import os
from math import sqrt, pi
from shallow_water_1d import *
from Numeric import allclose, array, zeros, ones, Float, take, sqrt
from config import g, epsilon

def analytical_sol(C,t):
    """Analytical dam-break solution at time t (depth and momentum).

    The dam sits at the centre of a 2000 m channel.  Initially the water is
    h1 = 10 m deep upstream (x <= 0) and h0 = 0.1 m deep downstream.  For
    t > 0 the profile consists of four regions, from upstream to downstream:
      * the undisturbed upstream state,
      * a rarefaction fan,
      * a constant state of depth h2 behind the shock,
      * the undisturbed downstream state.
    For t <= 0 the initial step is returned.

    C -- array of positions (m) measured from the upstream end of the
         channel; shifted by L/2 so the dam is at x = 0.
    t -- time (s).

    Returns (h, u*h): arrays of depth (m) and momentum (m^2/s), one entry
    per position in C.
    """
    # NOTE(review): this local g shadows the g imported from config;
    # confirm the two values agree.
    g  = 9.81    # gravity (m/s^2)
    h1 = 10.0    # depth upstream (m)
    h0 = 0.1     # depth downstream (m)
    L = 2000.0   # length of stream/domain (m)
    n = len(C)   # number of evaluation points

    u = zeros(n,Float)
    h = zeros(n,Float)

    x = C-L/2    # shift so the dam is at x = 0

    # Upstream and downstream boundary conditions are set to the initial
    # water depth for all time.

    if t <= 0.0:
        # Before release the water is at rest with a step in depth at x = 0.
        # The t > 0 formulas below divide by t, so handle this case here.
        for i in range(n):
            if x[i] <= 0.0:
                h[i] = h1
            else:
                h[i] = h0
        return h,u*h

    c1 = sqrt(g*h1)    # sqrt(g*h1): speed of the rarefaction head
    # Depth between the rarefaction tail and the shock.  Hard-coded;
    # presumably solved for h1 = 10, h0 = 0.1 -- it must be recomputed if
    # h1 or h0 change.
    h2 = 1.7117807
    c2 = sqrt(g*h2)

    # Shock speed, and water velocity behind the shock
    S2 = 2*h2/(h2-h0)*(c1-c2)
    u2 = S2 - g*h0/(4*S2)*(1+sqrt(1+8*S2*S2/(g*h0)))

    for i in range(n):
        xi = x[i]
        if xi <= -t*c1:
            # Upstream of the rarefaction head: undisturbed
            u[i] = 0.0
            h[i] = h1
        elif xi <= t*(u2-c2):
            # Inside the rarefaction fan.  The float literals matter: under
            # Python 2, 2/3 is integer division and evaluates to 0.
            u[i] = 2.0/3.0*(c1+xi/t)
            h[i] = 4.0/(9.0*g)*(c1-xi/(2.0*t))**2
        elif xi < t*S2:
            # Constant state between the rarefaction tail and the shock
            u[i] = u2
            h[i] = h2
        else:
            # Downstream of the shock: undisturbed
            u[i] = 0.0
            h[i] = h0

    return h,u*h


def newLinePlot(title='Simple Plot'):
    """Open a Gnuplot session set up for x/y line plots.

    title -- plot title.
    Returns the Gnuplot.Gnuplot session object, for use with linePlot.
    """
    import Gnuplot
    # Named gp so it does not shadow the module-level gravity constant g.
    gp = Gnuplot.Gnuplot(persist=1)   # keep the window after the session ends
    gp.title(title)
    # 'set style data' replaces the obsolete 'set data style', which newer
    # gnuplot releases no longer accept.
    gp('set style data linespoints')
    gp.xlabel('x')
    gp.ylabel('y')
    return gp

def linePlot(g,x1,y1,x2,y2):
    """Draw two curves on the Gnuplot session g, replacing its current plot.

    (x1, y1) are drawn with points joined by lines (the numerical
    solution in this script).  (x2, y2) are drawn as a plain line with
    gnuplot line type 3 (the analytical solution).  All arrays are
    flattened with .flat before plotting.
    """
    import Gnuplot
    # 'with' has been a reserved word since Python 2.6, so 'with=...' is a
    # SyntaxError there.  Passing it through a dict works on every version.
    plot1 = Gnuplot.PlotItems.Data(x1.flat,y1.flat,**{'with': "linespoints"})
    plot2 = Gnuplot.PlotItems.Data(x2.flat,y2.flat,**{'with': "lines 3"})
    g.plot(plot1,plot2)

debug = False

print "TEST 1D-SOLUTION II -- 0.1m Deep Downstream"

L = 2000.0     # Length of channel (m)
N = 100        # Number of compuational cells
cell_len = L/N # Origin = 0.0

points = zeros(N+1,Float)
for i in range(N+1):
    points[i] = i*cell_len

domain = Domain(points)

def stage(x):
    """Initial stage profile: 10.0 where x <= 1000.0, 0.1 elsewhere.

    x -- sequence of positions; returns a Float array of the same length.
    """
    result = zeros(len(x),Float)
    for idx in range(len(result)):
        if x[idx] <= 1000.0:
            result[idx] = 10.0    # upstream of the dam
        else:
            result[idx] = 0.1     # downstream of the dam
    return result

domain.set_quantity('stage', stage)

domain.order = 2
domain.default_order = 2
domain.cfl = 0.5
print "domain.order", domain.order

if (debug == True):
    area = domain.areas
    for i in range(len(area)):
        if area != 20:
            print "Cell Areas are Wrong"
            
L = domain.quantities['stage'].vertex_values
print "Initial Stage"
print L
raw_input('press enter')

domain.set_boundary({'exterior': Reflective_boundary(domain)})

X = domain.vertices
C = domain.centroids
plot1 = newLinePlot("Stage")
plot2 = newLinePlot("Momentum")

import time
t0 = time.time()
yieldstep = 1.0
finaltime = 50.0
for t in domain.evolve(yieldstep = yieldstep, finaltime = finaltime):
    domain.write_time()
    if t > 0.0:
        StageQ = domain.quantities['stage'].vertex_values
        y,my = analytical_sol(X.flat,domain.time)
        linePlot(plot1,X,StageQ,X,y)
        MomentumQ = domain.quantities['xmomentum'].vertex_values
        linePlot(plot2,X,MomentumQ,X,my)
    #raw_input('press_return')
    #pass

print 'That took %.2f seconds' %(time.time()-t0)

C = domain.quantities['stage'].centroid_values

if (debug == True):
    L = domain.quantities['stage'].vertex_values
    print "Final Stage Vertex Values"
    print L
    print "Final Stage Centroid Values"
    print C

#f = file('test_solution_I.out', 'w')
#for i in range(N):
#    f.write(str(C[i]))
#    f.write("\n")
#f.close

