"""Class Quantity - Implements values at each 1d element

To create:

   Quantity(domain, vertex_values)

   domain: Associated domain structure. Required.

   vertex_values: N x 2 array of values at each vertex for each element.
                  Default None

   If vertex_values is None, an array of zeros compatible with domain is
   created. Otherwise vertex_values is checked for compatibility with the
   dimensions of domain, and an exception is raised if it is incompatible.
"""



class Quantity:
    """Values of a scalar quantity over the elements of a 1d domain.

    Each of the N elements (intervals) of the associated domain carries
    two vertex values (one at each end of the interval) and a centroid
    value.  The class supports setting and getting those values, explicit
    and semi-implicit time updates, and first/second order reconstruction
    of vertex values from centroid values.  For second order
    reconstruction, domain.limiter selects the slope limiter.

    Main attributes (Numeric Float arrays, N = domain.number_of_elements):
        vertex_values          -- N x 2, values at both ends of each element
        centroid_values        -- length N, values at element centroids
        centroid_backup_values -- length N, copy made by backup_centroid_values
        explicit_update        -- length N, explicit rates used by update()
        semi_implicit_update   -- length N, semi-implicit rates used by update()
        gradients              -- length N, slopes from compute_*gradients
        qmax, qmin             -- length N, work arrays for limit_pyvolution
        boundary_values        -- length 2 (presumably one per end of the
                                  domain -- TODO confirm)
    """

    def __init__(self, domain, vertex_values=None):
        """Initialise Quantity on domain, optionally from vertex values.

        domain: instance of domain.Domain (or a subclass thereof).
        vertex_values: optional N x 2 sequence of values, N being
            domain.number_of_elements.  Defaults to all zeros.

        Raises AssertionError if domain is not a Domain or if
        vertex_values does not have N rows of two values.
        """
        from domain import Domain
        from Numeric import array, zeros, Float

        msg = 'First argument in Quantity.__init__ '
        msg += 'must be of class Domain (or a subclass thereof)'
        assert isinstance(domain, Domain), msg

        N = domain.number_of_elements
        if vertex_values is None:
            self.vertex_values = zeros((N, 2), Float)
        else:
            self.vertex_values = array(vertex_values, Float)

            M, V = self.vertex_values.shape
            assert V == 2,\
                   'Two vertex values per element must be specified'

            msg = 'Number of vertex values (%d) must be consistent with ' % M
            msg += 'number of elements in specified domain (%d).' % N
            assert M == N, msg

        self.domain = domain

        # Allocate space for other quantities
        self.centroid_values = zeros(N, Float)
        self.centroid_backup_values = zeros(N, Float)

        # Initialise centroid values from the vertex values
        self.interpolate()

        # Allocate space for boundary values (assumes no parallelism)
        self.boundary_values = zeros(2, Float)

        # Allocate space for updates of conserved quantities by
        # flux calculations and forcing functions
        self.explicit_update = zeros(N, Float)
        self.semi_implicit_update = zeros(N, Float)

        # Work arrays for reconstruction and limiting
        self.gradients = zeros(N, Float)
        self.qmax = zeros(N, Float)
        self.qmin = zeros(N, Float)

        self.beta = domain.beta

    def __len__(self):
        """Return the number of intervals (elements)."""
        return self.centroid_values.shape[0]

    def interpolate(self):
        """Set each centroid value to the mean of its two vertex values.

        Pre-condition: vertex_values have been set.
        """
        qv = self.vertex_values
        self.centroid_values[:] = (qv[:, 0] + qv[:, 1])/2.0

    def set_values(self, X, location='vertices'):
        """Set values for quantity.

        X: a callable f(x) -> z (evaluated at the requested points),
           a constant (float, int or long), or a list/Numeric array.
        location: 'vertices' (default) or 'centroids'.

        For location == 'centroids' an array X must have length N, N being
        the number of elements; for 'vertices' it must be N x 2.  Values
        are stored following the internal element ordering.

        When location is 'vertices' the centroid values are then
        recomputed by interpolation.  When it is 'centroids' only the
        centroid values change and the vertex values are left untouched.

        Raises ValueError for an invalid location or if X is None.
        """
        if location not in ['vertices', 'centroids']:
            msg = 'Invalid location: %s, (possible choices vertices, centroids)' % location
            raise ValueError(msg)

        if X is None:
            msg = 'Given values are None'
            raise ValueError(msg)

        import types

        if callable(X):
            # Use function specific method
            self.set_function_values(X, location)

        elif type(X) in [types.FloatType, types.IntType, types.LongType]:
            if location == 'centroids':
                self.centroid_values[:] = X
            else:
                self.vertex_values[:] = X

        else:
            # Use array specific method
            self.set_array_values(X, location)

        if location == 'vertices':
            # Keep centroid values consistent with the new vertex values
            self.interpolate()

    def set_function_values(self, f, location='vertices'):
        """Set values for quantity using specified function.

        f: x -> z function where x and z are arrays.
        location: 'centroids' evaluates f at domain.centroids; any other
                  value evaluates f at both columns of
                  domain.get_vertices() and stores into vertex_values.
                  Default is 'vertices'.
        """
        if location == 'centroids':
            P = self.domain.centroids
            self.set_values(f(P), location)
        else:
            # Vertices: evaluate column by column (left ends, right ends)
            P = self.domain.get_vertices()
            for i in range(2):
                self.vertex_values[:, i] = f(P[:, i])

    def set_array_values(self, values, location='vertices'):
        """Set values for quantity from an array.

        values: list or Numeric array (converted to Float).
        location: 'centroids' or (default) 'vertices'.

        For location == 'centroids' values must be 1d of length N, N being
        the number of elements; otherwise it must be N x 2.  The values
        are stored following the internal element ordering.  Only the
        values for the given location are replaced; callers such as
        set_values take care of re-interpolating centroids.

        Raises AssertionError if the shape of values is incompatible.
        """
        from Numeric import array, Float

        values = array(values).astype(Float)

        N = self.centroid_values.shape[0]

        msg = 'Number of values must match number of elements'
        assert values.shape[0] == N, msg

        if location == 'centroids':
            assert len(values.shape) == 1, 'Values array must be 1d'
            self.centroid_values = values
        else:
            assert len(values.shape) == 2, 'Values array must be 2d'
            msg = 'Array must be N x 2'
            assert values.shape[1] == 2, msg

            self.vertex_values = values

    def get_values(self, location='vertices', indices = None):
        """Get values for quantity.

        location: 'vertices' (default), 'centroids' or 'unique vertices'.
        indices: list or Numeric array of ids to return, or None for all.
            For 'vertices' and 'centroids' these are element ids; for
            'unique vertices' they index domain.coordinates/vertexlist.

        Returns:
            'centroids'       -- 1d array of centroid values.
            'vertices'        -- array of shape len(indices) x 2 holding
                                 both vertex values of each element.
            'unique vertices' -- 1d array; each entry is the mean of all
                                 element vertex values listed for that
                                 point in domain.vertexlist.

        Raises ValueError for an invalid location or for a unique vertex
        that is not associated with any element.
        """
        from Numeric import take

        if location not in ['vertices', 'centroids', 'unique vertices']:
            msg = 'Invalid location: %s' % location
            raise ValueError(msg)

        import types, Numeric
        assert type(indices) in [types.ListType, types.NoneType,
                                 Numeric.ArrayType],\
                                 'Indices must be a list or None'

        if location == 'centroids':
            if indices is None:
                indices = range(len(self))
            return take(self.centroid_values, indices)

        elif location == 'unique vertices':
            if indices is None:
                indices = range(self.domain.coordinates.shape[0])
            vert_values = []
            # Go through list of unique vertices
            for unique_vert_id in indices:
                cells = self.domain.vertexlist[unique_vert_id]

                # In case there are unused points
                if cells is None:
                    msg = 'Unique vertex not associated with cells'
                    raise ValueError(msg)

                # Average the values over all (cell, vertex) pairs
                total = 0
                for cell_id, vertex_id in cells:
                    total += self.vertex_values[cell_id, vertex_id]
                vert_values.append(total/len(cells))
            return Numeric.array(vert_values)

        else:
            if indices is None:
                indices = range(len(self))
            return take(self.vertex_values, indices)

    def get_vertex_values(self,
                          x=True,
                          smooth = None,
                          precision = None,
                          reduction = None):
        """Return vertex values (and optionally coordinates) for output.

        x: if True (the default) coordinates are returned as well.
        smooth: if equal to True, values at each point in domain.vertexlist
            are combined with reduction.  Defaults to domain.smooth.
        precision: Numeric typecode for the smoothed arrays.  Defaults to
            Float.
        reduction: function mapping the list of element vertex values
            contributing to one point onto a single value.  Defaults to
            domain.reduction.

        Smoothed:   A has one entry per entry of domain.vertexlist (zero
                    for unused points), and X is domain.coordinates cast
                    to precision.
        Unsmoothed: A is vertex_values flattened row by row (both values
                    of element 0, then of element 1, ...), and X is
                    domain.get_vertex_coordinates() flattened.
        In both cases V is the result of domain.get_vertices().

        Returns (X, A, V) if x is True, otherwise (A, V).
        """
        from Numeric import zeros, Float

        if smooth is None:
            smooth = self.domain.smooth

        if precision is None:
            precision = Float

        if reduction is None:
            reduction = self.domain.reduction

        # Previously the unsmoothed branch returned the unbound method
        # get_vertices instead of calling it; both branches now call it.
        V = self.domain.get_vertices()

        if smooth == True:
            N = len(self.domain.vertexlist)
            A = zeros(N, precision)

            # Smoothing loop over all points
            for k in range(N):
                L = self.domain.vertexlist[k]
                if L is None:
                    continue  # In case there are unused points

                # Collect every (element, vertex) value contributing to point k
                contributions = [self.vertex_values[volume_id, vertex_id]
                                 for volume_id, vertex_id in L]
                A[k] = reduction(contributions)

            if x is True:
                X = self.domain.coordinates[:].astype(precision)
                return X, A, V
            return A, V

        # Don't smooth: report the raw per-element vertex values
        A = self.vertex_values.flat

        if x is True:
            X = self.domain.get_vertex_coordinates()
            return X.flat, A, V
        return A, V

    def get_integral(self):
        """Compute the integral of quantity across entire domain.

        Sum over elements of centroid value times domain.areas[k].
        """
        integral = 0
        for k in range(self.domain.number_of_elements):
            integral += self.centroid_values[k]*self.domain.areas[k]

        return integral

    def update(self, timestep):
        """Advance centroid values by one timestep.

        First applies the explicit update
            centroid_values += timestep*explicit_update
        then the semi-implicit update
            centroid_values /= 1 - timestep*semi_implicit_update

        Raises ZeroDivisionError if any semi-implicit denominator is zero
        (the explicit update has already been applied at that point).
        """
        from Numeric import sum, equal, ones, Float

        N = self.centroid_values.shape[0]

        # Explicit updates
        self.centroid_values += timestep*self.explicit_update

        # Semi implicit updates
        denominator = ones(N, Float) - timestep*self.semi_implicit_update

        if sum(equal(denominator, 0.0)) > 0.0:
            msg = 'Zero division in semi implicit update. Call Stephen :-)'
            raise ZeroDivisionError(msg)

        # Update conserved quantities from semi implicit updates
        self.centroid_values /= denominator

    @staticmethod
    def _one_sided_gradient(q0, q1, q2, x0, x1, x2):
        """Derivative at x0 of the parabola through (x0,q0), (x1,q1), (x2,q2).

        Used at the first and last elements, which have only one
        neighbour; x1 and x2 are the next two centroids inwards.
        """
        numerator = (q1 - q0)*(x2 - x0)*(x2 - x0) - (q2 - q0)*(x1 - x0)*(x1 - x0)
        return numerator/((x1 - x0)*(x2 - x0)*(x2 - x1))

    @staticmethod
    def _small_mesh_gradients(N, G, Q, X):
        """Fill G for meshes with fewer than three elements.

        The one-sided boundary formula needs three centroids.  With two
        elements both get the two-point slope; a single element gets a
        zero slope.  Returns True if the mesh was handled here.
        """
        if N >= 3:
            return False
        if N == 2:
            slope = (Q[1] - Q[0])/(X[1] - X[0])
            G[0] = slope
            G[1] = slope
        elif N == 1:
            G[0] = 0.0
        return True

    def compute_gradients(self):
        """Compute gradients of the piecewise quadratic through centroids.

        Interior elements get the derivative at their own centroid of the
        parabola through their own and both neighbouring centroid values
        (neighbours taken to be elements k-1 and k+1).  The first and last
        elements use the one-sided derivative through themselves and the
        next two elements inwards.  Results are written to self.gradients.
        """
        N = self.centroid_values.shape[0]

        G = self.gradients
        Q = self.centroid_values
        X = self.domain.centroids

        if self._small_mesh_gradients(N, G, Q, X):
            return

        for k in range(N):
            if k == 0:
                G[k] = self._one_sided_gradient(Q[0], Q[1], Q[2],
                                                X[0], X[1], X[2])
            elif k == N-1:
                G[k] = self._one_sided_gradient(Q[k], Q[k-1], Q[k-2],
                                                X[k], X[k-1], X[k-2])
            else:
                # Interior volume (2 neighbours)
                q0, q1, q2 = Q[k-1], Q[k], Q[k+1]
                x0, x1, x2 = X[k-1], X[k], X[k+1]
                G[k] = ((q0-q1)/(x0-x1)*(x2-x1) - (q2-q1)/(x2-x1)*(x0-x1))/(x2-x0)

    def compute_minmod_gradients(self):
        """Compute limited gradients from neighbouring centroid values.

        Interior elements use xmic(domain.beta, d1, d2), where d1 and d2
        are the backward and forward centroid difference quotients.  An
        interior element whose two domain.wet_nodes entries both equal 2
        gets a zero gradient.  The first and last elements use the same
        one-sided (unlimited) formula as compute_gradients.  Results are
        written to self.gradients.
        """
        from Numeric import sign

        def xmin(a, b):
            # minmod: 0 if a and b differ in sign, else the smaller magnitude
            return 0.5*(sign(a) + sign(b))*min(abs(a), abs(b))

        def xmic(t, a, b):
            # minmod of t*minmod(a, b) and the mean of a and b
            return xmin(t*xmin(a, b), 0.50*(a + b))

        N = self.centroid_values.shape[0]

        G = self.gradients
        Q = self.centroid_values
        X = self.domain.centroids

        if self._small_mesh_gradients(N, G, Q, X):
            return

        for k in range(N):
            if k == 0:
                G[k] = self._one_sided_gradient(Q[0], Q[1], Q[2],
                                                X[0], X[1], X[2])
            elif k == N-1:
                G[k] = self._one_sided_gradient(Q[k], Q[k-1], Q[k-2],
                                                X[k], X[k-1], X[k-2])
            elif (self.domain.wet_nodes[k, 0] == 2) and (self.domain.wet_nodes[k, 1] == 2):
                G[k] = 0.0
            else:
                # Interior volume (2 neighbours)
                q0, q1, q2 = Q[k-1], Q[k], Q[k+1]
                x0, x1, x2 = X[k-1], X[k], X[k+1]

                d1 = (q1 - q0)/(x1 - x0)
                d2 = (q2 - q1)/(x2 - x1)

                G[k] = xmic(self.domain.beta, d1, d2)

    def _extrapolate_from_gradients(self):
        """Set vertex values from centroid values and self.gradients.

        Each vertex value becomes the element's centroid value plus its
        gradient times the vertex coordinate's offset from the centroid
        (vertex coordinates from domain.vertices, centroids from
        domain.centroids).
        """
        G = self.gradients
        V = self.domain.vertices
        C = self.domain.centroids
        qc = self.centroid_values
        qv = self.vertex_values

        for k in range(self.domain.number_of_elements):
            x = C[k]
            x0, x1 = V[k, :]
            qv[k, 0] = qc[k] + G[k]*(x0 - x)
            qv[k, 1] = qc[k] + G[k]*(x1 - x)

    def extrapolate_first_order(self):
        """Copy each centroid value to both vertices of its element."""
        qc = self.centroid_values
        qv = self.vertex_values

        for i in range(2):
            qv[:, i] = qc

    def extrapolate_second_order(self):
        """Reconstruct vertex values from centroids to second order.

        The method depends on domain.limiter:
            "pyvolution"   -- quadratic-fit gradients, then limit_pyvolution
            "minmod_steve" -- limit_minmod
            anything else  -- limit_range
        """
        if self.domain.limiter == "pyvolution":
            self.compute_gradients()
            self._extrapolate_from_gradients()
            self.limit_pyvolution()
        elif self.domain.limiter == "minmod_steve":
            self.limit_minmod()
        else:
            self.limit_range()

    def limit_minmod(self):
        """Set vertex values from minmod-limited gradients.

        See compute_minmod_gradients for how the gradients are formed.
        """
        self.compute_minmod_gradients()
        self._extrapolate_from_gradients()

    def limit_pyvolution(self):
        """Limit vertex deviations to remove artificial extrema.

        For each element, qmin/qmax are the min/max of its own and its
        neighbours' centroid values (domain.neighbours, negative entries
        meaning no neighbour).  For each vertex, r is the ratio of the
        allowed deviation (qmax - qc or qmin - qc) to the actual deviation
        of the vertex value from the centroid value.  One factor
        phi = min over both vertices of min(domain.beta*r, 1) then scales
        both vertex deviations.

        This is an unsophisticated limiter: it does not take into account
        dependencies among quantities.

        Precondition: vertex values are estimated from the gradient.
        Postcondition: vertex values are updated in place.
        """
        from Numeric import zeros, Float

        N = self.domain.number_of_elements
        beta = self.domain.beta

        qc = self.centroid_values
        qv = self.vertex_values

        # Find min and max of this and neighbours' centroid values
        qmax = self.qmax
        qmin = self.qmin

        for k in range(N):
            qmax[k] = qmin[k] = qc[k]
            for i in range(2):
                n = self.domain.neighbours[k, i]
                if n >= 0:
                    qn = qc[n]  # Neighbour's centroid value
                    qmin[k] = min(qmin[k], qn)
                    qmax[k] = max(qmax[k], qn)

        # Differences between centroids and maxima/minima
        dqmax = qmax - qc
        dqmin = qmin - qc

        # Deltas between vertex and centroid values
        dq = zeros(qv.shape, Float)
        for i in range(2):
            dq[:, i] = qv[:, i] - qc

        # Phi limiter
        for k in range(N):
            phi = 1.0
            for i in range(2):
                r = 1.0
                if dq[k, i] > 0:
                    r = dqmax[k]/dq[k, i]
                if dq[k, i] < 0:
                    r = dqmin[k]/dq[k, i]

                phi = min(min(r*beta, 1), phi)

            # Then update using phi limiter
            for i in range(2):
                qv[k, i] = qc[k] + phi*dq[k, i]

    def limit_range(self):
        """Second order reconstruction with a classical slope limiter.

        domain.limiter selects one of "minmod", "minmod_kurganov",
        "superbee", "vanleer" or "vanalbada" (functions from util).  For
        an interior element k (neighbours assumed to be k-1 and k+1, as
        the formulas below index qc[k-1] and qc[k+1]):
            beta_p = backward slope between centroids k-1 and k
            beta_m = forward slope between centroids k and k+1
            beta_x = central slope between centroids k-1 and k+1
        and the limiter combines these into a slope phi.  The first and
        last elements use the unlimited one-sided slope towards their
        single neighbour.  Vertex values become
        centroid value + phi*(vertex coordinate - centroid).

        Raises ValueError if domain.limiter is not one of the names above.
        """
        from Numeric import zeros, Float
        from util import minmod, minmod_kurganov, maxmod, vanleer, vanalbada

        limiter = self.domain.limiter
        known_limiters = ["minmod", "minmod_kurganov", "superbee",
                          "vanleer", "vanalbada"]
        if limiter not in known_limiters:
            # Otherwise interior elements would silently reuse the phi of
            # the previous element.
            msg = 'Unknown limiter: %s (possible choices %s)' \
                  % (limiter, ', '.join(known_limiters))
            raise ValueError(msg)

        N = self.domain.number_of_elements
        qc = self.centroid_values
        qv = self.vertex_values
        C = self.domain.centroids
        beta_p = zeros(N, Float)
        beta_m = zeros(N, Float)
        beta_x = zeros(N, Float)

        for k in range(N):
            n0 = self.domain.neighbours[k, 0]
            n1 = self.domain.neighbours[k, 1]

            if n0 >= 0 and n1 >= 0:
                # Slope derivative limit
                beta_p[k] = (qc[k] - qc[k-1])/(C[k] - C[k-1])
                beta_m[k] = (qc[k+1] - qc[k])/(C[k+1] - C[k])
                beta_x[k] = (qc[k+1] - qc[k-1])/(C[k+1] - C[k-1])

        # Offsets of each vertex coordinate from its element centroid
        dq = zeros(qv.shape, Float)
        for i in range(2):
            dq[:, i] = self.domain.vertices[:, i] - C

        # Phi limiter
        for k in range(N):
            n0 = self.domain.neighbours[k, 0]
            n1 = self.domain.neighbours[k, 1]
            if n0 < 0:
                phi = (qc[k+1] - qc[k])/(C[k+1] - C[k])
            elif n1 < 0:
                phi = (qc[k] - qc[k-1])/(C[k] - C[k-1])
            elif limiter == "minmod":
                phi = minmod(beta_p[k], beta_m[k])
            elif limiter == "minmod_kurganov":
                # Also known as monotonized central difference limiter
                # if theta = 2.0
                theta = 2.0
                phi = minmod_kurganov(theta*beta_p[k], theta*beta_m[k], beta_x[k])
            elif limiter == "superbee":
                slope1 = minmod(beta_m[k], 2.0*beta_p[k])
                slope2 = minmod(2.0*beta_m[k], beta_p[k])
                phi = maxmod(slope1, slope2)
            elif limiter == "vanleer":
                phi = vanleer(beta_p[k], beta_m[k])
            else:  # "vanalbada"
                phi = vanalbada(beta_m[k], beta_p[k])

            for i in range(2):
                qv[k, i] = qc[k] + phi*dq[k, i]

    def limit_steve_slope(self):
        """Flux-limiter style scaling of existing vertex deviations.

        For each interior element k (neighbours assumed to be k-1 and k+1)
        the ratios
            beta_p = (qc[k] - qc[k-1])/(qc[k+1] - qc[k])
            beta_m = (qc[k+2] - qc[k+1])/(qc[k+1] - qc[k])
        are formed.  Both are set to sys.maxint when qc[k+1] == qc[k], and
        both are left at 0 where qc[k+2] does not exist.  domain.limiter
        ("flux_minmod", "flux_superbee", "flux_muscl" or "flux_vanleer")
        turns them into phi; any other name gives phi = 0.  Each element
        with a right-hand neighbour then has its vertex deviations from the
        centroid scaled by 0.5*phi.  Elements without one are reset to
        first order (both vertex values equal to the centroid value).
        """
        import sys
        from Numeric import zeros, Float
        from util import minmod_kurganov

        N = self.domain.number_of_elements
        limiter = self.domain.limiter

        qc = self.centroid_values
        qv = self.vertex_values

        beta_p = zeros(N, Float)
        beta_m = zeros(N, Float)

        for k in range(N):
            n0 = self.domain.neighbours[k, 0]
            n1 = self.domain.neighbours[k, 1]

            # beta_m needs qc[k+2], which does not exist for k = N-2
            if n0 >= 0 and n1 >= 0 and k + 2 < N:
                if (qc[k+1] - qc[k]) == 0.0:
                    # Avoid division by zero: treat the ratios as very large
                    beta_p[k] = float(sys.maxint)
                    beta_m[k] = float(sys.maxint)
                else:
                    beta_p[k] = (qc[k] - qc[k-1])/(qc[k+1] - qc[k])
                    beta_m[k] = (qc[k+2] - qc[k+1])/(qc[k+1] - qc[k])

        # Phi limiter
        for k in range(N):
            bm = beta_m[k]
            bp = beta_p[k]

            phi = 0.0
            if limiter == "flux_minmod":
                phi = minmod_kurganov(1.0, bm, bp)
            elif limiter == "flux_superbee":
                phi = max(0.0, min(1.0, 2.0*bm), min(2.0, bm)) \
                      + max(0.0, min(1.0, 2.0*bp), min(2.0, bp)) - 1.0
            elif limiter == "flux_muscl":
                phi = max(0.0, min(2.0, 2.0*bm, 2.0*bp, 0.5*(bm + bp)))
            elif limiter == "flux_vanleer":
                phi = (bm + abs(bm))/(1.0 + abs(bm)) \
                      + (bp + abs(bp))/(1.0 + abs(bp)) - 1.0

            # Then update using phi limiter (applies to every limiter,
            # not only flux_vanleer)
            if self.domain.neighbours[k, 1] >= 0:
                qv[k, 0] = qc[k] + 0.5*phi*(qv[k, 0] - qc[k])
                qv[k, 1] = qc[k] + 0.5*phi*(qv[k, 1] - qc[k])
            else:
                qv[k, 0] = qc[k]
                qv[k, 1] = qc[k]

    def backup_centroid_values(self):
        """Copy centroid_values into centroid_backup_values (in place).

        Double precision is kept; the former astype('f') cast truncated
        the backup to single precision.
        """
        self.centroid_backup_values[:] = self.centroid_values

    def saxpy_centroid_values(self, a, b):
        """Set centroid_values to a*centroid_values + b*centroid_backup_values.

        The update is in place and keeps double precision (no single
        precision cast).
        """
        self.centroid_values[:] = a*self.centroid_values + b*self.centroid_backup_values
        
class Conserved_quantity(Quantity):
    """Deprecated alias of Quantity.

    Adds nothing to Quantity: construction is delegated unchanged to
    Quantity.__init__, after which a reminder to use Quantity directly
    is printed to standard output.
    """

    def __init__(self, domain, vertex_values=None):
        Quantity.__init__(self, domain, vertex_values)

        reminder = "Use Quantity instead of Conserved_quantity"
        print(reminder)


##
##def newLinePlot(title='Simple Plot'):
##    import Gnuplot
##    g = Gnuplot.Gnuplot()
##    g.title(title)
##    g('set data style linespoints') 
##    g.xlabel('x')
##    g.ylabel('y')
##    return g
##
##def linePlot(g,x,y):
##    import Gnuplot
##    g.plot(Gnuplot.PlotItems.Data(x.flat,y.flat))

def newLinePlot(title='Simple Plot'):
    """Prepare the current pylab figure for interactive line plots.

    Switches interactive mode on and hold off (so the next plot call
    replaces what is drawn), then sets the title to *title* and labels
    the axes 'x' and 'y'.
    """
    import pylab
    pylab.ion()
    pylab.hold(False)
    for set_text, text in ((pylab.title, title),
                           (pylab.xlabel, 'x'),
                           (pylab.ylabel, 'y')):
        set_text(text)
    

def linePlot(x,y):
    """Plot y against x on the current pylab axes.

    Both arguments are flattened via their .flat attribute first, so the
    N x 2 vertex arrays used in this module can be passed directly.
    """
    import pylab
    xs, ys = x.flat, y.flat
    pylab.plot(xs, ys)


def closePlots():
    """Close every open pylab figure."""
    import pylab
    pylab.close('all')
    
if __name__ == "__main__":
    # Interactive demonstration: build small domains, set values in
    # several ways and plot the second order extrapolation.
    # closePlots() is called inside this guard so that merely importing
    # the module neither imports pylab nor closes figures.
    from shallow_water_domain import Domain
    from Numeric import arange
    import pylab as g

    points1 = [0.0, 1.0, 2.0, 3.0]
    vertex_values = [[1.0,2.0],[4.0,5.0],[-1.0,2.0]]

    D1 = Domain(points1)

    Q1 = Quantity(D1, vertex_values)

    print(Q1.vertex_values)
    print(Q1.centroid_values)

    new_vertex_values = [[2.0,1.0],[3.0,4.0],[-2.0,4.0]]

    Q1.set_values(new_vertex_values)

    print(Q1.vertex_values)
    print(Q1.centroid_values)

    new_centroid_values = [20,30,40]
    Q1.set_values(new_centroid_values,'centroids')

    print(Q1.vertex_values)
    print(Q1.centroid_values)

    class FunClass:
        """Callable returning value*x**2."""
        def __init__(self,value):
            self.value = value

        def __call__(self,x):
            return self.value*(x**2)

    fun = FunClass(1.0)
    Q1.set_values(fun,'vertices')

    print(Q1.vertex_values)
    print(Q1.centroid_values)

    Xc = Q1.domain.vertices
    Qc = Q1.vertex_values
    print(Xc)
    print(Qc)

    Qc[1,0] = 3

    Q1.extrapolate_second_order()

    newLinePlot('plots')
    linePlot(Xc,Qc)
    raw_input('press return')

    points2 = arange(10)
    D2 = Domain(points2)

    Q2 = Quantity(D2)
    Q2.set_values(fun,'vertices')
    Xc = Q2.domain.vertices
    Qc = Q2.vertex_values
    linePlot(Xc,Qc)
    raw_input('press return')

    Q2.extrapolate_second_order()
    Xc = Q2.domain.vertices
    Qc = Q2.vertex_values
    print(Q2.centroid_values)
    print(Qc)
    linePlot(Xc,Qc)
    raw_input('press return')

    for i in range(10):
        g.hold(True)
        fun = FunClass(i/10.0)
        Q2.set_values(fun,'centroids')
        Q2.extrapolate_second_order()
        Qc = Q2.vertex_values
        linePlot(Xc,Qc)
        raw_input('press return')

    raw_input('press return to quit')
    closePlots()
