# source:branches/numpy/anuga/utilities/cg_solve.py@6415

Last change on this file since 6415 was 6304, checked in by rwilson, 15 years ago

Initial commit of numpy changes. Still a long way to go.

File size: 2.4 KB
Line
1import exceptions
class VectorShapeError(Exception):
    """Raised when an input vector does not have the expected shape.

    Derives from the builtin ``Exception`` (the same class that the
    Python 2 ``exceptions`` module exposes), so the class can be defined
    without relying on that module.
    """
class ConvergenceError(Exception):
    """Raised when the iterative solver fails to converge.

    Derives from the builtin ``Exception`` (the same class that the
    Python 2 ``exceptions`` module exposes), so the class can be defined
    without relying on that module.
    """
4
5import numpy as num
6
7import logging, logging.config
# Module-level logger for the solver; its level is set to WARNING here.
logger = logging.getLogger('cg_solve')
logger.setLevel(logging.WARNING)

# Loading 'log.ini' is best-effort: if the file cannot be loaded the
# failure is ignored and the settings above stay in effect.
# Catch Exception rather than using a bare ``except`` so that
# KeyboardInterrupt and SystemExit are not silently swallowed.
try:
    logging.config.fileConfig('log.ini')
except Exception:
    pass
15
17    """
18    Try to solve linear equation Ax = b using the conjugate gradient method.
20
21    If b is an array, solve it as if it was a set of vectors, solving each
22    vector.
23    """
24
25    if x0 is None:
26        x0 = num.zeros(b.shape, dtype=num.float)
27    else:
28        x0 = num.array(x0, dtype=num.float)
29
30    b  = num.array(b, dtype=num.float)
31    if len(b.shape) != 1 :
32
33        for i in range(b.shape[1]):
34            x0[:,i] = _conjugate_gradient(A, b[:,i], x0[:,i],
35                                          imax, tol, iprint)
36    else:
37        x0 = _conjugate_gradient(A, b, x0, imax, tol, iprint)
38
39    return x0
40
42   """
43   Try to solve linear equation Ax = b using the conjugate gradient method.
45
46   Input
47   A: matrix or function which applies a matrix, assumed symmetric
48      A can be either dense or sparse
49   b: right hand side
50   x0: initial guess (default the 0 vector)
51   imax: max number of iterations
52   tol: tolerance used for residual
53
54   Output
55   x: approximate solution
56   """
57
58
59   b  = num.array(b, dtype=num.float)
60   if len(b.shape) != 1 :
61      raise VectorShapeError, 'input vector should consist of only one column'
62
63   if x0 is None:
64      x0 = num.zeros(b.shape, dtype=num.float)
65   else:
66      x0 = num.array(x0, dtype=num.float)
67
68
69   #FIXME: Should test using None
70   if iprint == 0:
71      iprint = imax
72
73   i=1
74   x = x0
75   r = b - A*x
76   d = r
77   rTr = num.dot(r,r)
78   rTr0 = rTr
79
80   while (i<imax and rTr>tol**2*rTr0):
81       q = A*d
82       alpha = rTr/num.dot(d,q)
83       x = x + alpha*d
84       if i%50 :
85           r = b - A*x
86       else:
87           r = r - alpha*q
88       rTrOld = rTr
89       rTr = num.dot(r,r)
90       bt = rTr/rTrOld
91
92       d = r + bt*d
93       i = i+1
94       if i%iprint == 0 :
95          logger.info('i = %g rTr = %20.15e' %(i,rTr))
96
97       if i==imax:
98            logger.warning('max number of iterations attained')
99            msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' %rTr
100            raise ConvergenceError, msg
101
102   return x
103
Note: See TracBrowser for help on using the repository browser.