source: inundation/utilities/cg_solve.py @ 3390

Last change on this file since 3390 was 2841, checked in by duncan, 19 years ago

cg_solve now works if b is an array

File size: 2.5 KB
RevLine 
[435]1import exceptions
class VectorShapeError(Exception):
    """Raised when an input that must be a single 1-d vector is not."""
    pass


class ConvergenceError(Exception):
    """Raised when the conjugate gradient solver does not reach the
    requested tolerance within its iteration limit."""
    pass
[435]4
[2841]5from Numeric import dot, array, Float, zeros
6   
[1150]7import logging, logging.config
logger = logging.getLogger('cg_solve')
logger.setLevel(logging.WARNING)

# Optionally pick up logging configuration from a 'log.ini' file, resolved
# relative to the current working directory. This is deliberately
# best-effort: a missing or malformed file just leaves the defaults above in
# place. Catch Exception rather than using a bare except so that
# KeyboardInterrupt / SystemExit are not silently swallowed at import time.
try:
    logging.config.fileConfig('log.ini')
except Exception:
    pass
[1150]15
def conjugate_gradient(A, b, x0=None, imax=10000, tol=1.0e-8, iprint=0):
    """
    Try to solve linear equation Ax = b using the conjugate gradient method.

    If b is not a 1-d vector it is treated as a set of column vectors:
    column i of the result is obtained by solving A x[:,i] = b[:,i]
    independently.

    Input
    A: object for which A*x yields the matrix-vector product, assumed
       symmetric (see _conjugate_gradient)
    b: right hand side, a vector or an array whose columns are vectors
       (may be any sequence accepted by array())
    x0: initial guess with the same shape as b (default all zeros)
    imax: max number of iterations, per right hand side
    tol: relative tolerance used for the residual
    iprint: log progress every iprint iterations (0 disables)

    Output
    x: approximate solution, with the same shape as b
    """
    # Convert b before anything else: the default initial guess takes its
    # shape from b, so b must already be an array (callers may pass a list).
    b = array(b, typecode=Float)

    if x0 is None:
        x0 = zeros(b.shape, typecode=Float)
    else:
        # array() copies, so the caller's x0 is never modified below.
        x0 = array(x0, typecode=Float)

    if len(b.shape) != 1:
        # Solve each column of b as an independent right hand side,
        # writing each solution into the corresponding column of x0.
        for i in range(b.shape[1]):
            x0[:, i] = _conjugate_gradient(A, b[:, i], x0[:, i],
                                           imax, tol, iprint)
    else:
        x0 = _conjugate_gradient(A, b, x0, imax, tol, iprint)

    return x0
40   
def _conjugate_gradient(A, b, x0=None, imax=10000, tol=1.0e-8, iprint=0):
    """
    Try to solve linear equation Ax = b for a single right hand side using
    the conjugate gradient method.

    Input
    A: object for which A*x yields the matrix-vector product, assumed
       symmetric; A can be either dense or sparse.
       NOTE(review): a plain Numeric array would multiply elementwise under
       '*', so presumably A is a matrix class overriding '*' -- confirm.
       CG also normally requires A to be positive definite -- TODO confirm
       callers guarantee this.
    b: right hand side; must be a 1-d vector
    x0: initial guess (default the 0 vector)
    imax: iteration limit; the counter starts at 1, so at most imax-1
          CG steps are performed
    tol: relative tolerance used for the residual: iteration stops once
         r.r <= tol**2 * r0.r0, where r0 is the initial residual b - A*x0
    iprint: log progress (INFO level) every iprint iterations;
            0 or None disables progress logging

    Output
    x: approximate solution

    Raises
    VectorShapeError: if b is not a 1-d vector
    ConvergenceError: if the tolerance is not met within the iteration
                      limit (or the residual norm becomes NaN)
    """
    b = array(b, typecode=Float)
    if len(b.shape) != 1:
        raise VectorShapeError('input vector should consist of only one column')

    if x0 is None:
        x0 = zeros(b.shape, typecode=Float)
    else:
        x0 = array(x0, typecode=Float)

    # 0 or None both mean "no progress logging": with a period of imax, at
    # most one line is logged, on the final permitted iteration.
    if not iprint:
        iprint = imax

    x = x0
    r = b - A*x
    d = r
    rTr = dot(r, r)
    # Stopping threshold on r.r, relative to the initial residual.
    threshold = tol**2 * rTr

    i = 1
    while i < imax and rTr > threshold:
        q = A*d
        alpha = rTr/dot(d, q)
        x = x + alpha*d

        if i % 50 == 0:
            # Every 50 iterations recompute the true residual so that
            # round-off accumulated by the recurrence below cannot drift.
            r = b - A*x
        else:
            # Cheap recurrence update: reuses q = A*d, saving one matrix
            # product per iteration.
            r = r - alpha*q

        rTrOld = rTr
        rTr = dot(r, r)
        beta = rTr/rTrOld
        d = r + beta*d

        i = i + 1
        if i % iprint == 0:
            logger.info('i = %g rTr = %20.15e', i, rTr)

    # Checked after the loop rather than inside it, so convergence on the
    # last permitted iteration is accepted and an unconverged start with a
    # tiny imax is still reported. Written as "not <=" so that a NaN
    # residual (numerical breakdown) also counts as failure.
    if not rTr <= threshold:
        logger.warning('max number of iterations attained')
        msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' % rTr
        raise ConvergenceError(msg)

    return x
[2841]103
Note: See TracBrowser for help on using the repository browser.