source: anuga_core/source/anuga/utilities/sparse.py @ 6650

Last change on this file since 6650 was 6162, checked in by rwilson, 16 years ago

Changed num.array() calls that should have a defined type.

File size: 9.1 KB
Line 
1"""Proof of concept sparse matrix code
2"""
3
4import Numeric as num
5
6
7class Sparse:
8
9    def __init__(self, *args):
10        """Create sparse matrix.
11        There are two construction forms
12        Usage:
13
14        Sparse(A)     #Creates sparse matrix from dense matrix A
15        Sparse(M, N)  #Creates empty MxN sparse matrix
16        """
17
18        self.Data = {}
19           
20        if len(args) == 1:
21            try:
22                A = num.array(args[0])
23            except:
24                raise 'Input must be convertable to a Numeric array'
25
26            assert len(A.shape) == 2, 'Input must be a 2d matrix'
27           
28            self.M, self.N = A.shape
29            for i in range(self.M):
30                for j in range(self.N):
31                    if A[i, j] != 0.0:
32                        self.Data[i, j] = A[i, j]
33               
34           
35        elif len(args) == 2:
36            self.M = args[0]
37            self.N = args[1]
38        else:
39            raise 'Invalid construction'
40           
41        self.shape = (self.M, self.N) 
42
43
44    def __repr__(self):
45        return '%d X %d sparse matrix:\n' %(self.M, self.N) + `self.Data`
46
47    def __len__(self):
48        """Return number of nonzeros of A
49        """
50        return len(self.Data)
51
52    def nonzeros(self):
53        """Return number of nonzeros of A
54        """       
55        return len(self)
56   
57    def __setitem__(self, key, x):
58
59        i,j = key
60        # removing these asserts will not speed things up
61        assert 0 <= i < self.M
62        assert 0 <= j < self.N       
63
64        if x != 0:
65            self.Data[key] = float(x)
66        else:
67            if self.Data.has_key( key ):           
68                del self.Data[key]
69
70    def __getitem__(self, key):
71       
72        i,j = key
73        # removing these asserts will not speed things up
74        assert 0 <= i < self.M
75        assert 0 <= j < self.N               
76
77        if self.Data.has_key( key ):
78            return self.Data[ key ]
79        else:
80            return 0.0
81
82    def copy(self):
83        #FIXME: Use the copy module instead
84        new = Sparse(self.M,self.N)
85
86        for key in self.Data.keys():
87            i, j = key
88
89            new[i,j] = self.Data[i,j]
90
91        return new
92
93
94    def todense(self):
95        D = num.zeros( (self.M, self.N), num.Float)
96       
97        for i in range(self.M):
98            for j in range(self.N):
99                if self.Data.has_key( (i,j) ):               
100                    D[i, j] = self.Data[ (i,j) ]
101        return D
102
103
104   
105    def __mul__(self, other):
106        """Multiply this matrix onto 'other' which can either be
107        a Numeric vector, a Numeric matrix or another sparse matrix.
108        """
109
110        try:
111            B = num.array(other)
112        except:
113            msg = 'FIXME: Only Numeric types implemented so far'
114            raise msg
115           
116
117        # Assume numeric types from now on
118       
119        if len(B.shape) == 0:
120            # Scalar - use __rmul__ method
121            R = B*self
122           
123        elif len(B.shape) == 1:
124            # Vector
125            msg = 'Mismatching dimensions: You cannot multiply (%d x %d) matrix onto %d-vector'\
126                  %(self.M, self.N, B.shape[0])
127            assert B.shape[0] == self.N, msg
128
129            R = num.zeros(self.M, num.Float) #Result
130           
131            # Multiply nonzero elements
132            for key in self.Data.keys():
133                i, j = key
134
135                R[i] += self.Data[key]*B[j]
136        elif len(B.shape) == 2:
137       
138           
139            R = num.zeros((self.M, B.shape[1]), num.Float) #Result matrix
140
141            # Multiply nonzero elements
142            for col in range(R.shape[1]):
143                # For each column
144               
145                for key in self.Data.keys():
146                    i, j = key
147
148                    R[i, col] += self.Data[key]*B[j, col]
149           
150           
151        else:
152            raise ValueError, 'Dimension too high: d=%d' %len(B.shape)
153
154        return R
155   
156
157    def __add__(self, other):
158        """Add this matrix onto 'other'
159        """
160
161        new = other.copy()
162        for key in self.Data.keys():
163            i, j = key
164
165            new[i,j] += self.Data[key]
166
167        return new
168
169
170    def __rmul__(self, other):
171        """Right multiply this matrix with scalar
172        """
173
174        try:
175            other = float(other)
176        except:
177            msg = 'Sparse matrix can only "right-multiply" onto a scalar'
178            raise TypeError, msg
179        else:
180            new = self.copy()
181            #Multiply nonzero elements
182            for key in new.Data.keys():
183                i, j = key
184
185                new.Data[key] = other*new.Data[key]
186
187        return new
188
189
190    def trans_mult(self, other):
191        """Multiply the transpose of matrix with 'other' which can be
192        a Numeric vector.
193        """
194
195        try:
196            B = num.array(other)
197        except:
198            print 'FIXME: Only Numeric types implemented so far'
199
200
201        #Assume numeric types from now on
202        if len(B.shape) == 1:
203            #Vector
204
205            assert B.shape[0] == self.M, 'Mismatching dimensions'
206
207            R = num.zeros((self.N,), num.Float) #Result
208
209            #Multiply nonzero elements
210            for key in self.Data.keys():
211                i, j = key
212
213                R[j] += self.Data[key]*B[i]
214
215        else:
216            raise 'Can only multiply with 1d array'
217
218        return R
219
220class Sparse_CSR:
221
222    def __init__(self, A):
223        """Create sparse matrix in csr format.
224
225        Sparse_CSR(A) #creates csr sparse matrix from sparse matrix
226        Matrices are not built using this format, since it's painful to
227        add values to an existing sparse_CSR instance (hence there are no
228        objects to do this.)
229
230        Rather, build a matrix, and convert it to this format for a speed
231        increase.
232
233        data - a 1D array of the data
234        Colind - The ith item in this 1D array is the column index of the
235                 ith data in the data array
236        rowptr - 1D array, with the index representing the row of the matrix.
237                 The item in the row represents the index into colind of the
238                 first data value of this row.
239                 Regard it as a pointer into the colind array, for the ith row.
240
241                 
242        """
243
244        if isinstance(A,Sparse):
245
246            keys = A.Data.keys()
247            keys.sort()
248            nnz = len(keys)
249            data    = num.zeros ( (nnz,), num.Float)
250            colind  = num.zeros ( (nnz,), num.Int)
251            row_ptr = num.zeros ( (A.M+1,), num.Int)
252            current_row = -1
253            k = 0
254            for key in keys:
255                ikey0 = int(key[0])
256                ikey1 = int(key[1])
257                if ikey0 != current_row:
258                    current_row = ikey0
259                    row_ptr[ikey0] = k
260                data[k] = A.Data[key]
261                colind[k] = ikey1
262                k += 1
263            for row in range(current_row+1, A.M+1):
264                row_ptr[row] = nnz
265            #row_ptr[-1] = nnz
266       
267            self.data    = data
268            self.colind  = colind
269            self.row_ptr = row_ptr
270            self.M       = A.M
271            self.N       = A.N
272        else:
273            raise ValueError, "Sparse_CSR(A) expects A == Sparse Matrix"
274           
275    def __repr__(self):
276        return '%d X %d sparse matrix:\n' %(self.M, self.N) + `self.data`
277
278    def __len__(self):
279        """Return number of nonzeros of A
280        """
281        return self.row_ptr[-1]
282
283    def nonzeros(self):
284        """Return number of nonzeros of A
285        """       
286        return len(self)
287
288    def todense(self):
289        D = num.zeros( (self.M, self.N), num.Float)
290       
291        for i in range(self.M):
292            for ckey in range(self.row_ptr[i],self.row_ptr[i+1]):
293                j = self.colind[ckey]
294                D[i, j] = self.data[ckey]
295        return D
296
297    def __mul__(self, other):
298        """Multiply this matrix onto 'other' which can either be
299        a Numeric vector, a Numeric matrix or another sparse matrix.
300        """
301
302        try:
303            B = num.array(other)
304        except:
305            print 'FIXME: Only Numeric types implemented so far'
306
307        return csr_mv(self,B) 
308
309
310# Setup for C extensions
311from anuga.utilities import compile
312if compile.can_use_C_extension('sparse_ext.c'):
313    # Access underlying c implementations
314    from sparse_ext import csr_mv
315
316
317if __name__ == '__main__':
318    # A little selftest
319   
320    A = Sparse(3,3)
321
322    A[1,1] = 4
323
324
325    print A
326    print A.todense()
327
328    A[1,1] = 0
329
330    print A
331    print A.todense()   
332
333    A[1,2] = 0
334
335
336    A[0,0] = 3
337    A[1,1] = 2
338    A[1,2] = 2
339    A[2,2] = 1
340
341    print A
342    print A.todense()
343
344
345    #Right hand side vector
346    v = [2,3,4]
347
348    u = A*v
349    print u
350    assert num.allclose(u, [6,14,4])
351
352    u = A.trans_mult(v)
353    print u
354    assert num.allclose(u, [6,6,10])
355
356    #Right hand side column
357    v = num.array([[2,4],[3,4],[4,4]], num.Int)      #array default#
358
359    u = A*v[:,0]
360    assert num.allclose(u, [6,14,4])
361
362    #u = A*v[:,1]
363    #print u
364    print A.shape
365
366    B = 3*A
367    print B.todense()
368
369    B[1,0] = 2
370
371    C = A+B
372
373    print C.todense()
374
375    C = Sparse_CSR(C)
376
377    y = C*[6,14,4]
378
379    print y
380
381    y2 = C*[[6,4],[4,28],[4,8]]
382
383    print y2
Note: See TracBrowser for help on using the repository browser.