source: trunk/anuga_core/source/anuga/file/test_sww.py @ 7872

Last change on this file since 7872 was 7872, checked in by hudson, 14 years ago

Fixed a few unit test errors.

File size: 16.5 KB
Line 
1import os
2import unittest
3import tempfile
4import numpy as num
5
6from anuga.coordinate_transforms.geo_reference import Geo_reference
7from csv_file import load_csv_as_array, load_csv_as_dict
8from anuga.abstract_2d_finite_volumes.mesh_factory import rectangular
9from anuga.shallow_water.shallow_water_domain import Domain
10from sww import load_sww_as_domain, weed, get_mesh_and_quantities_from_file, \
11                Write_sww
12from Scientific.IO.NetCDF import NetCDFFile
13
14from anuga.config import netcdf_mode_w, netcdf_float
15
16# boundary functions
17from anuga.shallow_water.boundaries import Reflective_boundary, \
18            Field_boundary, Transmissive_momentum_set_stage_boundary, \
19            Transmissive_stage_zero_momentum_boundary
20from anuga.abstract_2d_finite_volumes.generic_boundary_conditions\
21     import Transmissive_boundary, Dirichlet_boundary, \
22            Time_boundary, File_boundary, AWI_boundary
23
24
25class Test_sww(unittest.TestCase):
26    def setUp(self):
27        self.verbose = False
28        pass
29
30    def tearDown(self):
31        pass
32       
33    def test_sww2domain1(self):
34        ################################################
35        #Create a test domain, and evolve and save it.
36        ################################################
37        from mesh_factory import rectangular
38
39        #Create basic mesh
40
41        yiel=0.01
42        points, vertices, boundary = rectangular(10,10)
43
44        #Create shallow water domain
45        domain = Domain(points, vertices, boundary)
46        domain.geo_reference = Geo_reference(56,11,11)
47        domain.smooth = False
48        domain.store = True
49        domain.set_name('bedslope')
50        domain.default_order=2
51        #Bed-slope and friction
52        domain.set_quantity('elevation', lambda x,y: -x/3)
53        domain.set_quantity('friction', 0.1)
54        # Boundary conditions
55        from math import sin, pi
56        Br = Reflective_boundary(domain)
57        Bt = Transmissive_boundary(domain)
58        Bd = Dirichlet_boundary([0.2,0.,0.])
59        Bw = Time_boundary(domain=domain,f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
60
61        #domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
62        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
63
64        domain.quantities_to_be_stored['xmomentum'] = 2
65        domain.quantities_to_be_stored['ymomentum'] = 2
66        #Initial condition
67        h = 0.05
68        elevation = domain.quantities['elevation'].vertex_values
69        domain.set_quantity('stage', elevation + h)
70
71        domain.check_integrity()
72        #Evolution
73        #domain.tight_slope_limiters = 1
74        for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
75            #domain.write_time()
76            pass
77
78
79
80        filename = domain.datadir + os.sep + domain.get_name() + '.sww'
81        domain2 = load_sww_as_domain(filename, None, fail_if_NaN=False,
82                                        verbose=self.verbose)
83        #points, vertices, boundary = rectangular(15,15)
84        #domain2.boundary = boundary
85        ###################
86        ##NOW TEST IT!!!
87        ###################
88
89        os.remove(filename)
90
91        bits = ['vertex_coordinates']
92        for quantity in domain.quantities_to_be_stored:
93            bits.append('get_quantity("%s").get_integral()' % quantity)
94            bits.append('get_quantity("%s").get_values()' % quantity)
95
96        for bit in bits:
97            #print 'testing that domain.'+bit+' has been restored'
98            #print bit
99            #print 'done'
100            assert num.allclose(eval('domain.'+bit),eval('domain2.'+bit))
101
102        ######################################
103        #Now evolve them both, just to be sure
104        ######################################x
105        domain.time = 0.
106        from time import sleep
107
108        final = .1
109        domain.set_quantity('friction', 0.1)
110        domain.store = False
111        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
112
113
114        for t in domain.evolve(yieldstep = yiel, finaltime = final):
115            #domain.write_time()
116            pass
117
118        final = final - (domain2.starttime-domain.starttime)
119        #BUT since domain1 gets time hacked back to 0:
120        final = final + (domain2.starttime-domain.starttime)
121
122        domain2.smooth = False
123        domain2.store = False
124        domain2.default_order=2
125        domain2.set_quantity('friction', 0.1)
126        #Bed-slope and friction
127        # Boundary conditions
128        Bd2=Dirichlet_boundary([0.2,0.,0.])
129        domain2.boundary = domain.boundary
130        #print 'domain2.boundary'
131        #print domain2.boundary
132        domain2.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
133        #domain2.set_boundary({'exterior': Bd})
134
135        domain2.check_integrity()
136
137        for t in domain2.evolve(yieldstep = yiel, finaltime = final):
138            #domain2.write_time()
139            pass
140
141        ###################
142        ##NOW TEST IT!!!
143        ##################
144
145        bits = ['vertex_coordinates']
146
147        for quantity in ['elevation','stage', 'ymomentum','xmomentum']:
148            bits.append('get_quantity("%s").get_integral()' %quantity)
149            bits.append('get_quantity("%s").get_values()' %quantity)
150
151        #print bits
152        for bit in bits:
153            #print bit
154            #print eval('domain.'+bit)
155            #print eval('domain2.'+bit)
156           
157            #print eval('domain.'+bit+'-domain2.'+bit)
158            msg = 'Values in the two domains are different for ' + bit
159            assert num.allclose(eval('domain.'+bit),eval('domain2.'+bit),
160                                rtol=1.e-5, atol=3.e-8), msg
161
162
163
164    def test_get_mesh_and_quantities_from_sww_file(self):
165        """test_get_mesh_and_quantities_from_sww_file(self):
166        """     
167       
168        # Generate a test sww file with non trivial georeference
169       
170        import time, os
171        from Scientific.IO.NetCDF import NetCDFFile
172
173        # Setup
174        from mesh_factory import rectangular
175
176        # Create basic mesh (100m x 5m)
177        width = 5
178        length = 50
179        t_end = 10
180        points, vertices, boundary = rectangular(length, width, 50, 5)
181
182        # Create shallow water domain
183        domain = Domain(points, vertices, boundary,
184                        geo_reference = Geo_reference(56,308500,6189000))
185
186        domain.set_name('flowtest')
187        swwfile = domain.get_name() + '.sww'
188        domain.set_datadir('.')
189
190        Br = Reflective_boundary(domain)    # Side walls
191        Bd = Dirichlet_boundary([1, 0, 0])  # inflow
192
193        domain.set_boundary( {'left': Bd, 'right': Bd, 'top': Br, 'bottom': Br})
194
195        for t in domain.evolve(yieldstep=1, finaltime = t_end):
196            pass
197
198       
199        # Read it
200
201        # Get mesh and quantities from sww file
202        X = get_mesh_and_quantities_from_file(swwfile,
203                                              quantities=['elevation',
204                                                          'stage',
205                                                          'xmomentum',
206                                                          'ymomentum'], 
207                                              verbose=False)
208        mesh, quantities, time = X
209       
210
211        # Check that mesh has been recovered
212        assert num.alltrue(mesh.triangles == domain.get_triangles())
213        assert num.allclose(mesh.nodes, domain.get_nodes())
214
215        # Check that time has been recovered
216        assert num.allclose(time, range(t_end+1))
217
218        # Check that quantities have been recovered
219        # (sww files use single precision)
220        z=domain.get_quantity('elevation').get_values(location='unique vertices')
221        assert num.allclose(quantities['elevation'], z)
222
223        for q in ['stage', 'xmomentum', 'ymomentum']:
224            # Get quantity at last timestep
225            q_ref=domain.get_quantity(q).get_values(location='unique vertices')
226
227            #print q,quantities[q]
228            q_sww=quantities[q][-1,:]
229
230            msg = 'Quantity %s failed to be recovered' %q
231            assert num.allclose(q_ref, q_sww, atol=1.0e-6), msg
232           
233        # Cleanup
234        os.remove(swwfile)
235       
236       
237
238    def test_weed(self):
239        coordinates1 = [[0.,0.],[1.,0.],[1.,1.],[1.,0.],[2.,0.],[1.,1.]]
240        volumes1 = [[0,1,2],[3,4,5]]
241        boundary1= {(0,1): 'external',(1,2): 'not external',(2,0): 'external',(3,4): 'external',(4,5): 'external',(5,3): 'not external'}
242        coordinates2,volumes2,boundary2=weed(coordinates1,volumes1,boundary1)
243
244        points2 = {(0.,0.):None,(1.,0.):None,(1.,1.):None,(2.,0.):None}
245
246        assert len(points2)==len(coordinates2)
247        for i in range(len(coordinates2)):
248            coordinate = tuple(coordinates2[i])
249            assert points2.has_key(coordinate)
250            points2[coordinate]=i
251
252        for triangle in volumes1:
253            for coordinate in triangle:
254                assert coordinates2[points2[tuple(coordinates1[coordinate])]][0]==coordinates1[coordinate][0]
255                assert coordinates2[points2[tuple(coordinates1[coordinate])]][1]==coordinates1[coordinate][1]
256
257
258    def test_triangulation(self):
259        #
260       
261       
262        filename = tempfile.mktemp("_data_manager.sww")
263        outfile = NetCDFFile(filename, netcdf_mode_w)
264        points_utm = num.array([[0.,0.],[1.,1.], [0.,1.]])
265        volumes = (0,1,2)
266        elevation = [0,1,2]
267        new_origin = None
268        new_origin = Geo_reference(56, 0, 0)
269        times = [0, 10]
270        number_of_volumes = len(volumes)
271        number_of_points = len(points_utm)
272        sww = Write_sww(['elevation'], ['stage', 'xmomentum', 'ymomentum'])
273        sww.store_header(outfile, times, number_of_volumes,
274                         number_of_points, description='fully sick testing',
275                         verbose=self.verbose,sww_precision=netcdf_float)
276        sww.store_triangulation(outfile, points_utm, volumes,
277                                elevation,  new_origin=new_origin,
278                                verbose=self.verbose)       
279        outfile.close()
280        fid = NetCDFFile(filename)
281
282        x = fid.variables['x'][:]
283        y = fid.variables['y'][:]
284        fid.close()
285
286        assert num.allclose(num.array(map(None, x,y)), points_utm)
287        os.remove(filename)
288
289       
290    def test_triangulationII(self):
291        #
292       
293
294        DEFAULT_ZONE = 0 # Not documented anywhere what this should be.
295       
296        filename = tempfile.mktemp("_data_manager.sww")
297        outfile = NetCDFFile(filename, netcdf_mode_w)
298        points_utm = num.array([[0.,0.],[1.,1.], [0.,1.]])
299        volumes = (0,1,2)
300        elevation = [0,1,2]
301        new_origin = None
302        #new_origin = Geo_reference(56, 0, 0)
303        times = [0, 10]
304        number_of_volumes = len(volumes)
305        number_of_points = len(points_utm)
306        sww = Write_sww(['elevation'], ['stage', 'xmomentum', 'ymomentum'])       
307        sww.store_header(outfile, times, number_of_volumes,
308                         number_of_points, description='fully sick testing',
309                         verbose=self.verbose,sww_precision=netcdf_float)
310        sww.store_triangulation(outfile, points_utm, volumes,
311                                new_origin=new_origin,
312                                verbose=self.verbose)
313        sww.store_static_quantities(outfile, elevation=elevation)                               
314                               
315        outfile.close()
316        fid = NetCDFFile(filename)
317
318        x = fid.variables['x'][:]
319        y = fid.variables['y'][:]
320        results_georef = Geo_reference()
321        results_georef.read_NetCDF(fid)
322        assert results_georef == Geo_reference(DEFAULT_ZONE, 0, 0)
323        fid.close()
324
325        assert num.allclose(num.array(map(None, x,y)), points_utm)
326        os.remove(filename)
327
328       
329    def test_triangulation_new_origin(self):
330        #
331       
332       
333        filename = tempfile.mktemp("_data_manager.sww")
334        outfile = NetCDFFile(filename, netcdf_mode_w)
335        points_utm = num.array([[0.,0.],[1.,1.], [0.,1.]])
336        volumes = (0,1,2)
337        elevation = [0,1,2]
338        new_origin = None
339        new_origin = Geo_reference(56, 1, 554354)
340        points_utm = new_origin.change_points_geo_ref(points_utm)
341        times = [0, 10]
342        number_of_volumes = len(volumes)
343        number_of_points = len(points_utm)
344        sww = Write_sww(['elevation'], ['stage', 'xmomentum', 'ymomentum'])       
345        sww.store_header(outfile, times, number_of_volumes,
346                         number_of_points, description='fully sick testing',
347                         verbose=self.verbose,sww_precision=netcdf_float)
348        sww.store_triangulation(outfile, points_utm, volumes,
349                                elevation,  new_origin=new_origin,
350                                verbose=self.verbose)
351        outfile.close()
352        fid = NetCDFFile(filename)
353
354        x = fid.variables['x'][:]
355        y = fid.variables['y'][:]
356        results_georef = Geo_reference()
357        results_georef.read_NetCDF(fid)
358        assert results_georef == new_origin
359        fid.close()
360
361        absolute = Geo_reference(56, 0,0)
362        assert num.allclose(num.array(
363            absolute.change_points_geo_ref(map(None, x,y),
364                                           new_origin)),points_utm)
365       
366        os.remove(filename)
367       
368    def test_triangulation_points_georeference(self):
369        #
370       
371       
372        filename = tempfile.mktemp("_data_manager.sww")
373        outfile = NetCDFFile(filename, netcdf_mode_w)
374        points_utm = num.array([[0.,0.],[1.,1.], [0.,1.]])
375        volumes = (0,1,2)
376        elevation = [0,1,2]
377        new_origin = None
378        points_georeference = Geo_reference(56, 1, 554354)
379        points_utm = points_georeference.change_points_geo_ref(points_utm)
380        times = [0, 10]
381        number_of_volumes = len(volumes)
382        number_of_points = len(points_utm)
383        sww = Write_sww(['elevation'], ['stage', 'xmomentum', 'ymomentum'])       
384        sww.store_header(outfile, times, number_of_volumes,
385                         number_of_points, description='fully sick testing',
386                         verbose=self.verbose,sww_precision=netcdf_float)
387        sww.store_triangulation(outfile, points_utm, volumes,
388                                elevation,  new_origin=new_origin,
389                                points_georeference=points_georeference,
390                                verbose=self.verbose)       
391        outfile.close()
392        fid = NetCDFFile(filename)
393
394        x = fid.variables['x'][:]
395        y = fid.variables['y'][:]
396        results_georef = Geo_reference()
397        results_georef.read_NetCDF(fid)
398        assert results_georef == points_georeference
399        fid.close()
400
401        assert num.allclose(num.array(map(None, x,y)), points_utm)
402        os.remove(filename)
403       
404    def test_triangulation_2_geo_refs(self):
405        #
406       
407       
408        filename = tempfile.mktemp("_data_manager.sww")
409        outfile = NetCDFFile(filename, netcdf_mode_w)
410        points_utm = num.array([[0.,0.],[1.,1.], [0.,1.]])
411        volumes = (0,1,2)
412        elevation = [0,1,2]
413        new_origin = Geo_reference(56, 1, 1)
414        points_georeference = Geo_reference(56, 0, 0)
415        points_utm = points_georeference.change_points_geo_ref(points_utm)
416        times = [0, 10]
417        number_of_volumes = len(volumes)
418        number_of_points = len(points_utm)
419        sww = Write_sww(['elevation'], ['stage', 'xmomentum', 'ymomentum'])       
420        sww.store_header(outfile, times, number_of_volumes,
421                         number_of_points, description='fully sick testing',
422                         verbose=self.verbose,sww_precision=netcdf_float)
423        sww.store_triangulation(outfile, points_utm, volumes,
424                                elevation,  new_origin=new_origin,
425                                points_georeference=points_georeference,
426                                verbose=self.verbose)       
427        outfile.close()
428        fid = NetCDFFile(filename)
429
430        x = fid.variables['x'][:]
431        y = fid.variables['y'][:]
432        results_georef = Geo_reference()
433        results_georef.read_NetCDF(fid)
434        assert results_georef == new_origin
435        fid.close()
436
437
438        absolute = Geo_reference(56, 0,0)
439        assert num.allclose(num.array(
440            absolute.change_points_geo_ref(map(None, x,y),
441                                           new_origin)),points_utm)
442        os.remove(filename)
443
444#################################################################################
445
446if __name__ == "__main__":
447    suite = unittest.makeSuite(Test_sww, 'test')
448    runner = unittest.TextTestRunner(verbosity=1)
449    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.