"""Stochastic study of the ANUGA implementation of the
shallow water wave equation.

This script runs the model for one realisation of bathymetry as
given in the file bathymetry.txt and outputs a full simulation in
sww NetCDF format.

The left boundary condition is a timeseries defined in
NetCDF file: input_wave.tms

Note: This script needs create_mesh.py to have been run

Suresh Kumar and Ole Nielsen 2006
"""


#------------------------------------------------------------------------------
# Import necessary modules
#------------------------------------------------------------------------------

# Standard modules
import os
import time
import cPickle

# Related major packages
from anuga.pyvolution.shallow_water import Domain
from anuga.pyvolution.shallow_water import Reflective_boundary
from anuga.pyvolution.shallow_water import Transmissive_Momentum_Set_Stage_boundary
from anuga.pyvolution.pmesh2domain import pmesh_to_domain_instance
from anuga.pyvolution.data_manager import xya2pts
from anuga.pyvolution.util import file_function
from caching.caching import cache

# Application specific imports
import project                 # Definition of file names and polygons


#-----------------------------------------------------------------------------
# Read in processor information
#-----------------------------------------------------------------------------

# Use pypar when it can be imported so that the realisations can be shared
# between processes; otherwise run everything as a single local process.
#
# Only ImportError is caught.  A bare except would also swallow
# KeyboardInterrupt/SystemExit and any unrelated failure, in which case every
# process of a parallel launch would silently believe it is process 0 of 1.
try:
    import pypar
except ImportError:
    print('Could not import pypar')
    myid = 0
    numprocs = 1
    processor_name = 'local host'
else:
    myid = pypar.rank()
    numprocs = pypar.size()
    processor_name = pypar.Get_processor_name()

print('I am process %d of %d running on %s'
      % (myid, numprocs, processor_name))


#-----------------------------------------------------------------------------
# Setup computational domain
#----------------------------------------------------------------------------- 
#print 'Creating domain from', project.mesh_filename

# Build the domain from the mesh file named in the project module.
# Caching and verbose output are both switched off.
mesh_path = project.working_dir + project.mesh_filename
domain = Domain(mesh_path, use_cache=False, verbose=False)

# Uncomment for diagnostics about the domain that was built:
#print 'Number of triangles = ', len(domain)
#print domain.statistics()

# Output goes to the working directory; these quantities are stored
stored_quantities = ['stage', 'xmomentum', 'ymomentum']
domain.set_datadir(project.working_dir)
domain.set_quantities_to_be_stored(stored_quantities)


#------------------------------------------------------------------------------
# Setup boundary conditions
#------------------------------------------------------------------------------

# Input wave: boundary object built from the function that file_function
# creates from project.boundary_filename
function = file_function(project.boundary_filename,
                         domain,
                         verbose=False)
Bts = Transmissive_Momentum_Set_Stage_boundary(domain, function)

# Wall: reflective boundary
Br = Reflective_boundary(domain)

# Associate each boundary tag with its boundary object
boundary_map = {}
boundary_map['wave'] = Bts
boundary_map['wall'] = Br
domain.set_boundary(boundary_map)


#------------------------------------------------------------------------------
# Setup initial conditions
#------------------------------------------------------------------------------
domain.set_quantity('friction', 0.0)
domain.set_quantity('stage', 0.0)

# Final time of each simulation and the yieldstep passed to domain.evolve
finaltime = 22.5
timestep = 0.05


def _load_pickle(path):
    """Return the object stored in the pickle file at path.

    The file is closed even if unpickling raises.
    """
    # NOTE(review): the file is opened in text mode, as before.  If the
    # writer used a binary pickle protocol this should be 'rb' - confirm
    # against the script that creates the .pck files.
    fid = open(path)
    try:
        return cPickle.load(fid)
    finally:
        fid.close()


def _write_gauge_timeseries(path, f, point_id):
    """Write one 'time value' line per timestep of f to the file at path.

    f is the function returned by file_function; point_id selects which
    interpolation point is sampled.  The file is closed even if sampling
    or writing raises.
    """
    fid = open(path, 'w')
    try:
        for t in f.get_time():
            # For all precomputed timesteps
            val = f(t, point_id=point_id)[0]
            fid.write('%f %f\n' % (t, val))
    finally:
        fid.close()


# Get prefitted realisations.
#
# os.listdir returns names in arbitrary order.  The listing is sorted so that
# the running realisation number below is assigned identically on every
# process and on every run; otherwise two processes could give the same
# number to different columns and overwrite each other's output files.
realisation = 0
for filename in sorted(os.listdir(project.working_dir)):
    is_block = (filename.startswith(project.basename) and
                filename.endswith('.pck'))
    if not is_block:
        continue

    print('P%d: Reading %s' % (myid, filename))
    V = _load_pickle(project.working_dir + filename)

    # For each column (each realisation)
    for i in range(V.shape[1]):

        # Distribute work in round-robin fashion
        if i % numprocs == myid:

            # One sww file per process; it is rewritten for every
            # realisation after its time series have been extracted.
            name = project.basename + '_P%d' % myid
            domain.set_name(name)                       # Output name
            print('V %s' % str(V.shape))
            domain.set_quantity('elevation', V[:, i])   # Assign bathymetry

            print('P%d: Setting quantity %d: %s'
                  % (myid, i, str(V[:4, i])))

            domain.set_time(0.0)                        # Reset time

            #---------------------------------------------------
            # Evolve system through time
            #---------------------------------------------------
            print('P%d: Running realisation %d of %d in block %s'
                  % (myid, realisation, V.shape[1], filename))

            t0 = time.time()
            for t in domain.evolve(yieldstep=timestep,
                                   finaltime=finaltime):
                domain.write_time()

            print('P%d: Simulation of realisation %d took %.2f seconds'
                  % (myid, realisation, time.time() - t0))

            #---------------------------------------------------
            # Now extract the 3 timeseries (Ch 5-7-9) and store them
            # in three files for this realisation
            sww_filename = project.working_dir + domain.get_name() + '.sww'
            print('P%d: Extracting time series for realisation %d from file %s'
                  % (myid, realisation, sww_filename))
            f = file_function(sww_filename,
                              quantities='stage',
                              interpolation_points=project.gauges,
                              verbose=False)

            simulation_name = (project.working_dir +
                               project.basename +
                               '_realisation_%d' % realisation)

            # A separate loop variable is used for the gauge so that the
            # output name above is not clobbered, and the message reports
            # the file that is actually written.
            for k, gauge_name in enumerate(project.gauge_names):
                gauge_filename = simulation_name + '_' + gauge_name + '.txt'
                print('P%d: Writing to file %s' % (myid, gauge_filename))
                _write_gauge_timeseries(gauge_filename, f, k)

        # Incremented for every column, not only those handled by this
        # process, so the realisation number is global across processes.
        realisation += 1


# Shut down pypar only if it was actually imported above.  In the serial
# fallback the name pypar is never bound, and calling it unconditionally
# raised NameError after all the work had completed.
if 'pypar' in globals():
    pypar.finalize()
