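"""Validate model output against laboratory gauge records.

Read the input wave (.tms) and the reference gauge records for channels 5, 7
and 9. Interpolate the model output (.sww) at the boundary and at those
channels, then compare and plot modelled against observed stage.
"""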

import sys
from os import sep

# Make modules two directories up importable. This has to run BEFORE the
# project imports below: sys.path changes only affect imports executed
# afterwards. Because the path is appended (not prepended), modules already
# found on the default path are unaffected.
sys.path.append('..'+sep+'..'+sep)

from caching import cache
import project


# Locations at which the model output is interpolated (presumably (x, y)
# coordinates in the model's units -- TODO confirm). The order must match
# gauge_names below, because the validation loop pairs them by index.
gauges = [[0.000, 1.696]]  #Boundary
gauges += [[4.521, 1.196],  [4.521, 1.696],  [4.521, 2.196]] #Ch 5-7-9

gauge_names = ['Boundary', 'ch5', 'ch7', 'ch9']


# The reference record is checked to end exactly at finaltime.
finaltime = 22.5
# NOTE(review): timestep is not used anywhere in this script as visible.
timestep = 0.05

#Read reference data

#Input wave
filename = 'input_wave.tms'
print 'Reading', filename
from Scientific.IO.NetCDF import NetCDFFile
from Numeric import allclose
fid = NetCDFFile(filename, 'r')#

input_time = fid.variables['time'][:]
input_stage = fid.variables['stage'][:]


#gauges: observed stage at channels 5, 7 and 9, read from a text table
reference_time = []
ch5 = []
ch7 = []
ch9 = []
filename = 'output_ch5-7-9.txt'
fid = open(filename)
try:
    lines = fid.readlines()
finally:
    fid.close()

# Skip the first line (presumably a column header). Each remaining row is
# time, ch5, ch7, ch9, with stages converted from cm to m. Stop once as many
# records have been collected as the input wave has time steps, so the two
# series can be compared point by point.
for line in lines[1:]:
    if len(reference_time) == len(input_time): break

    fields = line.split()
    if not fields:
        continue  # tolerate blank lines instead of failing on fields[1]

    reference_time.append(float(fields[0]))
    ch5.append(float(fields[1])/100)   #cm2m
    ch7.append(float(fields[2])/100)   #cm2m
    ch9.append(float(fields[3])/100)   #cm2m


from anuga.pyvolution.util import file_function
from anuga.utilities.numerical_tools import ensure_numeric

# Reference series as numeric arrays, indexed in parallel with gauges and
# gauge_names. The input-wave stage is the reference at the 'Boundary'
# gauge; the lab records are the references at the channel gauges.
reference_series = (input_stage, ch5, ch7, ch9)
gauge_values = [ensure_numeric(series) for series in reference_series]



#Read model output and build f(t, point_id=...), the model stage
#interpolated at each gauge location. The file_function call goes through
#the caching module with the .sww file listed as a dependency.
#(Alternative model run: project.basename + '_original.sww')
filename = project.basename + '.sww'

file_function_options = {'quantities': 'stage',
                         'interpolation_points': gauges,
                         'verbose': True}

f = cache(file_function,
          filename,
          file_function_options,
          #evaluate = True,
          dependencies=[filename],
          verbose=True)




#Checks: the reference gauge record must cover the same time steps as the
#input wave, starting at t = 0 and ending at finaltime.
#Explicit exceptions are used instead of assert so the checks still run
#under python -O, and each failure says what was wrong.
if not reference_time:
    raise ValueError('No reference gauge data was read')
if len(reference_time) != len(input_time):
    raise ValueError('Reference record has %d time steps, input wave has %d'
                     % (len(reference_time), len(input_time)))
if reference_time[0] != 0.0:
    raise ValueError('Reference record starts at t=%s, expected 0.0'
                     % reference_time[0])
if reference_time[-1] != finaltime:
    raise ValueError('Reference record ends at t=%s, expected %s'
                     % (reference_time[-1], finaltime))
if not allclose(reference_time, input_time):
    raise ValueError('Reference and input-wave time vectors differ')



#Validation


for k, name in enumerate(gauge_names):
    sqsum = 0
    denom = 0
    model = []
    print 'Validating ' + name
    for i, t in enumerate(reference_time):
        ref = gauge_values[k][i]
        val = f(t, point_id = k)[0]
        model.append(val)

        sqsum += (ref - val)**2
        denom += ref**2

    print sqsum
    print sqsum/denom

    from pylab import *
    ion()
    hold(False)
    plot(reference_time, gauge_values[k], 'r.-',
         reference_time, model, 'k-')
    title('Gauge %s' %name)
    xlabel('time(s)')
    ylabel('stage (m)')    
    legend(('Observed', 'Modelled'), shadow=True, loc='upper left')
    savefig(name, dpi = 300)

    raw_input('Next')
show()



#from pylab import *
#plot(time, stage)
#show()
