source: anuga_core/source/anuga/utilities/cg_solve.py @ 7276

Last change on this file since 7276 was 7276, checked in by ole, 13 years ago

Merged numpy branch back into the trunk.

In ~/sandpit/anuga/anuga_core/source
svn merge -r 6246:HEAD ../../branches/numpy .

In ~/sandpit/anuga/anuga_validation
svn merge -r 6417:HEAD ../branches/numpy_anuga_validation .

In ~/sandpit/anuga/misc
svn merge -r 6809:HEAD ../branches/numpy_misc .

For all merges, I used numpy version where conflicts existed

The suites test_all.py (in source/anuga) and validate_all.py passed using Python2.5 with numpy on my Ubuntu Linux box.

File size: 2.4 KB
Line 
1import exceptions
class VectorShapeError(Exception):
    """Raised when a right hand side is not a one-dimensional vector."""
    pass
class ConvergenceError(Exception):
    """Raised when the solver reaches its maximum number of iterations."""
    pass
4
5import numpy as num
6   
7import logging, logging.config
# Module level logger; only warnings and above are emitted by default.
logger = logging.getLogger('cg_solve')
logger.setLevel(logging.WARNING)

# Optionally apply a logging configuration read from 'log.ini' in the
# current working directory.  This is best effort: any failure to load the
# file is ignored and the settings above stay in force.  Only Exception is
# caught so that KeyboardInterrupt and SystemExit still propagate.
try:
    logging.config.fileConfig('log.ini')
except Exception:
    pass
15
def conjugate_gradient(A, b, x0=None, imax=10000, tol=1.0e-8, iprint=0):
    """
    Try to solve linear equation Ax = b using
    conjugate gradient method

    If b is a two-dimensional array, each of its columns is treated as a
    separate right hand side and solved independently.

    Input
    A: matrix or object for which A*x applies the matrix to a vector,
       assumed symmetric
    b: right hand side; a vector or a 2-D array of column vectors
       (any sequence accepted by numpy.array)
    x0: initial guess with the same shape as b (default the 0 vector)
    imax: max number of iterations for each right hand side
    tol: tolerance used for residual
    iprint: log progress every iprint iterations; 0 means only when the
            iteration counter reaches imax

    Output
    x: approximate solution, same shape as b

    Raises
    VectorShapeError: if a column passed to the solver is not 1-D
    ConvergenceError: if a solve does not reach the tolerance within imax
    """

    # Convert b before anything else so that plain sequences are accepted
    # and b.shape is available for building the default initial guess.
    b = num.array(b, dtype=float)

    # num.array copies, so the caller's x0 is never modified in place.
    if x0 is None:
        x0 = num.zeros(b.shape, dtype=float)
    else:
        x0 = num.array(x0, dtype=float)

    if len(b.shape) != 1:
        # Solve column by column, writing each solution back into x0.
        for i in range(b.shape[1]):
            x0[:, i] = _conjugate_gradient(A, b[:, i], x0[:, i],
                                           imax, tol, iprint)
    else:
        x0 = _conjugate_gradient(A, b, x0, imax, tol, iprint)

    return x0


def _conjugate_gradient(A, b, x0=None, imax=10000, tol=1.0e-8, iprint=0):
    """
    Try to solve linear equation Ax = b using
    conjugate gradient method

    Input
    A: matrix or object for which A*x applies the matrix to a vector,
       assumed symmetric
       A can be either dense or sparse
    b: right hand side (a single vector)
    x0: initial guess (default the 0 vector)
    imax: max number of iterations
    tol: tolerance used for residual; iteration stops once the squared
         residual norm drops to tol**2 times its initial value
    iprint: log progress every iprint iterations; 0 means only when the
            iteration counter reaches imax

    Output
    x: approximate solution

    Raises
    VectorShapeError: if b is not one-dimensional
    ConvergenceError: if the tolerance is not met when the iteration
                      counter reaches imax
    """

    b = num.array(b, dtype=float)
    if len(b.shape) != 1:
        msg = 'input vector should consist of only one column'
        raise VectorShapeError(msg)

    if x0 is None:
        x0 = num.zeros(b.shape, dtype=float)
    else:
        x0 = num.array(x0, dtype=float)

    #FIXME: Should test using None
    if iprint == 0:
        iprint = imax

    i = 1
    x = x0
    r = b - A*x              # initial residual
    d = r                    # initial search direction
    rTr = num.dot(r, r)      # squared norm of the residual
    rTr0 = rTr               # reference value for the relative tolerance

    while i < imax and rTr > tol**2*rTr0:
        q = A*d
        alpha = rTr/num.dot(d, q)
        x = x + alpha*d

        # The recurrence r - alpha*q reuses q and so costs no extra
        # application of A.  Only every 50th iteration is the residual
        # recomputed from its definition, to discard accumulated
        # floating point drift.
        if i % 50 == 0:
            r = b - A*x
        else:
            r = r - alpha*q

        rTrOld = rTr
        rTr = num.dot(r, r)
        bt = rTr/rTrOld

        d = r + bt*d
        i = i+1
        if i % iprint == 0:
            logger.info('i = %g rTr = %20.15e' % (i, rTr))

        # Only report failure if the final permitted iteration still has
        # not met the tolerance.
        if i == imax and rTr > tol**2*rTr0:
            logger.warning('max number of iterations attained')
            msg = ('Conjugate gradient solver did not converge: '
                   'rTr==%20.15e' % rTr)
            raise ConvergenceError(msg)

    return x
103
Note: See TracBrowser for help on using the repository browser.