source: trunk/anuga_core/source/anuga/utilities/cg_solve.py @ 8025

Last change on this file since 8025 was 8025, checked in by steve, 12 years ago

Changed name of intersection of segment routine

File size: 2.4 KB
RevLine 
[6158]1import exceptions
class VectorShapeError(Exception):
    """Raised when a right hand side vector does not have the expected shape."""
class ConvergenceError(Exception):
    """Raised when the iterative solver fails to converge within its limit."""
4
[7276]5import numpy as num
[6158]6
[7845]7import anuga.utilities.log as log
[6158]8
[7845]9
def conjugate_gradient(A, b, x0=None, imax=10000, tol=1.0e-8, atol=1.0e-14,
                       iprint=None):
    """
    Try to solve linear equation Ax = b using
    conjugate gradient method

    If b is a 2-D array, each column b[:, i] is treated as a separate
    right hand side and solved independently; column i of the result is
    the solution for column i of b.

    Input
    A: matrix (dense or sparse) or operator such that ``A * v`` returns
       the matrix-vector product; assumed symmetric
    b: right hand side -- a vector, or a 2-D array whose columns are
       right hand sides (any array-like, e.g. a list, is accepted)
    x0: initial guess with the same shape as b (None means all zeros)
    imax, tol, atol, iprint: passed unchanged to _conjugate_gradient for
       every right hand side

    Output
    x: approximate solution, same shape as b

    Errors raised by _conjugate_gradient (VectorShapeError,
    ConvergenceError) propagate to the caller.
    """
    # Convert b before anything reads b.shape: the zero initial guess is
    # sized from it, which would fail for plain sequences such as lists.
    b = num.array(b, dtype=float)

    if x0 is None:
        x = num.zeros(b.shape, dtype=float)
    else:
        # Copy, so the caller's initial guess is never modified in place.
        x = num.array(x0, dtype=float)

    if len(b.shape) != 1:
        # Solve one column at a time, writing each solution into x.
        for i in range(b.shape[1]):
            x[:, i] = _conjugate_gradient(A, b[:, i], x[:, i],
                                          imax, tol, atol, iprint)
    else:
        x = _conjugate_gradient(A, b, x, imax, tol, atol, iprint)

    return x
34   
def _conjugate_gradient(A, b, x0,
                        imax=10000, tol=1.0e-8, atol=1.0e-14, iprint=None):
    """
    Try to solve linear equation Ax = b using
    conjugate gradient method

    Input
    A: matrix (dense or sparse) or operator such that ``A * v`` returns
       the matrix-vector product; assumed symmetric (and positive
       definite, which conjugate gradient needs in order to converge)
    b: right hand side, a 1-D vector
    x0: initial guess (None means the 0 vector); never modified
    imax: max number of iterations
    tol: relative tolerance -- stop once |r| <= tol * |r0|
    atol: absolute tolerance -- stop once |r| <= atol
    iprint: log progress every iprint iterations (None or 0 means
       only when imax is reached)

    Output
    x: approximate solution

    Raises
    VectorShapeError: if b is not one-dimensional
    ConvergenceError: if neither tolerance is met within imax iterations
    """

    b = num.array(b, dtype=float)
    if len(b.shape) != 1:
        raise VectorShapeError('input vector should consist of only one column')

    if x0 is None:
        x = num.zeros(b.shape, dtype=float)
    else:
        # num.array copies, so the caller's initial guess is untouched.
        x = num.array(x0, dtype=float)

    if iprint is None or iprint == 0:
        iprint = imax

    r = b - A * x
    d = r
    rTr = num.dot(r, r)

    # Compare squared thresholds against r.r to avoid taking square roots.
    rel_thresh = tol**2 * rTr
    abs_thresh = atol**2

    def not_converged(rr):
        # A small initial residual also stops immediately via abs_thresh.
        return rr > rel_thresh and rr > abs_thresh

    i = 1
    while i < imax and not_converged(rTr):
        q = A * d
        alpha = rTr / num.dot(d, q)
        x = x + alpha * d
        if i % 50:
            # Cheap residual update by recurrence (no extra mat-vec).
            r = r - alpha * q
        else:
            # Every 50 iterations recompute the true residual so that
            # round-off accumulated by the recurrence cannot drift.
            r = b - A * x
        rTrOld = rTr
        rTr = num.dot(r, r)
        d = r + (rTr / rTrOld) * d
        i = i + 1
        if i % iprint == 0:
            log.info('i = %g rTr = %20.15e' % (i, rTr))

    # The loop exits either on convergence or on hitting imax; only the
    # latter is an error (converging on the final iteration is success).
    if not_converged(rTr):
        log.warning('max number of iterations attained')
        msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' % rTr
        raise ConvergenceError(msg)

    return x
99
Note: See TracBrowser for help on using the repository browser.