source: branches/source_numpy_conversion/anuga/utilities/cg_solve.py @ 6982

Last change on this file since 6982 was 5902, checked in by rwilson, 16 years ago

NumPy? conversion.

File size: 2.5 KB
Line 
1
2import exceptions
3class VectorShapeError(exceptions.Exception): pass
4class ConvergenceError(exceptions.Exception): pass
5
6import numpy
7   
8import logging, logging.config
9logger = logging.getLogger('cg_solve')
10logger.setLevel(logging.WARNING)
11
12try:
13    logging.config.fileConfig('log.ini')
14except:
15    pass
16
17def conjugate_gradient(A,b,x0=None,imax=10000,tol=1.0e-8,iprint=0):
18    """
19    Try to solve linear equation Ax = b using
20    conjugate gradient method
21
22    If b is an array, solve it as if it was a set of vectors, solving each
23    vector.
24    """
25   
26    if x0 is None:
27        x0 = numpy.zeros(b.shape, dtype=numpy.float)
28    else:
29        x0 = numpy.array(x0, dtype=numpy.float)
30
31    b  = numpy.array(b, dtype=numpy.float)
32    if len(b.shape) != 1 :
33       
34        for i in range(b.shape[1]):
35            x0[:,i] = _conjugate_gradient(A, b[:,i], x0[:,i],
36                                          imax, tol, iprint)
37    else:
38        x0 = _conjugate_gradient(A, b, x0, imax, tol, iprint)
39
40    return x0
41   
42def _conjugate_gradient(A,b,x0=None,imax=10000,tol=1.0e-8,iprint=0):
43   """
44   Try to solve linear equation Ax = b using
45   conjugate gradient method
46
47   Input
48   A: matrix or function which applies a matrix, assumed symmetric
49      A can be either dense or sparse
50   b: right hand side
51   x0: inital guess (default the 0 vector)
52   imax: max number of iterations
53   tol: tolerance used for residual
54
55   Output
56   x: approximate solution
57   """
58
59
60   b  = numpy.array(b, dtype=numpy.float)
61   if len(b.shape) != 1 :
62      raise VectorShapeError, 'input vector should consist of only one column'
63
64   if x0 is None:
65      x0 = numpy.zeros(b.shape, dtype=numpy.float)
66   else:
67      x0 = numpy.array(x0, dtype=numpy.float)
68
69
70   #FIXME: Should test using None
71   if iprint == 0:
72      iprint = imax
73
74   i=1
75   x = x0
76   r = b - A*x
77   d = r
78   rTr = numpy.dot(r,r)
79   rTr0 = rTr
80
81   while (i<imax and rTr>tol**2*rTr0):
82       q = A*d
83       alpha = rTr/numpy.dot(d,q)
84       x = x + alpha*d
85       if i%50 :
86           r = b - A*x
87       else:
88           r = r - alpha*q
89       rTrOld = rTr
90       rTr = numpy.dot(r,r)
91       bt = rTr/rTrOld
92
93       d = r + bt*d
94       i = i+1
95       if i%iprint == 0 :
96          logger.info('i = %g rTr = %20.15e' %(i,rTr))
97
98       if i==imax:
99            logger.warning('max number of iterations attained')
100            msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' %rTr
101            raise ConvergenceError, msg
102
103   return x
104
Note: See TracBrowser for help on using the repository browser.