Changeset 7841


Ignore:
Timestamp:
Jun 15, 2010, 12:06:46 PM (14 years ago)
Author:
hudson
Message:

Refactorings to allow tests to pass.

Location:
trunk/anuga_core/source/anuga
Files:
1 added
15 edited

Legend:

Unmodified
Added
Removed
  • trunk/anuga_core/source/anuga/__init__.py

    r7814 r7841  
    3030from anuga.shallow_water.shallow_water_domain import Domain
    3131
    32 from anuga.abstract_2d_finite_volumes.util import file_function, sww2timeseries
    33 
    34 from anuga.abstract_2d_finite_volumes.mesh_factory import rectangular_cross
     32from anuga.abstract_2d_finite_volumes.util import file_function, \
     33                                        sww2timeseries, sww2csv_gauges
     34
     35from anuga.abstract_2d_finite_volumes.mesh_factory import rectangular_cross, \
     36                                                    rectangular
    3537
    3638from anuga.file.csv_file import load_csv_as_building_polygons,  \
     
    4749
    4850from anuga.utilities.system_tools import file_length
    49 
     51from anuga.utilities.sww_merge import sww_merge
    5052from anuga.utilities.file_utils import copy_code_files
    5153
     
    189191    throw an error
    190192   
    191     Interior_holes is a list of ploygons for each hole.
     193    interior_holes is a list of ploygons for each hole. These polygons do not
     194    need to be closed, but their points must be specified in a counter-clockwise
     195    order.
    192196
    193197    This function does not allow segments to share points - use underlying
  • trunk/anuga_core/source/anuga/config.py

    r7810 r7841  
    3434# Major revision number for use with create_distribution
    3535# and update_anuga_user_guide
    36 major_revision = '1.1beta'
     36major_revision = '1.2'
    3737
    3838################################################################################
  • trunk/anuga_core/source/anuga/file/sww.py

    r7796 r7841  
    4747
    4848        #FIXME: Should we have a general set_precision function?
    49 
    50 
    51 ##
    52 # @brief Class for handling checkpoints data
    53 # @note This is not operational at the moment
    54 class CPT_file(Data_format):
    55     """Interface to native NetCDF format (.cpt) to be
    56     used for checkpointing (one day)
    57     """
    58 
    59     ##
    60     # @brief Initialize this instantiation.
    61     # @param domain ??
    62     # @param mode Mode of underlying data file (default WRITE).
    63     def __init__(self, domain, mode=netcdf_mode_w):
    64         from Scientific.IO.NetCDF import NetCDFFile
    65 
    66         self.precision = netcdf_float #Use full precision
    67 
    68         Data_format.__init__(self, domain, 'sww', mode)
    69 
    70         # NetCDF file definition
    71         fid = NetCDFFile(self.filename, mode)
    72         if mode[0] == 'w':
    73             # Create new file
    74             fid.institution = 'Geoscience Australia'
    75             fid.description = 'Checkpoint data'
    76             #fid.smooth = domain.smooth
    77             fid.order = domain.default_order
    78 
    79             # Dimension definitions
    80             fid.createDimension('number_of_volumes', self.number_of_volumes)
    81             fid.createDimension('number_of_vertices', 3)
    82 
    83             # Store info at all vertices (no smoothing)
    84             fid.createDimension('number_of_points', 3*self.number_of_volumes)
    85             fid.createDimension('number_of_timesteps', None) #extensible
    86 
    87             # Variable definitions
    88 
    89             # Mesh
    90             fid.createVariable('x', self.precision, ('number_of_points',))
    91             fid.createVariable('y', self.precision, ('number_of_points',))
    92 
    93 
    94             fid.createVariable('volumes', netcdf_int, ('number_of_volumes',
    95                                                        'number_of_vertices'))
    96 
    97             fid.createVariable('time', self.precision, ('number_of_timesteps',))
    98 
    99             #Allocate space for all quantities
    100             for name in domain.quantities.keys():
    101                 fid.createVariable(name, self.precision,
    102                                    ('number_of_timesteps',
    103                                     'number_of_points'))
    104 
    105         fid.close()
    106 
    107     ##
    108     # @brief Store connectivity data to underlying data file.
    109     def store_checkpoint(self):
    110         """Write x,y coordinates of triangles.
    111         Write connectivity (
    112         constituting
    113         the bed elevation.
    114         """
    115 
    116         from Scientific.IO.NetCDF import NetCDFFile
    117 
    118         domain = self.domain
    119 
    120         #Get NetCDF
    121         fid = NetCDFFile(self.filename, netcdf_mode_a)
    122 
    123         # Get the variables
    124         x = fid.variables['x']
    125         y = fid.variables['y']
    126 
    127         volumes = fid.variables['volumes']
    128 
    129         # Get X, Y and bed elevation Z
    130         Q = domain.quantities['elevation']
    131         X,Y,Z,V = Q.get_vertex_values(xy=True, precision=self.precision)
    132 
    133         x[:] = X.astype(self.precision)
    134         y[:] = Y.astype(self.precision)
    135         z[:] = Z.astype(self.precision)
    136 
    137         volumes[:] = V
    138 
    139         fid.close()
    140 
    141     ##
    142     # @brief Store time and named quantities to underlying data file.
    143     # @param name
    144     def store_timestep(self, name):
    145         """Store time and named quantity to file
    146         """
    147 
    148         from Scientific.IO.NetCDF import NetCDFFile
    149         from time import sleep
    150 
    151         #Get NetCDF
    152         retries = 0
    153         file_open = False
    154         while not file_open and retries < 10:
    155             try:
    156                 fid = NetCDFFile(self.filename, netcdf_mode_a)
    157             except IOError:
    158                 #This could happen if someone was reading the file.
    159                 #In that case, wait a while and try again
    160                 msg = 'Warning (store_timestep): File %s could not be opened' \
    161                       ' - trying again' % self.filename
    162                 log.critical(msg)
    163                 retries += 1
    164                 sleep(1)
    165             else:
    166                 file_open = True
    167 
    168         if not file_open:
    169             msg = 'File %s could not be opened for append' % self.filename
    170             raise DataFileNotOpenError, msg
    171 
    172         domain = self.domain
    173 
    174         # Get the variables
    175         time = fid.variables['time']
    176         stage = fid.variables['stage']
    177         i = len(time)
    178 
    179         #Store stage
    180         time[i] = self.domain.time
    181 
    182         # Get quantity
    183         Q = domain.quantities[name]
    184         A,V = Q.get_vertex_values(xy=False, precision=self.precision)
    185 
    186         stage[i,:] = A.astype(self.precision)
    187 
    188         #Flush and close
    189         fid.sync()
    190         fid.close()
    19149
    19250
     
    536394
    537395    def read_mesh(self):
     396        """ Read and store the mesh data contained within this sww file.
     397        """
    538398        fin = NetCDFFile(self.source, 'r')
    539399
     
    551411
    552412
    553 
    554413        fin.close()
    555414       
    556415    def read_quantities(self, frame_number=0):
    557 
     416        """
     417        Read the quantities contained in this file.
     418        frame_number is the time index to load.
     419        """
    558420        assert frame_number >= 0 and frame_number <= self.last_frame_number
    559421
     
    575437
    576438    def get_bounds(self):
     439        """
     440            Get the bounding rect around the mesh.
     441        """
    577442        return [self.xmin, self.xmax, self.ymin, self.ymax]
    578443
    579444    def get_last_frame_number(self):
     445        """
     446            Return the last loaded frame index.
     447        """
    580448        return self.last_frame_number
    581449
    582450    def get_time(self):
     451        """
     452            Get time at the current frame num, in secs.
     453        """
    583454        return self.time[self.frame_number]
    584455
    585456
    586 # @brief A class to write an SWW file.
    587457class Write_sww:
    588    
     458    """
     459        A class to write an SWW file.
     460       
     461        It is domain agnostic, and requires all the data to be fed in
     462        manually.
     463    """
    589464    RANGE = '_range'
    590465    EXTREMA = ':extrema'
    591466
    592     ##
    593     # brief Instantiate the SWW writer class.
    594467    def __init__(self, static_quantities, dynamic_quantities):
    595468        """Initialise Write_sww with two list af quantity names:
     
    607480
    608481
    609     ##
    610     # @brief Store a header in the SWW file.
    611     # @param outfile Open handle to the file that will be written.
    612     # @param times A list of time slices *or* a start time.
    613     # @param number_of_volumes The number of triangles.
    614     # @param number_of_points The number of points.
    615     # @param description The internal file description string.
    616     # @param smoothing True if smoothing is to be used.
    617     # @param order
    618     # @param sww_precision Data type of the quantity written (netcdf constant)
    619     # @param verbose True if this function is to be verbose.
    620     # @note If 'times' is a list, the info will be made relative.
    621482    def store_header(self,
    622483                     outfile,
     
    631492        """Write an SWW file header.
    632493
     494        Writes the first section of the .sww file.
     495
    633496        outfile - the open file that will be written
    634497        times - A list of the time slice times OR a start time
    635498        Note, if a list is given the info will be made relative.
    636499        number_of_volumes - the number of triangles
     500        number_of_points - the number of vertices in the mesh
    637501        """
    638502
     
    746610                         % (num.min(times), num.max(times), len(times.flat)))
    747611
    748     ##
    749     # @brief Store triangulation data in the underlying file.
    750     # @param outfile Open handle to underlying file.
    751     # @param points_utm List or array of points in UTM.
    752     # @param volumes
    753     # @param zone
    754     # @param new_origin georeference that the points can be set to.
    755     # @param points_georeference The georeference of the points_utm.
    756     # @param verbose True if this function is to be verbose.
     612
    757613    def store_triangulation(self,
    758614                            outfile,
     
    764620                            verbose=False):
    765621        """
     622        Store triangulation data in the underlying file.
     623       
     624        Stores the points and triangle indices in the sww file
     625       
     626        outfile Open handle to underlying file.
     627
     628        new_origin georeference that the points can be set to.
     629
     630        points_georeference The georeference of the points_utm.
     631
     632        verbose True if this function is to be verbose.
     633
    766634        new_origin - qa georeference that the points can be set to. (Maybe
    767635        do this before calling this function.)
     
    840708
    841709
    842 
    843     # @brief Write the static quantity data to the underlying file.
    844     # @param outfile Handle to open underlying file.
    845     # @param sww_precision Format of quantity data to write (default Float32).
    846     # @param verbose True if this function is to be verbose.
    847     # @param **quant
    848710    def store_static_quantities(self,
    849711                                outfile,
     
    896758       
    897759       
    898     ##
    899     # @brief Write the quantity data to the underlying file.
    900     # @param outfile Handle to open underlying file.
    901     # @param sww_precision Format of quantity data to write (default Float32).
    902     # @param slice_index
    903     # @param time
    904     # @param verbose True if this function is to be verbose.
    905     # @param **quant
    906760    def store_quantities(self,
    907761                         outfile,
     
    986840
    987841
    988 ##
    989 # @brief Get the extents of a NetCDF data file.
    990 # @param file_name The path to the SWW file.
    991 # @return A list of x, y, z and stage limits (min, max).
    992842def extent_sww(file_name):
    993     """Read in an sww file.
     843    """Read in an sww file, then get its extents
    994844
    995845    Input:
     
    1015865
    1016866
    1017 ##
    1018 # @brief
    1019 # @param filename
    1020 # @param boundary
    1021 # @param t
    1022 # @param fail_if_NaN
    1023 # @param NaN_filler
    1024 # @param verbose
    1025 # @param very_verbose
    1026 # @return
    1027867def load_sww_as_domain(filename, boundary=None, t=None,
    1028868               fail_if_NaN=True, NaN_filler=0,
    1029869               verbose=False, very_verbose=False):
    1030870    """
     871    Load an sww file into a domain.
     872
    1031873    Usage: domain = load_sww_as_domain('file.sww',
    1032874                        t=time (default = last time in file))
     
    11941036
    11951037
    1196 ##
    1197 # @brief Get mesh and quantity data from an SWW file.
    1198 # @param filename Path to data file to read.
    1199 # @param quantities UNUSED!
    1200 # @param verbose True if this function is to be verbose.
    1201 # @return (mesh, quantities, time) where mesh is the mesh data, quantities is
    1202 #         a dictionary of {name: value}, and time is the time vector.
    1203 # @note Quantities extracted: 'elevation', 'stage', 'xmomentum' and 'ymomentum'
    12041038def get_mesh_and_quantities_from_file(filename,
    12051039                                      quantities=None,
     
    12801114
    12811115
    1282 
    1283 ##
    1284 # @brief
    1285 # @parm time
    1286 # @param t
    1287 # @return An (index, ration) tuple.
    12881116def get_time_interp(time, t=None):
    1289     #Finds the ratio and index for time interpolation.
    1290     #It is borrowed from previous abstract_2d_finite_volumes code.
     1117    """Finds the ratio and index for time interpolation.
     1118        time is an array of time steps
     1119        t is the sample time.
     1120        returns a tuple containing index into time, and ratio
     1121    """
    12911122    if t is None:
    12921123        t=time[-1]
     
    13371168
    13381169
    1339 ##
    1340 # @brief
    1341 # @param coordinates
    1342 # @param volumes
    1343 # @param boundary
    13441170def weed(coordinates, volumes, boundary=None):
     1171    """ Excise all duplicate points.
     1172    """
    13451173    if isinstance(coordinates, num.ndarray):
    13461174        coordinates = coordinates.tolist()
  • trunk/anuga_core/source/anuga/file_conversion/dem2pts.py

    r7814 r7841  
    66from anuga.config import netcdf_mode_r, netcdf_mode_w, netcdf_mode_a, \
    77                            netcdf_float
     8
     9from asc2dem import asc2dem
    810                           
    9 ##
    10 # @brief Convert DEM data  to PTS data.
    11 # @param basename_in Stem of input filename.
    12 # @param basename_out Stem of output filename.
    13 # @param easting_min
    14 # @param easting_max
    15 # @param northing_min
    16 # @param northing_max
    17 # @param use_cache
    18 # @param verbose
    19 # @return
     11
    2012def dem2pts(name_in, name_out=None,
    2113            easting_min=None, easting_max=None,
     
    3325    NODATA_value  -9999
    3426    138.3698 137.4194 136.5062 135.5558 ..........
     27
     28    name_in may be a .asc or .dem file to be converted.
    3529
    3630    Convert to NetCDF pts format which is
     
    8175    from Scientific.IO.NetCDF import NetCDFFile
    8276
    83     if name_in[-4:] != '.dem':
    84         raise IOError('Input file %s should be of type .dem.' % name_in)
     77    root = name_in[:-4]
     78
     79    if name_in[-4:] == '.asc':
     80        intermediate = root + '.dem'
     81        if verbose:
     82            log.critical('Preconvert %s from asc to %s' % \
     83                                    (name_in, intermediate))
     84        asc2dem(name_in)
     85        name_in = intermediate
     86    elif name_in[-4:] != '.dem':
     87        raise IOError('Input file %s should be of type .asc or .dem.' % name_in)
    8588
    8689    if name_out != None and basename_out[-4:] != '.pts':
    8790        raise IOError('Input file %s should be of type .pts.' % name_out)
    88 
    89     root = name_in[:-4]
    9091
    9192    # Get NetCDF
  • trunk/anuga_core/source/anuga/file_conversion/sww2dem.py

    r7814 r7841  
    124124    if block_size is None:
    125125        block_size = DEFAULT_BLOCK_SIZE
     126
     127    assert(isinstance(block_size, (int, long, float)))
    126128
    127129    # Read sww file
     
    217219    result = num.zeros(number_of_points, num.float)
    218220
     221    if verbose:
     222        msg = 'Slicing sww file, num points: ' + str(number_of_points)
     223        msg += ', block size: ' + str(block_size)
     224        log.critical(msg)
     225
    219226    for start_slice in xrange(0, number_of_points, block_size):
    220227        # Limit slice size to array end if at last block
     
    357364
    358365        #Write
    359         if verbose: log.critical('Writing %s' % demfile)
     366        if verbose:
     367            log.critical('Writing %s' % name_out)
    360368
    361369        import ermapper_grids
     
    437445                format='ers'):
    438446    """Wrapper for sww2dem.
    439     See sww2dem to find out what most of the parameters do.
     447    See sww2dem to find out what most of the parameters do. Note that since this
     448    is a batch command, the normal filename naming conventions do not apply.
    440449
    441450    basename_in is a path to sww file/s, without the .sww extension.
     451    extra_name_out is a postfix to add to the output filename.
    442452
    443453    Quantities is a list of quantities.  Each quantity will be
     
    474484                basename_out = sww_file + '_' + quantity + '_' + extra_name_out
    475485
    476             file_out = sww2dem(dir+os.sep+sww_file+'.sww',
    477                                dir+os.sep+basename_out,
     486            swwin = dir+os.sep+sww_file+'.sww'
     487            demout = dir+os.sep+basename_out+'.'+format
     488
     489            if verbose:
     490                log.critical('sww2dem: %s => %s' % (swwin, demout))
     491
     492            file_out = sww2dem(swwin,
     493                               demout,
    478494                               quantity,
    479495                               reduction,
     
    487503                               verbose,
    488504                               origin,
    489                                datum,
    490                                format)
     505                               datum)
     506                               
    491507            files_out.append(file_out)
    492508    return files_out
  • trunk/anuga_core/source/anuga/file_conversion/test_dem2pts.py

    r7814 r7841  
    7979
    8080        #Convert to NetCDF pts
    81         asc2dem(filename)
    82         dem2pts(root+'.dem', easting_min=2002.0, easting_max=2007.0,
     81        dem2pts(filename, easting_min=2002.0, easting_max=2007.0,
    8382                northing_min=3003.0, northing_max=3006.0,
    8483                verbose=False)
     
    199198
    200199        #Convert to NetCDF pts
    201         asc2dem(filename)
    202         dem2pts(root+'.dem', easting_min=2002.0, easting_max=2007.0,
     200        dem2pts(filename, easting_min=2002.0, easting_max=2007.0,
    203201                northing_min=3003.0, northing_max=3006.0)
    204202
     
    328326
    329327        #Convert to NetCDF pts
    330         asc2dem(filename)
    331         dem2pts(root+'.dem', easting_min=2002.0, easting_max=2007.0,
     328        dem2pts(filename, easting_min=2002.0, easting_max=2007.0,
    332329                northing_min=3003.0, northing_max=3006.0)
    333330
  • trunk/anuga_core/source/anuga/file_conversion/test_sww2dem.py

    r7814 r7841  
    2020
    2121# local modules
    22 from sww2dem import sww2dem
     22from sww2dem import sww2dem, sww2dem_batch
    2323
    2424class Test_Sww2Dem(unittest.TestCase):
     
    13921392        os.remove(self.domain.get_name() + '_elevation')
    13931393        os.remove(self.domain.get_name() + '_elevation.ers')
     1394       
     1395    def test_export_grid_parallel(self):
     1396        """Test that sww information can be converted correctly to asc/prj
     1397        format readable by e.g. ArcView
     1398        """
     1399
     1400        import time, os
     1401        from Scientific.IO.NetCDF import NetCDFFile
     1402
     1403        base_name = 'tegp'
     1404        #Setup
     1405        self.domain.set_name(base_name+'_P0_8')
     1406        swwfile = self.domain.get_name() + '.sww'
     1407
     1408        self.domain.set_datadir('.')
     1409        self.domain.format = 'sww'
     1410        self.domain.smooth = True
     1411        self.domain.set_quantity('elevation', lambda x,y: -x-y)
     1412        self.domain.set_quantity('stage', 1.0)
     1413
     1414        self.domain.geo_reference = Geo_reference(56,308500,6189000)
     1415
     1416        sww = SWW_file(self.domain)
     1417        sww.store_connectivity()
     1418        sww.store_timestep()
     1419        self.domain.evolve_to_end(finaltime = 0.0001)
     1420        #Setup
     1421        self.domain.set_name(base_name+'_P1_8')
     1422        swwfile2 = self.domain.get_name() + '.sww'
     1423        sww = SWW_file(self.domain)
     1424        sww.store_connectivity()
     1425        sww.store_timestep()
     1426        self.domain.evolve_to_end(finaltime = 0.0002)
     1427        sww.store_timestep()
     1428
     1429        cellsize = 0.25
     1430        #Check contents
     1431        #Get NetCDF
     1432
     1433        fid = NetCDFFile(sww.filename, netcdf_mode_r)
     1434
     1435        # Get the variables
     1436        x = fid.variables['x'][:]
     1437        y = fid.variables['y'][:]
     1438        z = fid.variables['elevation'][:]
     1439        time = fid.variables['time'][:]
     1440        stage = fid.variables['stage'][:]
     1441
     1442        fid.close()
     1443
     1444        #Export to ascii/prj files
     1445        extra_name_out = 'yeah'
     1446        sww2dem_batch(base_name,
     1447                    quantities = ['elevation', 'depth'],
     1448                    extra_name_out = extra_name_out,
     1449                    cellsize = cellsize,
     1450                    verbose = self.verbose,
     1451                    format = 'asc')
     1452
     1453        prjfile = base_name + '_P0_8_elevation_yeah.prj'
     1454        ascfile = base_name + '_P0_8_elevation_yeah.asc'       
     1455        #Check asc file
     1456        ascid = open(ascfile)
     1457        lines = ascid.readlines()
     1458        ascid.close()
     1459        #Check grid values
     1460        for j in range(5):
     1461            L = lines[6+j].strip().split()
     1462            y = (4-j) * cellsize
     1463            for i in range(5):
     1464                #print " -i*cellsize - y",  -i*cellsize - y
     1465                #print "float(L[i])", float(L[i])
     1466                assert num.allclose(float(L[i]), -i*cellsize - y)               
     1467        #Cleanup
     1468        os.remove(prjfile)
     1469        os.remove(ascfile)
     1470
     1471        prjfile = base_name + '_P1_8_elevation_yeah.prj'
     1472        ascfile = base_name + '_P1_8_elevation_yeah.asc'       
     1473        #Check asc file
     1474        ascid = open(ascfile)
     1475        lines = ascid.readlines()
     1476        ascid.close()
     1477        #Check grid values
     1478        for j in range(5):
     1479            L = lines[6+j].strip().split()
     1480            y = (4-j) * cellsize
     1481            for i in range(5):
     1482                #print " -i*cellsize - y",  -i*cellsize - y
     1483                #print "float(L[i])", float(L[i])
     1484                assert num.allclose(float(L[i]), -i*cellsize - y)               
     1485        #Cleanup
     1486        os.remove(prjfile)
     1487        os.remove(ascfile)
     1488        os.remove(swwfile)
     1489
     1490        #Check asc file
     1491        ascfile = base_name + '_P0_8_depth_yeah.asc'
     1492        prjfile = base_name + '_P0_8_depth_yeah.prj'
     1493        ascid = open(ascfile)
     1494        lines = ascid.readlines()
     1495        ascid.close()
     1496        #Check grid values
     1497        for j in range(5):
     1498            L = lines[6+j].strip().split()
     1499            y = (4-j) * cellsize
     1500            for i in range(5):
     1501                assert num.allclose(float(L[i]), 1 - (-i*cellsize - y))
     1502        #Cleanup
     1503        os.remove(prjfile)
     1504        os.remove(ascfile)
     1505
     1506        #Check asc file
     1507        ascfile = base_name + '_P1_8_depth_yeah.asc'
     1508        prjfile = base_name + '_P1_8_depth_yeah.prj'
     1509        ascid = open(ascfile)
     1510        lines = ascid.readlines()
     1511        ascid.close()
     1512        #Check grid values
     1513        for j in range(5):
     1514            L = lines[6+j].strip().split()
     1515            y = (4-j) * cellsize
     1516            for i in range(5):
     1517                assert num.allclose(float(L[i]), 1 - (-i*cellsize - y))
     1518        #Cleanup
     1519        os.remove(prjfile)
     1520        os.remove(ascfile)
     1521        os.remove(swwfile2)
     1522       
    13941523
    13951524#################################################################################
  • trunk/anuga_core/source/anuga/fit_interpolate/interpolate.py

    r7810 r7841  
    11"""Least squares interpolation.
    22
    3    Implements a least-squares interpolation.
    4    Putting mesh data onto points.
     3   These functions and classes calculate a value at a particular point on
     4   the given mesh. It interpolates the values stored at the vertices of the
     5   mesh.
     6   
     7   For example, if you want to get the height of a terrain mesh at particular
     8   point, you pass the point to an Interpolate class. The point will intersect
     9   one of the triangles on the mesh, and the interpolated height will be an
     10   intermediate value between the three vertices of that triangle.
     11   This value is returned by the class.
    512
    613   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
     
    896903                    if verbose is True:
    897904                        import sys
    898                         if sys.platform == 'win32':
    899                             # FIXME (Ole): Why only Windoze?
    900                             from anuga.geometry.polygon import plot_polygons
    901                             title = ('Interpolation points fall '
    902                                      'outside specified mesh')
    903                             plot_polygons([mesh_boundary_polygon,
    904                                            interpolation_points,
    905                                            out_interp_pts],
    906                                           ['line', 'point', 'outside'],
    907                                           figname='points_boundary_out',
    908                                           label=title,
    909                                           verbose=verbose)
     905                        from anuga.geometry.polygon import plot_polygons
     906                        title = ('Interpolation points fall '
     907                                 'outside specified mesh')
     908                        plot_polygons([mesh_boundary_polygon,
     909                                       interpolation_points,
     910                                       out_interp_pts],
     911                                      ['line', 'point', 'outside'],
     912                                      figname='points_boundary_out',
     913                                      label=title)
    910914
    911915                    # Joaquim Luis suggested this as an Exception, so
  • trunk/anuga_core/source/anuga/geometry/__init__.py

    r7711 r7841  
    22"""
    33
    4 pass
    54
    65#Add path of package to PYTHONPATH to allow C-extensions to be loaded
  • trunk/anuga_core/source/anuga/geometry/polygon.py

    r7778 r7841  
    55import numpy as num
    66
    7 from math import sqrt
    87from anuga.utilities.numerical_tools import ensure_numeric
    98from anuga.geospatial_data.geospatial_data import ensure_absolute, \
    109                                                    Geospatial_data
    11 from anuga.config import netcdf_float
    1210import anuga.utilities.log as log
    1311
     
    129127    line1 = ensure_numeric(line1, num.float)
    130128
    131     x0 = line0[0,0]; y0 = line0[0,1]
    132     x1 = line0[1,0]; y1 = line0[1,1]
    133 
    134     x2 = line1[0,0]; y2 = line1[0,1]
    135     x3 = line1[1,0]; y3 = line1[1,1]
     129    x0 = line0[0, 0]; y0 = line0[0, 1]
     130    x1 = line0[1, 0]; y1 = line0[1, 1]
     131
     132    x2 = line1[0, 0]; y2 = line1[0, 1]
     133    x3 = line1[1, 0]; y3 = line1[1, 1]
    136134
    137135    denom = (y3-y2)*(x1-x0) - (x3-x2)*(y1-y0)
     
    208206    line1 = ensure_numeric(line1, num.float)
    209207
    210     status, value = _intersection(line0[0,0], line0[0,1],
    211                                   line0[1,0], line0[1,1],
    212                                   line1[0,0], line1[0,1],
    213                                   line1[1,0], line1[1,1])
     208    status, value = _intersection(line0[0, 0], line0[0, 1],
     209                                  line0[1, 0], line0[1, 1],
     210                                  line1[0, 0], line1[0, 1],
     211                                  line1[1, 0], line1[1, 1])
    214212
    215213    return status, value
     
    219217                       rtol=1.0e-12,
    220218                       atol=1.0e-12,                     
    221                        check_inputs=True,
    222                        verbose=False):
     219                       check_inputs=True):
    223220    """Determine if one point is inside a triangle
    224221   
     
    278275   
    279276    # Quickly reject points that are clearly outside
    280     if point[0] < min(triangle[:, 0]): return False
    281     if point[0] > max(triangle[:, 0]): return False   
    282     if point[1] < min(triangle[:, 1]): return False
    283     if point[1] > max(triangle[:, 1]): return False       
     277    if point[0] < min(triangle[:, 0]):
     278        return False
     279    if point[0] > max(triangle[:, 0]):
     280        return False   
     281    if point[1] < min(triangle[:, 1]):
     282        return False
     283    if point[1] > max(triangle[:, 1]):
     284        return False       
    284285
    285286
     
    314315        # Check if point lies on one of the edges
    315316       
    316         for X, Y in [[A,B], [B,C], [C,A]]:
     317        for X, Y in [[A, B], [B, C], [C, A]]:
    317318            res = _point_on_line(point[0], point[1],
    318319                                 X[0], X[1],
     
    341342   
    342343    def segments_joined(seg0, seg1):
     344        """ See if there are identical segments in the 2 lists. """
    343345        for i in seg0:
    344346            for j in seg1:   
     
    381383    return False
    382384   
    383    
    384 def is_inside_polygon_quick(point, polygon, closed=True, verbose=False):
    385     """Determine if one point is inside a polygon
    386     Both point and polygon are assumed to be numeric arrays or lists and
    387     no georeferencing etc or other checks will take place.
    388    
    389     As such it is faster than is_inside_polygon
    390     """
    391 
    392     # FIXME(Ole): This function isn't being used
    393     polygon = ensure_numeric(polygon, num.float)
    394     points = ensure_numeric(point, num.float) # Convert point to array of points
    395     points = num.ascontiguousarray(points[num.newaxis, :])
    396     msg = ('is_inside_polygon() must be invoked with one point only.\n'
    397            'I got %s and converted to %s' % (str(point), str(points.shape)))
    398     assert points.shape[0] == 1 and points.shape[1] == 2, msg
    399    
    400     indices = num.zeros(1, num.int)
    401 
    402     count = _separate_points_by_polygon(points, polygon, indices,
    403                                         int(closed), int(verbose))
    404 
    405     return count > 0
    406 
    407385
    408386def is_inside_polygon(point, polygon, closed=True, verbose=False):
     
    420398    else:
    421399        msg = 'is_inside_polygon must be invoked with one point only'
    422         raise msg
     400        raise Exception(msg)
    423401
    424402##
     
    743721                  figname=None,
    744722                  label=None,
    745                   alpha=None,
    746                   verbose=False):
     723                  alpha=None):
    747724    """ Take list of polygons and plot.
    748725
     
    769746    """
    770747
    771     from pylab import ion, hold, plot, axis, figure, legend, savefig, xlabel, \
     748    from pylab import ion, hold, plot, savefig, xlabel, \
    772749                      ylabel, title, close, title, fill
    773750
     
    793770            alpha = None
    794771        else:
    795             if alpha < 0.0:
    796                 alpha = 0.0
    797             if alpha > 1.0:
    798                 alpha = 1.0
     772            alpha = max(0.0, min(1.0, alpha))
    799773
    800774    n = len(polygons_points)
     
    818792    for i, item in enumerate(polygons_points):
    819793        x, y = poly_xy(item)
    820         if min(x) < minx: minx = min(x)
    821         if max(x) > maxx: maxx = max(x)
    822         if min(y) < miny: miny = min(y)
    823         if max(y) > maxy: maxy = max(y)
    824         plot(x,y,colour[i])
     794        if min(x) < minx:
     795            minx = min(x)
     796        if max(x) > maxx:
     797            maxx = max(x)
     798        if min(y) < miny:
     799            miny = min(y)
     800        if max(y) > maxy:
     801            maxy = max(y)
     802        plot(x, y, colour[i])
    825803        if alpha:
    826804            fill(x, y, colour[i], alpha=alpha)
     
    840818
    841819
    842 def poly_xy(polygon, verbose=False):
     820def poly_xy(polygon):
    843821    """ this is used within plot_polygons so need to duplicate
    844822        the first point so can have closed polygon in plot
     
    858836        raise Exception, msg
    859837
    860     x = polygon[:,0]
    861     y = polygon[:,1]
    862     x = num.concatenate((x, [polygon[0,0]]), axis = 0)
    863     y = num.concatenate((y, [polygon[0,1]]), axis = 0)
     838    x = polygon[:, 0]
     839    y = polygon[:, 1]
     840    x = num.concatenate((x, [polygon[0, 0]]), axis = 0)
     841    y = num.concatenate((y, [polygon[0, 1]]), axis = 0)
    864842
    865843    return x, y
     
    960938        assert y.shape[0] == N
    961939
    962         points = num.ascontiguousarray(num.concatenate((x[:,num.newaxis],
    963                                                         y[:,num.newaxis]),
    964                                                        axis=1 ))
     940        points = num.ascontiguousarray(num.concatenate((x[:, num.newaxis],
     941                                                        y[:, num.newaxis]),
     942                                                       axis = 1 ))
    965943
    966944        if callable(self.default):
     
    10421020        fid.write('%f, %f\n' % point)
    10431021    fid.close()
    1044 
    1045 ##
    1046 # @brief Unimplemented.
    1047 def read_tagged_polygons(filename):
    1048     """
    1049     """
    1050     pass
    10511022
    10521023
     
    10971068            if exclude is not None:
    10981069                for ex_poly in exclude:
    1099                     if is_inside_polygon([x,y], ex_poly):
     1070                    if is_inside_polygon([x, y], ex_poly):
    11001071                        append = False
    11011072
    11021073        if append is True:
    1103             points.append([x,y])
     1074            points.append([x, y])
    11041075
    11051076    return points
     
    11221093    import exceptions
    11231094
    1124     class Found(exceptions.Exception): pass
     1095    class Found(exceptions.Exception):
     1096        pass
    11251097
    11261098    polygon = ensure_numeric(polygon)
     
    12321204    # Find outer extent of polygon
    12331205    num_polygon = ensure_numeric(polygon)
    1234     max_x = max(num_polygon[:,0])
    1235     max_y = max(num_polygon[:,1])
    1236     min_x = min(num_polygon[:,0])
    1237     min_y = min(num_polygon[:,1])
     1206    max_x = max(num_polygon[:, 0])
     1207    max_y = max(num_polygon[:, 1])
     1208    min_x = min(num_polygon[:, 0])
     1209    min_y = min(num_polygon[:, 1])
    12381210
    12391211    # Keep only some points making sure extrema are kept
     
    12571229                         interpolation_points=None,
    12581230                         rtol=1.0e-6,
    1259                          atol=1.0e-8,
    1260                          verbose=False):
     1231                         atol=1.0e-8):
    12611232    """Interpolate linearly between values data on polyline nodes
    12621233    of a polyline to list of interpolation points.
     
    12891260    gauge_neighbour_id = ensure_numeric(gauge_neighbour_id, num.int)
    12901261
    1291     n = polyline_nodes.shape[0]    # Number of nodes in polyline
     1262    num_nodes = polyline_nodes.shape[0]    # Number of nodes in polyline
    12921263
    12931264    # Input sanity check
    1294     msg = 'interpolation_points are not given (interpolate.py)'
    1295     assert interpolation_points is not None, msg
    1296 
    1297     msg = 'function value must be specified at every interpolation node'
    1298     assert data.shape[0] == polyline_nodes.shape[0], msg
    1299 
    1300     msg = 'Must define function value at one or more nodes'
    1301     assert data.shape[0] > 0, msg
    1302 
    1303     if n == 1:
    1304         msg = 'Polyline contained only one point. I need more. ' + str(data)
    1305         raise Exception, msg
    1306     elif n > 1:
     1265    assert_msg = 'interpolation_points are not given (interpolate.py)'
     1266    assert interpolation_points is not None, assert_msg
     1267
     1268    assert_msg = 'function value must be specified at every interpolation node'
     1269    assert data.shape[0] == polyline_nodes.shape[0], assert_msg
     1270
     1271    assert_msg = 'Must define function value at one or more nodes'
     1272    assert data.shape[0] > 0, assert_msg
     1273
     1274    if num_nodes == 1:
     1275        assert_msg = 'Polyline contained only one point. I need more. '
     1276        assert_msg += str(data)
     1277        raise Exception, assert_msg
     1278    elif num_nodes > 1:
    13071279        _interpolate_polyline(data,
    13081280                              polyline_nodes,
     
    13461318
    13471319else:
    1348     msg = 'C implementations could not be accessed by %s.\n ' %__file__
    1349     msg += 'Make sure compile_all.py has been run as described in '
    1350     msg += 'the ANUGA installation guide.'
    1351     raise Exception, msg
     1320    error_msg = 'C implementations could not be accessed by %s.\n ' %__file__
     1321    error_msg += 'Make sure compile_all.py has been run as described in '
     1322    error_msg += 'the ANUGA installation guide.'
     1323    raise Exception(error_msg)
    13521324
    13531325
  • trunk/anuga_core/source/anuga/geometry/test_polygon.py

    r7711 r7841  
    178178
    179179    def test_inside_polygon_main(self):
    180         """test_is_inside_polygon_quick
     180        """test_is_inside_polygon
    181181       
    182182        Test fast version of of is_inside_polygon
     
    186186        polygon = [[0,0], [1,0], [1,1], [0,1]]
    187187
    188         assert is_inside_polygon_quick( (0.5, 0.5), polygon )
    189         assert not is_inside_polygon_quick( (0.5, 1.5), polygon )
    190         assert not is_inside_polygon_quick( (0.5, -0.5), polygon )
    191         assert not is_inside_polygon_quick( (-0.5, 0.5), polygon )
    192         assert not is_inside_polygon_quick( (1.5, 0.5), polygon )
     188        assert is_inside_polygon( (0.5, 0.5), polygon )
     189        assert not is_inside_polygon( (0.5, 1.5), polygon )
     190        assert not is_inside_polygon( (0.5, -0.5), polygon )
     191        assert not is_inside_polygon( (-0.5, 0.5), polygon )
     192        assert not is_inside_polygon( (1.5, 0.5), polygon )
    193193
    194194        # Try point on borders
    195         assert is_inside_polygon_quick( (1., 0.5), polygon, closed=True)
    196         assert is_inside_polygon_quick( (0.5, 1), polygon, closed=True)
    197         assert is_inside_polygon_quick( (0., 0.5), polygon, closed=True)
    198         assert is_inside_polygon_quick( (0.5, 0.), polygon, closed=True)
    199 
    200         assert not is_inside_polygon_quick( (0.5, 1), polygon, closed=False)
    201         assert not is_inside_polygon_quick( (0., 0.5), polygon, closed=False)
    202         assert not is_inside_polygon_quick( (0.5, 0.), polygon, closed=False)
    203         assert not is_inside_polygon_quick( (1., 0.5), polygon, closed=False)
     195        assert is_inside_polygon( (1., 0.5), polygon, closed=True)
     196        assert is_inside_polygon( (0.5, 1), polygon, closed=True)
     197        assert is_inside_polygon( (0., 0.5), polygon, closed=True)
     198        assert is_inside_polygon( (0.5, 0.), polygon, closed=True)
     199
     200        assert not is_inside_polygon( (0.5, 1), polygon, closed=False)
     201        assert not is_inside_polygon( (0., 0.5), polygon, closed=False)
     202        assert not is_inside_polygon( (0.5, 0.), polygon, closed=False)
     203        assert not is_inside_polygon( (1., 0.5), polygon, closed=False)
    204204
    205205
     
    229229        assert not is_inside_polygon( (0.5, -0.5), polygon )
    230230
    231         assert is_inside_polygon_quick( (0.5, 0.5), polygon )
    232         assert is_inside_polygon_quick( (1, -0.5), polygon )
    233         assert is_inside_polygon_quick( (1.5, 0), polygon )
    234 
    235         assert not is_inside_polygon_quick( (0.5, 1.5), polygon )
    236         assert not is_inside_polygon_quick( (0.5, -0.5), polygon )
     231        assert is_inside_polygon( (0.5, 0.5), polygon )
     232        assert is_inside_polygon( (1, -0.5), polygon )
     233        assert is_inside_polygon( (1.5, 0), polygon )
     234
     235        assert not is_inside_polygon( (0.5, 1.5), polygon )
     236        assert not is_inside_polygon( (0.5, -0.5), polygon )
    237237
    238238        # Very convoluted polygon
     
    449449            assert is_inside_polygon(point, polygon)
    450450
    451             assert is_inside_polygon_quick(point, polygon)
     451            assert is_inside_polygon(point, polygon)
    452452
    453453
     
    458458        for point in points:
    459459            assert is_inside_polygon(point, polygon)
    460             assert is_inside_polygon_quick(point, polygon)
     460            assert is_inside_polygon(point, polygon)
    461461
    462462
     
    15981598        assert y[4] == 6
    15991599
    1600     # Disabled
    1601     def xtest_plot_polygons(self):
     1600
     1601    def test_plot_polygons(self):
    16021602        import os
    16031603
     
    16051605        polygon1 = [[0,0], [1,0], [1,1], [0,1]]
    16061606        polygon2 = [[1,1], [2,1], [3,2], [2,2]]
    1607         v = plot_polygons([polygon1, polygon2], 'test1')
     1607        v = plot_polygons([polygon1, polygon2], figname='test1')
    16081608        assert len(v) == 4
    16091609        assert v[0] == 0
     
    16141614        # Another case
    16151615        polygon3 = [[1,5], [10,1], [100,10], [50,10], [3,6]]
    1616         v = plot_polygons([polygon2,polygon3], 'test2')
     1616        v = plot_polygons([polygon2,polygon3], figname='test2')
    16171617        assert len(v) == 4
    16181618        assert v[0] == 1
  • trunk/anuga_core/source/anuga/shallow_water/test_data_manager.py

    r7780 r7841  
    11211121            self.failUnless(0 ==1,  'Bad input did not throw exception error!')
    11221122
    1123     def test_export_grid_parallel(self):
    1124         """Test that sww information can be converted correctly to asc/prj
    1125         format readable by e.g. ArcView
    1126         """
    1127 
    1128         import time, os
    1129         from Scientific.IO.NetCDF import NetCDFFile
    1130 
    1131         base_name = 'tegp'
    1132         #Setup
    1133         self.domain.set_name(base_name+'_P0_8')
    1134         swwfile = self.domain.get_name() + '.sww'
    1135 
    1136         self.domain.set_datadir('.')
    1137         self.domain.format = 'sww'
    1138         self.domain.smooth = True
    1139         self.domain.set_quantity('elevation', lambda x,y: -x-y)
    1140         self.domain.set_quantity('stage', 1.0)
    1141 
    1142         self.domain.geo_reference = Geo_reference(56,308500,6189000)
    1143 
    1144         sww = SWW_file(self.domain)
    1145         sww.store_connectivity()
    1146         sww.store_timestep()
    1147         self.domain.evolve_to_end(finaltime = 0.0001)
    1148         #Setup
    1149         self.domain.set_name(base_name+'_P1_8')
    1150         swwfile2 = self.domain.get_name() + '.sww'
    1151         sww = SWW_file(self.domain)
    1152         sww.store_connectivity()
    1153         sww.store_timestep()
    1154         self.domain.evolve_to_end(finaltime = 0.0002)
    1155         sww.store_timestep()
    1156 
    1157         cellsize = 0.25
    1158         #Check contents
    1159         #Get NetCDF
    1160 
    1161         fid = NetCDFFile(sww.filename, netcdf_mode_r)
    1162 
    1163         # Get the variables
    1164         x = fid.variables['x'][:]
    1165         y = fid.variables['y'][:]
    1166         z = fid.variables['elevation'][:]
    1167         time = fid.variables['time'][:]
    1168         stage = fid.variables['stage'][:]
    1169 
    1170         fid.close()
    1171 
    1172         #Export to ascii/prj files
    1173         extra_name_out = 'yeah'
    1174         sww2dem_batch(base_name,
    1175                     quantities = ['elevation', 'depth'],
    1176                     extra_name_out = extra_name_out,
    1177                     cellsize = cellsize,
    1178                     verbose = self.verbose,
    1179                     format = 'asc')
    1180 
    1181         prjfile = base_name + '_P0_8_elevation_yeah.prj'
    1182         ascfile = base_name + '_P0_8_elevation_yeah.asc'       
    1183         #Check asc file
    1184         ascid = open(ascfile)
    1185         lines = ascid.readlines()
    1186         ascid.close()
    1187         #Check grid values
    1188         for j in range(5):
    1189             L = lines[6+j].strip().split()
    1190             y = (4-j) * cellsize
    1191             for i in range(5):
    1192                 #print " -i*cellsize - y",  -i*cellsize - y
    1193                 #print "float(L[i])", float(L[i])
    1194                 assert num.allclose(float(L[i]), -i*cellsize - y)               
    1195         #Cleanup
    1196         os.remove(prjfile)
    1197         os.remove(ascfile)
    1198 
    1199         prjfile = base_name + '_P1_8_elevation_yeah.prj'
    1200         ascfile = base_name + '_P1_8_elevation_yeah.asc'       
    1201         #Check asc file
    1202         ascid = open(ascfile)
    1203         lines = ascid.readlines()
    1204         ascid.close()
    1205         #Check grid values
    1206         for j in range(5):
    1207             L = lines[6+j].strip().split()
    1208             y = (4-j) * cellsize
    1209             for i in range(5):
    1210                 #print " -i*cellsize - y",  -i*cellsize - y
    1211                 #print "float(L[i])", float(L[i])
    1212                 assert num.allclose(float(L[i]), -i*cellsize - y)               
    1213         #Cleanup
    1214         os.remove(prjfile)
    1215         os.remove(ascfile)
    1216         os.remove(swwfile)
    1217 
    1218         #Check asc file
    1219         ascfile = base_name + '_P0_8_depth_yeah.asc'
    1220         prjfile = base_name + '_P0_8_depth_yeah.prj'
    1221         ascid = open(ascfile)
    1222         lines = ascid.readlines()
    1223         ascid.close()
    1224         #Check grid values
    1225         for j in range(5):
    1226             L = lines[6+j].strip().split()
    1227             y = (4-j) * cellsize
    1228             for i in range(5):
    1229                 assert num.allclose(float(L[i]), 1 - (-i*cellsize - y))
    1230         #Cleanup
    1231         os.remove(prjfile)
    1232         os.remove(ascfile)
    1233 
    1234         #Check asc file
    1235         ascfile = base_name + '_P1_8_depth_yeah.asc'
    1236         prjfile = base_name + '_P1_8_depth_yeah.prj'
    1237         ascid = open(ascfile)
    1238         lines = ascid.readlines()
    1239         ascid.close()
    1240         #Check grid values
    1241         for j in range(5):
    1242             L = lines[6+j].strip().split()
    1243             y = (4-j) * cellsize
    1244             for i in range(5):
    1245                 assert num.allclose(float(L[i]), 1 - (-i*cellsize - y))
    1246         #Cleanup
    1247         os.remove(prjfile)
    1248         os.remove(ascfile)
    1249         os.remove(swwfile2)
    1250 
    1251 
    1252 
    1253     def DISABLEDtest_sww2domain2(self):
    1254         ##################################################################
    1255         #Same as previous test, but this checks how NaNs are handled.
    1256         ##################################################################
    1257 
    1258         #FIXME: See ticket 223
    1259 
    1260         from mesh_factory import rectangular
    1261 
    1262         #Create basic mesh
    1263         points, vertices, boundary = rectangular(2,2)
    1264 
    1265         #Create shallow water domain
    1266         domain = Domain(points, vertices, boundary)
    1267         domain.smooth = False
    1268         domain.store = True
    1269         domain.set_name('test_file')
    1270         domain.set_datadir('.')
    1271         domain.default_order=2
    1272 
    1273         domain.set_quantity('elevation', lambda x,y: -x/3)
    1274         domain.set_quantity('friction', 0.1)
    1275 
    1276         from math import sin, pi
    1277         Br = Reflective_boundary(domain)
    1278         Bt = Transmissive_boundary(domain)
    1279         Bd = Dirichlet_boundary([0.2,0.,0.])
    1280         Bw = Time_boundary(domain=domain,
    1281                            f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
    1282 
    1283         domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
    1284 
    1285         h = 0.05
    1286         elevation = domain.quantities['elevation'].vertex_values
    1287         domain.set_quantity('stage', elevation + h)
    1288 
    1289         domain.check_integrity()
    1290 
    1291         for t in domain.evolve(yieldstep = 1, finaltime = 2.0):
    1292             pass
    1293             #domain.write_time()
    1294 
    1295 
    1296         filename = domain.datadir + os.sep + domain.get_name() + '.sww'
    1297 
    1298         # Fail because NaNs are present
    1299         #domain2 = sww2domain(filename,
    1300         #                     boundary,
    1301         #                     fail_if_NaN=True,
    1302         #                     verbose=self.verbose)       
    1303         try:
    1304             domain2 = load_sww_as_domain(filename,
    1305                                  boundary,
    1306                                  fail_if_NaN=True,
    1307                                  verbose=self.verbose)
    1308         except DataDomainError:
    1309             # Now import it, filling NaNs to be -9999
    1310             filler = -9999
    1311             domain2 = load_sww_as_domain(filename,
    1312                                  None,
    1313                                  fail_if_NaN=False,
    1314                                  NaN_filler=filler,
    1315                                  verbose=self.verbose)
    1316         else:
    1317             raise Exception, 'should have failed'
    1318 
    1319            
    1320         # Now import it, filling NaNs to be 0
    1321         filler = -9999
    1322         domain2 = load_sww_as_domain(filename,
    1323                              None,
    1324                              fail_if_NaN=False,
    1325                              NaN_filler=filler,
    1326                              verbose=self.verbose)           
    1327                              
    1328         import sys; sys.exit()
    1329            
    1330         # Clean up
    1331         os.remove(filename)
    1332 
    1333 
    1334         bits = ['geo_reference.get_xllcorner()',
    1335                 'geo_reference.get_yllcorner()',
    1336                 'vertex_coordinates']
    1337 
    1338         for quantity in domain.quantities_to_be_stored:
    1339             bits.append('get_quantity("%s").get_integral()' %quantity)
    1340             bits.append('get_quantity("%s").get_values()' %quantity)
    1341 
    1342         for bit in bits:
    1343         #    print 'testing that domain.'+bit+' has been restored'
    1344             assert num.allclose(eval('domain.'+bit),eval('domain2.'+bit))
    1345 
    1346         print
    1347         print
    1348         print domain2.get_quantity('xmomentum').get_values()
    1349         print
    1350         print domain2.get_quantity('stage').get_values()
    1351         print
    1352              
    1353         print 'filler', filler
    1354         print max(domain2.get_quantity('xmomentum').get_values().flat)
    1355        
    1356         assert max(max(domain2.get_quantity('xmomentum').get_values()))==filler
    1357         assert min(min(domain2.get_quantity('xmomentum').get_values()))==filler
    1358         assert max(max(domain2.get_quantity('ymomentum').get_values()))==filler
    1359         assert min(min(domain2.get_quantity('ymomentum').get_values()))==filler
    1360 
    1361 
    1362 
    1363     #FIXME This fails - smooth makes the comparism too hard for allclose
    1364     def ztest_sww2domain3(self):
    1365         ################################################
    1366         #DOMAIN.SMOOTH = TRUE !!!!!!!!!!!!!!!!!!!!!!!!!!
    1367         ################################################
    1368         from mesh_factory import rectangular
    1369         #Create basic mesh
    1370 
    1371         yiel=0.01
    1372         points, vertices, boundary = rectangular(10,10)
    1373 
    1374         #Create shallow water domain
    1375         domain = Domain(points, vertices, boundary)
    1376         domain.geo_reference = Geo_reference(56,11,11)
    1377         domain.smooth = True
    1378         domain.store = True
    1379         domain.set_name('bedslope')
    1380         domain.default_order=2
    1381         #Bed-slope and friction
    1382         domain.set_quantity('elevation', lambda x,y: -x/3)
    1383         domain.set_quantity('friction', 0.1)
    1384         # Boundary conditions
    1385         from math import sin, pi
    1386         Br = Reflective_boundary(domain)
    1387         Bt = Transmissive_boundary(domain)
    1388         Bd = Dirichlet_boundary([0.2,0.,0.])
    1389         Bw = Time_boundary(domain=domain,
    1390                            f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
    1391 
    1392         domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
    1393 
    1394         domain.quantities_to_be_stored['xmomentum'] = 2
    1395         domain.quantities_to_be_stored['ymomentum'] = 2       
    1396         #Initial condition
    1397         h = 0.05
    1398         elevation = domain.quantities['elevation'].vertex_values
    1399         domain.set_quantity('stage', elevation + h)
    1400 
    1401 
    1402         domain.check_integrity()
    1403         #Evolution
    1404         for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
    1405         #    domain.write_time()
    1406             pass
    1407 
    1408 
    1409         filename = domain.datadir + os.sep + domain.get_name() + '.sww'
    1410         domain2 = load_sww_as_domain(filename,None,fail_if_NaN=False,verbose=self.verbose)
    1411         #points, vertices, boundary = rectangular(15,15)
    1412         #domain2.boundary = boundary
    1413         ###################
    1414         ##NOW TEST IT!!!
    1415         ###################
    1416 
    1417         os.remove(domain.get_name() + '.sww')
    1418 
    1419         #FIXME smooth domain so that they can be compared
    1420 
    1421 
    1422         bits = []
    1423         for quantity in domain.quantities_to_be_stored:
    1424             bits.append('quantities["%s"].get_integral()'%quantity)
    1425 
    1426 
    1427         for bit in bits:
    1428             #print 'testing that domain.'+bit+' has been restored'
    1429             #print bit
    1430             #print 'done'
    1431             #print ('domain.'+bit), eval('domain.'+bit)
    1432             #print ('domain2.'+bit), eval('domain2.'+bit)
    1433             assert num.allclose(eval('domain.'+bit),eval('domain2.'+bit),rtol=1.0e-1,atol=1.e-3)
    1434             pass
    1435 
    1436         ######################################
    1437         #Now evolve them both, just to be sure
    1438         ######################################x
    1439         domain.time = 0.
    1440         from time import sleep
    1441 
    1442         final = .5
    1443         domain.set_quantity('friction', 0.1)
    1444         domain.store = False
    1445         domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Br})
    1446 
    1447         for t in domain.evolve(yieldstep = yiel, finaltime = final):
    1448             #domain.write_time()
    1449             pass
    1450 
    1451         domain2.smooth = True
    1452         domain2.store = False
    1453         domain2.default_order=2
    1454         domain2.set_quantity('friction', 0.1)
    1455         #Bed-slope and friction
    1456         # Boundary conditions
    1457         Bd2=Dirichlet_boundary([0.2,0.,0.])
    1458         Br2 = Reflective_boundary(domain2)
    1459         domain2.boundary = domain.boundary
    1460         #print 'domain2.boundary'
    1461         #print domain2.boundary
    1462         domain2.set_boundary({'left': Bd2, 'right': Bd2, 'top': Bd2, 'bottom': Br2})
    1463         #domain2.boundary = domain.boundary
    1464         #domain2.set_boundary({'exterior': Bd})
    1465 
    1466         domain2.check_integrity()
    1467 
    1468         for t in domain2.evolve(yieldstep = yiel, finaltime = final):
    1469             #domain2.write_time()
    1470             pass
    1471 
    1472         ###################
    1473         ##NOW TEST IT!!!
    1474         ##################
    1475 
    1476         print '><><><><>>'
    1477         bits = [ 'vertex_coordinates']
    1478 
    1479         for quantity in ['elevation','xmomentum','ymomentum']:
    1480             #bits.append('quantities["%s"].get_integral()'%quantity)
    1481             bits.append('get_quantity("%s").get_values()' %quantity)
    1482 
    1483         for bit in bits:
    1484             print bit
    1485             assert num.allclose(eval('domain.'+bit),eval('domain2.'+bit))
    14861123
    14871124    def test_file_boundary_stsIV_sinewave_ordering(self):
  • trunk/anuga_core/source/anuga/shallow_water/test_shallow_water_domain.py

    r7804 r7841  
    13391339
    13401340                assert num.allclose(Q, uh*width)
     1341
    13411342
    13421343    def test_get_energy_through_cross_section_with_geo(self):
  • trunk/anuga_core/source/anuga/shallow_water/tsh2sww.py

    r7735 r7841  
    1313from anuga.pyvolution.pmesh2domain import pmesh_to_domain_instance
    1414import time, os
    15 from anuga.pyvolution.sww_file import SWW_file
     15from anuga.file.sww import SWW_file
    1616from anuga.utilities.numerical_tools import mean
    1717import anuga.utilities.log as log
  • trunk/anuga_core/source/anuga/utilities/file_utils.py

    r7778 r7841  
    88import numpy as num
    99import shutil
     10import log
    1011
    1112def make_filename(s):
     
    333334        if verbose:
    334335            log.critical('Make directory %s' % dir_name)
    335         mkdir(dir_name, 0777)
     336        os.mkdir(dir_name, 0777)
    336337
    337338    if verbose:
Note: See TracChangeset for help on using the changeset viewer.