Ignore:
Timestamp:
Mar 17, 2011, 5:51:42 PM (12 years ago)
Author:
steve
Message:

Adding in kinematic viscosity. Added in some procedures to change boundary_values of
quantities

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/anuga_core/source/anuga/utilities/cg_solve.py

    r8025 r8154  
    88
    99
    10 def conjugate_gradient(A,b,x0=None,imax=10000,tol=1.0e-8,atol=1.0e-14,iprint=None):
     10def conjugate_gradient(A, b, x0=None, imax=10000, tol=1.0e-8, atol=1.0e-14, iprint=None):
    1111    """
    1212    Try to solve linear equation Ax = b using
     
    2323
    2424    b  = num.array(b, dtype=num.float)
    25     if len(b.shape) != 1 :
     25    if len(b.shape) != 1:
    2626       
    2727        for i in range(b.shape[1]):
    28             x0[:,i] = _conjugate_gradient(A, b[:,i], x0[:,i],
    29                                           imax, tol, atol, iprint)
     28            x0[:, i] = _conjugate_gradient(A, b[:, i], x0[:, i],
     29                                           imax, tol, atol, iprint)
    3030    else:
    3131        x0 = _conjugate_gradient(A, b, x0, imax, tol, atol, iprint)
     
    3434   
    3535def _conjugate_gradient(A, b, x0,
    36                         imax=10000, tol=1.0e-8, atol=1.0e-14, iprint=None):
    37    """
     36                        imax=10000, tol=1.0e-8, atol=1.0e-10, iprint=None):
     37    """
    3838   Try to solve linear equation Ax = b using
    3939   conjugate gradient method
     
    5151   """
    5252
     53    b  = num.array(b, dtype=num.float)
     54    if len(b.shape) != 1:
     55        raise VectorShapeError, 'input vector should consist of only one column'
    5356
    54    b  = num.array(b, dtype=num.float)
    55    if len(b.shape) != 1 :
    56       raise VectorShapeError, 'input vector should consist of only one column'
    57 
    58    if x0 is None:
    59       x0 = num.zeros(b.shape, dtype=num.float)
    60    else:
    61       x0 = num.array(x0, dtype=num.float)
     57    if x0 is None:
     58        x0 = num.zeros(b.shape, dtype=num.float)
     59    else:
     60        x0 = num.array(x0, dtype=num.float)
    6261
    6362
    64    #FIXME: Should test using None
    65    if iprint == None  or iprint == 0:
    66       iprint = imax
     63    if iprint == None  or iprint == 0:
     64        iprint = imax
    6765
    68    i=1
    69    x = x0
    70    r = b - A*x
    71    d = r
    72    rTr = num.dot(r,r)
    73    rTr0 = rTr
     66    i = 1
     67    x = x0
     68    r = b - A * x
     69    d = r
     70    rTr = num.dot(r, r)
     71    rTr0 = rTr
    7472
    75    #FIXME Let the iterations stop if starting with a small residual
    76    while (i<imax and rTr>tol**2*rTr0 and rTr>atol**2 ):
    77        q = A*d
    78        alpha = rTr/num.dot(d,q)
    79        x = x + alpha*d
    80        if i%50 :
    81            r = b - A*x
    82        else:
    83            r = r - alpha*q
    84        rTrOld = rTr
    85        rTr = num.dot(r,r)
    86        bt = rTr/rTrOld
    8773
    88        d = r + bt*d
    89        i = i+1
    90        if i%iprint == 0 :
    91           log.info('i = %g rTr = %20.15e' %(i,rTr))
     74   
     75    #FIXME Let the iterations stop if starting with a small residual
     76    while (i < imax and rTr > tol ** 2 * rTr0 and rTr > atol ** 2):
     77        q = A * d
     78        alpha = rTr / num.dot(d, q)
     79        xold = x
     80        x = x + alpha * d
    9281
    93        if i==imax:
     82        dx = num.linalg.norm(x-xold)
     83        if dx < atol :
     84            break
     85           
     86        if i % 50:
     87            r = b - A * x
     88        else:
     89            r = r - alpha * q
     90        rTrOld = rTr
     91        rTr = num.dot(r, r)
     92        bt = rTr / rTrOld
     93
     94        d = r + bt * d
     95        i = i + 1
     96        if i % iprint == 0:
     97            log.info('i = %g rTr = %15.8e dx = %15.8e' % (i, rTr, dx))
     98
     99        if i == imax:
    94100            log.warning('max number of iterations attained')
    95             msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' %rTr
     101            msg = 'Conjugate gradient solver did not converge: rTr==%20.15e' % rTr
    96102            raise ConvergenceError, msg
    97103
    98    return x
     104    #print x
     105    return x
    99106
Note: See TracChangeset for help on using the changeset viewer.