source: inundation/ga/storm_surge/pyvolution/cg_solve.py @ 527

Last change on this file since 527 was 477, checked in by ole, 20 years ago

Minor cosmetics and comments

File size: 1.5 KB
RevLine 
[435]1import exceptions
class VectorShapeError(Exception):
    """Raised when a vector argument is not one-dimensional."""
[435]3
4
def conjugate_gradient(A,b,x0=None,imax=10000,tol=1.0e-8,iprint=0):
    """
    Try to solve the linear equation Ax = b using the
    conjugate gradient method.

    Input
    A: matrix-like object, assumed symmetric. It is only ever used in
       the form A*v, which must return the matrix-vector product.
       NOTE(review): earlier documentation said A may be dense, sparse
       or a function applying a matrix -- confirm that whatever is
       passed really implements A*v as a matrix-vector product.
    b: right hand side (must be one-dimensional)
    x0: initial guess (default the 0 vector)
    imax: iteration limit; the loop stops once the iteration counter,
          which starts at 1, reaches imax
    tol: relative tolerance; iteration stops once
         dot(r,r) <= tol**2 * dot(r0,r0), r0 being the initial residual
    iprint: print progress whenever the iteration counter is a multiple
            of iprint (0 is replaced by imax)

    Output
    x: approximate solution

    Raises
    VectorShapeError if b is not one-dimensional
    """

    from Numeric import dot, array, Float, zeros

    b = array(b, typecode=Float)
    if len(b.shape) != 1:
        raise VectorShapeError('input vector should consist of only one column')

    if x0 is None:
        x0 = zeros(b.shape, typecode=Float)
    else:
        x0 = array(x0, typecode=Float)

    #FIXME: Should test using None
    if iprint == 0:
        iprint = imax

    i = 1
    x = x0
    r = b - A*x
    d = r
    rTr = dot(r, r)
    rTr0 = rTr

    # The convergence threshold does not change inside the loop
    threshold = tol**2 * rTr0

    while i < imax and rTr > threshold:
        q = A*d
        alpha = rTr/dot(d, q)
        x = x + alpha*d
        if i % 50 == 0:
            # Every 50th iteration recompute the residual from its
            # definition to discard accumulated rounding error
            r = b - A*x
        else:
            # Cheap recurrence: reuses q, no extra product with A
            r = r - alpha*q
        rTrOld = rTr
        rTr = dot(r, r)
        bt = rTr/rTrOld

        d = r + bt*d
        i = i + 1
        if i % iprint == 0:
            print('i = %g rTr = %20.15e' % (i, rTr))

    #FIXME: Should this raise an exception?
    if i == imax:
        print('max number of iterations attained')

    return x
67
68
Note: See TracBrowser for help on using the repository browser.