source: anuga_core/source/anuga/utilities/sparse_ext.c @ 3745

Last change on this file since 3745 was 3730, checked in by ole, 18 years ago

Added proper error messages to return NULL in c extensions.
Fixed up some indentation issues as well

File size: 5.1 KB
Line 
1// Python - C extension for sparse module.
2//
3// To compile (Python2.3):
4//  gcc -c util_ext.c -I/usr/include/python2.3 -o util_ext.o -Wall -O
5//  gcc -shared util_ext.o  -o util_ext.so     
6//
7// See the module sparse.py
8//
9//
10// Steve Roberts, ANU 2004
11       
12#include "Python.h"
13#include "Numeric/arrayobject.h"
14#include "math.h"
15#include "stdio.h"
16
17//Matrix-vector routine
18int _csr_mv(int M,
19            double* data, 
20            long* colind,
21            long* row_ptr,
22            double* x,
23            double* y) {
24               
25  long i, j, ckey;
26
27  for (i=0; i<M; i++ ) 
28    for (ckey=row_ptr[i]; ckey<row_ptr[i+1]; ckey++) {
29      j = colind[ckey];
30      y[i] += data[ckey]*x[j];
31    }             
32 
33  return 0;
34}           
35
36//Matrix-matrix routine
37int _csr_mm(int M,
38            int columns, 
39            double* data, 
40            long* colind,
41            long* row_ptr,
42            double* x,
43            double* y) {
44               
45  long i, j, ckey, c, rowind_i, rowind_j;
46
47  for (i=0; i<M; i++ ) {
48    rowind_i = i*columns;
49   
50    for (ckey=row_ptr[i]; ckey<row_ptr[i+1]; ckey++) {
51      j = colind[ckey];
52      rowind_j = j*columns;
53         
54      for (c=0; c<columns; c++) {
55        y[rowind_i+c] += data[ckey]*x[rowind_j+c];
56      }             
57    } 
58  }
59 
60  return 0;
61}           
62
63                     
64 
65/////////////////////////////////////////////////
66// Gateways to Python
67PyObject *csr_mv(PyObject *self, PyObject *args) {
68 
69  PyObject *csr_sparse, // input sparse matrix
70    *xin, *R;           // output wrapped vector
71 
72  PyArrayObject
73    *data,            //Non Zeros Data array
74    *colind,          //Column indices array
75    *row_ptr,         //Row pointers array
76    *x,               //Input vector array
77    *y;               //Return vector array
78
79 
80  int dimensions[1], M, err, columns, rows;
81 
82  // Convert Python arguments to C 
83  if (!PyArg_ParseTuple(args, "OO", &csr_sparse, &xin)) {
84    PyErr_SetString(PyExc_RuntimeError, "Csr_mv could not parse input"); 
85    return NULL;
86  }
87
88  x = (PyArrayObject*) PyArray_ContiguousFromObject(xin,PyArray_DOUBLE,1,2);
89  if (!x) {
90    PyErr_SetString(PyExc_RuntimeError, 
91                    "Input array could not be read in csr_mv");   
92    return NULL;
93  }
94
95/*   printf("x.nd = %i\n",x->nd); */
96/*   printf("x.descr->type_num = %i %i\n",x->descr->type_num,PyArray_LONG); */
97/*   printf("x.dimensions[0] = %i\n",x->dimensions[0]); */
98/*   printf("x.data[0] = %g\n",((double*) x->data)[0]); */
99/*   printf("x.data[1] = %g\n",((double*) x->data)[1]); */
100/*   printf("x.data[2] = %g\n",((double*) x->data)[2]); */
101/*   printf("x.data[3] = %g\n",((double*) x->data)[3]); */
102/*   printf("x.data[4] = %g\n",((double*) x->data)[4]); */
103/*   printf("x.data[5] = %g\n",((double*) x->data)[5]); */
104
105 
106
107
108  data = (PyArrayObject*) 
109    PyObject_GetAttrString(csr_sparse, "data");     
110  if (!data) {
111    PyErr_SetString(PyExc_RuntimeError, 
112                    "Data array could not be allocated in csr_mv");     
113    return NULL;
114  } 
115
116  colind = (PyArrayObject*)
117    PyObject_GetAttrString(csr_sparse, "colind"); 
118  if (!colind) {
119    PyErr_SetString(PyExc_RuntimeError, 
120                    "Column index array could not be allocated in csr_mv");     
121    return NULL;
122  } 
123
124  row_ptr = (PyArrayObject*) 
125    PyObject_GetAttrString(csr_sparse, "row_ptr");   
126  if (!row_ptr) {
127    PyErr_SetString(PyExc_RuntimeError, 
128                    "Row pointer array could not be allocated in csr_mv"); 
129  }
130 
131  M = (row_ptr -> dimensions[0])-1;
132   
133  if (x -> nd == 1) {
134    // Multiplicant is a vector
135 
136    //Allocate space for return vectors y (don't DECREF)
137    dimensions[0] = M;
138    y = (PyArrayObject *) PyArray_FromDims(1, dimensions, PyArray_DOUBLE);
139 
140    err = _csr_mv(M,
141                  (double*) data -> data, 
142                  (long*)   colind -> data,
143                  (long*)   row_ptr -> data,
144                  (double*) x -> data,
145                  (double*) y -> data); 
146
147                           
148    if (err != 0) {
149      PyErr_SetString(PyExc_RuntimeError, "Matrix vector mult could not be calculated");
150      return NULL;
151    }
152  } else if(x -> nd == 2) {
153 
154
155    rows = x -> dimensions[0];     //Number of rows in x       
156    columns = x -> dimensions[1];  //Number of columns in x       
157   
158    //Allocate space for return matrix y (don't DECREF)
159    dimensions[0] = M;                   //Number of rows in sparse matrix 
160    dimensions[1] = columns;
161    y = (PyArrayObject *) PyArray_FromDims(2, dimensions, PyArray_DOUBLE);
162   
163    err = _csr_mm(M, columns,
164                  (double*) data -> data, 
165                  (long*)   colind -> data,
166                  (long*)   row_ptr -> data,
167                  (double*) x -> data,
168                  (double*) y -> data); 
169   
170  } else {
171    PyErr_SetString(PyExc_RuntimeError, 
172                    "Allowed dimensions in sparse_ext.c restricted to 1 or 2");
173    return NULL; 
174  }
175 
176                     
177  //Release                 
178  Py_DECREF(data);   
179  Py_DECREF(colind);   
180  Py_DECREF(row_ptr); 
181  Py_DECREF(x);           
182         
183  //Build result, release and return
184  R = Py_BuildValue("O", PyArray_Return(y)); 
185  Py_DECREF(y);       
186  return R;
187}
188
189
190
191
192// Method table for python module
193static struct PyMethodDef MethodTable[] = {
194  {"csr_mv", csr_mv, METH_VARARGS, "Print out"},   
195  {NULL, NULL, 0, NULL}   /* sentinel */
196};
197
198// Module initialisation   
199void initsparse_ext(void){
200  Py_InitModule("sparse_ext", MethodTable);
201 
202  import_array();     //Necessary for handling of NumPY structures 
203}
204
Note: See TracBrowser for help on using the repository browser.