source: trunk/anuga_core/source/anuga/utilities/cg_solve.py @ 7845

Last change on this file since 7845 was 7845, checked in by steve, 13 years ago
File size: 2.3 KB
Line 
1import exceptions
2class VectorShapeError(exceptions.Exception): pass
3class ConvergenceError(exceptions.Exception): pass
4
5import numpy as num
6
7import anuga.utilities.log as log
8
9
10def conjugate_gradient(A,b,x0=None,imax=10000,tol=1.0e-8,iprint=0):
11    """
12    Try to solve linear equation Ax = b using
13    conjugate gradient method
14
15    If b is an array, solve it as if it was a set of vectors, solving each
16    vector.
17    """
18   
19    if x0 is None:
20        x0 = num.zeros(b.shape, dtype=num.float)
21    else:
22        x0 = num.array(x0, dtype=num.float)
23
24    b  = num.array(b, dtype=num.float)
25    if len(b.shape) != 1 :
26       
27        for i in range(b.shape[1]):
28            x0[:,i] = _conjugate_gradient(A, b[:,i], x0[:,i],
29                                          imax, tol, iprint)
30    else:
31        x0 = _conjugate_gradient(A, b, x0, imax, tol, iprint)
32
33    return x0
34   
35def _conjugate_gradient(A,b,x0=None,imax=10000,tol=1.0e-8,iprint=0):
36   """
37   Try to solve linear equation Ax = b using
38   conjugate gradient method
39
40   Input
41   A: matrix or function which applies a matrix, assumed symmetric
42      A can be either dense or sparse
43   b: right hand side
44   x0: inital guess (default the 0 vector)
45   imax: max number of iterations
46   tol: tolerance used for residual
47
48   Output
49   x: approximate solution
50   """
51
52
53   b  = num.array(b, dtype=num.float)
54   if len(b.shape) != 1 :
55      raise VectorShapeError, 'input vector should consist of only one column'
56
57   if x0 is None:
58      x0 = num.zeros(b.shape, dtype=num.float)
59   else:
60      x0 = num.array(x0, dtype=num.float)
61
62
63   #FIXME: Should test using None
64   if iprint == 0:
65      iprint = imax
66
67   i=1
68   x = x0
69   r = b - A*x
70   d = r
71   rTr = num.dot(r,r)
72   rTr0 = rTr
73
74   while (i<imax and rTr>tol**2*rTr0):
75       q = A*d
76       alpha = rTr/num.dot(d,q)
77       x = x + alpha*d
78       if i%50 :
79           r = b - A*x
80       else:
81           r = r - alpha*q
82       rTrOld = rTr
83       rTr = num.dot(r,r)
84       bt = rTr/rTrOld
85
86       d = r + bt*d
87       i = i+1
88       if i%iprint == 0 :
89          log.info('i = %g rTr = %20.15e' %(i,rTr))
90
91       if i==imax:
92            log.warning('max number of iterations attained')
93            msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' %rTr
94            raise ConvergenceError, msg
95
96   return x
97
Note: See TracBrowser for help on using the repository browser.