

class AnalyticDam:
    """Analytic solution of the 1-D wet-bed shallow-water dam-break problem.

    At t = 0 a dam at x = L/2 separates water of depth ``h1`` on the left
    from water of depth ``h0`` on the right.  For t > 0, from left to
    right, the solution consists of:

    * the undisturbed left state (depth ``h1``, at rest);
    * a rarefaction fan;
    * a constant plateau (depth ``h2``, velocity ``u2``);
    * a shock moving right with speed ``z``;
    * the undisturbed right state (depth ``h0``, at rest).

    This wave structure assumes ``h1 > h0`` (deeper water on the left) and a
    wet bed, ``h0 > 0``.  With ``h0 == 0`` the bisection divides by zero.

    The intermediate state is solved once in ``__init__``.  Calling the
    instance evaluates depth and discharge at given positions and time.
    """

    def __init__(self, h0=5.0, h1=10.0, L=2000.0, g=9.81):
        """Solve for the shock speed and the plateau state.

        Parameters
        ----------
        h0 : float
            Initial depth right of the dam (m).  Must be > 0.
        h1 : float
            Initial depth left of the dam (m).
        L : float
            Length of the domain (m).  The dam sits at x = L/2.
        g : float
            Gravitational acceleration (m/s^2).
        """
        from math import sqrt

        self.h0 = h0  # depth right of the dam (shallow side, m)
        self.h1 = h1  # depth left of the dam (deep side, m)
        self.L = L    # length of domain (m)

        c0 = sqrt(g * h0)  # celerity of the right (h0) state
        c1 = sqrt(g * h1)  # celerity of the left (h1) state

        # Bisect for the shock speed z.  For a trial z, the shock relations
        # into still water of depth h0 give the plateau velocity u2 and
        # celerity c2.  The correct z also satisfies the rarefaction's
        # Riemann invariant u2 + 2*c2 == 2*c1, i.e. func(z) == 0.
        # func decreases as z increases, hence the update rule below.
        zmin = -100.0
        zmax = 101.0
        for _ in range(100):
            z = 0.5 * (zmin + zmax)
            root = sqrt(1.0 + 8.0 * z * z / (c0 * c0))
            u2 = z - c0 * c0 / (4.0 * z) * (1.0 + root)
            c2 = c0 * sqrt(0.5 * (root - 1.0))
            func = (2.0 * c1 - u2 - 2.0 * c2) / c0
            if func > 0.0:
                zmin = z
            else:
                zmax = z

        # A result pinned near either end of the initial bracket means no
        # root was found inside it.
        if abs(z) > 99.0:
            print('no convergence')

        self.u2 = u2  # plateau velocity (m/s)
        self.c0 = c0
        self.c1 = c1
        self.c2 = c2  # plateau celerity sqrt(g*h2) (m/s)
        self.g = g
        self.z = z    # shock speed (m/s)

    def __call__(self, C, t):
        """Evaluate the solution at positions ``C`` and time ``t``.

        Parameters
        ----------
        C : sequence of float
            Cell positions (m), measured so that the dam is at ``L/2``.
            A plain list or an array is accepted.
        t : float
            Time since the dam was removed (s).  At ``t == 0`` the initial
            step is returned: ``h1`` for positions <= L/2, ``h0`` otherwise.

        Returns
        -------
        (h, uh) : tuple of two 1-D float arrays
            Water depth and discharge (depth * velocity) at each position.
        """
        try:
            from numpy import asarray, zeros
            Float = float
        except ImportError:  # legacy installs that only provide Numeric
            from Numeric import asarray, zeros, Float

        h0 = self.h0
        h1 = self.h1
        u2 = self.u2
        c1 = self.c1
        c2 = self.c2
        g = self.g
        z = self.z

        x = asarray(C, Float) - self.L / 2.0  # coordinate relative to the dam
        n = len(x)
        h = zeros(n, Float)
        uh = zeros(n, Float)  # stays 0 in the undisturbed (resting) states

        h2 = h0 / (1.0 - u2 / z)  # plateau depth from mass conservation at the shock
        x1 = -c1 * t              # head of the rarefaction (moves left)
        x3 = (u2 - c2) * t        # tail of the rarefaction
        x2 = z * t                # shock position

        for i in range(n):
            xi = x[i]
            if xi <= x1:
                h[i] = h1
            elif xi <= x3:
                # Rarefaction fan.  This branch is empty when t == 0
                # (x1 == x3 == 0), so dividing by t here is safe.
                s = xi / t
                depth = (2.0 * c1 - s) ** 2 / (9.0 * g)
                h[i] = depth
                uh[i] = 2.0 / 3.0 * (c1 + s) * depth
            elif xi < x2:
                h[i] = h2
                uh[i] = u2 * h2
            else:
                h[i] = h0

        return h, uh
