source: trunk/anuga_core/source/anuga/utilities/cg_solve.py @ 7848

Last change on this file since 7848 was 7848, checked in by steve, 12 years ago

Changed the logging levels in log.py so that the message about opening the
file ./anuga.log is now logged at info level instead of critical level, i.e. by default
it is not written to the console.

File size: 2.4 KB
Line 
1import exceptions
class VectorShapeError(Exception):
    """Raised when a right-hand-side vector is not one-dimensional."""


class ConvergenceError(Exception):
    """Raised when the conjugate gradient solver does not reach tolerance."""
4
5import numpy as num
6
7import anuga.utilities.log as log
8
9
def conjugate_gradient(A, b, x0=None, imax=10000, tol=1.0e-8, iprint=None):
    """
    Try to solve linear equation Ax = b using
    conjugate gradient method

    If b is an array, solve it as if it was a set of vectors, solving each
    vector.

    Input
    A: matrix operator, passed unchanged to _conjugate_gradient
    b: right hand side; a 1-D vector, or a 2-D array whose columns are
       solved independently
    x0: initial guess with the same shape as b (default all zeros)
    imax, tol, iprint: passed unchanged to _conjugate_gradient

    Output
    x: float array with the same shape as b holding the solution(s)
    """
    # Convert b first so its shape is available even when it was passed
    # as a plain list (previously x0=None with a list b failed on b.shape).
    b = num.array(b, dtype=float)

    if x0 is None:
        x0 = num.zeros(b.shape, dtype=float)
    else:
        x0 = num.array(x0, dtype=float)

    if len(b.shape) != 1:
        # Solve each column of b as an independent right-hand side,
        # using the matching column of x0 as its initial guess.
        for i in range(b.shape[1]):
            x0[:, i] = _conjugate_gradient(A, b[:, i], x0[:, i],
                                           imax, tol, iprint)
    else:
        x0 = _conjugate_gradient(A, b, x0, imax, tol, iprint)

    return x0
34   
def _matvec(A, v):
    """Apply operator A to vector v.

    A numpy ndarray (including numpy.matrix) is multiplied with num.dot,
    because ``A*v`` on an ndarray is elementwise broadcasting rather than
    a matrix-vector product. Any other object (e.g. a sparse matrix class
    overloading ``*``) keeps the original ``A*v`` behaviour.
    """
    if isinstance(A, num.ndarray):
        # asarray strips numpy.matrix so the product stays a 1-D array.
        return num.dot(num.asarray(A), v)
    return A * v


def _conjugate_gradient(A, b, x0=None, imax=10000, tol=1.0e-8, iprint=None):
    """
    Try to solve linear equation Ax = b using
    conjugate gradient method

    Input
    A: matrix, assumed symmetric positive definite (a CG requirement).
       Either a numpy ndarray (applied with num.dot) or an object whose
       ``*`` operator applies the matrix to a vector (e.g. a sparse matrix)
    b: right hand side, a single 1-D vector
    x0: initial guess (default the 0 vector)
    imax: max number of iterations
    tol: relative tolerance; iterate until |r| <= tol * |r0|, where r0 is
         the residual of the initial guess
    iprint: log the residual every iprint iterations (None or 0: only at
            iteration imax)

    Output
    x: approximate solution

    Raises
    VectorShapeError: if b is not one-dimensional
    ConvergenceError: if the tolerance is not met within imax iterations
    """
    b = num.array(b, dtype=float)
    if len(b.shape) != 1:
        raise VectorShapeError('input vector should consist of only one column')

    if x0 is None:
        x0 = num.zeros(b.shape, dtype=float)
    else:
        x0 = num.array(x0, dtype=float)

    if iprint is None or iprint == 0:
        iprint = imax

    i = 1
    x = x0
    r = b - _matvec(A, x)
    d = r
    rTr = num.dot(r, r)
    rTr0 = rTr

    #FIXME Let the iterations stop if starting with a small residual
    while i < imax and rTr > tol**2 * rTr0:
        q = _matvec(A, d)
        alpha = rTr / num.dot(d, q)
        x = x + alpha * d
        if i % 50 == 0:
            # Every 50th iteration recompute the true residual to discard
            # rounding error accumulated by the cheap recurrence below.
            r = b - _matvec(A, x)
        else:
            r = r - alpha * q
        rTrOld = rTr
        rTr = num.dot(r, r)
        bt = rTr / rTrOld

        d = r + bt * d
        i = i + 1
        if i % iprint == 0:
            log.info('i = %g rTr = %20.15e' % (i, rTr))

    # Only fail if the tolerance was genuinely not met: an iteration that
    # converges exactly at i == imax is a success, and imax <= 1 no longer
    # silently returns an unconverged initial guess.
    if rTr > tol**2 * rTr0:
        log.warning('max number of iterations attained')
        msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' % rTr
        raise ConvergenceError(msg)

    return x
98
Note: See TracBrowser for help on using the repository browser.