import numpy as num

import anuga.utilities.log as log


class VectorShapeError(Exception):
    pass


class ConvergenceError(Exception):
    pass


def conjugate_gradient(A, b, x0=None, imax=10000, tol=1.0e-8, atol=1.0e-14,
                       iprint=None):
    """Try to solve the linear system Ax = b using the conjugate
    gradient method.

    If b is a two-dimensional array, treat each column as a separate
    right-hand side and solve for each one in turn.
    """

    b = num.array(b, dtype=float)

    if x0 is None:
        # b is converted above so b.shape is available for the default guess
        x0 = num.zeros(b.shape, dtype=float)
    else:
        x0 = num.array(x0, dtype=float)

    if len(b.shape) != 1:
        # Multiple right-hand sides: solve each column independently
        for i in range(b.shape[1]):
            x0[:, i] = _conjugate_gradient(A, b[:, i], x0[:, i],
                                           imax, tol, atol, iprint)
    else:
        x0 = _conjugate_gradient(A, b, x0, imax, tol, atol, iprint)

    return x0


def _conjugate_gradient(A, b, x0,
                        imax=10000, tol=1.0e-8, atol=1.0e-14, iprint=None):
    """Try to solve the linear system Ax = b using the conjugate
    gradient method.

    Input
        A: matrix, or an object that applies a matrix via the '*'
           operator; assumed symmetric positive definite.
           A can be either dense or sparse.
        b: right-hand side (a single one-dimensional vector)
        x0: initial guess (default: the zero vector)
        imax: maximum number of iterations
        tol: relative tolerance for the residual
        atol: absolute tolerance for the residual
        iprint: log progress every iprint iterations

    Output
        x: approximate solution
    """

    b = num.array(b, dtype=float)
    if len(b.shape) != 1:
        raise VectorShapeError('input vector should consist of only one column')

    if x0 is None:
        x0 = num.zeros(b.shape, dtype=float)
    else:
        x0 = num.array(x0, dtype=float)

    if iprint is None or iprint == 0:
        # No reporting interval given: log at most once, at iteration imax
        iprint = imax

    i = 1
    x = x0
    r = b - A * x          # initial residual
    d = r                  # initial search direction
    rTr = num.dot(r, r)
    rTr0 = rTr             # reference value for the relative tolerance test

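    # The loop below implements the standard conjugate gradient
    # recurrences (see e.g. Shewchuk, "An Introduction to the Conjugate
    # Gradient Method Without the Agonizing Pain"):
    #   alpha_k = (r_k, r_k) / (d_k, A d_k)
    #   x_{k+1} = x_k + alpha_k d_k
    #   r_{k+1} = r_k - alpha_k A d_k
    #   beta_k  = (r_{k+1}, r_{k+1}) / (r_k, r_k)
    #   d_{k+1} = r_{k+1} + beta_k d_k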
    # Stop once the squared residual falls below tol**2 times its
    # initial value, or below atol**2 absolutely; the atol test also
    # ends the iteration immediately if the starting residual is small.
    while i < imax and rTr > tol ** 2 * rTr0 and rTr > atol ** 2:
[6158] | 77 | q = A*d |
---|
| 78 | alpha = rTr/num.dot(d,q) |
---|
| 79 | x = x + alpha*d |
---|
| 80 | if i%50 : |
---|
| 81 | r = b - A*x |
---|
| 82 | else: |
---|
| 83 | r = r - alpha*q |
---|
| 84 | rTrOld = rTr |
---|
| 85 | rTr = num.dot(r,r) |
---|
| 86 | bt = rTr/rTrOld |
---|
| 87 | |
---|
| 88 | d = r + bt*d |
---|
| 89 | i = i+1 |
---|
| 90 | if i%iprint == 0 : |
---|
[7845] | 91 | log.info('i = %g rTr = %20.15e' %(i,rTr)) |
---|
[6158] | 92 | |
---|
| 93 | if i==imax: |
---|
[7845] | 94 | log.warning('max number of iterations attained') |
---|
[6158] | 95 | msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' %rTr |
---|
| 96 | raise ConvergenceError, msg |
---|
| 97 | |
---|
| 98 | return x |
---|
| 99 | |
---|
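

if __name__ == '__main__':
    # Minimal usage sketch (illustration only, not part of the solver).
    # The solver applies A through the '*' operator, so a plain 2D numpy
    # array (whose '*' is elementwise) needs a thin matrix-vector wrapper;
    # ANUGA's sparse matrix classes supply this operator themselves. The
    # MatVec helper below is hypothetical and exists only for this demo.
    class MatVec(object):
        def __init__(self, data):
            self.data = num.asarray(data, dtype=float)

        def __mul__(self, v):
            # Apply the stored matrix to a vector
            return num.dot(self.data, v)

    A = MatVec([[4.0, 1.0],
                [1.0, 3.0]])       # small symmetric positive definite matrix
    b = num.array([1.0, 2.0])

    x = conjugate_gradient(A, b)
    print('x =', x)                # expect approximately [0.0909, 0.6364]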