#!/usr/bin/env python

'''
A program to compare two SWW files for "equality".

This program makes many assumptions about the structure of SWW files,
so if that structure changes, this program must change too.
'''

import sys
import os
import os.path
import getopt
from Scientific.IO.NetCDF import NetCDFFile
import numpy as num
from anuga.config import netcdf_mode_r


#####
# Various constants.
#####

# Global attributes that should exist and be the same in both files.
# Files don't have to have all of these, and we don't care about others.
expect_global_attributes = ['smoothing', 'vertices_are_stored_uniquely',
                            'order', 'revision_number', 'starttime',
                            'xllcorner', 'yllcorner',
                            'zone', 'false_easting', 'false_northing',
                            'datum', 'projection', 'units']

# Dimensions expected, with expected values (None means unknown).
expected_dimensions = {'number_of_volumes': None,
                       'number_of_vertices': 3,
                       'numbers_in_range': 2,
                       'number_of_points': None,
                       'number_of_timesteps': None}

# Variables expected, with expected dimensions (always a tuple of names
# drawn from expected_dimensions).  A variable whose first dimension is
# 'number_of_timesteps' is treated as time series data by the comparison.
# Files don't have to have all of these, and we don't care about others.
expected_variables = {'x': ('number_of_points',),
                      'y': ('number_of_points',),
                      'elevation': ('number_of_points',),
                      'elevation_range': ('numbers_in_range',),
                      'z': ('number_of_points',),
                      'volumes': ('number_of_volumes', 'number_of_vertices'),
                      'time': ('number_of_timesteps',),
                      'stage': ('number_of_timesteps', 'number_of_points'),
                      'stage_range': ('numbers_in_range',),
                      'xmomentum': ('number_of_timesteps', 'number_of_points'),
                      'xmomentum_range': ('numbers_in_range',),
                      'ymomentum': ('number_of_timesteps', 'number_of_points'),
                      'ymomentum_range': ('numbers_in_range',)}

##
# @brief An exception to inform user of usage problems.
# @param msg The message to show the user; kept in the 'msg' attribute and
#            also passed to Exception so that str(err) shows it.
class Usage(Exception):
    def __init__(self, msg):
        super(Usage, self).__init__(msg)
        self.msg = msg


##
# @brief Compare two SWW files.
# @param files A sequence of exactly two filenames.
# @param globals A list of global attribute names to compare.  If empty or
#                None, whichever names from expect_global_attributes exist
#                are compared, and the two files must have the same set.
# @param timesteps Accepted for interface compatibility but currently
#                  ignored; every timestep is compared.
# @param variables A list of variable names to compare.  If empty or None,
#                  whichever names from expected_variables exist are
#                  compared, and the two files must have the same set.
# @return Returns None if the files are 'equal', else raises RuntimeError
#         whose message lists the differences found.
# @note The checks run in stages (structure, timestep count, variable
#       shapes, values), and the first stage that finds a difference raises.
#       Both files are closed before this function returns or raises.
def files_are_the_same(files, globals=None, timesteps=None, variables=None):
    # split out the filenames
    (file1, file2) = files

    # open both files, collecting failures so both are reported together
    error_msg = ''
    fid1 = None
    fid2 = None
    try:
        fid1 = NetCDFFile(file1, netcdf_mode_r)
    except Exception:
        error_msg += "File '%s' can't be opened?\n" % file1
    try:
        fid2 = NetCDFFile(file2, netcdf_mode_r)
    except Exception:
        error_msg += "File '%s' can't be opened?\n" % file2

    if error_msg:
        # close whichever file did open so it isn't leaked
        for fid in (fid1, fid2):
            if fid is not None:
                fid.close()
        raise RuntimeError(error_msg)

    try:
        _compare_open_files(((fid1, file1), (fid2, file2)),
                            globals, variables)
    finally:
        fid1.close()
        fid2.close()


##
# @brief Raise RuntimeError with all messages concatenated, if there are any.
# @param errors A list of newline-terminated error strings.
def _raise_if_errors(errors):
    if errors:
        raise RuntimeError(''.join(errors))


##
# @brief Run the staged comparison of two already-open SWW files.
# @param pair ((fid1, file1), (fid2, file2)) open handles and their names.
# @param glob_names Global attribute names requested by the caller, or None.
# @param var_names Variable names requested by the caller, or None.
def _compare_open_files(pair, glob_names, var_names):
    (errors, glob_names, var_names) = _check_structure(pair, glob_names,
                                                       var_names)
    _raise_if_errors(errors)

    (errors, num_timesteps) = _check_timesteps(pair)
    _raise_if_errors(errors)

    _raise_if_errors(_check_shapes(pair, var_names))

    _raise_if_errors(_check_values(pair, glob_names, var_names,
                                   num_timesteps))


##
# @brief Check that both files have the required dimensions, globals and
#        variables.
# @return (errors, glob_names, var_names).  The two name lists are resolved
#         from the expected_* tables when the caller passed none.
def _check_structure(pair, glob_names, var_names):
    ((fid1, file1), (fid2, file2)) = pair
    errors = []

    # dimensions - only check expected dimensions
    for dim in expected_dimensions:
        for (fid, fname) in pair:
            if dim not in fid.dimensions:
                errors.append("File %s doesn't contain dimension '%s'\n"
                              % (fname, dim))

    # now check that dimensions are the same length
    # NOTE: unlimited dimensions report None, so they are not compared here.
    for dim in expected_dimensions:
        size1 = fid1.dimensions.get(dim, None)
        size2 = fid2.dimensions.get(dim, None)
        if size1 and size2 and size1 != size2:
            errors.append('File %s has %s dimension of size %s, '
                          'file %s has that dimension of size %s\n'
                          % (file1, dim, str(size1), file2, str(size2)))

    # global attributes: required ones, or the expected ones actually present
    attrs1 = dir(fid1)
    attrs2 = dir(fid2)
    if glob_names:
        for glob in glob_names:
            for (attrs, fname) in ((attrs1, file1), (attrs2, file2)):
                if glob not in attrs:
                    errors.append("Global attribute '%s' isn't in file %s\n"
                                  % (glob, fname))
    else:
        glob_vars1 = [g for g in expect_global_attributes if g in attrs1]
        glob_vars2 = [g for g in expect_global_attributes if g in attrs2]
        if glob_vars1 != glob_vars2:
            errors.append('Files differ in global attributes:\n'
                          '%s: %s,\n'
                          '%s: %s\n' % (file1, str(glob_vars1),
                                        file2, str(glob_vars2)))
        glob_names = glob_vars1

    # variables: required ones, or the expected ones actually present
    if var_names:
        for var in var_names:
            for (fid, fname) in pair:
                if var not in fid.variables:
                    errors.append("Variable '%s' isn't in file %s\n"
                                  % (var, fname))
    else:
        var_names1 = [v for v in expected_variables if v in fid1.variables]
        var_names2 = [v for v in expected_variables if v in fid2.variables]
        if var_names1 != var_names2:
            errors.append('Variables are not the same between files: '
                          '%s variables= %s, '
                          '%s variables = %s\n'
                          % (file1, str(var_names1), file2, str(var_names2)))
        var_names = var_names1

    return (errors, glob_names, var_names)


##
# @brief Check that both files have a 'time' variable of the same length.
# @return (errors, num_timesteps), where num_timesteps is file1's count.
def _check_timesteps(pair):
    ((fid1, file1), (fid2, file2)) = pair
    errors = []
    for (fid, fname) in pair:
        if 'time' not in fid.variables:
            errors.append("Variable '%s' isn't in file %s\n" % ('time', fname))
    if errors:
        return (errors, 0)

    num_timesteps1 = fid1.variables['time'].shape[0]
    num_timesteps2 = fid2.variables['time'].shape[0]
    if num_timesteps1 != num_timesteps2:
        errors.append('Files have different number of timesteps: '
                      '%s=%d, %s=%d\n'
                      % (file1, num_timesteps1, file2, num_timesteps2))
    return (errors, num_timesteps1)


##
# @brief Check that each variable to compare has the same shape in both files.
# @return A list of error strings (empty if all shapes match).
def _check_shapes(pair, var_names):
    ((fid1, file1), (fid2, file2)) = pair
    errors = []
    for var_name in var_names:
        shape1 = fid1.variables[var_name].shape
        shape2 = fid2.variables[var_name].shape
        if shape1 != shape2:
            errors.append('Files differ in variable %s shape:\n'
                          '%s: %s,\n'
                          '%s: %s\n'
                          % (var_name, file1, str(shape1),
                             file2, str(shape2)))
    return errors


##
# @brief Compare global attribute values and variable data.
# @return A list of error strings (empty if everything matches).
def _check_values(pair, glob_names, var_names, num_timesteps):
    ((fid1, file1), (fid2, file2)) = pair
    errors = []

    # check values of global attributes; array_equal copes with array-valued
    # attributes, where '!=' could give an array with an ambiguous truth value
    for glob_name in glob_names:
        g1 = getattr(fid1, glob_name)
        g2 = getattr(fid2, glob_name)
        if not num.array_equal(g1, g2):
            errors.append("Files differ in global '%s': "
                          "%s: '%s', "
                          "%s: '%s'\n"
                          % (glob_name, file1, str(g1), file2, str(g2)))

    # check data variables, compare time series data one timestep at a time
    for var_name in var_names:
        var1 = fid1.variables[var_name]
        var2 = fid2.variables[var_name]
        # variables not listed in expected_variables are compared all at once
        var_dims = expected_variables.get(var_name, ())
        if len(var_dims) > 1 and var_dims[0] == 'number_of_timesteps':
            for i in range(num_timesteps):
                data1 = num.array(var1[i,:])
                data2 = num.array(var2[i,:])
                if not num.allclose(data1, data2):
                    errors.append('Files differ in variable %s data '
                                  'at timestep %d:\n'
                                  '%s: %s\n'
                                  '%s: %s\n'
                                  % (var_name, i, file1, str(data1),
                                     file2, str(data2)))
        else:
            data1 = num.array(var1[:])
            data2 = num.array(var2[:])
            if not num.allclose(data1, data2):
                errors.append('Files differ in variable %s:\n'
                              '%s: %s,\n'
                              '%s: %s\n'
                              % (var_name, file1, str(data1),
                                 file2, str(data2)))

    return errors


##
# @brief Return a usage string.
# @note Only options that main() actually accepts are listed.
def usage():
    # ProgName is set by the __main__ guard; fall back to argv[0] so that
    # usage() also works when this module is imported
    try:
        prog_name = ProgName
    except NameError:
        prog_name = os.path.basename(sys.argv[0])

    result = []
    a = result.append
    a('Usage: %s <options> <file1> <file2>\n' % prog_name)
    a('where <options> is zero or more of:\n')
    a('                   -h        print this help\n')
    a('                   -g <arg>  check only global attributes specified\n')
    a('                             <arg> has the form <globname>[,<globname>[,...]]\n')
    a('                   -t <arg>  check only timesteps specified (currently ignored)\n')
    a('                             <arg> has the form <starttime>[,<stoptime>[,<step>]]\n')
    a('                   -v <arg>  check only the named variables\n')
    a('                             <arg> has the form <varname>[,<varname>[,...]]\n')
    a('and <file1> and <file2> are two SWW files to compare.\n')
    a('\n')
    a('The program exit status is one of:\n')
    a('   0    the two files are equivalent\n')
    a('   else the files are not equivalent.')
    return ''.join(result)

##
# @brief Print a message to stderr.
# @param msg The message, which is converted with str() and followed by a newline.
def warn(msg):
    # sys.stderr.write works under both Python 2 and 3, unlike 'print >>'
    sys.stderr.write('%s\n' % (msg,))


##
# @brief Parse the command line and compare the two SWW files named on it.
# @param argv The argument vector (program name first); defaults to sys.argv.
# @return The status code the program will exit with:
#         0 if the files are equivalent or help was requested,
#         2 on a command-line usage error,
#         10 if the files are not equivalent.
def main(argv=None):
    if argv is None:
        argv = sys.argv

    try:
        # long options that take a value need a trailing '=' for getopt
        opts, args = getopt.getopt(argv[1:], 'hg:t:v:',
                                   ['help', 'globals=',
                                    'variables=', 'timesteps='])
    except getopt.GetoptError as err:
        sys.stderr.write('%s\n' % str(err))
        sys.stderr.write('for help use --help\n')
        return 2

    # process options ('glob_names' avoids shadowing the globals() builtin)
    glob_names = None
    timesteps = None
    variables = None
    for (opt, arg) in opts:
        if opt in ('-h', '--help'):
            print(usage())
            return 0
        elif opt in ('-g', '--globals'):
            glob_names = arg.split(',')
        elif opt in ('-t', '--timesteps'):
            timesteps = arg.split(',')
        elif opt in ('-v', '--variables'):
            variables = arg.split(',')

    # process arguments: exactly two filenames are required
    if len(args) != 2:
        sys.stderr.write('%s\n' % usage())
        return 2

    try:
        files_are_the_same(args, globals=glob_names,
                           timesteps=timesteps, variables=variables)
    except RuntimeError as err:
        print(err)
        return 10

    return 0


if __name__ == "__main__":
    # This assignment is already at module scope, so it sets the module-level
    # ProgName read by usage(); no 'global' statement is needed.
    ProgName = os.path.basename(sys.argv[0])

    sys.exit(main())

