source: inundation/ga/storm_surge/pyvolution/cg_solve.py @ 1454

Last change on this file since 1454 was 1160, checked in by steve, 20 years ago
File size: 1.7 KB
RevLine 
[435]1import exceptions
class VectorShapeError(Exception):
    """Raised when a right-hand side vector is not one-dimensional."""
[435]3
[1150]4import logging, logging.config
# Module-level logger used by the solver below; its level is set to WARNING.
logger = logging.getLogger('cg_solve')
logger.setLevel(logging.WARNING)

# Best-effort logging configuration: 'log.ini' is optional, so any failure
# to read or parse it is ignored.  Only Exception subclasses are swallowed,
# so KeyboardInterrupt and SystemExit still propagate.
try:
    logging.config.fileConfig('log.ini')
except Exception:
    pass
[1150]12
def conjugate_gradient(A, b, x0=None, imax=10000, tol=1.0e-8, iprint=0):
    """
    Try to solve the linear equation Ax = b using the
    conjugate gradient method.

    Input
    A: object for which ``A*v`` yields the matrix-vector product;
       the matrix is assumed symmetric (the method additionally needs it
       to be positive definite).  A can be either dense or sparse as long
       as ``*`` applies it to a vector.
    b: right hand side, a one-dimensional sequence
    x0: initial guess (default the 0 vector)
    imax: bound on the iteration counter; the counter starts at 1 and the
       loop runs while it is below imax
    tol: relative tolerance; iteration stops once the squared residual
       norm is no larger than tol**2 times its initial value
    iprint: log progress (at INFO level) whenever the iteration counter is
       a multiple of iprint; 0 or None uses imax as the interval

    Output
    x: approximate solution

    Raises
    VectorShapeError: if b is not one-dimensional
    """

    try:
        from Numeric import dot, array, Float, zeros
    except ImportError:
        # Numeric is not available on modern interpreters; NumPy offers the
        # same functions, with the builtin float as the double type.
        from numpy import dot, array, zeros
        Float = float

    # The element type is passed positionally because the keyword differs
    # between Numeric ('typecode') and NumPy ('dtype').
    b = array(b, Float)
    if len(b.shape) != 1:
        raise VectorShapeError(
            'input vector should consist of only one column')

    if x0 is None:
        x0 = zeros(b.shape, Float)
    else:
        x0 = array(x0, Float)

    if iprint is None or iprint == 0:
        iprint = imax

    i = 1
    x = x0
    r = b - A*x
    d = r
    rTr = dot(r, r)
    rTr0 = rTr

    # Convergence threshold on the squared residual norm (loop invariant).
    threshold = tol**2*rTr0

    while i < imax and rTr > threshold:
        q = A*d
        alpha = rTr/dot(d, q)
        x = x + alpha*d
        if i % 50 == 0:
            # Every 50th step recompute the residual from its definition to
            # discard rounding error accumulated by the recurrence.
            r = b - A*x
        else:
            # Cheap recurrence: reuses q, so no extra application of A.
            r = r - alpha*q
        rTrOld = rTr
        rTr = dot(r, r)
        bt = rTr/rTrOld

        d = r + bt*d
        i = i + 1
        if i % iprint == 0:
            logger.info('i = %g rTr = %20.15e', i, rTr)

    # The loop ends either on convergence or on the iteration bound; a
    # residual still above the threshold means the bound was what stopped it.
    # FIXME: Should this raise an exception?
    if rTr > threshold:
        logger.warning('max number of iterations attained')

    return x
Note: See TracBrowser for help on using the repository browser.