"""Example of an analytical solution of the shallow water wave equation,
consisting of a parabolic profile in a parabolic basin. The analytical
solution to this problem was derived by Carrier and Greenspan
and used by Yoon and Chou.

   Copyright 2005
   Christopher Zoppou, Stephen Roberts, ANU, Geoscience Australia

"""

#---------------
# Module imports
#import sys
#from os import sep
#sys.path.append('..'+sep+'pyvolution')
from anuga.pyvolution.shallow_water import Domain, Transmissive_boundary, Reflective_boundary,\
     Dirichlet_boundary
from math import sqrt, cos, sin, pi
from anuga.pyvolution.mesh_factory import rectangular_cross
from anuga.utilities.polygon import inside_polygon
from Numeric import asarray
from anuga.pyvolution.least_squares import Interpolation

#-------------------------------
# Domain
# Square mesh of n x m cells with 80.0 spacing: 8000.0 x 8000.0 overall,
# lower-left corner at (-4000, -4000), so the mesh is centred on (0, 0),
# the centre of the parabolic basin defined by x_slope below.
# NOTE(review): m and n are equal here, so the argument order passed to
# rectangular_cross (m, n, lenx, leny) vs. lenx = delta_x*n cannot cause a
# mismatch; confirm the intended order if they are ever made different.
n = 100
m = 100
delta_x = 80.0
delta_y = 80.0
lenx = delta_x*n
leny = delta_y*m
origin = (-4000.0, -4000.0)

points, elements, boundary = rectangular_cross(m, n, lenx, leny, origin)
domain = Domain(points, elements, boundary)

#----------------
# Order of scheme
# NOTE(review): first order is selected here, but the output name set below
# says 'second_order' -- confirm which one is intended.
domain.default_order = 1

domain.smooth = True

#-------------------------------------
# Provide file name for storing output
# NOTE(review): store is False, which presumably disables writing the sww
# file, so format / name / quantities_to_be_stored may have no effect on
# disk output -- confirm. The name is still used later for the gauge file.
domain.store = False
domain.format = 'sww'
domain.set_name('yoon_mesh_second_order_cross')
print 'Number of triangles = ', len(domain)

#----------------------------------------------------------
# Decide which quantities are to be stored at each timestep
domain.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']

#------------------------------------------
# Reduction operation for get_vertex_values
# (how values from neighbouring triangles are combined at shared vertices;
# the trailing comment records an alternative the author found useful)
from anuga.pyvolution.util import mean
domain.reduction = mean #domain.reduction = min  #Looks better near steep slopes

#------------------
# Initial condition
# Parameters of the analytical solution. These are module globals read by
# x_slope() and level() below (units presumably SI, given g = 9.81).
print 'Initial condition'
t = 0.0     # time at which level() evaluates the initial stage
D0 = 1.     # basin depth at the centre: bed z = -D0 at r = 0 (see x_slope)
L = 2500.   # radius at which the bed reaches z = 0
R0 = 2000.  # initial shoreline radius: with A below, (1+A)/(1-A) = (L/R0)**4,
            # so the t = 0 stage from level() meets the bed exactly at r = R0
g = 9.81    # gravitational acceleration

# Amplitude parameter of the oscillating solution (0 <= A < 1 here since L > R0)
A = (L**4 - R0**4)/(L**4 + R0**4)
# Angular frequency used in cos(omega*t) inside level()
omega = 2./L*sqrt(2.*g*D0)
# pi/omega: half the period 2*pi/omega of cos(omega*t)
T = pi/omega

#------------------
# Set bed elevation
def x_slope(x,y):
    """Return the parabolic bed elevation at each point (x[k], y[k]).

    The bed is z(r) = -D0 * (1 - r**2 / L**2) with r the distance from the
    origin, so it sits at -D0 in the centre and rises through 0 at r = L.
    D0 and L are read from module globals at call time.

    x and y are indexable arrays of equal length (x.shape[0] points are
    evaluated). The result is built as 0*x, so it has the same array type
    and dtype as x (an integer x would therefore truncate the result).
    """
    elevation = 0*x
    for k in range(x.shape[0]):
        xk, yk = x[k], y[k]
        radius = sqrt(xk*xk + yk*yk)
        elevation[k] = -D0*(1. - radius*radius/L/L)
    return elevation
# Assign the parabolic bed computed by x_slope() to the 'elevation' quantity.
domain.set_quantity('elevation', x_slope)

#----------------------------
# Set the initial water level
def level(x,y):
    """Return the initial stage (water-surface elevation) at each point.

    Evaluates the analytical free-surface profile

        D0 * ( sqrt(1-A**2) / (1 - A*cos(omega*t))
               - 1 - r**2/L**2 * ((1-A**2) / (1 - A*cos(omega*t))**2 - 1) )

    at radius r = sqrt(x**2 + y**2). D0, A, omega, L and the time t are read
    from module globals at call time. The result is clipped from below by the
    bed elevation from x_slope(), so the stage never lies beneath the bed
    (points outside the shoreline become dry).

    x and y are indexable arrays of equal length. The result is built as 0*x,
    so it has the array type and dtype of x.
    """
    z = x_slope(x,y)
    n = x.shape[0]
    h = 0*x
    # Time-dependent factors are the same for every point; compute them once.
    denom = 1.-A*cos(omega*t)
    for i in range(n):
        r = sqrt(x[i]*x[i] + y[i]*y[i])
        h[i] = D0*((sqrt(1-A*A))/(denom)
                -1.-r*r/L/L*((1.-A*A)/((denom)**2)-1.))
        # Clip to the bed for EVERY point. Previously this test sat outside
        # the loop, so only the last point was ever clipped.
        if h[i] < z[i]:
            h[i] = z[i]
    return h
# Set the initial water surface from level(); the module-level t is still
# 0.0 here (assigned in the initial-condition section), so this is the t = 0 state.
domain.set_quantity('stage', level)

#---------
# Boundary
print 'Boundary conditions'
R = Reflective_boundary(domain)
T = Transmissive_boundary(domain)
D = Dirichlet_boundary([0.0, 0.0, 0.0])
domain.set_boundary({'left': D, 'right': D, 'top': D, 'bottom': D})

#---------------------------------------------
# Find triangle that contains the point points
# and print to file
points = [0.,0.]
for n in range(len(domain.triangles)):
    t = domain.triangles[n]
    tri = [domain.coordinates[t[0]],domain.coordinates[t[1]],domain.coordinates[t[2]]]

    if inside_polygon(points,tri):
        print 'Point is within triangle with vertices '+'%s'%tri
        n_point = n

print 'n_point = ',n_point
t = domain.triangles[n_point]
tri = [domain.coordinates[t[0]],domain.coordinates[t[1]],domain.coordinates[t[2]]]

filename=domain.get_name()
file = open(filename,'w')

#----------
# Evolution
import time
t0 = time.time()

time_array = []
stage_array = []
Stage     = domain.quantities['stage']
Xmomentum = domain.quantities['xmomentum']
Ymomentum = domain.quantities['ymomentum']

for t in domain.evolve(yieldstep = 20.0, finaltime = 17700.0 ):
    domain.write_time()

    tri_array = asarray(tri)
    t_array = asarray([[0,1,2]])
    interp = Interpolation(tri_array,t_array,[points])


    stage     = Stage.get_values(location='centroids',indices=[n_point])[0]
    xmomentum = Xmomentum.get_values(location='centroids',indices=[n_point])[0]
    ymomentum = Ymomentum.get_values(location='centroids',indices=[n_point])[0]
    file.write( '%10.6f   %10.6f  %10.6f   %10.6f\n'%(t,stage,xmomentum,ymomentum) )

    time_array.append(t)
    stage_array.append(stage)

file.close()
print 'That took %.2f seconds' %(time.time()-t0)


# Plot the modelled stage time series recorded at the gauge triangle.
from pylab import *
ion()
hold(False)
plot(time_array, stage_array, 'r.-')
xlabel('time(s)')
ylabel('stage (m)')
# Only one (modelled) series is plotted. The old ('Observed', 'Modelled')
# labels attached the text 'Observed' to the model output.
legend(('Modelled',), shadow=True, loc='upper left')
show()


