import exceptions
class VectorShapeError(Exception):
    """Raised when a vector argument does not have the expected shape."""

import logging, logging.config
logger = logging.getLogger('cg_solve')
logger.setLevel(logging.WARNING)

# Optional, best-effort configuration from ./log.ini (relative to the current
# working directory).  disable_existing_loggers=False keeps loggers that
# already exist -- including the one created above -- enabled; the default
# (True) would silently disable them unless log.ini names them explicitly.
try:
    logging.config.fileConfig('log.ini', disable_existing_loggers=False)
except Exception:
    # Missing or malformed log.ini: keep the default setup.  Catching
    # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
    logger.debug('could not load log.ini; using default logging setup',
                 exc_info=True)

def conjugate_gradient(A,b,x0=None,imax=10000,tol=1.0e-8,iprint=0):
    """
    Try to solve linear equation Ax = b using
    conjugate gradient method

    Input
    A: the system matrix, assumed symmetric positive definite. Accepted forms:
       - an object for which ``A*v`` returns the matrix-vector product
         (e.g. a dense or sparse matrix class overloading ``*``),
       - a 2-D numpy array (applied with ``dot``, since ``*`` on numpy
         arrays is elementwise),
       - a plain function ``A(v)`` returning the matrix-vector product.
    b: right hand side; must be one-dimensional
    x0: initial guess (default the 0 vector); must have the same shape as b
    imax: iteration limit. The counter starts at 1 and the loop stops when
          it reaches imax, so at most imax-1 iterations are performed.
    tol: relative tolerance; iteration stops once ||r|| <= tol*||r0||,
         where r0 = b - A*x0 is the initial residual
    iprint: log progress (INFO level) every iprint iterations; 0 or None
            means log only if the iteration limit is reached

    Output
    x: approximate solution (a new array; x0 is not modified)

    Raises
    VectorShapeError: if b is not one-dimensional, or x0 does not have
                      the same shape as b
    """

    from numpy import dot, array, asarray, ndarray, zeros

    b = array(b, dtype=float)
    if b.ndim != 1:
        raise VectorShapeError('input vector should consist of only one column')

    if x0 is None:
        x0 = zeros(b.shape, dtype=float)
    else:
        x0 = array(x0, dtype=float)
        if x0.shape != b.shape:
            raise VectorShapeError('initial guess should have the same shape as b')

    # Pick how to apply A once, outside the loop.
    if isinstance(A, ndarray) and A.ndim == 2:
        # asarray strips numpy.matrix subclasses so products stay 1-D.
        dense = asarray(A)

        def matvec(v):
            return dot(dense, v)
    elif callable(A) and not hasattr(A, '__mul__'):
        matvec = A
    else:
        # Original contract: A overloads * as a matrix-vector product.
        def matvec(v):
            return A*v

    # Accept None as well as 0 for "no periodic progress output".
    if not iprint:
        iprint = imax

    i = 1
    x = x0
    r = b - matvec(x)
    d = r
    rTr = dot(r, r)
    # Squared form of ||r|| <= tol*||r0||, avoiding square roots.
    threshold = tol**2*rTr

    while i < imax and rTr > threshold:
        q = matvec(d)
        alpha = rTr/dot(d, q)
        x = x + alpha*d
        if i % 50 == 0:
            # Periodically recompute the true residual to discard rounding
            # error accumulated by the recurrence below.
            r = b - matvec(x)
        else:
            # Cheap recurrence: reuses q = A*d instead of a second product.
            r = r - alpha*q
        rTrOld = rTr
        rTr = dot(r, r)
        beta = rTr/rTrOld

        d = r + beta*d
        i = i + 1
        if i % iprint == 0:
            logger.info('i = %g rTr = %20.15e', i, rTr)

    # Still above the threshold means the iteration limit stopped the loop.
    if rTr > threshold:
        logger.warning('max number of iterations attained')

    return x
