

class AnalyticDam:
    """Analytic wet-bed dam-break solution of the 1-D shallow-water equations.

    The dam sits at x = 0, i.e. cell coordinate C = 3*L/4.  Still water of
    depth h1 lies to its left and still water of depth h0 to its right.
    After the dam is removed:
      * a rarefaction travels left into the h1 region,
      * a shock travels right (speed z) into the h0 region,
      * a constant state (h2, u2) lies between them.
    In addition, a second rarefaction is superposed that drains the h1
    region into a dry bed across x = -L/2 (C = L/4); cells left of that
    point start dry.
    """

    def __init__(self, h0 = 5.0, h1 = 10.0, L = 2000.0):
        """Solve for the shock speed and the intermediate state.

        h0 -- still-water depth right of the dam, ahead of the shock (m)
        h1 -- still-water depth left of the dam, behind the rarefaction (m);
              presumably expected to exceed h0 so that a right-moving shock
              forms -- confirm with callers
        L  -- length of the domain (m)

        Prints 'no convergence' if the bisection for the shock speed ends at
        an edge of its search bracket.
        """
        from math import sqrt

        self.h0 = h0 # depth right of the dam, ahead of the shock (m)
        self.h1 = h1 # depth left of the dam, the reservoir side (m)
        self.L  = L  # length of domain (m)

        g  = 9.81    # gravity (m/s^2)

        c0 = sqrt(g*h0) # celerity in the right-hand (h0) region
        c1 = sqrt(g*h1) # celerity in the left-hand (h1) region

        # Bisection for the shock speed z.  For a trial z:
        #   u2, c2 -- the velocity and celerity behind a shock moving at
        #             speed z into still water of depth h0;
        #   func   -- the mismatch in the Riemann invariant u + 2c carried
        #             across the left rarefaction (u2 + 2*c2 must equal 2*c1),
        #             scaled by c0.
        # The bracket is asymmetric, which keeps every midpoint away from
        # z == 0, where u2 would divide by zero (presumably the intent).
        zmin=-100.0
        zmax=101.0
        for _ in range(100):
            z=(zmin+zmax)/2.0
            u2=z-c0*c0/4.0/z*(1.0+sqrt(1.0+8.0*z*z/c0/c0))
            c2=c0*sqrt(0.5*(sqrt(1.0+8.0*z*z/c0/c0)-1.0))
            func=2.0*c1/c0-u2/c0-2.0*c2/c0
            if (func > 0.0):
                zmin=z
            else:
                zmax=z
        # A root outside the bracket drives z to one of its edges.
        if( abs(z) > 99.0):
            print('no convergence')

        self.u2 = u2
        self.c0 = c0
        self.c1 = c1
        self.c2 = c2
        self.g = g
        self.z = z

    def __call__(self, C, t):
        """Evaluate the analytic solution at cell coordinates C and time t.

        C -- sequence of cell coordinates (array or plain list); x = C - 3*L/4
             is the coordinate relative to the dam.
        t -- time since the dam was removed (s).  t == 0 yields the initial
             condition.  Negative t is not meaningful.

        Returns (h, uh): two arrays of length len(C) holding depth and the
        discharge per unit width u*h.

        NOTE(review): the branch order assumes the left-boundary fan stays
        left of the dam-break fan (x1_ <= x1, i.e. t <= L/(4*c1)).  After
        that the two rarefactions would interact and this piecewise
        superposition no longer applies.
        """
        try:
            from Numeric import zeros, Float
        except ImportError:
            # Numeric is long unmaintained; NumPy is its drop-in successor.
            from numpy import zeros, float64 as Float

        L = self.L
        g = self.g
        h0 = self.h0
        h1 = self.h1
        u2 = self.u2
        c1 = self.c1    # identical to sqrt(g*h1)
        c2 = self.c2
        z = self.z

        n = len(C)
        h = zeros(n, Float)
        uh = zeros(n, Float)

        # Depth between the rarefaction tail and the shock, from mass
        # conservation across the shock.
        h2 = h0/(1.0-u2/z)

        # Wave positions relative to the dam at time t.
        x1 = -c1*t            # head of the dam-break rarefaction
        x3 = (u2-c2)*t        # tail of the dam-break rarefaction
        x2 = z*t              # shock front
        x1_ = -1*L/2.0-x1     # tail of the left-boundary (dry-bed) rarefaction
        x2_ = -1*L/2.0+2*x1   # wet/dry front of the left-boundary rarefaction

        for i, ci in enumerate(C):
            x = ci-3*L/4.0
            # Fan values are computed only inside their own branches.  At
            # t == 0 both fans are empty intervals, so t is never divided by.
            if x <= x2_:
                # dry bed
                u, d = 0.0, 0.0
            elif x <= x1_:
                # rarefaction draining into the dry bed left of x = -L/2
                xi = (x+L/2.0)/t
                u = 2.0/3.0*(xi-c1)
                d = 1.0/(9.0*g)*(xi+2*c1)*(xi+2*c1)
            elif x <= x1:
                # undisturbed left state
                u, d = 0.0, h1
            elif x <= x3:
                # dam-break rarefaction
                u = 2.0/3.0*(c1+x/t)
                d = 4.0/(9.0*g)*(c1-x/(2.0*t))*(c1-x/(2.0*t))
            elif x < x2:
                # constant state between the rarefaction and the shock
                u, d = u2, h2
            else:
                # undisturbed right state ahead of the shock
                u, d = 0.0, h0
            h[i] = d
            uh[i] = u*d

        return h, uh
