import os
from math import sqrt, pi
from shallow_water_1d import *
from Numeric import allclose, array, zeros, ones, Float, take, sqrt
from config import g, epsilon
from analytic_dam import AnalyticDam


# Initial water levels on the two sides of the dam: stage() assigns h1 to
# points with x <= 1000.0 and h0 to all other points.
h0, h1 = 5.0, 10.0

# Reference solution; called below as analytical_sol(x, t) -> (h, uh).
analytical_sol = AnalyticDam(h0, h1)

def newLinePlot(title='Simple Plot'):
    """Open a persistent Gnuplot window set up for simple x/y line plots.

    title -- title shown on the plot.

    Returns the Gnuplot.Gnuplot instance, with axes labelled 'x' and 'y'
    and the default data style set to linespoints.
    """
    import Gnuplot
    # Named 'plotter' rather than 'g' so it does not shadow (even locally)
    # the gravitational constant g imported from config for this module.
    plotter = Gnuplot.Gnuplot(persist=1)
    plotter.title(title)
    # 'set data style ...' is the old (pre-4.0) gnuplot spelling, deprecated
    # and not accepted by default by newer gnuplot; 'set style data ...'
    # is the supported form.
    plotter('set style data linespoints')
    plotter.xlabel('x')
    plotter.ylabel('y')
    return plotter

def linePlot(g,x1,y1,x2,y2):
    """Draw two data series together on an existing Gnuplot window.

    g      -- a Gnuplot.Gnuplot instance (for example from newLinePlot).
    x1, y1 -- first series, drawn with the "linespoints" style.
    x2, y2 -- second series, drawn with the "lines 3" style.

    Every array is flattened via its .flat attribute before plotting, so
    multi-dimensional (e.g. per-cell vertex) arrays are accepted as long
    as x and y flatten to the same length.
    """
    import Gnuplot
    # 'with' is a reserved word from Python 2.6 onward and cannot be used
    # as a keyword argument (it made this whole file a SyntaxError).
    # Gnuplot.py spells the option 'with_' for exactly this reason.
    plot1 = Gnuplot.PlotItems.Data(x1.flat, y1.flat, with_="linespoints")
    plot2 = Gnuplot.PlotItems.Data(x2.flat, y2.flat, with_="lines 3")
    g.plot(plot1, plot2)

def stage(x):
    """Initial stage profile: a step located at x = 1000.0.

    x -- sequence of coordinates.

    Returns a Numeric Float array of len(x) holding h1 wherever the
    coordinate is <= 1000.0 and h0 everywhere else.
    """
    values = zeros(len(x),Float)
    for idx, xi in enumerate(x):
        if xi > 1000.0:
            values[idx] = h0
        else:
            values[idx] = h1
    return values


import time
finaltime = 30.0
yieldstep = finaltime

L = 2000.0     # Length of channel (m)
#number_of_cells = [25,50,100,200,400,800,1600,3200,6400,12800,25600]
number_of_cells = [20]
h_error = zeros(len(number_of_cells),Float)
uh_error = zeros(len(number_of_cells),Float)
k = 0
for i in range(len(number_of_cells)):
    N = int(number_of_cells[i])
    print "Evaluating domain with",N,"cells"
    cell_len = L/N # Origin = 0.0
    points = zeros(N+1,Float)
    for j in range(N+1):
        points[j] = j*cell_len
    domain = Domain(points)
    domain.set_quantity('stage', stage)
    domain.set_boundary({'exterior': Reflective_boundary(domain)})
    domain.default_order = 2
    domain.default_time_order = 2
    print "time order", domain.default_time_order
    domain.cfl = 1.0
    domain.beta = 1.0
    domain.limiter = "steve_minmod"
    #domain.limiter = "superbee"
    init_integral = domain.quantities['stage'].get_integral()
    t0 = time.time()
    for t in domain.evolve(yieldstep = yieldstep, finaltime = finaltime):
        pass
    N = float(N)
    assert(allclose(domain.quantities['stage'].get_integral(),init_integral))
    StageC = domain.quantities['stage'].centroid_values
    XmomC = domain.quantities['xmomentum'].centroid_values
    C = domain.centroids
    h, uh = analytical_sol(C,domain.time)
    h_error[k] = 1.0/(N)*sum(abs(h-StageC))
    uh_error[k] = 1.0/(N)*sum(abs(uh-XmomC))
    print "h_error %.10f" %(h_error[k])
    print "uh_error %.10f"% (uh_error[k])
    k = k+1
    print 'That took %.2f seconds' %(time.time()-t0)
    X = domain.vertices
    StageQ = domain.quantities['stage'].vertex_values
    XmomQ = domain.quantities['xmomentum'].vertex_values
    h, uh = analytical_sol(X.flat,domain.time)
    from pylab import plot,title,xlabel,ylabel,legend,savefig,show,hold,subplot#,rc
    #rc('text', usetex=True)
    hold(False)
    plot1 = subplot(211)
    plot(X,h,X,StageQ)
    plot1.set_ylim([4,11])
    #title('Free Surface Elevation of a Dry Dam-Break')
    #ylabel('Stage (m)')
    #legend(('Analytical Solution', 'Numerical Solution'),
    #       'upper right', shadow=True)
    #plot2 = subplot(212)
    #plot(X,uh,X,XmomQ)
    #plot2.set_ylim([-1,25])
    #title('Xmomentum Profile of a Dry Dam-Break')
    #xlabel('x (m)')
    #ylabel(r'X-momentum ($m^2/s$)')
    #legend(('Analytical Solution', 'Numerical Solution'),
    #       'upper right', shadow=True)
    #filename = "subcritical_flow_s2_t2_"
    #filename += domain.limiter
    #filename += str(number_of_cells[i])
    #filename += ".eps"
    #savefig(filename)
    show()

print "Error in height", h_error
print "Error in xmom", uh_error

