source: anuga_core/source/anuga/utilities/sparse.py @ 5571

Last change on this file since 5571 was 5223, checked in by ole, 16 years ago

Work done during Water Down Under 2008.
Input checks.

File size: 9.5 KB
Line 
1"""Proof of concept sparse matrix code
2"""
3
4
5class Sparse:
6
7    def __init__(self, *args):
8        """Create sparse matrix.
9        There are two construction forms
10        Usage:
11
12        Sparse(A)     #Creates sparse matrix from dense matrix A
13        Sparse(M, N)  #Creates empty MxN sparse matrix
14        """
15
16        self.Data = {}
17           
18        if len(args) == 1:
19            from Numeric import array
20            try:
21                A = array(args[0])
22            except:
23                raise 'Input must be convertable to a Numeric array'
24
25            assert len(A.shape) == 2, 'Input must be a 2d matrix'
26           
27            self.M, self.N = A.shape
28            for i in range(self.M):
29                for j in range(self.N):
30                    if A[i, j] != 0.0:
31                        self.Data[i, j] = A[i, j]
32               
33           
34        elif len(args) == 2:
35            self.M = args[0]
36            self.N = args[1]
37        else:
38            raise 'Invalid construction'
39           
40        self.shape = (self.M, self.N) 
41
42
43    def __repr__(self):
44        return '%d X %d sparse matrix:\n' %(self.M, self.N) + `self.Data`
45
46    def __len__(self):
47        """Return number of nonzeros of A
48        """
49        return len(self.Data)
50
51    def nonzeros(self):
52        """Return number of nonzeros of A
53        """       
54        return len(self)
55   
56    def __setitem__(self, key, x):
57
58        i,j = key
59        # removing these asserts will not speed things up
60        assert 0 <= i < self.M
61        assert 0 <= j < self.N       
62
63        if x != 0:
64            self.Data[key] = float(x)
65        else:
66            if self.Data.has_key( key ):           
67                del self.Data[key]
68
69    def __getitem__(self, key):
70       
71        i,j = key
72        # removing these asserts will not speed things up
73        assert 0 <= i < self.M
74        assert 0 <= j < self.N               
75
76        if self.Data.has_key( key ):
77            return self.Data[ key ]
78        else:
79            return 0.0
80
81    def copy(self):
82        #FIXME: Use the copy module instead
83        new = Sparse(self.M,self.N)
84
85        for key in self.Data.keys():
86            i, j = key
87
88            new[i,j] = self.Data[i,j]
89
90        return new
91
92
93    def todense(self):
94        from Numeric import zeros, Float
95
96        D = zeros( (self.M, self.N), Float)
97       
98        for i in range(self.M):
99            for j in range(self.N):
100                if self.Data.has_key( (i,j) ):               
101                    D[i, j] = self.Data[ (i,j) ]
102        return D
103
104
105   
106    def __mul__(self, other):
107        """Multiply this matrix onto 'other' which can either be
108        a Numeric vector, a Numeric matrix or another sparse matrix.
109        """
110
111        from Numeric import array, zeros, Float
112       
113        try:
114            B = array(other)
115        except:
116            msg = 'FIXME: Only Numeric types implemented so far'
117            raise msg
118           
119
120        # Assume numeric types from now on
121       
122        if len(B.shape) == 0:
123            # Scalar - use __rmul__ method
124            R = B*self
125           
126        elif len(B.shape) == 1:
127            # Vector
128            msg = 'Mismatching dimensions: You cannot multiply (%d x %d) matrix onto %d-vector'\
129                  %(self.M, self.N, B.shape[0])
130            assert B.shape[0] == self.N, msg
131
132            R = zeros(self.M, Float) #Result
133           
134            # Multiply nonzero elements
135            for key in self.Data.keys():
136                i, j = key
137
138                R[i] += self.Data[key]*B[j]
139        elif len(B.shape) == 2:
140       
141           
142            R = zeros((self.M, B.shape[1]), Float) #Result matrix
143
144            # Multiply nonzero elements
145            for col in range(R.shape[1]):
146                # For each column
147               
148                for key in self.Data.keys():
149                    i, j = key
150
151                    R[i, col] += self.Data[key]*B[j, col]
152           
153           
154        else:
155            raise ValueError, 'Dimension too high: d=%d' %len(B.shape)
156
157        return R
158   
159
160    def __add__(self, other):
161        """Add this matrix onto 'other'
162        """
163
164        from Numeric import array, zeros, Float
165       
166        new = other.copy()
167        for key in self.Data.keys():
168            i, j = key
169
170            new[i,j] += self.Data[key]
171
172        return new
173
174
175    def __rmul__(self, other):
176        """Right multiply this matrix with scalar
177        """
178
179        from Numeric import array, zeros, Float
180       
181        try:
182            other = float(other)
183        except:
184            msg = 'Sparse matrix can only "right-multiply" onto a scalar'
185            raise TypeError, msg
186        else:
187            new = self.copy()
188            #Multiply nonzero elements
189            for key in new.Data.keys():
190                i, j = key
191
192                new.Data[key] = other*new.Data[key]
193
194        return new
195
196
197    def trans_mult(self, other):
198        """Multiply the transpose of matrix with 'other' which can be
199        a Numeric vector.
200        """
201
202        from Numeric import array, zeros, Float
203       
204        try:
205            B = array(other)
206        except:
207            print 'FIXME: Only Numeric types implemented so far'
208
209
210        #Assume numeric types from now on
211        if len(B.shape) == 1:
212            #Vector
213
214            assert B.shape[0] == self.M, 'Mismatching dimensions'
215
216            R = zeros((self.N,), Float) #Result
217
218            #Multiply nonzero elements
219            for key in self.Data.keys():
220                i, j = key
221
222                R[j] += self.Data[key]*B[i]
223
224        else:
225            raise 'Can only multiply with 1d array'
226
227        return R
228
229class Sparse_CSR:
230
231    def __init__(self, A):
232        """Create sparse matrix in csr format.
233
234        Sparse_CSR(A) #creates csr sparse matrix from sparse matrix
235        Matrices are not built using this format, since it's painful to
236        add values to an existing sparse_CSR instance (hence there are no
237        objects to do this.)
238
239        Rather, build a matrix, and convert it to this format for a speed
240        increase.
241
242        data - a 1D array of the data
243        Colind - The ith item in this 1D array is the column index of the
244                 ith data in the data array
245        rowptr - 1D array, with the index representing the row of the matrix.
246                 The item in the row represents the index into colind of the
247                 first data value of this row.
248                 Regard it as a pointer into the colind array, for the ith row.
249
250                 
251        """
252
253        from Numeric import array, Float, Int
254
255        if isinstance(A,Sparse):
256
257            from Numeric import zeros
258            keys = A.Data.keys()
259            keys.sort()
260            nnz = len(keys)
261            data    = zeros ( (nnz,), Float)
262            colind  = zeros ( (nnz,), Int)
263            row_ptr = zeros ( (A.M+1,), Int)
264            current_row = -1
265            k = 0
266            for key in keys:
267                ikey0 = int(key[0])
268                ikey1 = int(key[1])
269                if ikey0 != current_row:
270                    current_row = ikey0
271                    row_ptr[ikey0] = k
272                data[k] = A.Data[key]
273                colind[k] = ikey1
274                k += 1
275            for row in range(current_row+1, A.M+1):
276                row_ptr[row] = nnz
277            #row_ptr[-1] = nnz
278       
279            self.data    = data
280            self.colind  = colind
281            self.row_ptr = row_ptr
282            self.M       = A.M
283            self.N       = A.N
284        else:
285            raise ValueError, "Sparse_CSR(A) expects A == Sparse Matrix"
286           
287    def __repr__(self):
288        return '%d X %d sparse matrix:\n' %(self.M, self.N) + `self.data`
289
290    def __len__(self):
291        """Return number of nonzeros of A
292        """
293        return self.row_ptr[-1]
294
295    def nonzeros(self):
296        """Return number of nonzeros of A
297        """       
298        return len(self)
299
300    def todense(self):
301        from Numeric import zeros, Float
302
303        D = zeros( (self.M, self.N), Float)
304       
305        for i in range(self.M):
306            for ckey in range(self.row_ptr[i],self.row_ptr[i+1]):
307                j = self.colind[ckey]
308                D[i, j] = self.data[ckey]
309        return D
310
311    def __mul__(self, other):
312        """Multiply this matrix onto 'other' which can either be
313        a Numeric vector, a Numeric matrix or another sparse matrix.
314        """
315
316        from Numeric import array, zeros, Float
317       
318        try:
319            B = array(other)
320        except:
321            print 'FIXME: Only Numeric types implemented so far'
322
323        return csr_mv(self,B) 
324
325
326# Setup for C extensions
327from anuga.utilities import compile
328if compile.can_use_C_extension('sparse_ext.c'):
329    # Access underlying c implementations
330    from sparse_ext import csr_mv
331
332
333if __name__ == '__main__':
334    # A little selftest
335   
336    from Numeric import allclose, array, Float
337   
338    A = Sparse(3,3)
339
340    A[1,1] = 4
341
342
343    print A
344    print A.todense()
345
346    A[1,1] = 0
347
348    print A
349    print A.todense()   
350
351    A[1,2] = 0
352
353
354    A[0,0] = 3
355    A[1,1] = 2
356    A[1,2] = 2
357    A[2,2] = 1
358
359    print A
360    print A.todense()
361
362
363    #Right hand side vector
364    v = [2,3,4]
365
366    u = A*v
367    print u
368    assert allclose(u, [6,14,4])
369
370    u = A.trans_mult(v)
371    print u
372    assert allclose(u, [6,6,10])
373
374    #Right hand side column
375    v = array([[2,4],[3,4],[4,4]])
376
377    u = A*v[:,0]
378    assert allclose(u, [6,14,4])
379
380    #u = A*v[:,1]
381    #print u
382    print A.shape
383
384    B = 3*A
385    print B.todense()
386
387    B[1,0] = 2
388
389    C = A+B
390
391    print C.todense()
392
393    C = Sparse_CSR(C)
394
395    y = C*[6,14,4]
396
397    print y
398
399    y2 = C*[[6,4],[4,28],[4,8]]
400
401    print y2
Note: See TracBrowser for help on using the repository browser.