source: anuga_core/source/anuga/utilities/cg_solve.py @ 4249

Last change on this file since 4249 was 2841, checked in by duncan, 19 years ago

cg_solve now works if b is an array

File size: 2.5 KB
Line 
1import exceptions
2class VectorShapeError(exceptions.Exception): pass
3class ConvergenceError(exceptions.Exception): pass
4
5from Numeric import dot, array, Float, zeros
6   
7import logging, logging.config
8logger = logging.getLogger('cg_solve')
9logger.setLevel(logging.WARNING)
10
11try:
12    logging.config.fileConfig('log.ini')
13except:
14    pass
15
16def conjugate_gradient(A,b,x0=None,imax=10000,tol=1.0e-8,iprint=0):
17    """
18    Try to solve linear equation Ax = b using
19    conjugate gradient method
20
21    If b is an array, solve it as if it was a set of vectors, solving each
22    vector.
23    """
24   
25    if x0 is None:
26        x0 = zeros(b.shape, typecode=Float)
27    else:
28        x0 = array(x0, typecode=Float)
29
30    b  = array(b, typecode=Float)
31    if len(b.shape) != 1 :
32       
33        for i in range(b.shape[1]):
34            x0[:,i] = _conjugate_gradient(A, b[:,i], x0[:,i],
35                                          imax, tol, iprint)
36    else:
37        x0 = _conjugate_gradient(A, b, x0, imax, tol, iprint)
38
39    return x0
40   
41def _conjugate_gradient(A,b,x0=None,imax=10000,tol=1.0e-8,iprint=0):
42   """
43   Try to solve linear equation Ax = b using
44   conjugate gradient method
45
46   Input
47   A: matrix or function which applies a matrix, assumed symmetric
48      A can be either dense or sparse
49   b: right hand side
50   x0: inital guess (default the 0 vector)
51   imax: max number of iterations
52   tol: tolerance used for residual
53
54   Output
55   x: approximate solution
56   """
57
58
59   b  = array(b, typecode=Float)
60   if len(b.shape) != 1 :
61      raise VectorShapeError, 'input vector should consist of only one column'
62
63   if x0 is None:
64      x0 = zeros(b.shape, typecode=Float)
65   else:
66      x0 = array(x0, typecode=Float)
67
68
69   #FIXME: Should test using None
70   if iprint == 0:
71      iprint = imax
72
73   i=1
74   x = x0
75   r = b - A*x
76   d = r
77   rTr = dot(r,r)
78   rTr0 = rTr
79
80   while (i<imax and rTr>tol**2*rTr0):
81       q = A*d
82       alpha = rTr/dot(d,q)
83       x = x + alpha*d
84       if i%50 :
85           r = b - A*x
86       else:
87           r = r - alpha*q
88       rTrOld = rTr
89       rTr = dot(r,r)
90       bt = rTr/rTrOld
91
92       d = r + bt*d
93       i = i+1
94       if i%iprint == 0 :
95          logger.info('i = %g rTr = %20.15e' %(i,rTr))
96
97       if i==imax:
98            logger.warning('max number of iterations attained')
99            msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' %rTr
100            raise ConvergenceError, msg
101
102   return x
103
Note: See TracBrowser for help on using the repository browser.