source: anuga_core/source/anuga/utilities/sparse_ext.c @ 6671

Last change on this file since 6671 was 5897, checked in by ole, 16 years ago

Reverted numpy changes to the trunk that should have been made to the branch.
The command was svn merge -r 5895:5890 .

File size: 5.1 KB
Line 
1// Python - C extension for sparse module.
2//
3// To compile (Python2.3):
4//  gcc -c util_ext.c -I/usr/include/python2.3 -o util_ext.o -Wall -O
5//  gcc -shared util_ext.o  -o util_ext.so     
6//
7// See the module sparse.py
8//
9//
10// Steve Roberts, ANU 2004
11       
12#include "Python.h"
13#include "Numeric/arrayobject.h"
14#include "math.h"
15#include "stdio.h"
16
17//Matrix-vector routine
18int _csr_mv(int M,
19            double* data, 
20            long* colind,
21            long* row_ptr,
22            double* x,
23            double* y) {
24               
25  long i, j, ckey;
26
27  for (i=0; i<M; i++ ) 
28    for (ckey=row_ptr[i]; ckey<row_ptr[i+1]; ckey++) {
29      j = colind[ckey];
30      y[i] += data[ckey]*x[j];
31    }             
32 
33  return 0;
34}           
35
36//Matrix-matrix routine
37int _csr_mm(int M,
38            int columns, 
39            double* data, 
40            long* colind,
41            long* row_ptr,
42            double* x,
43            double* y) {
44               
45  long i, j, ckey, c, rowind_i, rowind_j;
46
47  for (i=0; i<M; i++ ) {
48    rowind_i = i*columns;
49   
50    for (ckey=row_ptr[i]; ckey<row_ptr[i+1]; ckey++) {
51      j = colind[ckey];
52      rowind_j = j*columns;
53         
54      for (c=0; c<columns; c++) {
55        y[rowind_i+c] += data[ckey]*x[rowind_j+c];
56      }             
57    } 
58  }
59 
60  return 0;
61}           
62
63                     
64 
65/////////////////////////////////////////////////
66// Gateways to Python
67PyObject *csr_mv(PyObject *self, PyObject *args) {
68 
69  PyObject *csr_sparse, // input sparse matrix
70    *xin, *R;           // output wrapped vector
71 
72  PyArrayObject
73    *data,            //Non Zeros Data array
74    *colind,          //Column indices array
75    *row_ptr,         //Row pointers array
76    *x,               //Input vector array
77    *y;               //Return vector array
78
79 
80  int dimensions[1], M, err, columns, rows;
81 
82  // Convert Python arguments to C 
83  if (!PyArg_ParseTuple(args, "OO", &csr_sparse, &xin)) {
84    PyErr_SetString(PyExc_RuntimeError, "Csr_mv could not parse input"); 
85    return NULL;
86  }
87
88  x = (PyArrayObject*) PyArray_ContiguousFromObject(xin,PyArray_DOUBLE,1,2);
89  if (!x) {
90    PyErr_SetString(PyExc_RuntimeError, 
91                    "Input array could not be read in csr_mv");   
92    return NULL;
93  }
94
95/*   printf("x.nd = %i\n",x->nd); */
96/*   printf("x.descr->type_num = %i %i\n",x->descr->type_num,PyArray_LONG); */
97/*   printf("x.dimensions[0] = %i\n",x->dimensions[0]); */
98/*   printf("x.data[0] = %g\n",((double*) x->data)[0]); */
99/*   printf("x.data[1] = %g\n",((double*) x->data)[1]); */
100/*   printf("x.data[2] = %g\n",((double*) x->data)[2]); */
101/*   printf("x.data[3] = %g\n",((double*) x->data)[3]); */
102/*   printf("x.data[4] = %g\n",((double*) x->data)[4]); */
103/*   printf("x.data[5] = %g\n",((double*) x->data)[5]); */
104
105 
106
107
108  data = (PyArrayObject*) 
109    PyObject_GetAttrString(csr_sparse, "data");     
110  if (!data) {
111    PyErr_SetString(PyExc_RuntimeError, 
112                    "Data array could not be allocated in csr_mv");     
113    return NULL;
114  } 
115
116  colind = (PyArrayObject*)
117    PyObject_GetAttrString(csr_sparse, "colind"); 
118  if (!colind) {
119    PyErr_SetString(PyExc_RuntimeError, 
120                    "Column index array could not be allocated in csr_mv");     
121    return NULL;
122  } 
123
124  row_ptr = (PyArrayObject*) 
125    PyObject_GetAttrString(csr_sparse, "row_ptr");   
126  if (!row_ptr) {
127    PyErr_SetString(PyExc_RuntimeError, 
128                    "Row pointer array could not be allocated in csr_mv"); 
129  }
130 
131  M = (row_ptr -> dimensions[0])-1;
132   
133  if (x -> nd == 1) {
134    // Multiplicant is a vector
135 
136    //Allocate space for return vectors y (don't DECREF)
137    dimensions[0] = M;
138    y = (PyArrayObject *) PyArray_FromDims(1, dimensions, PyArray_DOUBLE);
139 
140    err = _csr_mv(M,
141                  (double*) data -> data, 
142                  (long*)   colind -> data,
143                  (long*)   row_ptr -> data,
144                  (double*) x -> data,
145                  (double*) y -> data); 
146
147                           
148    if (err != 0) {
149      PyErr_SetString(PyExc_RuntimeError, "Matrix vector mult could not be calculated");
150      return NULL;
151    }
152  } else if(x -> nd == 2) {
153 
154
155    rows = x -> dimensions[0];     //Number of rows in x       
156    columns = x -> dimensions[1];  //Number of columns in x       
157   
158    //Allocate space for return matrix y (don't DECREF)
159    dimensions[0] = M;                   //Number of rows in sparse matrix 
160    dimensions[1] = columns;
161    y = (PyArrayObject *) PyArray_FromDims(2, dimensions, PyArray_DOUBLE);
162   
163    err = _csr_mm(M, columns,
164                  (double*) data -> data, 
165                  (long*)   colind -> data,
166                  (long*)   row_ptr -> data,
167                  (double*) x -> data,
168                  (double*) y -> data); 
169   
170  } else {
171    PyErr_SetString(PyExc_RuntimeError, 
172                    "Allowed dimensions in sparse_ext.c restricted to 1 or 2");
173    return NULL; 
174  }
175 
176                     
177  //Release                 
178  Py_DECREF(data);   
179  Py_DECREF(colind);   
180  Py_DECREF(row_ptr); 
181  Py_DECREF(x);           
182         
183  //Build result, release and return
184  R = Py_BuildValue("O", PyArray_Return(y)); 
185  Py_DECREF(y);       
186  return R;
187}
188
189
190
191
192// Method table for python module
193static struct PyMethodDef MethodTable[] = {
194  {"csr_mv", csr_mv, METH_VARARGS, "Print out"},   
195  {NULL, NULL, 0, NULL}   /* sentinel */
196};
197
198// Module initialisation   
199void initsparse_ext(void){
200  Py_InitModule("sparse_ext", MethodTable);
201 
202  import_array();     //Necessary for handling of NumPY structures 
203}
204
Note: See TracBrowser for help on using the repository browser.