source: trunk/anuga_core/source/anuga/utilities/cg_solve.py @ 8272

Last change on this file since 8272 was 8164, checked in by steve, 14 years ago

Changes to allow kinematic viscosity

File size: 3.0 KB
Line 
1import exceptions
2class VectorShapeError(exceptions.Exception): pass
3class ConvergenceError(exceptions.Exception): pass
4
5import numpy as num
6
7import anuga.utilities.log as log
8
9class Stats:
10
11    def __init__(self):
12
13        self.iter = None
14        self.rTr = None
15        self.dt = None
16
17def conjugate_gradient(A, b, x0=None, imax=10000, tol=1.0e-8, atol=1.0e-14,
18                        iprint=None, output_stats=False):
19    """
20    Try to solve linear equation Ax = b using
21    conjugate gradient method
22
23    If b is an array, solve it as if it was a set of vectors, solving each
24    vector.
25    """
26   
27    if x0 is None:
28        x0 = num.zeros(b.shape, dtype=num.float)
29    else:
30        x0 = num.array(x0, dtype=num.float)
31
32    b  = num.array(b, dtype=num.float)
33
34
35    if len(b.shape) != 1:
36       
37        for i in range(b.shape[1]):
38            x0[:, i], stats = _conjugate_gradient(A, b[:, i], x0[:, i],
39                                           imax, tol, atol, iprint)
40    else:
41        x0 , stats = _conjugate_gradient(A, b, x0, imax, tol, atol, iprint)
42
43    if output_stats:
44        return x0, stats
45    else:
46        return x0
47
48   
49def _conjugate_gradient(A, b, x0, 
50                        imax=10000, tol=1.0e-8, atol=1.0e-10, iprint=None):
51    """
52   Try to solve linear equation Ax = b using
53   conjugate gradient method
54
55   Input
56   A: matrix or function which applies a matrix, assumed symmetric
57      A can be either dense or sparse or a function
58      (__mul__ just needs to be defined)
59   b: right hand side
60   x0: inital guess (default the 0 vector)
61   imax: max number of iterations
62   tol: tolerance used for residual
63
64   Output
65   x: approximate solution
66   """
67
68    b  = num.array(b, dtype=num.float)
69    if len(b.shape) != 1:
70        raise VectorShapeError, 'input vector should consist of only one column'
71
72    if x0 is None:
73        x0 = num.zeros(b.shape, dtype=num.float)
74    else:
75        x0 = num.array(x0, dtype=num.float)
76
77
78    if iprint == None  or iprint == 0:
79        iprint = imax
80
81    dx = 0.0
82   
83    i = 1
84    x = x0
85    r = b - A * x
86    d = r
87    rTr = num.dot(r, r)
88    rTr0 = rTr
89   
90    #FIXME Let the iterations stop if starting with a small residual
91    while (i < imax and rTr > tol ** 2 * rTr0 and rTr > atol ** 2):
92        q = A * d
93        alpha = rTr / num.dot(d, q)
94        xold = x
95        x = x + alpha * d
96
97        dx = num.linalg.norm(x-xold)
98        #if dx < atol :
99        #    break
100           
101        if i % 50:
102            r = b - A * x
103        else:
104            r = r - alpha * q
105        rTrOld = rTr
106        rTr = num.dot(r, r)
107        bt = rTr / rTrOld
108
109        d = r + bt * d
110        i = i + 1
111        if i % iprint == 0:
112            log.info('i = %g rTr = %15.8e dx = %15.8e' % (i, rTr, dx))
113
114        if i == imax:
115            log.warning('max number of iterations attained')
116            msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' % rTr
117            raise ConvergenceError, msg
118
119    #print x
120
121    stats = Stats()
122    stats.iter = i
123    stats.rTr = rTr
124    stats.dx = dx
125
126    return x, stats
127
Note: See TracBrowser for help on using the repository browser.