Changeset 452
- Timestamp:
- Oct 26, 2004, 11:55:00 PM (20 years ago)
- Location:
- inundation/ga/storm_surge/pyvolution
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
inundation/ga/storm_surge/pyvolution/least_squares.py
r441 r452 27 27 from Numeric import zeros, array, Float, Int, dot, transpose 28 28 from LinearAlgebra import solve_linear_equations 29 from scipy import sparse 29 #from scipy import sparse 30 from sparse import Sparse 30 31 from cg_solve import conjugate_gradient, VectorShapeError 31 32 … … 210 211 211 212 #self.A = zeros((n,m), Float) 212 self.A = sparse.dok_matrix() 213 self.AtA = sparse.dok_matrix() 213 self.A = Sparse(n,m) 214 215 #print 'n by m ',self.A 216 self.AtA = Sparse(m,m) 214 217 215 218 #Compute matrix elements … … 265 268 266 269 267 self.A = (self.A).tocsc()268 self.AtA = (self.AtA).tocsc()269 self.At = self.A.transp()270 ## self.A = (self.A).tocsc() 271 ## self.AtA = (self.AtA).tocsc() 272 ## self.At = self.A.transp() 270 273 271 274 def get_A(self): … … 314 317 315 318 #self.D = zeros((m,m), Float) 316 self.D = sparse.dok_matrix()319 self.D = Sparse(m,m) 317 320 318 321 #For each triangle compute contributions to D = D1+D2 … … 360 363 self.D[v0,v2] += e20 361 364 362 self.D = (self.D).tocsc()365 #self.D = (self.D).tocsc() 363 366 364 367 def fit(self, z): … … 383 386 384 387 #Compute right hand side based on data 385 Atz = self.At * z 386 388 Atz = self.A.trans_mult(z) 389 390 391 #print 'fit: Atz',Atz 392 387 393 #Check sanity 388 394 n, m = self.A.shape … … 485 491 """ 486 492 import os, sys 487 usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh alpha" % 493 usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh alpha" % os.path.basename(sys.argv[0]) 488 494 489 495 if len(sys.argv) < 4: -
inundation/ga/storm_surge/pyvolution/sparse.py
r447 r452 1 1 """Proof of concept sparse matrix code 2 2 """ 3 4 3 from scipy_base import * 4 from cg_solve import conjugate_gradient, VectorShapeError 5 5 6 6 class Sparse: … … 12 12 self.M = M 13 13 self.N = N 14 self.shape = (M,N) 14 15 self.A = {} 15 16 … … 40 41 return 0.0 41 42 43 def copy(self): 44 45 new = Sparse(self.M,self.N) 46 47 for key in self.A.keys(): 48 i, j = key 49 50 new[i,j] = self.A[i,j] 51 52 return new 53 42 54 43 55 def todense(self): … … 49 61 for j in range(self.N): 50 62 if self.A.has_key( (i,j) ): 51 D[i, j] = A[ (i,j) ]63 D[i, j] = self.A[ (i,j) ] 52 64 return D 53 65 … … 67 79 68 80 #Assume numeric types from now on 69 R = zeros( B.shape, Float) #Result81 R = zeros((self.M,), Float) #Result 70 82 71 83 if len(B.shape) == 1: 72 84 #Vector 73 85 86 ## print 'B.shape ',B.shape 87 ## print 'self.shape ',self.shape 88 74 89 assert B.shape[0] == self.N, 'Mismatching dimensions' 75 90 … … 78 93 i, j = key 79 94 80 R[i] += A[key]*B[j]81 82 else: 83 raise 'Numeric matrix not yet implemented'95 R[i] += self.A[key]*B[j] 96 97 else: 98 raise ValueError, 'Numeric matrix not yet implemented' 84 99 85 100 return R 86 87 101 102 def __add__(self, other): 103 """Add this matrix onto 'other' 104 """ 105 106 from Numeric import array, zeros, Float 107 108 new = other.copy() 109 110 # print 'self.shape',self.shape 111 # print 'other.shape',other.shape 112 113 for key in self.A.keys(): 114 i, j = key 115 116 new[i,j] = new[i,j] + self.A[key] 117 118 return new 119 120 121 def __rmul__(self, other): 122 """Right multiply this matrix with scalar 123 """ 124 125 from Numeric import array, zeros, Float 126 127 if isscalar(other): 128 new = self.copy() 129 #Multiply nonzero elements 130 for key in new.A.keys(): 131 i, j = key 132 133 new.A[key] = other*new.A[key] 134 else: 135 raise 'only right multiple with scalar implemented' 136 137 # print 'new.shape',new.shape 138 139 return new 140 141 142 def trans_mult(self, other): 143 """Multiply the transpose of matrix with 'other' which can be 144 a Numeric vector. 145 """ 146 147 from Numeric import array, zeros, Float 148 149 try: 150 B = array(other) 151 except: 152 print 'FIXME: Only Numeric types implemented so far' 153 154 155 #Assume numeric types from now on 156 157 158 if len(B.shape) == 1: 159 #Vector 160 161 assert B.shape[0] == self.M, 'Mismatching dimensions' 162 163 R = zeros((self.N,), Float) #Result 164 165 #Multiply nonzero elements 166 for key in self.A.keys(): 167 i, j = key 168 169 R[j] += self.A[key]*B[i] 170 171 else: 172 raise 'Can only multiply with 1d array' 173 174 return R 175 176 177 178 88 179 if __name__ == '__main__': 89 180 … … 122 213 assert allclose(u, [6,14,4]) 123 214 215 u = A.trans_mult(v) 216 print u 217 assert allclose(u, [6,6,10]) 218 124 219 #Right hand side column 125 220 v = array([[2,4],[3,4],[4,4]]) … … 130 225 #u = A*v[:,1] 131 226 #print u 227 print A.shape 228 229 B = 3*A 230 print B.todense() 231 232 B[1,0] = 2 233 234 C = A+B 235 236 print C.todense() -
inundation/ga/storm_surge/pyvolution/test_least_squares.py
r441 r452 3 3 import unittest 4 4 from math import sqrt 5 from scipy import mat 5 6 6 7 from least_squares import * … … 121 122 122 123 #Data points 123 data_points = [ [-3., 1.9], [-2, 1], [0.0, 1], [0, 3], [2, 3], [-1.0/3,-4./3], [-1.0,-1.5 ], [1.0,-1.0]] 124 data_points = [ [-3., 1.9], [-2, 1], [0.0, 1], [0, 3], [2, 3], 125 [-1.0/3,-4./3], [-1.0,-1.5 ], [1.0,-1.0]] 124 126 interp = Interpolation(points, triangles, data_points) 125 127 128 #FIXME Which matrix does this refer to? AtA? 126 129 answer = [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0], #Affects point d 127 130 [0.5, 0.0, 0.0, 0.5, 0.0, 0.0], #Affects points a and d … … 131 134 [1./3, 0.0, 0.0, 0.0, 1./3, 1./3]] #Affects points a,e and f 132 135 133 assert allclose(interp.get_A(), answer) 136 137 A = mat(interp.get_A()) 138 At = transpose(A) 139 AtA = At * A 140 AtA = AtA.asarray() 141 #print AtA 142 143 #FIXME These two matrices are not correct, recalculate 144 assert allclose(AtA, answer) 134 145 135 146 … … 202 213 [0.5, -1.9], 203 214 [3.0,1.0]] 204 215 205 216 z = linear_function(point_coords) 206 217 interp = Interpolation(vertices, triangles, point_coords, alpha=0.0) 218 219 #print 'z',z 207 220 f = interp.fit(z) 208 221 answer = linear_function(vertices)
Note: See TracChangeset
for help on using the changeset viewer.