"""Class Domain -
1D interval domains for finite-volume computations of
the shallow water wave equation.

This module contains a specialisation of class Domain from module domain.py
consisting of methods specific to the Shallow Water Wave Equation


U_t + E_x = S

where

U = [w, uh]
E = [uh, u^2h + gh^2/2]
S represents source terms forcing the system
(e.g. gravity, friction, wind stress, ...)

and _t and _x denote the derivatives with respect to t and x respectively.

The quantities are

symbol    variable name    explanation
x         x                horizontal distance from origin [m]
z         elevation        elevation of bed on which flow is modelled [m]
h         height           water height above z [m]
w         stage            absolute water level, w = z+h [m]
u                          speed in the x direction [m/s]
uh        xmomentum        momentum in the x direction [m^2/s]

eta                        Manning's friction coefficient [to appear]
nu                         wind stress coefficient [to appear]

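The conserved quantities are w and uh (note: the Domain class below is
currently set up with 'area' and 'discharge' as its conserved quantities).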

For details see e.g.
Christopher Zoppou and Stephen Roberts,
Catastrophic Collapse of Water Supply Reservoirs in Urban Areas,
Journal of Hydraulic Engineering, vol. 127, No. 7 July 1999


John Jakeman, Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
Geoscience Australia, 2006
"""


from domain import *
#Keep a handle on the generic Domain brought in by 'from domain import *'
#before the subclass defined below rebinds the name Domain.
Generic_Domain = Domain

#Shallow water domain
class Domain(Generic_Domain):
    """1D channel-flow domain built on the generic Domain from domain.py.

    Conserved quantities: 'area', 'discharge'.
    Evolved quantities:   'area', 'discharge', 'elevation', 'height', 'stage'.
    Other quantities:     'friction'.

    Fluxes are computed by compute_fluxes_channel and centroid values are
    distributed to vertices by distribute_to_vertices_and_edges_limit_a_d.
    """

    #Quantity names, in the order the numerics in this module expect them.
    #__init__ hands copies to Generic_Domain and check_integrity validates
    #against the same lists, so the two cannot drift apart.
    _conserved_names = ['area', 'discharge']
    _evolved_names = ['area', 'discharge', 'elevation', 'height', 'stage']
    _other_names = ['friction']

    def __init__(self, coordinates, boundary = None, tagged_elements = None):
        """Create the domain on the given vertex coordinates.

        coordinates, boundary and tagged_elements are forwarded unchanged
        to the generic Domain constructor.  minimum_allowed_height, g and
        h0 are read from config.
        """
        Generic_Domain.__init__(self,
                                coordinates = coordinates,
                                boundary = boundary,
                                conserved_quantities = list(self._conserved_names),
                                evolved_quantities = list(self._evolved_names),
                                other_quantities = list(self._other_names),
                                tagged_elements = tagged_elements)

        from config import minimum_allowed_height, g, h0
        self.minimum_allowed_height = minimum_allowed_height
        self.g = g
        self.h0 = h0

        #Forcing terms.  manning_friction is not registered: it reads the
        #'xmomentum' quantity, which this domain does not define.
        self.forcing_terms.append(gravity)

        #Stored output
        self.store = True
        self.format = 'sww'
        self.smooth = True

        #Reduction operation for get_vertex_values
        #(min is an alternative; reportedly looks better near steep slopes)
        from util import mean
        self.reduction = mean

        self.set_quantities_to_be_stored(['area','discharge'])

        self.__doc__ = 'shallow_water_domain'

        #NOTE(review): check_integrity() is still not called here; enable it
        #once Generic_Domain.check_integrity is known to pass for this setup.

    def check_integrity(self):
        """Check the domain carries the quantities this module expects.

        The conserved and evolved quantity lists must start with the names
        in _conserved_names and _evolved_names, in that order.  Then the
        generic Domain checks are run.

        Raises AssertionError describing the first mismatch.
        """
        ordinals = ['First', 'Second', 'Third', 'Fourth', 'Fifth']
        for kind, actual, expected in [
                ('conserved', self.conserved_quantities, self._conserved_names),
                ('evolved', self.evolved_quantities, self._evolved_names)]:
            for k, name in enumerate(expected):
                if k >= len(actual) or actual[k] != name:
                    msg = '%s %s quantity must be "%s"' %(ordinals[k], kind, name)
                    raise AssertionError(msg)

        Generic_Domain.check_integrity(self)

    def compute_fluxes(self):
        """Compute fluxes with the channel flux routine of this module."""
        compute_fluxes_channel(self)

    def distribute_to_vertices_and_edges(self):
        """Distribute centroid values to vertices (area/discharge limiter)."""
        distribute_to_vertices_and_edges_limit_a_d(self)
        

#=============== End of Shallow Water Domain ===============================
#-----------------------------------
# Compute fluxes interface
#-----------------------------------
def compute_fluxes(domain):
    """Placeholder flux computation: set a fixed flux timestep.

    No fluxes or explicit updates are computed; the only effect is
    domain.flux_timestep = 0.1.  The call into the C extension
    (comp_flux_vel_ext.compute_fluxes_ext) is disabled.  According to the
    original note, a Python version (local_compute_fluxes) is available in
    test_shallow_water_vel_domain.py.

    Domain.compute_fluxes in this module does not use this function; it
    calls compute_fluxes_channel instead.
    """
    domain.flux_timestep = .1

#-----------------------------------
# Compute flux definition with vel
#-----------------------------------
def compute_fluxes_vel(domain):
    """Compute fluxes with the velocity-based C extension.

    Passes the 'stage', 'xmomentum', 'elevation', 'height' and 'velocity'
    quantities, plus an initial timestep of float(sys.maxint), to
    comp_flux_vel_ext.compute_fluxes_vel_ext and stores its return value
    in domain.flux_timestep.

    NOTE(review): 'xmomentum' and 'velocity' are not created by the Domain
    class in this module, so this only works for a domain defining them.
    """
    import sys

    #Presumably the extension lowers this bound to the stable timestep -- confirm
    timestep = float(sys.maxint)

    quantities = domain.quantities
    stage    = quantities['stage']
    xmom     = quantities['xmomentum']
    bed      = quantities['elevation']
    height   = quantities['height']
    velocity = quantities['velocity']

    from comp_flux_vel_ext import compute_fluxes_vel_ext

    domain.flux_timestep = compute_fluxes_vel_ext(timestep, domain, stage, xmom,
                                                  bed, height, velocity)

#----------------------------------
#  Compute fluxes channel
#----------------------------------
def compute_fluxes_channel(domain):
    """Compute the flux contribution to the explicit updates of a channel.

    Cells are numbered 0..N-1 with N = len(domain.coordinates) - 1, and
    interface k separates cells k-1 and k.  This is the convention of
    channel_flux_func, which returns zero flux at the end interfaces 0
    and N.  For every cell i the conservative update

        explicit_update[i] = -(F[i+1] - F[i]) / domain.areas[i]

    is stored for 'area' (first flux component) and 'discharge' (second
    component), overwriting any previous value.  Each interface flux is
    evaluated once.

    domain.flux_timestep is set to the fixed value 0.1; no CFL-based
    timestep is computed yet.
    """
    area = domain.quantities['area']
    discharge = domain.quantities['discharge']

    #FIXME: fixed timestep until wave speeds are used for a CFL estimate
    domain.flux_timestep = .1

    N = len(domain.coordinates) - 1

    #fluxes[k] = (mass flux, momentum flux) through interface k
    fluxes = [channel_flux_func(domain, k) for k in range(N + 1)]

    for i in range(N):
        flux_left = fluxes[i]
        flux_right = fluxes[i + 1]

        #Divide by the cell size so the update is a rate per unit length,
        #like the gravity source term.  Assumes domain.areas[i] is the
        #length of cell i -- TODO confirm against domain.py.
        dx = domain.areas[i]
        area.explicit_update[i] = -(flux_right[0] - flux_left[0])/dx
        discharge.explicit_update[i] = -(flux_right[1] - flux_left[1])/dx
        
def channel_flux_func(domain, i):
    """Return the numerical flux (mass, momentum) through interface i.

    Interface i separates cell i-1 (left state) from cell i (right state).
    Only centroid values are used.  Interfaces 0 and
    len(domain.coordinates)-1 (the two ends of the channel), and any index
    outside that range, get zero flux.  Negative indices would otherwise
    silently wrap around to the far end of the channel.

    The flux is the plain average of the left and right physical fluxes
    (a crude central flux, no upwinding):

        mass:     d
        momentum: d*d/a + g*(calculateI(h, 0) - calculateI(hbar, 0))

    where a is area, d discharge, h height, hbar a reference height (fixed
    at 0 for now) and g = domain.g.  The advective term d*d/a is taken as
    zero where a < 1e-12.

    Returns a tuple (mass_flux, momentum_flux) of floats.
    """
    last = len(domain.coordinates) - 1
    if i <= 0 or i >= last:
        return (0.0, 0.0)

    g = domain.g

    area = domain.quantities['area'].centroid_values
    discharge = domain.quantities['discharge'].centroid_values
    height = domain.quantities['height'].centroid_values

    #Reference height for the pressure term; the vertex-averaged
    #alternative is not enabled, so both sides use zero.
    hbar_left = hbar_right = 0.0

    def physical_flux(a, d, h, hbar):
        #Mass flux is the discharge itself
        if a < 1.0e-12:
            advection = 0
        else:
            advection = d*d/a
        return d, advection + g*(calculateI(h, 0) - calculateI(hbar, 0))

    mass_left, mom_left = physical_flux(area[i-1], discharge[i-1],
                                        height[i-1], hbar_left)
    mass_right, mom_right = physical_flux(area[i], discharge[i],
                                          height[i], hbar_right)

    return (0.5*(mass_left + mass_right), 0.5*(mom_left + mom_right))

def calculateI(H,z0):
    """Return H*H/2 - z0*H + z0*z0/2, i.e. (H - z0)**2 / 2 written out."""
    return H*H*0.5-z0*H+z0*z0*0.5

    
#--------------------------------------------------------------------------
def distribute_to_vertices_and_edges_limit_w_u(domain):
    """Distribute centroid values to vertices, limiting stage and velocity.

    Steps:
      1. Recompute centroid height h = w - z.  Where h <= 1e-12 the cell is
         treated as dry (h = 0, stage reset to the bed).
      2. Compute centroid velocity u = uh/(h + h0/h) with h0 = 1e-12, and
         u = 0 where h < 1e-12.
      3. Extrapolate 'velocity' and 'stage' to the vertices, first order
         (domain.order == 1) or second order (domain.order == 2).
      4. Recompute vertex heights.  A negative vertex height is set to zero
         (stage reset to the bed there) and its deficit is taken from the
         other vertex of the same element.
      5. Set vertex x-momentum uh = u*h.

    Precondition:
      All quantities defined at centroids and bed elevation defined at
      vertices.  Requires 'stage', 'xmomentum', 'elevation', 'height' and
      'velocity'; the Domain class in this module does not define
      'xmomentum' or 'velocity'.

    Postcondition
      Stage, velocity, height and x-momentum defined at vertices.

    Raises ValueError if domain.order is neither 1 nor 2.
    """
    #Dry threshold and desingularising constant for the velocity
    h0 = 1.0e-12

    N = domain.number_of_elements

    #Shortcuts
    Stage = domain.quantities['stage']
    Xmom = domain.quantities['xmomentum']
    Bed = domain.quantities['elevation']
    Height = domain.quantities['height']
    Velocity = domain.quantities['velocity']

    #Arrays
    w_C  = Stage.centroid_values
    uh_C = Xmom.centroid_values
    z_C  = Bed.centroid_values
    h_C  = Height.centroid_values
    u_C  = Velocity.centroid_values

    #Centroid heights; (nearly) dry cells get zero height, stage on the bed
    for i in range(N):
        h_C[i] = w_C[i] - z_C[i]
        if h_C[i] <= 1.0e-12:
            h_C[i] = 0.0
            w_C[i] = z_C[i]

    #Centroid velocity; h + h0/h keeps uh/h bounded as h -> 0
    for i in range(len(h_C)):
        if h_C[i] < 1.0e-12:
            u_C[i] = 0.0
            h_C[i] = 0.0
        else:
            u_C[i] = uh_C[i]/(h_C[i] + h0/h_C[i])

    for name in ['velocity', 'stage']:
        Q = domain.quantities[name]
        if domain.order == 1:
            Q.extrapolate_first_order()
        elif domain.order == 2:
            Q.extrapolate_second_order()
        else:
            raise ValueError('Unknown order')

    w_V  = Stage.vertex_values
    z_V  = Bed.vertex_values
    h_V  = Height.vertex_values
    u_V  = Velocity.vertex_values
    uh_V = Xmom.vertex_values

    h_V[:] = w_V - z_V
    for i in range(len(h_C)):
        for j in range(2):
            if h_V[i,j] < 0.0:
                #Clip to zero and take the deficit from the other vertex so
                #the element's summed vertex height is unchanged
                other = (j+1)%2
                dh = h_V[i,j]
                h_V[i,j] = 0.0
                w_V[i,j] = z_V[i,j]
                h_V[i,other] = h_V[i,other] + dh
                w_V[i,other] = w_V[i,other] + dh

    uh_V[:] = u_V * h_V

#---------------------------------------------------------------------------
def distribute_to_vertices_and_edges_limit_w_uh(domain):
    """Distribute centroid values to vertices, limiting stage and x-momentum.

    Steps:
      1. Recompute centroid height h = w - z.  Cells with h <= 1e-6 are
         dry: h = 0, stage reset to the bed, x-momentum zeroed.
      2. Extrapolate 'stage' and 'xmomentum' to the vertices, first order
         (domain.order == 1) or second order (domain.order == 2).
      3. Recompute vertex heights.  A negative one is clipped to zero (stage
         reset to the bed, velocity zeroed) and its deficit is taken from
         the other vertex of the element.
      4. Vertex velocity u = uh/(h + h0/h) where h >= h0 (h0 from config),
         otherwise u = 0.

    Precondition:
      All quantities defined at centroids and bed elevation defined at
      vertices.

    Postcondition
      Stage, x-momentum, height and velocity defined at vertices.

    Raises ValueError if domain.order is neither 1 nor 2.
    """
    from config import h0

    N = domain.number_of_elements

    #Shortcuts
    Stage = domain.quantities['stage']
    Xmom = domain.quantities['xmomentum']
    Bed = domain.quantities['elevation']
    Height = domain.quantities['height']
    Velocity = domain.quantities['velocity']

    #Arrays
    w_C  = Stage.centroid_values
    uh_C = Xmom.centroid_values
    z_C  = Bed.centroid_values
    h_C  = Height.centroid_values

    #Centroid heights; (nearly) dry cells lose their water and momentum
    for i in range(N):
        h_C[i] = w_C[i] - z_C[i]
        if h_C[i] <= 1.0e-6:
            h_C[i] = 0.0
            w_C[i] = z_C[i]
            uh_C[i] = 0.0

    for name in ['stage', 'xmomentum']:
        Q = domain.quantities[name]
        if domain.order == 1:
            Q.extrapolate_first_order()
        elif domain.order == 2:
            Q.extrapolate_second_order()
        else:
            raise ValueError('Unknown order')

    w_V  = Stage.vertex_values
    z_V  = Bed.vertex_values
    h_V  = Height.vertex_values
    u_V  = Velocity.vertex_values
    uh_V = Xmom.vertex_values

    h_V[:] = w_V - z_V

    #NOTE(review): u_V[i,0] is computed before vertex 1 is processed; if
    #vertex 1 is then clipped, h_V[i,0] changes but u_V[i,0] is not redone.
    for i in range(len(h_C)):
        for j in range(2):
            if h_V[i,j] < 0.0:
                #Clip to zero and take the deficit from the other vertex so
                #the element's summed vertex height is unchanged
                other = (j+1)%2
                dh = h_V[i,j]
                h_V[i,j] = 0.0
                w_V[i,j] = z_V[i,j]
                h_V[i,other] = h_V[i,other] + dh
                w_V[i,other] = w_V[i,other] + dh
                u_V[i,j] = 0.0
            if h_V[i,j] < h0:
                #Too shallow for a meaningful velocity; zero it instead of
                #keeping whatever value was left from earlier
                u_V[i,j] = 0.0
            else:
                u_V[i,j] = uh_V[i,j]/(h_V[i,j] + h0/h_V[i,j])

#---------------------------------------------------------------------------
def distribute_to_vertices_and_edges_limit_a_d(domain):
    """Distribute centroid values to vertices for the area/discharge form.

    Steps:
      1. Recompute centroid height h = w - z.  Cells with h <= 1e-6 are
         dry: h = 0, stage reset to the bed, discharge zeroed.
      2. Extrapolate 'area' and 'discharge' to the vertices, first order
         (domain.order == 1) or second order (domain.order == 2).
      3. Set vertex heights to stage vertex values minus elevation vertex
         values.

    Precondition:
      All quantities defined at centroids and bed elevation defined at
      vertices.

    Postcondition
      Conserved quantities (area, discharge) defined at vertices.

    Raises ValueError if domain.order is neither 1 nor 2.
    """
    N = domain.number_of_elements

    #Shortcuts
    Discharge = domain.quantities['discharge']
    Bed = domain.quantities['elevation']
    Height = domain.quantities['height']
    Stage = domain.quantities['stage']

    #Arrays
    d_C = Discharge.centroid_values
    z_C = Bed.centroid_values
    h_C = Height.centroid_values
    w_C = Stage.centroid_values

    #Centroid heights; (nearly) dry cells get zero height and discharge
    for i in range(N):
        h_C[i] = w_C[i] - z_C[i]
        if h_C[i] <= 1.0e-6:
            h_C[i] = 0.0
            w_C[i] = z_C[i]
            d_C[i] = 0.0

    #Distribute area and discharge
    for name in ['area', 'discharge']:
        Q = domain.quantities[name]
        if domain.order == 1:
            Q.extrapolate_first_order()
        elif domain.order == 2:
            Q.extrapolate_second_order()
        else:
            raise ValueError('Unknown order')

    #Height at vertices.
    #NOTE(review): stage is not extrapolated here, so the stage vertex values
    #are whatever was set previously -- confirm this is intended.
    h_V = Height.vertex_values
    h_V[:] = Stage.vertex_values - Bed.vertex_values
                
#--------------------------------------------------------
#Boundaries - specific to the shallow_water_vel_domain
#--------------------------------------------------------
class Reflective_boundary(Boundary):
    """Reflective boundary for the stage/xmomentum/velocity formulation.

    Returns the evolved quantities of the neighbouring volume at the
    boundary edge, ordered [stage, xmomentum, elevation, height, velocity],
    with x-momentum and velocity negated so the flow is reflected.

    NOTE: a second class named Reflective_boundary (for the area/discharge
    channel formulation) is defined further down in this module and rebinds
    this name at import time.
    """

    def __init__(self, domain = None):
        Boundary.__init__(self)

        if domain is None:
            msg = 'Domain must be specified for reflective boundary'
            raise ValueError(msg)

        #Handy shorthands
        self.normals  = domain.normals
        self.stage    = domain.quantities['stage'].vertex_values
        self.xmom     = domain.quantities['xmomentum'].vertex_values
        self.bed      = domain.quantities['elevation'].vertex_values
        self.height   = domain.quantities['height'].vertex_values
        self.velocity = domain.quantities['velocity'].vertex_values

        from Numeric import zeros, Float
        #Output buffer, refilled and returned by every evaluate() call
        self.evolved_quantities = zeros(5, Float)

    def __repr__(self):
        return 'Reflective_boundary'

    def evaluate(self, vol_id, edge_id):
        """Return the reflected evolved quantities at (vol_id, edge_id).

        The same array object is filled in and returned on every call.
        """
        q = self.evolved_quantities
        q[0] = self.stage[vol_id, edge_id]
        q[1] = -self.xmom[vol_id, edge_id]
        q[2] = self.bed[vol_id, edge_id]
        q[3] = self.height[vol_id, edge_id]
        q[4] = -self.velocity[vol_id, edge_id]
        return q

class Dirichlet_boundary(Boundary):
    """Dirichlet boundary returning fixed values for the evolved quantities.

    The values given at construction are stored as a Numeric Float array
    and returned unchanged by evaluate() for every boundary edge.

    NOTE: an identical Dirichlet_boundary is defined again further down in
    this module, and that later definition rebinds this name.
    """

    def __init__(self, evolved_quantities=None):
        Boundary.__init__(self)

        if evolved_quantities is None:
            msg = 'Must specify one value for each evolved quantity'
            raise ValueError(msg)

        from Numeric import array, Float
        self.evolved_quantities = array(evolved_quantities).astype(Float)

    def __repr__(self):
        return 'Dirichlet boundary (%s)' %self.evolved_quantities

    def evaluate(self, vol_id=None, edge_id=None):
        """Return the fixed boundary values; vol_id and edge_id are ignored."""
        return self.evolved_quantities

#--------------------------------------------------------
#Boundaries for channel - specific to the channel domain
#--------------------------------------------------------
class Reflective_boundary(Boundary):
    """Reflective boundary for the area/discharge channel formulation.

    Returns the evolved quantities of the neighbouring volume at the
    boundary edge, ordered [area, discharge, elevation, height, stage],
    with the discharge (index 1) negated so the flow is reflected.
    """

    def __init__(self, domain = None):
        Boundary.__init__(self)

        if domain is None:
            msg = 'Domain must be specified for reflective boundary'
            raise ValueError(msg)

        #Handy shorthands
        self.normals   = domain.normals
        self.area      = domain.quantities['area'].vertex_values
        self.discharge = domain.quantities['discharge'].vertex_values
        self.bed       = domain.quantities['elevation'].vertex_values
        self.height    = domain.quantities['height'].vertex_values
        self.stage     = domain.quantities['stage'].vertex_values

        from Numeric import zeros, Float
        #Output buffer, refilled and returned by every evaluate() call
        self.evolved_quantities = zeros(5, Float)

    def __repr__(self):
        return 'Reflective_boundary'

    def evaluate(self, vol_id, edge_id):
        """Return the reflected evolved quantities at (vol_id, edge_id).

        The same array object is filled in and returned on every call.
        """
        q = self.evolved_quantities
        q[0] = self.area[vol_id, edge_id]
        q[1] = -self.discharge[vol_id, edge_id]
        q[2] = self.bed[vol_id, edge_id]
        q[3] = self.height[vol_id, edge_id]
        q[4] = self.stage[vol_id, edge_id]
        return q

class Dirichlet_boundary(Boundary):
    """Dirichlet boundary returning fixed values for the evolved quantities.

    The values given at construction are stored as a Numeric Float array
    and returned unchanged by evaluate() for every boundary edge.
    """

    def __init__(self, evolved_quantities=None):
        Boundary.__init__(self)

        if evolved_quantities is None:
            msg = 'Must specify one value for each evolved quantity'
            raise ValueError(msg)

        from Numeric import array, Float
        self.evolved_quantities = array(evolved_quantities).astype(Float)

    def __repr__(self):
        return 'Dirichlet boundary (%s)' %self.evolved_quantities

    def evaluate(self, vol_id=None, edge_id=None):
        """Return the fixed boundary values; vol_id and edge_id are ignored."""
        return self.evolved_quantities
    
#----------------------------
#Standard forcing terms:
#---------------------------
def gravity(domain):
    """Apply gravitational pull in the presence of bed slope.

    For every element k, adds -g * bx * avg_h to the explicit update of
    'discharge'.  Here bx is the bed slope across the element, obtained from
    util.gradient on the element's vertex coordinates and vertex
    elevations; avg_h is the mean of the two vertex heights; and
    g = domain.g.  The update is accumulated (+=), not reset.
    """
    from util import gradient

    discharge_ud = domain.quantities['discharge'].explicit_update
    h = domain.quantities['height'].vertex_values
    b = domain.quantities['elevation'].vertex_values

    x = domain.get_vertex_coordinates()
    g = domain.g

    for k in range(domain.number_of_elements):
        avg_h = 0.5*(h[k,0] + h[k,1])

        #Bed slope across the element
        x0, x1 = x[k,:]
        b0, b1 = b[k,:]
        bx = gradient(x0, x1, b0, b1)

        discharge_ud[k] += -g*bx*avg_h
 
 
def manning_friction(domain):
    """Apply (Manning) friction to water momentum.

    For every element k with friction coefficient eta[k] >= eps and depth
    h[k] = stage - elevation >= eps (eps = domain.minimum_allowed_height),
    adds S*uh[k] to the 'xmomentum' semi_implicit_update, where

        S = -g * eta[k]**2 * |uh[k]| / h[k]**(7/3)

    The contribution -g eta^2 |uh| uh / h^(7/3) therefore always opposes
    the flow.  Requires the 'xmomentum' and 'friction' quantities.
    """
    w = domain.quantities['stage'].centroid_values
    z = domain.quantities['elevation'].centroid_values
    h = w-z

    uh = domain.quantities['xmomentum'].centroid_values
    eta = domain.quantities['friction'].centroid_values

    xmom_update = domain.quantities['xmomentum'].semi_implicit_update

    N = domain.number_of_elements
    eps = domain.minimum_allowed_height
    g = domain.g

    for k in range(N):
        if eta[k] >= eps and h[k] >= eps:
            #1D form of the 2D speed sqrt(uh**2 + vh**2) is |uh|; using uh
            #itself would make S*uh = -c*uh**2 always negative, accelerating
            #flow in the -x direction instead of damping it.
            S = -g * eta[k]**2 * abs(uh[k])
            S /= h[k]**(7.0/3)

            #Update momentum
            xmom_update[k] += S*uh[k]

def linear_friction(domain):
    """Apply linear friction to the x-momentum (semi-implicit term).

    Needs a quantity called 'linear_friction' holding the coefficient tau.
    For each element whose tau and depth (stage - elevation) are both at
    least domain.minimum_allowed_height, (-tau/depth) * uh is added to the
    'xmomentum' semi_implicit_update.  Other elements are left untouched.
    """
    quantities = domain.quantities

    stage_c = quantities['stage'].centroid_values
    bed_c = quantities['elevation'].centroid_values
    depth = stage_c - bed_c

    momentum = quantities['xmomentum']
    uh = momentum.centroid_values
    tau = quantities['linear_friction'].centroid_values
    update = momentum.semi_implicit_update

    threshold = domain.minimum_allowed_height

    for idx in range(domain.number_of_elements):
        #Skip elements without friction or without (enough) water
        if not (tau[idx] >= threshold and depth[idx] >= threshold):
            continue
        coefficient = -tau[idx]/depth[idx]
        update[idx] += coefficient*uh[idx]



def check_forcefield(f):
    """Check that f is a usable force field and return it.

    f must be either
    1: a callable object, probed as f(1.0, x=x) with x a Numeric Float
       vector of length 2, that returns a list or array of the same length
       as x.  In this case f itself is returned.  Or
    2: a scalar, in which case float(f) is returned.

    Raises
      ValueError      if the callable fails, or its result cannot be
                      converted to a Numeric array of floats,
      AssertionError  if the callable's result does not have length 2,
      TypeError       if f is neither callable nor convertible to float.
    """
    if callable(f):
        #Numeric is only needed to probe vector functions
        from Numeric import ones, Float, array

        N = 2
        x = ones(N, Float)

        try:
            q = f(1.0, x=x)
        except Exception as e:
            #FIXME: Reconsider this semantics
            msg = 'Function %s could not be executed:\n%s' %(f, e)
            raise ValueError(msg)

        try:
            q = array(q).astype(Float)
        except Exception:
            msg = 'Return value from vector function %s could ' %(f,)
            msg += 'not be converted into a Numeric array of floats.\n'
            msg += 'Specified function should return either list or array.'
            raise ValueError(msg)

        #Is this really what we want?
        if len(q) != N:
            msg = 'Return vector from function %s ' %(f,)
            msg += 'must have same length as input vectors'
            raise AssertionError(msg)

        return f

    try:
        return float(f)
    except (TypeError, ValueError):
        msg = 'Force field %s must be either a scalar' %(f,)
        msg += ' or a vector function'
        raise TypeError(msg)

