Changeset 8154 for trunk/anuga_core/source/anuga/utilities/cg_solve.py
- Timestamp:
- Mar 17, 2011, 5:51:42 PM (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/anuga_core/source/anuga/utilities/cg_solve.py
r8025 r8154 8 8 9 9 10 def conjugate_gradient(A, b,x0=None,imax=10000,tol=1.0e-8,atol=1.0e-14,iprint=None):10 def conjugate_gradient(A, b, x0=None, imax=10000, tol=1.0e-8, atol=1.0e-14, iprint=None): 11 11 """ 12 12 Try to solve linear equation Ax = b using … … 23 23 24 24 b = num.array(b, dtype=num.float) 25 if len(b.shape) != 1 25 if len(b.shape) != 1: 26 26 27 27 for i in range(b.shape[1]): 28 x0[:, i] = _conjugate_gradient(A, b[:,i], x0[:,i],29 imax, tol, atol, iprint)28 x0[:, i] = _conjugate_gradient(A, b[:, i], x0[:, i], 29 imax, tol, atol, iprint) 30 30 else: 31 31 x0 = _conjugate_gradient(A, b, x0, imax, tol, atol, iprint) … … 34 34 35 35 def _conjugate_gradient(A, b, x0, 36 imax=10000, tol=1.0e-8, atol=1.0e-1 4, iprint=None):37 """36 imax=10000, tol=1.0e-8, atol=1.0e-10, iprint=None): 37 """ 38 38 Try to solve linear equation Ax = b using 39 39 conjugate gradient method … … 51 51 """ 52 52 53 b = num.array(b, dtype=num.float) 54 if len(b.shape) != 1: 55 raise VectorShapeError, 'input vector should consist of only one column' 53 56 54 b = num.array(b, dtype=num.float) 55 if len(b.shape) != 1 : 56 raise VectorShapeError, 'input vector should consist of only one column' 57 58 if x0 is None: 59 x0 = num.zeros(b.shape, dtype=num.float) 60 else: 61 x0 = num.array(x0, dtype=num.float) 57 if x0 is None: 58 x0 = num.zeros(b.shape, dtype=num.float) 59 else: 60 x0 = num.array(x0, dtype=num.float) 62 61 63 62 64 #FIXME: Should test using None 65 if iprint == None or iprint == 0: 66 iprint = imax 63 if iprint == None or iprint == 0: 64 iprint = imax 67 65 68 i=169 x = x070 r = b - A*x71 d = r72 rTr = num.dot(r,r)73 rTr0 = rTr66 i = 1 67 x = x0 68 r = b - A * x 69 d = r 70 rTr = num.dot(r, r) 71 rTr0 = rTr 74 72 75 #FIXME Let the iterations stop if starting with a small residual76 while (i<imax and rTr>tol**2*rTr0 and rTr>atol**2 ):77 q = A*d78 alpha = rTr/num.dot(d,q)79 x = x + alpha*d80 if i%50 :81 r = b - A*x82 else:83 r = r - alpha*q84 rTrOld = rTr85 rTr = num.dot(r,r)86 bt = rTr/rTrOld87 73 88 d = r + bt*d 89 i = i+1 90 if i%iprint == 0 : 91 log.info('i = %g rTr = %20.15e' %(i,rTr)) 74 75 #FIXME Let the iterations stop if starting with a small residual 76 while (i < imax and rTr > tol ** 2 * rTr0 and rTr > atol ** 2): 77 q = A * d 78 alpha = rTr / num.dot(d, q) 79 xold = x 80 x = x + alpha * d 92 81 93 if i==imax: 82 dx = num.linalg.norm(x-xold) 83 if dx < atol : 84 break 85 86 if i % 50: 87 r = b - A * x 88 else: 89 r = r - alpha * q 90 rTrOld = rTr 91 rTr = num.dot(r, r) 92 bt = rTr / rTrOld 93 94 d = r + bt * d 95 i = i + 1 96 if i % iprint == 0: 97 log.info('i = %g rTr = %15.8e dx = %15.8e' % (i, rTr, dx)) 98 99 if i == imax: 94 100 log.warning('max number of iterations attained') 95 msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' % rTr101 msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' % rTr 96 102 raise ConvergenceError, msg 97 103 98 return x 104 #print x 105 return x 99 106
Note: See TracChangeset
for help on using the changeset viewer.