source: trunk/anuga_core/source/anuga/utilities/cg_solve.py @ 8025

Last change on this file since 8025 was 8025, checked in by steve, 14 years ago

Changed name of intersection of segment routine

File size: 2.4 KB
Line 
1import exceptions
2class VectorShapeError(exceptions.Exception): pass
3class ConvergenceError(exceptions.Exception): pass
4
5import numpy as num
6
7import anuga.utilities.log as log
8
9
10def conjugate_gradient(A,b,x0=None,imax=10000,tol=1.0e-8,atol=1.0e-14,iprint=None):
11    """
12    Try to solve linear equation Ax = b using
13    conjugate gradient method
14
15    If b is an array, solve it as if it was a set of vectors, solving each
16    vector.
17    """
18   
19    if x0 is None:
20        x0 = num.zeros(b.shape, dtype=num.float)
21    else:
22        x0 = num.array(x0, dtype=num.float)
23
24    b  = num.array(b, dtype=num.float)
25    if len(b.shape) != 1 :
26       
27        for i in range(b.shape[1]):
28            x0[:,i] = _conjugate_gradient(A, b[:,i], x0[:,i],
29                                          imax, tol, atol, iprint)
30    else:
31        x0 = _conjugate_gradient(A, b, x0, imax, tol, atol, iprint)
32
33    return x0
34   
35def _conjugate_gradient(A, b, x0, 
36                        imax=10000, tol=1.0e-8, atol=1.0e-14, iprint=None):
37   """
38   Try to solve linear equation Ax = b using
39   conjugate gradient method
40
41   Input
42   A: matrix or function which applies a matrix, assumed symmetric
43      A can be either dense or sparse
44   b: right hand side
45   x0: inital guess (default the 0 vector)
46   imax: max number of iterations
47   tol: tolerance used for residual
48
49   Output
50   x: approximate solution
51   """
52
53
54   b  = num.array(b, dtype=num.float)
55   if len(b.shape) != 1 :
56      raise VectorShapeError, 'input vector should consist of only one column'
57
58   if x0 is None:
59      x0 = num.zeros(b.shape, dtype=num.float)
60   else:
61      x0 = num.array(x0, dtype=num.float)
62
63
64   #FIXME: Should test using None
65   if iprint == None  or iprint == 0:
66      iprint = imax
67
68   i=1
69   x = x0
70   r = b - A*x
71   d = r
72   rTr = num.dot(r,r)
73   rTr0 = rTr
74
75   #FIXME Let the iterations stop if starting with a small residual
76   while (i<imax and rTr>tol**2*rTr0 and rTr>atol**2 ):
77       q = A*d
78       alpha = rTr/num.dot(d,q)
79       x = x + alpha*d
80       if i%50 :
81           r = b - A*x
82       else:
83           r = r - alpha*q
84       rTrOld = rTr
85       rTr = num.dot(r,r)
86       bt = rTr/rTrOld
87
88       d = r + bt*d
89       i = i+1
90       if i%iprint == 0 :
91          log.info('i = %g rTr = %20.15e' %(i,rTr))
92
93       if i==imax:
94            log.warning('max number of iterations attained')
95            msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' %rTr
96            raise ConvergenceError, msg
97
98   return x
99
Note: See TracBrowser for help on using the repository browser.