source: trunk/anuga_core/source/anuga/utilities/cg_solve.py @ 8154

Last change on this file since 8154 was 8154, checked in by steve, 13 years ago

Adding in kinematic viscosity. Added in some procedures to change boundary_values of
quantities

File size: 2.6 KB
RevLine 
[6158]1import exceptions
class VectorShapeError(Exception):
    """Raised when a right hand side given to the CG solver is not 1-D."""
class ConvergenceError(Exception):
    """Raised when the CG iteration limit is reached without convergence."""
4
[7276]5import numpy as num
[6158]6
[7845]7import anuga.utilities.log as log
[6158]8
[7845]9
def conjugate_gradient(A, b, x0=None, imax=10000, tol=1.0e-8, atol=1.0e-14,
                       iprint=None):
    """
    Try to solve the linear equation Ax = b using the conjugate gradient
    method.

    If b is a 2-D array, each column is treated as a separate right hand
    side and solved independently; the solutions are returned as the
    matching columns of the result.

    Input
    A: matrix, passed unchanged to the single-vector solver for every column
    b: right hand side: a 1-D vector or a 2-D array of column vectors
       (anything numpy.array accepts, e.g. a list)
    x0: initial guess with the same shape as b (default: all zeros)
    imax: iteration limit passed to the single-vector solver
    tol: relative residual tolerance passed to the single-vector solver
    atol: absolute tolerance passed to the single-vector solver
    iprint: progress logging interval passed to the single-vector solver

    Output
    x: float array with the same shape as b

    Errors raised by the single-vector solver (VectorShapeError,
    ConvergenceError) propagate to the caller.
    """
    # Convert b first: the default x0 is shaped after b, so b must already
    # be an array (a plain list has no .shape).  The builtin float is what
    # the removed numpy alias num.float referred to.
    b = num.array(b, dtype=float)

    if x0 is None:
        x0 = num.zeros(b.shape, dtype=float)
    else:
        x0 = num.array(x0, dtype=float)

    if b.ndim == 2:
        # One independent solve per column of b.
        for i in range(b.shape[1]):
            x0[:, i] = _conjugate_gradient(A, b[:, i], x0[:, i],
                                           imax, tol, atol, iprint)
    else:
        # 1-D b is solved directly; any other rank is rejected by the
        # single-vector solver's shape check.
        x0 = _conjugate_gradient(A, b, x0, imax, tol, atol, iprint)

    return x0
34   
def _matvec(A, x):
    """Return the matrix-vector product A x.

    A numpy ndarray (including numpy.matrix, an ndarray subclass) is
    converted to a plain array and multiplied with numpy.dot, because ``*``
    on an ndarray is element-wise and would silently broadcast.  Any other
    object must implement the product as ``A * x``, which is what this
    solver has always relied on (e.g. a sparse matrix class).
    """
    if isinstance(A, num.ndarray):
        return num.dot(num.asarray(A), x)
    return A * x


def _conjugate_gradient(A, b, x0,
                        imax=10000, tol=1.0e-8, atol=1.0e-10, iprint=None):
    """
    Try to solve the linear equation Ax = b using the conjugate gradient
    method.

    Input
    A: matrix, assumed symmetric positive definite (as conjugate gradient
       requires).  Either a numpy ndarray, multiplied with numpy.dot, or any
       object implementing the matrix-vector product as ``A * x``.
    b: right hand side; must be one-dimensional
    x0: initial guess (None means the 0 vector); it is copied, not modified
    imax: iteration limit.  The counter starts at 1 and the loop runs while
          it is below imax, so at most imax - 1 iterations are performed.
    tol: relative tolerance: stop once |r| <= tol * |r0|
    atol: absolute tolerance: stop once |r| <= atol, or once a single
          update changes x by less than atol (2-norm)
    iprint: log progress every iprint iterations; None or 0 means imax

    Output
    x: approximate solution (float array)

    Raises
    VectorShapeError: b is not one-dimensional
    ConvergenceError: the iteration limit was reached while the residual
                      was still above both tolerances
    """
    b = num.array(b, dtype=float)
    if len(b.shape) != 1:
        raise VectorShapeError('input vector should consist of only one column')

    if x0 is None:
        x0 = num.zeros(b.shape, dtype=float)
    else:
        x0 = num.array(x0, dtype=float)

    if iprint is None or iprint == 0:
        iprint = imax

    i = 1
    x = x0
    r = b - _matvec(A, x)
    d = r
    rTr = num.dot(r, r)
    rTr0 = rTr

    # Keep iterating while |r|^2 exceeds both the relative and the absolute
    # target; this also stops immediately if x0 is already good enough.
    threshold = max(tol ** 2 * rTr0, atol ** 2)

    stalled = False
    while i < imax and rTr > threshold:
        q = _matvec(A, d)
        alpha = rTr / num.dot(d, q)
        xold = x
        x = x + alpha * d

        # Stop when the update itself has become negligible (stagnation).
        dx = num.linalg.norm(x - xold)
        if dx < atol:
            stalled = True
            break

        # Use the cheap recurrence for the residual, but every 50
        # iterations recompute it from scratch so rounding errors do not
        # accumulate in the recurrence.
        if i % 50 == 0:
            r = b - _matvec(A, x)
        else:
            r = r - alpha * q
        rTrOld = rTr
        rTr = num.dot(r, r)
        bt = rTr / rTrOld

        d = r + bt * d
        i = i + 1
        if i % iprint == 0:
            log.info('i = %g rTr = %15.8e dx = %15.8e' % (i, rTr, dx))

    # Only a loop that ran out of iterations with the residual still too
    # large is a failure; converging on the final allowed iteration is not.
    if not stalled and rTr > threshold:
        log.warning('max number of iterations attained')
        msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' % rTr
        raise ConvergenceError(msg)

    return x
[6158]106
Note: See TracBrowser for help on using the repository browser.