
import os
from math import sqrt, pi
from shallow_water_1d import *
from Numeric import allclose, array, zeros, ones, Float, take, sqrt
from config import g, epsilon
from analytic_dam import AnalyticDam

# Initial water depths (m) on either side of the dam located at x = 1000:
# h1 is used for x <= 1000 and h0 beyond it (see stage() below).
h0, h1 = 0.1, 10.0

# Reference solution object; it is called later with the vertex positions and
# the current time, and returns the reference stage and momentum to plot.
analytical_sol = AnalyticDam(h0, h1)

## def analytical_sol(C,t):
    
##     #t  = 0.0     # time (s)
##     g  = 9.81    # gravity (m/s^2)
##     h1 = 10.0    # depth upstream (m)
##     h0 = 5.0     # depth downstream (m)
##     L = 2000.0   # length of stream/domain (m)
##     n = len(C)    # number of cells

##     u = zeros(n,Float)
##     h = zeros(n,Float)
##     x = C-L/2
##     #x = zeros(n,Float)
##     #for i in range(n):
##     #    x[i] = C[i]-1000.0

##     # Upstream and downstream boundary conditions are set to the initial water
##     # depth for all time.

##     # Calculate Shock Speed
##     h2 = 7.2692044
##     S2 = 2*h2/(h2-h0)*(sqrt(g*h1)-sqrt(g*h2))
##     u2 = S2 - g*h0/(4*S2)*(1+sqrt(1+8*S2*S2/(g*h0)))
    
##     #t=50 
##     #x = (-L/2:L/2) 
##     for i in range(n):
##         # Calculate Analytical Solution at time t > 0
##         u3 = 2/3*(sqrt(g*h1)+x[i]/t) 
##         h3 = 4/(9*g)*(sqrt(g*h1)-x[i]/(2*t))*(sqrt(g*h1)-x[i]/(2*t)) 

##         if ( x[i] <= -t*sqrt(g*h1) ):
##             u[i] = 0.0 
##             h[i] = h1 
##         elif ( x[i] <= t*(u2-sqrt(g*h2)) ):
##             u[i] = u3 
##             h[i] = h3 
##         elif ( x[i] < t*S2 ):
##             u[i] = u2 
##             h[i] = h2 
##         else:
##             u[i] = 0.0 
##             h[i] = h0 

##     return h , u*h


def newLinePlot(title='Simple Plot'):
    """Open a new Gnuplot session set up for line plots.

    The session gets the given title, the 'lines' data style and the axis
    labels 'x' and 'y'.

    Parameters:
        title -- text shown as the plot title.

    Returns:
        The configured Gnuplot.Gnuplot instance.
    """
    import Gnuplot

    session = Gnuplot.Gnuplot()
    session.title(title)
    # Raw gnuplot command: draw data sets as lines unless told otherwise.
    session('set style data lines')
    session.xlabel('x')
    session.ylabel('y')
    return session

def linePlot(g,x1,y1,x2,y2):
    """Draw two curves in the Gnuplot session g.

    The first curve (x1, y1) is drawn with the 'linespoints' style and the
    second curve (x2, y2) with the 'lines 3' style.  Both are sent in a
    single plot() call on the session.

    Parameters:
        g      -- Gnuplot session, e.g. as returned by newLinePlot().
        x1, y1 -- coordinates of the first curve; must provide `.flat`.
        x2, y2 -- coordinates of the second curve; must provide `.flat`.
    """
    import Gnuplot

    # `with` is a reserved word from Python 2.6 onwards, so it cannot be used
    # as a keyword argument name; Gnuplot.py 1.8 spells the option `with_`.
    plot1 = Gnuplot.PlotItems.Data(x1.flat, y1.flat, with_="linespoints")
    plot2 = Gnuplot.PlotItems.Data(x2.flat, y2.flat, with_="lines 3")
    g.plot(plot1, plot2)

# Set to True to print extra diagnostics (cell areas and stage values).
debug = False

print("TEST 1D-SOLUTION I")

# Channel geometry; the first grid point sits at x = 0.0.
L = 2000.0          # Length of channel (m)
N = 100             # Number of computational cells
cell_len = L / N    # Spacing between consecutive grid points (m)

# N+1 equally spaced grid points covering the channel from 0 to L.
points = zeros(N + 1, Float)
for i in range(len(points)):
    points[i] = i * cell_len

domain = Domain(points)

def stage(x):
    """Return the initial dam-break stage profile at the positions x.

    Positions with x <= 1000.0 get the depth h1, all others get h0.

    Parameters:
        x -- indexable sequence of positions.

    Returns:
        Float array of len(x) stage values.
    """
    count = len(x)
    profile = zeros(count, Float)
    for k in range(count):
        behind_dam = x[k] <= 1000.0
        if behind_dam:
            profile[k] = h1
        else:
            profile[k] = h0
    return profile

# Impose the initial dam-break profile on the stage quantity.
domain.set_quantity('stage', stage)

# Numerical scheme settings.
domain.default_order = 2
domain.cfl = 1.0
domain.beta = 0.85
print("domain.order %s" % (domain.default_order,))

if debug:
    # The grid points are equally spaced, so each entry of domain.areas is
    # expected to equal cell_len.  Compare the cell being visited, not the
    # whole array.
    areas = domain.areas
    for i in range(len(areas)):
        if areas[i] != cell_len:
            print("Cell Areas are Wrong")

    # Use a dedicated name so the channel length L is not overwritten.
    initial_stage = domain.quantities['stage'].vertex_values
    print("Initial Stage")
    print(initial_stage)

domain.set_boundary({'exterior': Reflective_boundary(domain)})

# Vertex positions, used as the x-axis of every plot and as the sample
# points passed to the analytic solution.
X = domain.vertices
plot1 = newLinePlot("Stage")
plot2 = newLinePlot("Momentum")
plot3 = newLinePlot("Velocity")

import time
t0 = time.time()
yieldstep = 1.0
finaltime = 50.0

for t in domain.evolve(yieldstep = yieldstep, finaltime = finaltime):
    domain.write_time()
    # Skip the first yield at t = 0; only later states are plotted.
    if t > 0.0:
        StageQ = domain.quantities['stage'].vertex_values
        MomentumQ = domain.quantities['xmomentum'].vertex_values
        Velocity = MomentumQ/StageQ
        # y: analytic stage, my: analytic momentum at the vertex positions.
        y, my = analytical_sol(X.flat, domain.time)
        linePlot(plot1, X, StageQ, X, y)
        linePlot(plot2, X, MomentumQ, X, my)
        linePlot(plot3, X, Velocity, X, my/y)
    # To pause after every frame, uncomment:
    #raw_input('press_return')

print('That took %.2f seconds' % (time.time() - t0))

C = domain.quantities['stage'].centroid_values

if debug:
    final_stage = domain.quantities['stage'].vertex_values
    print("Final Stage Vertex Values")
    print(final_stage)
    print("Final Stage Centroid Values")
    print(C)

# Wait for the user before the script ends (presumably so the plot windows
# stay open for inspection).
raw_input('press return')
#f = file('test_solution_I.out', 'w')
#for i in range(N):
#    f.write(str(C[i]))
#    f.write("\n")
#f.close

#del plot1, plot2,plot3
