source: inundation/pyvolution/sparse.py @ 1847

Last change on this file since 1847 was 1847, checked in by duncan, 19 years ago

added comments

File size: 10.2 KB
RevLine 
[485]1"""Proof of concept sparse matrix code
2"""
3
4
5class Sparse:
6
7    def __init__(self, *args):
8        """Create sparse matrix.
9        There are two construction forms
10        Usage:
11
12        Sparse(A)     #Creates sparse matrix from dense matrix A
13        Sparse(M, N)  #Creates empty MxN sparse matrix
14
15
16       
17        """
18
[586]19        self.Data = {}
[485]20           
21        if len(args) == 1:
22            from Numeric import array
23            try:
24                A = array(args[0])
25            except:
26                raise 'Input must be convertable to a Numeric array'
27
28            assert len(A.shape) == 2, 'Input must be a 2d matrix'
29           
30            self.M, self.N = A.shape
31            for i in range(self.M):
32                for j in range(self.N):
33                    if A[i, j] != 0.0:
[586]34                        self.Data[i, j] = A[i, j]
[485]35               
36           
37        elif len(args) == 2:
38            self.M = args[0]
39            self.N = args[1]
40        else:
41            raise 'Invalid construction'
42           
43        self.shape = (self.M, self.N) 
44
45
46    def __repr__(self):
[586]47        return '%d X %d sparse matrix:\n' %(self.M, self.N) + `self.Data`
[485]48
49    def __len__(self):
50        """Return number of nonzeros of A
51        """
[586]52        return len(self.Data)
[485]53
54    def nonzeros(self):
55        """Return number of nonzeros of A
56        """       
57        return len(self)
58   
59    def __setitem__(self, key, x):
60
61        i,j = key
62        assert 0 <= i < self.M
63        assert 0 <= j < self.N       
64
65        if x != 0:
[586]66            self.Data[key] = float(x)
[485]67        else:
[586]68            if self.Data.has_key( key ):           
69                del self.Data[key]
[485]70
71    def __getitem__(self, key):
72       
73        i,j = key
74        assert 0 <= i < self.M
75        assert 0 <= j < self.N               
76
[586]77        if self.Data.has_key( key ):
78            return self.Data[ key ]
[485]79        else:
80            return 0.0
81
82    def copy(self):
83        #FIXME: Use the copy module instead
84        new = Sparse(self.M,self.N)
85
[586]86        for key in self.Data.keys():
[485]87            i, j = key
88
[586]89            new[i,j] = self.Data[i,j]
[485]90
91        return new
92
93
94    def todense(self):
95        from Numeric import zeros, Float
96
97        D = zeros( (self.M, self.N), Float)
98       
99        for i in range(self.M):
100            for j in range(self.N):
[586]101                if self.Data.has_key( (i,j) ):               
102                    D[i, j] = self.Data[ (i,j) ]
[485]103        return D
104
105
[586]106   
[485]107    def __mul__(self, other):
108        """Multiply this matrix onto 'other' which can either be
109        a Numeric vector, a Numeric matrix or another sparse matrix.
110        """
111
112        from Numeric import array, zeros, Float
113       
114        try:
115            B = array(other)
116        except:
[1632]117            msg = 'FIXME: Only Numeric types implemented so far'
118            raise msg
119           
[485]120
121        #Assume numeric types from now on
122       
123        if len(B.shape) == 0:
124            #Scalar - use __rmul__ method
125            R = B*self
126           
127        elif len(B.shape) == 1:
128            #Vector
129            assert B.shape[0] == self.N, 'Mismatching dimensions'
130
131            R = zeros(self.M, Float) #Result
132           
133            #Multiply nonzero elements
[586]134            for key in self.Data.keys():
[485]135                i, j = key
136
[586]137                R[i] += self.Data[key]*B[j]
[485]138        elif len(B.shape) == 2:
139       
140           
141            R = zeros((self.M, B.shape[1]), Float) #Result matrix
142
143            #Multiply nonzero elements
144            for col in range(R.shape[1]):
145                #For each column
146               
[586]147                for key in self.Data.keys():
[485]148                    i, j = key
149
[586]150                    R[i, col] += self.Data[key]*B[j, col]
[485]151           
152           
153        else:
154            raise ValueError, 'Dimension too high: d=%d' %len(B.shape)
155
156        return R
157   
158
159    def __add__(self, other):
160        """Add this matrix onto 'other'
161        """
162
163        from Numeric import array, zeros, Float
164       
165        new = other.copy()
[586]166        for key in self.Data.keys():
[485]167            i, j = key
168
[586]169            new[i,j] += self.Data[key]
[485]170
171        return new
172
173
174    def __rmul__(self, other):
175        """Right multiply this matrix with scalar
176        """
177
178        from Numeric import array, zeros, Float
179       
180        try:
181            other = float(other)
182        except:
183            msg = 'Sparse matrix can only "right-multiply" onto a scalar'
184            raise TypeError, msg
185        else:
186            new = self.copy()
187            #Multiply nonzero elements
[586]188            for key in new.Data.keys():
[485]189                i, j = key
190
[586]191                new.Data[key] = other*new.Data[key]
[485]192
193        return new
194
195
196    def trans_mult(self, other):
197        """Multiply the transpose of matrix with 'other' which can be
198        a Numeric vector.
199        """
200
201        from Numeric import array, zeros, Float
202       
203        try:
204            B = array(other)
205        except:
206            print 'FIXME: Only Numeric types implemented so far'
207
208
209        #Assume numeric types from now on
210        if len(B.shape) == 1:
211            #Vector
212
213            assert B.shape[0] == self.M, 'Mismatching dimensions'
214
215            R = zeros((self.N,), Float) #Result
216
217            #Multiply nonzero elements
[586]218            for key in self.Data.keys():
[485]219                i, j = key
220
[586]221                R[j] += self.Data[key]*B[i]
[485]222
223        else:
224            raise 'Can only multiply with 1d array'
225
226        return R
227
[586]228class Sparse_CSR:
[485]229
[586]230    def __init__(self, A):
231        """Create sparse matrix in csr format.
232
233        Sparse_CSR(A) #creates csr sparse matrix from sparse matrix
[1847]234
235        data - a 1D array of the data
236        Colind - The ith item in this 1D array is the column index of the
237                 ith data in the data array
[586]238        """
239
240        from Numeric import array, Float, Int
241
242        if isinstance(A,Sparse):
243
244            from Numeric import zeros
245            keys = A.Data.keys()
246            keys.sort()
247            nnz = len(keys)
248            data    = zeros ( (nnz,), Float)
249            colind  = zeros ( (nnz,), Int)
250            row_ptr = zeros ( (A.M+1,), Int)
251            current_row = -1
252            k = 0
253            for key in keys:
254                ikey0 = int(key[0])
255                ikey1 = int(key[1])
256                if ikey0 != current_row:
257                    current_row = ikey0
258                    row_ptr[ikey0] = k
259                data[k] = A.Data[key]
260                colind[k] = ikey1
261                k += 1
262            for row in range(current_row+1, A.M+1):
263                row_ptr[row] = nnz
264            #row_ptr[-1] = nnz
265       
266            self.data    = data
267            self.colind  = colind
268            self.row_ptr = row_ptr
269            self.M       = A.M
270            self.N       = A.N
271        else:
272            raise ValueError, "Sparse_CSR(A) expects A == Sparse Matrix"
273           
274    def __repr__(self):
275        return '%d X %d sparse matrix:\n' %(self.M, self.N) + `self.data`
276
277    def __len__(self):
278        """Return number of nonzeros of A
279        """
280        return self.row_ptr[-1]
281
[1847]282    def __setitem__(self, key, x):
283
284        i,j = key
285        assert 0 <= i < self.M
286        assert 0 <= j < self.N       
287
288        # From Sparse
289        #if x != 0:
290        #    self.Data[key] = float(x)
291        #else:
292        #    if self.Data.has_key( key ):           
293        #        del self.Data[key]
294
[586]295    def nonzeros(self):
296        """Return number of nonzeros of A
297        """       
298        return len(self)
299
300    def todense(self):
301        from Numeric import zeros, Float
302
303        D = zeros( (self.M, self.N), Float)
304       
305        for i in range(self.M):
306            for ckey in range(self.row_ptr[i],self.row_ptr[i+1]):
307                j = self.colind[ckey]
308                D[i, j] = self.data[ckey]
309        return D
310
311    def __mul__(self, other):
312        """Multiply this matrix onto 'other' which can either be
313        a Numeric vector, a Numeric matrix or another sparse matrix.
314        """
315
316        from Numeric import array, zeros, Float
317       
318        try:
319            B = array(other)
320        except:
321            print 'FIXME: Only Numeric types implemented so far'
322
[594]323        return csr_mv(self,B) 
[587]324
[586]325
326
[594]327def csr_mv(self, B):
[605]328    """Python version of sparse (CSR) matrix multiplication
329    """
[586]330
[594]331    from Numeric import zeros, Float
[586]332
[587]333
[594]334    #Assume numeric types from now on
335       
336    if len(B.shape) == 0:
337        #Scalar - use __rmul__ method
338        R = B*self
339       
340    elif len(B.shape) == 1:
[587]341        #Vector
342        assert B.shape[0] == self.N, 'Mismatching dimensions'
[594]343       
[587]344        R = zeros(self.M, Float) #Result
[594]345       
[587]346        #Multiply nonzero elements
347        for i in range(self.M):
348            for ckey in range(self.row_ptr[i],self.row_ptr[i+1]):
349                j = self.colind[ckey]
350                R[i] += self.data[ckey]*B[j]           
[594]351               
352    elif len(B.shape) == 2:
353       
354        R = zeros((self.M, B.shape[1]), Float) #Result matrix
355       
356        #Multiply nonzero elements
357       
358        for col in range(R.shape[1]):
359            #For each column
360            for i in range(self.M):
361                for ckey in range(self.row_ptr[i],self.row_ptr[i+1]):
362                    j = self.colind[ckey]
363                    R[i, col] += self.data[ckey]*B[j,col]           
364                   
[587]365    else:
366        raise ValueError, 'Dimension too high: d=%d' %len(B.shape)
[594]367   
[587]368    return R
369
370
[605]371
372#Setup for C extensions
[587]373import compile
374if compile.can_use_C_extension('sparse_ext.c'):
[605]375    #Replace python version with c implementation
376    from sparse_ext import csr_mv
[587]377
[485]378if __name__ == '__main__':
379
380    from Numeric import allclose, array, Float
381   
382    A = Sparse(3,3)
383
384    A[1,1] = 4
385
386
387    print A
388    print A.todense()
389
390    A[1,1] = 0
391
392    print A
393    print A.todense()   
394
395    A[1,2] = 0
396
397
398    A[0,0] = 3
399    A[1,1] = 2
400    A[1,2] = 2
401    A[2,2] = 1
402
403    print A
404    print A.todense()
405
406
407    #Right hand side vector
408    v = [2,3,4]
409
410    u = A*v
411    print u
412    assert allclose(u, [6,14,4])
413
414    u = A.trans_mult(v)
415    print u
416    assert allclose(u, [6,6,10])
417
418    #Right hand side column
419    v = array([[2,4],[3,4],[4,4]])
420
421    u = A*v[:,0]
422    assert allclose(u, [6,14,4])
423
424    #u = A*v[:,1]
425    #print u
426    print A.shape
427
428    B = 3*A
429    print B.todense()
430
431    B[1,0] = 2
432
433    C = A+B
434
435    print C.todense()
[594]436
437    C = Sparse_CSR(C)
438
439    y = C*[6,14,4]
440
441    print y
442
443    y2 = C*[[6,4],[4,28],[4,8]]
444
445    print y2
Note: See TracBrowser for help on using the repository browser.