source: inundation/ga/storm_surge/pyvolution/sparse_ext.c @ 1507

Last change on this file since 1507 was 605, checked in by ole, 20 years ago

Implemented matrix-matrix mult in c-extension using CSR format - all tests work again.

File size: 4.6 KB
Line 
1// Python - C extension for sparse module.
2//
3// To compile (Python2.3):
4//  gcc -c util_ext.c -I/usr/include/python2.3 -o util_ext.o -Wall -O
5//  gcc -shared util_ext.o  -o util_ext.so     
6//
7// See the module sparse.py
8//
9//
10// Steve Roberts, ANU 2004
11       
12#include "Python.h"
13#include "Numeric/arrayobject.h"
14#include "math.h"
15#include "stdio.h"
16
17//Matrix-vector routine
18int _csr_mv(int M,
19            double* data, 
20            long* colind,
21            long* row_ptr,
22            double* x,
23            double* y) {
24               
25  long i, j, ckey;
26
27  for (i=0; i<M; i++ ) 
28    for (ckey=row_ptr[i]; ckey<row_ptr[i+1]; ckey++) {
29      j = colind[ckey];
30      y[i] += data[ckey]*x[j];
31    }             
32 
33  return 0;
34}           
35
36//Matrix-matrix routine
37int _csr_mm(int M,
38            int columns, 
39            double* data, 
40            long* colind,
41            long* row_ptr,
42            double* x,
43            double* y) {
44               
45  long i, j, ckey, c, rowind_i, rowind_j;
46
47  for (i=0; i<M; i++ ) {
48    rowind_i = i*columns;
49   
50    for (ckey=row_ptr[i]; ckey<row_ptr[i+1]; ckey++) {
51      j = colind[ckey];
52      rowind_j = j*columns;
53         
54      for (c=0; c<columns; c++) {
55        y[rowind_i+c] += data[ckey]*x[rowind_j+c];
56      }             
57    } 
58  }
59 
60  return 0;
61}           
62
63                     
64 
65/////////////////////////////////////////////////
66// Gateways to Python
67PyObject *csr_mv(PyObject *self, PyObject *args) {
68 
69  PyObject *csr_sparse, // input sparse matrix
70    *xin, *R;           // output wrapped vector
71 
72  PyArrayObject
73    *data,            //Non Zeros Data array
74    *colind,          //Column indices array
75    *row_ptr,         //Row pointers array
76    *x,               //Input vector array
77    *y;               //Return vector array
78
79 
80  int dimensions[1], M, err, columns, rows;
81 
82  // Convert Python arguments to C 
83  if (!PyArg_ParseTuple(args, "OO", &csr_sparse, &xin)) 
84    return NULL;
85
86  x = (PyArrayObject*) PyArray_ContiguousFromObject(xin,PyArray_DOUBLE,1,2);
87  if (!x)
88    return NULL;
89
90/*   printf("x.nd = %i\n",x->nd); */
91/*   printf("x.descr->type_num = %i %i\n",x->descr->type_num,PyArray_LONG); */
92/*   printf("x.dimensions[0] = %i\n",x->dimensions[0]); */
93/*   printf("x.data[0] = %g\n",((double*) x->data)[0]); */
94/*   printf("x.data[1] = %g\n",((double*) x->data)[1]); */
95/*   printf("x.data[2] = %g\n",((double*) x->data)[2]); */
96/*   printf("x.data[3] = %g\n",((double*) x->data)[3]); */
97/*   printf("x.data[4] = %g\n",((double*) x->data)[4]); */
98/*   printf("x.data[5] = %g\n",((double*) x->data)[5]); */
99
100 
101
102
103  data =  (PyArrayObject*)
104    PyObject_GetAttrString(csr_sparse, "data");     
105  if (!data) 
106    return NULL; 
107
108  colind = (PyArrayObject*)
109    PyObject_GetAttrString(csr_sparse, "colind"); 
110  if (!colind) return NULL;   
111
112  row_ptr = (PyArrayObject*) 
113    PyObject_GetAttrString(csr_sparse, "row_ptr");   
114  if (!row_ptr) return NULL;       
115 
116  M = (row_ptr -> dimensions[0])-1;
117   
118  if (x -> nd == 1) {
119    // Multiplicant is a vector
120 
121    //Allocate space for return vectors y (don't DECREF)
122    dimensions[0] = M;
123    y = (PyArrayObject *) PyArray_FromDims(1, dimensions, PyArray_DOUBLE);
124 
125    err = _csr_mv(M,
126                  (double*) data -> data, 
127                  (long*)   colind -> data,
128                  (long*)   row_ptr -> data,
129                  (double*) x -> data,
130                  (double*) y -> data); 
131
132                           
133    if (err != 0) {
134      PyErr_SetString(PyExc_RuntimeError, "matrix vector mult could not be calculated");
135      return NULL;
136    }
137  } else if(x -> nd == 2) {
138 
139
140    rows = x -> dimensions[0];     //Number of rows in x       
141    columns = x -> dimensions[1];  //Number of columns in x       
142   
143    //Allocate space for return matrix y (don't DECREF)
144    dimensions[0] = M;                   //Number of rows in sparse matrix 
145    dimensions[1] = columns;
146    y = (PyArrayObject *) PyArray_FromDims(2, dimensions, PyArray_DOUBLE);
147   
148    err = _csr_mm(M, columns,
149                  (double*) data -> data, 
150                  (long*)   colind -> data,
151                  (long*)   row_ptr -> data,
152                  (double*) x -> data,
153                  (double*) y -> data); 
154   
155  } else {
156    PyErr_SetString(PyExc_RuntimeError, 
157                    "Allowed dimensions in sparse_ext.c restricted to 1 or 2");
158    return NULL; 
159  }
160 
161                     
162  //Release                 
163  Py_DECREF(data);   
164  Py_DECREF(colind);   
165  Py_DECREF(row_ptr); 
166  Py_DECREF(x);           
167         
168  //Build result, release and return
169  R = Py_BuildValue("O", PyArray_Return(y)); 
170  Py_DECREF(y);       
171  return R;
172}
173
174
175
176
177// Method table for python module
178static struct PyMethodDef MethodTable[] = {
179  {"csr_mv", csr_mv, METH_VARARGS, "Print out"},   
180  {NULL, NULL, 0, NULL}   /* sentinel */
181};
182
183// Module initialisation   
184void initsparse_ext(void){
185  Py_InitModule("sparse_ext", MethodTable);
186 
187  import_array();     //Necessary for handling of NumPY structures 
188}
189
Note: See TracBrowser for help on using the repository browser.