source: trunk/anuga_core/source/anuga/utilities/cg_solve.py @ 8154

Last change on this file since 8154 was 8154, checked in by steve, 12 years ago

Adding in kinematic viscosity. Added in some procedures to change boundary_values of
quantities

File size: 2.6 KB
Line 
1import exceptions
2class VectorShapeError(exceptions.Exception): pass
3class ConvergenceError(exceptions.Exception): pass
4
5import numpy as num
6
7import anuga.utilities.log as log
8
9
10def conjugate_gradient(A, b, x0=None, imax=10000, tol=1.0e-8, atol=1.0e-14, iprint=None):
11    """
12    Try to solve linear equation Ax = b using
13    conjugate gradient method
14
15    If b is an array, solve it as if it was a set of vectors, solving each
16    vector.
17    """
18   
19    if x0 is None:
20        x0 = num.zeros(b.shape, dtype=num.float)
21    else:
22        x0 = num.array(x0, dtype=num.float)
23
24    b  = num.array(b, dtype=num.float)
25    if len(b.shape) != 1:
26       
27        for i in range(b.shape[1]):
28            x0[:, i] = _conjugate_gradient(A, b[:, i], x0[:, i],
29                                           imax, tol, atol, iprint)
30    else:
31        x0 = _conjugate_gradient(A, b, x0, imax, tol, atol, iprint)
32
33    return x0
34   
35def _conjugate_gradient(A, b, x0, 
36                        imax=10000, tol=1.0e-8, atol=1.0e-10, iprint=None):
37    """
38   Try to solve linear equation Ax = b using
39   conjugate gradient method
40
41   Input
42   A: matrix or function which applies a matrix, assumed symmetric
43      A can be either dense or sparse
44   b: right hand side
45   x0: inital guess (default the 0 vector)
46   imax: max number of iterations
47   tol: tolerance used for residual
48
49   Output
50   x: approximate solution
51   """
52
53    b  = num.array(b, dtype=num.float)
54    if len(b.shape) != 1:
55        raise VectorShapeError, 'input vector should consist of only one column'
56
57    if x0 is None:
58        x0 = num.zeros(b.shape, dtype=num.float)
59    else:
60        x0 = num.array(x0, dtype=num.float)
61
62
63    if iprint == None  or iprint == 0:
64        iprint = imax
65
66    i = 1
67    x = x0
68    r = b - A * x
69    d = r
70    rTr = num.dot(r, r)
71    rTr0 = rTr
72
73
74   
75    #FIXME Let the iterations stop if starting with a small residual
76    while (i < imax and rTr > tol ** 2 * rTr0 and rTr > atol ** 2):
77        q = A * d
78        alpha = rTr / num.dot(d, q)
79        xold = x
80        x = x + alpha * d
81
82        dx = num.linalg.norm(x-xold)
83        if dx < atol :
84            break
85           
86        if i % 50:
87            r = b - A * x
88        else:
89            r = r - alpha * q
90        rTrOld = rTr
91        rTr = num.dot(r, r)
92        bt = rTr / rTrOld
93
94        d = r + bt * d
95        i = i + 1
96        if i % iprint == 0:
97            log.info('i = %g rTr = %15.8e dx = %15.8e' % (i, rTr, dx))
98
99        if i == imax:
100            log.warning('max number of iterations attained')
101            msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' % rTr
102            raise ConvergenceError, msg
103
104    #print x
105    return x
106
Note: See TracBrowser for help on using the repository browser.