source: inundation/ga/storm_surge/pyvolution/cg_solve.py @ 438

Last change on this file since 438 was 438, checked in by steve, 20 years ago

CG tests

File size: 1.5 KB
Line 
1import exceptions
# Subclass the builtin directly: in Python 2 exceptions.Exception is this very
# same class, and the 'exceptions' module no longer exists in Python 3.
class VectorShapeError(Exception):
    """Raised when an input vector does not have the required 1-D shape."""
3
4
def conjugate_gradient(A, b, x0=None, imax=2000, tol=1.0e-8, iprint=0):
    """
    Try to solve the linear equation Ax = b using the
    conjugate gradient method.

    Input
    A: symmetric positive-definite linear operator. The expression A*x
       must return the matrix-vector product (e.g. a sparse matrix
       object). For a plain Numeric/numpy 2-D array '*' is elementwise,
       so such an array has to be wrapped before being passed in.
    b: right hand side, a one-dimensional vector
    x0: initial guess; defaults to the zero vector when omitted or None
    imax: maximum number of CG iterations performed
    tol: relative tolerance; iteration stops once |r| <= tol*|r0|,
         where r0 is the residual of the initial guess
    iprint: print progress every iprint iterations
            (0 means every imax iterations)

    Output
    x: approximate solution

    Raises
    VectorShapeError: if b is not a one-dimensional vector
    """
    try:
        from Numeric import dot, asarray, zeros, Float
    except ImportError:
        # Numeric is long obsolete; numpy is its successor and accepts the
        # same positional (array, element-type) arguments used below.
        from numpy import dot, asarray, zeros, float64 as Float

    # The element type is passed positionally because Numeric names that
    # argument 'typecode' while numpy names it 'dtype'.
    b = asarray(b, Float)
    if x0 is None:
        x0 = zeros(b.shape, Float)
    x = asarray(x0, Float)

    if len(b.shape) != 1:
        raise VectorShapeError('input vector should consist of only one column')

    if iprint == 0:
        iprint = imax

    r = b - A*x
    d = r
    rTr = dot(r, r)
    # Converged when |r|^2 <= tol^2 * |r0|^2, i.e. relative residual <= tol.
    threshold = tol**2 * rTr

    i = 0  # number of completed iterations
    while i < imax and rTr > threshold:
        i = i + 1
        q = A*d
        alpha = rTr/dot(d, q)
        x = x + alpha*d
        if i % 50 == 0:
            # Periodically recompute the true residual to discard the
            # rounding error accumulated by the recurrence below.
            r = b - A*x
        else:
            # Cheap recurrence: b - A(x + alpha*d) == r - alpha*(A*d).
            r = r - alpha*q
        rTrOld = rTr
        rTr = dot(r, r)
        bt = rTr/rTrOld

        d = r + bt*d
        if i % iprint == 0:
            print('i = %g rTr = %20.15e' % (i, rTr))

    # Report only when the iteration cap, not convergence, ended the loop.
    if rTr > threshold:
        print('max number of iterations attained')

    return x
71
72
Note: See TracBrowser for help on using the repository browser.