source: trunk/anuga_core/source/anuga/utilities/sparse_ext.c @ 8710

Last change on this file since 8710 was 8436, checked in by steve, 12 years ago

Made achange to sparse_ext.c as there was a memory leak.

File size: 5.1 KB
Line 
1// Python - C extension for sparse module.
2//
3// To compile (Python2.3):
4//  gcc -c util_ext.c -I/usr/include/python2.3 -o util_ext.o -Wall -O
5//  gcc -shared util_ext.o  -o util_ext.so     
6//
7// See the module sparse.py
8//
9//
10// Steve Roberts, ANU 2004
11       
12#include "Python.h"
13#include "numpy/arrayobject.h"
14#include "math.h"
15#include "stdio.h"
16#include "numpy_shim.h"
17
18//Matrix-vector routine
19int _csr_mv(int M,
20            double* data, 
21            long* colind,
22            long* row_ptr,
23            double* x,
24            double* y) {
25               
26  long i, j, ckey;
27
28  for (i=0; i<M; i++ ) 
29    for (ckey=row_ptr[i]; ckey<row_ptr[i+1]; ckey++) {
30      j = colind[ckey];
31      y[i] += data[ckey]*x[j];
32    }             
33 
34  return 0;
35}           
36
37//Matrix-matrix routine
38int _csr_mm(int M,
39            int columns, 
40            double* data, 
41            long* colind,
42            long* row_ptr,
43            double* x,
44            double* y) {
45               
46  long i, j, ckey, c, rowind_i, rowind_j;
47
48  for (i=0; i<M; i++ ) {
49    rowind_i = i*columns;
50   
51    for (ckey=row_ptr[i]; ckey<row_ptr[i+1]; ckey++) {
52      j = colind[ckey];
53      rowind_j = j*columns;
54         
55      for (c=0; c<columns; c++) {
56        y[rowind_i+c] += data[ckey]*x[rowind_j+c];
57      }             
58    } 
59  }
60 
61  return 0;
62}           
63
64                     
65 
66/////////////////////////////////////////////////
67// Gateways to Python
68PyObject *csr_mv(PyObject *self, PyObject *args) {
69 
70  PyObject *csr_sparse, // input sparse matrix
71    *xin, *R;           // output wrapped vector
72 
73  PyArrayObject
74    *data,            //Non Zeros Data array
75    *colind,          //Column indices array
76    *row_ptr,         //Row pointers array
77    *x,               //Input vector array
78    *y;               //Return vector array
79
80 
81  int dimensions[2], M, err, columns, rows;
82 
83  // Convert Python arguments to C 
84  if (!PyArg_ParseTuple(args, "OO", &csr_sparse, &xin)) {
85    PyErr_SetString(PyExc_RuntimeError, "Csr_mv could not parse input"); 
86    return NULL;
87  }
88
89  x = (PyArrayObject*) PyArray_ContiguousFromObject(xin,PyArray_DOUBLE,1,2);
90  if (!x) {
91    PyErr_SetString(PyExc_RuntimeError, 
92                    "Input array could not be read in csr_mv");   
93    return NULL;
94  }
95
96/*   printf("x.nd = %i\n",x->nd); */
97/*   printf("x.descr->type_num = %i %i\n",x->descr->type_num,PyArray_LONG); */
98/*   printf("x.dimensions[0] = %i\n",x->dimensions[0]); */
99/*   printf("x.data[0] = %g\n",((double*) x->data)[0]); */
100/*   printf("x.data[1] = %g\n",((double*) x->data)[1]); */
101/*   printf("x.data[2] = %g\n",((double*) x->data)[2]); */
102/*   printf("x.data[3] = %g\n",((double*) x->data)[3]); */
103/*   printf("x.data[4] = %g\n",((double*) x->data)[4]); */
104/*   printf("x.data[5] = %g\n",((double*) x->data)[5]); */
105
106 
107
108
109  data = (PyArrayObject*) 
110    PyObject_GetAttrString(csr_sparse, "data");     
111  if (!data) {
112    PyErr_SetString(PyExc_RuntimeError, 
113                    "Data array could not be allocated in csr_mv");     
114    return NULL;
115  } 
116
117  colind = (PyArrayObject*)
118    PyObject_GetAttrString(csr_sparse, "colind"); 
119  if (!colind) {
120    PyErr_SetString(PyExc_RuntimeError, 
121                    "Column index array could not be allocated in csr_mv");     
122    return NULL;
123  } 
124
125  row_ptr = (PyArrayObject*) 
126    PyObject_GetAttrString(csr_sparse, "row_ptr");   
127  if (!row_ptr) {
128    PyErr_SetString(PyExc_RuntimeError, 
129                    "Row pointer array could not be allocated in csr_mv"); 
130  }
131 
132  M = (row_ptr -> dimensions[0])-1;
133   
134  if (x -> nd == 1) {
135    // Multiplicant is a vector
136 
137    //Allocate space for return vectors y (don't DECREF)
138    dimensions[0] = M;
139    y = (PyArrayObject *) anuga_FromDims(1, dimensions, PyArray_DOUBLE);
140 
141    err = _csr_mv(M,
142                  (double*) data -> data, 
143                  (long*)   colind -> data,
144                  (long*)   row_ptr -> data,
145                  (double*) x -> data,
146                  (double*) y -> data); 
147
148                           
149    if (err != 0) {
150      PyErr_SetString(PyExc_RuntimeError, "Matrix vector mult could not be calculated");
151      return NULL;
152    }
153  } else if(x -> nd == 2) {
154 
155
156    rows = x -> dimensions[0];     //Number of rows in x       
157    columns = x -> dimensions[1];  //Number of columns in x       
158   
159    //Allocate space for return matrix y (don't DECREF)
160    dimensions[0] = M;                   //Number of rows in sparse matrix 
161    dimensions[1] = columns;
162    y = (PyArrayObject *) anuga_FromDims(2, dimensions, PyArray_DOUBLE);
163   
164    err = _csr_mm(M, columns,
165                  (double*) data -> data, 
166                  (long*)   colind -> data,
167                  (long*)   row_ptr -> data,
168                  (double*) x -> data,
169                  (double*) y -> data); 
170   
171  } else {
172    PyErr_SetString(PyExc_RuntimeError, 
173                    "Allowed dimensions in sparse_ext.c restricted to 1 or 2");
174    return NULL; 
175  }
176 
177                     
178  //Release                 
179  Py_DECREF(data);   
180  Py_DECREF(colind);   
181  Py_DECREF(row_ptr); 
182  Py_DECREF(x);           
183         
184  //Build result, release and return
185  R = Py_BuildValue("O", PyArray_Return(y)); 
186  Py_DECREF(y);       
187  return R;
188}
189
190
191
192
193// Method table for python module
194static struct PyMethodDef MethodTable[] = {
195  {"csr_mv", csr_mv, METH_VARARGS, "Print out"},   
196  {NULL, NULL, 0, NULL}   /* sentinel */
197};
198
199// Module initialisation   
200void initsparse_ext(void){
201  Py_InitModule("sparse_ext", MethodTable);
202 
203  import_array();     //Necessary for handling of NumPY structures 
204}
205
Note: See TracBrowser for help on using the repository browser.