source: trunk/anuga_core/source/anuga/file/test_sww.py @ 8879

Last change on this file since 8879 was 8780, checked in by steve, 12 years ago

Some changes to allow netcdf4 use

File size: 16.9 KB
Line 
1import os
2import unittest
3import tempfile
4import numpy as num
5
6from anuga.coordinate_transforms.geo_reference import Geo_reference
7from csv_file import load_csv_as_array, load_csv_as_dict
8from anuga.abstract_2d_finite_volumes.mesh_factory import rectangular
9from anuga.shallow_water.shallow_water_domain import Domain
10from sww import load_sww_as_domain, weed, get_mesh_and_quantities_from_file, \
11                Write_sww
12from anuga.file.netcdf import NetCDFFile
13
14from anuga.config import netcdf_mode_w, netcdf_float
15
16# boundary functions
17from anuga.shallow_water.boundaries import Reflective_boundary, \
18            Field_boundary, Transmissive_momentum_set_stage_boundary, \
19            Transmissive_stage_zero_momentum_boundary
20from anuga.abstract_2d_finite_volumes.generic_boundary_conditions\
21     import Transmissive_boundary, Dirichlet_boundary, \
22            Time_boundary, File_boundary, AWI_boundary
23
24
25class Test_sww(unittest.TestCase):
26    def setUp(self):
27        self.verbose = False
28        pass
29
30    def tearDown(self):
31        pass
32       
33    def test_sww2domain1(self):
34        ################################################
35        #Create a test domain, and evolve and save it.
36        ################################################
37        from mesh_factory import rectangular
38
39        #Create basic mesh
40
41        yiel=0.01
42        points, vertices, boundary = rectangular(10,10)
43
44        #print "=============== boundary rect ======================="
45        #print boundary
46
47        #Create shallow water domain
48        domain = Domain(points, vertices, boundary)
49        domain.geo_reference = Geo_reference(56,11,11)
50        domain.smooth = False
51        domain.store = True
52        domain.set_name('bedslope')
53        domain.default_order=2
54        #Bed-slope and friction
55        domain.set_quantity('elevation', lambda x,y: -x/3)
56        domain.set_quantity('friction', 0.1)
57        # Boundary conditions
58        from math import sin, pi
59        Br = Reflective_boundary(domain)
60        Bt = Transmissive_boundary(domain)
61        Bd = Dirichlet_boundary([0.2,0.,0.])
62        Bw = Time_boundary(domain=domain,function=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
63
64        #domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
65        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
66
67        domain.quantities_to_be_stored['xmomentum'] = 2
68        domain.quantities_to_be_stored['ymomentum'] = 2
69        #Initial condition
70        h = 0.05
71        elevation = domain.quantities['elevation'].vertex_values
72        domain.set_quantity('stage', elevation + h)
73
74        domain.check_integrity()
75        #Evolution
76        #domain.tight_slope_limiters = 1
77        for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
78            #domain.write_time()
79            pass
80
81        #print boundary
82
83
84        filename = domain.datadir + os.sep + domain.get_name() + '.sww'
85        domain2 = load_sww_as_domain(filename, None, fail_if_NaN=False,
86                                        verbose=self.verbose)
87
88        # Unfortunately we loss the boundaries top, bottom, left and right,
89        # they are now all lumped into "exterior"
90
91        #print "=============== boundary domain2 ======================="
92        #print domain2.boundary
93       
94
95        #print domain2.get_boundary_tags()
96       
97        #points, vertices, boundary = rectangular(15,15)
98        #domain2.boundary = boundary
99        ###################
100        ##NOW TEST IT!!!
101        ###################
102
103        os.remove(filename)
104
105        bits = ['vertex_coordinates']
106        for quantity in ['stage']:
107            bits.append('get_quantity("%s").get_integral()' % quantity)
108            bits.append('get_quantity("%s").get_values()' % quantity)
109
110        for bit in bits:
111            #print 'testing that domain.'+bit+' has been restored'
112            #print bit
113            #print 'done'
114            #print eval('domain.'+bit)
115            #print eval('domain2.'+bit)
116            assert num.allclose(eval('domain.'+bit),eval('domain2.'+bit))
117
118        ######################################
119        #Now evolve them both, just to be sure
120        ######################################x
121        from time import sleep
122
123        final = .1
124        domain.set_quantity('friction', 0.1)
125        domain.store = False
126        domain.set_boundary({'exterior': Bd, 'left' : Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
127
128
129        for t in domain.evolve(yieldstep = yiel, finaltime = final):
130            #domain.write_time()
131            pass
132
133        #BUT since domain1 gets time hacked back to 0:
134       
135        final = final + (domain2.get_starttime() - domain.get_starttime())
136
137        domain2.smooth = False
138        domain2.store = False
139        domain2.default_order=2
140        domain2.set_quantity('friction', 0.1)
141        #Bed-slope and friction
142        # Boundary conditions
143        Bd2=Dirichlet_boundary([0.2,0.,0.])
144        domain2.boundary = domain.boundary
145        #print 'domain2.boundary'
146        #print domain2.boundary
147        domain2.set_boundary({'exterior': Bd, 'left' : Bd,  'right': Bd, 'top': Bd, 'bottom': Bd})
148        #domain2.set_boundary({'exterior': Bd})
149
150        domain2.check_integrity()
151
152        for t in domain2.evolve(yieldstep = yiel, finaltime = final):
153            #domain2.write_time()
154            pass
155
156        ###################
157        ##NOW TEST IT!!!
158        ##################
159
160        bits = ['vertex_coordinates']
161
162        for quantity in ['elevation','stage', 'ymomentum','xmomentum']:
163            bits.append('get_quantity("%s").get_integral()' %quantity)
164            bits.append('get_quantity("%s").get_values()' %quantity)
165
166        #print bits
167        for bit in bits:
168            #print bit
169            #print eval('domain.'+bit)
170            #print eval('domain2.'+bit)
171           
172            #print eval('domain.'+bit+'-domain2.'+bit)
173            msg = 'Values in the two domains are different for ' + bit
174            assert num.allclose(eval('domain.'+bit),eval('domain2.'+bit),
175                                rtol=5.e-2, atol=5.e-2), msg
176
177
178
179    def test_get_mesh_and_quantities_from_sww_file(self):
180        """test_get_mesh_and_quantities_from_sww_file(self):
181        """     
182       
183        # Generate a test sww file with non trivial georeference
184       
185        import time, os
186
187        # Setup
188        from mesh_factory import rectangular
189
190        # Create basic mesh (100m x 5m)
191        width = 5
192        length = 50
193        t_end = 10
194        points, vertices, boundary = rectangular(length, width, 50, 5)
195
196        # Create shallow water domain
197        domain = Domain(points, vertices, boundary,
198                        geo_reference = Geo_reference(56,308500,6189000))
199
200        domain.set_name('test_get_mesh_and_quantities_from_sww_file')
201        swwfile = domain.get_name() + '.sww'
202        domain.set_datadir('.')
203
204        Br = Reflective_boundary(domain)    # Side walls
205        Bd = Dirichlet_boundary([1, 0, 0])  # inflow
206
207        domain.set_boundary( {'left': Bd, 'right': Bd, 'top': Br, 'bottom': Br})
208
209        for t in domain.evolve(yieldstep=1, finaltime = t_end):
210            pass
211
212       
213        # Read it
214
215        # Get mesh and quantities from sww file
216        X = get_mesh_and_quantities_from_file(swwfile,
217                                              quantities=['elevation',
218                                                          'stage',
219                                                          'xmomentum',
220                                                          'ymomentum'], 
221                                              verbose=False)
222        mesh, quantities, time = X
223       
224
225        # Check that mesh has been recovered
226        assert num.alltrue(mesh.triangles == domain.get_triangles())
227        assert num.allclose(mesh.nodes, domain.get_nodes())
228
229        # Check that time has been recovered
230        assert num.allclose(time, range(t_end+1))
231
232        # Check that quantities have been recovered
233        # (sww files use single precision)
234        z=domain.get_quantity('elevation').get_values(location='unique vertices')
235        assert num.allclose(quantities['elevation'], z)
236
237        for q in ['stage', 'xmomentum', 'ymomentum']:
238            # Get quantity at last timestep
239            q_ref=domain.get_quantity(q).get_values(location='unique vertices')
240
241            #print q,quantities[q]
242            q_sww=quantities[q][-1,:]
243
244            msg = 'Quantity %s failed to be recovered' %q
245            assert num.allclose(q_ref, q_sww, atol=1.0e-6), msg
246           
247        # Cleanup
248        #os.remove(swwfile)
249       
250       
251
252    def test_weed(self):
253        coordinates1 = [[0.,0.],[1.,0.],[1.,1.],[1.,0.],[2.,0.],[1.,1.]]
254        volumes1 = [[0,1,2],[3,4,5]]
255        boundary1= {(0,1): 'external',(1,2): 'not external',(2,0): 'external',(3,4): 'external',(4,5): 'external',(5,3): 'not external'}
256        coordinates2,volumes2,boundary2=weed(coordinates1,volumes1,boundary1)
257
258        points2 = {(0.,0.):None,(1.,0.):None,(1.,1.):None,(2.,0.):None}
259
260        assert len(points2)==len(coordinates2)
261        for i in range(len(coordinates2)):
262            coordinate = tuple(coordinates2[i])
263            assert points2.has_key(coordinate)
264            points2[coordinate]=i
265
266        for triangle in volumes1:
267            for coordinate in triangle:
268                assert coordinates2[points2[tuple(coordinates1[coordinate])]][0]==coordinates1[coordinate][0]
269                assert coordinates2[points2[tuple(coordinates1[coordinate])]][1]==coordinates1[coordinate][1]
270
271
272    def test_triangulation(self):
273        #
274       
275       
276        filename = tempfile.mktemp("_data_manager.sww")
277        outfile = NetCDFFile(filename, netcdf_mode_w)
278        points_utm = num.array([[0.,0.],[1.,1.],[0.,1.]])
279        volumes = [[0,1,2]]
280        elevation = [0,1,2]
281        new_origin = None
282        new_origin = Geo_reference(56, 0, 0)
283        times = [0, 10]
284        number_of_volumes = len(volumes)
285        number_of_points = len(points_utm)
286        sww = Write_sww(['elevation'], ['stage', 'xmomentum', 'ymomentum'])
287        sww.store_header(outfile, times, number_of_volumes,
288                         number_of_points, description='fully sick testing',
289                         verbose=self.verbose,sww_precision=netcdf_float)
290        sww.store_triangulation(outfile, points_utm, volumes,
291                                elevation,  new_origin=new_origin,
292                                verbose=self.verbose)       
293        outfile.close()
294        fid = NetCDFFile(filename)
295
296        x = fid.variables['x'][:]
297        y = fid.variables['y'][:]
298        fid.close()
299
300        assert num.allclose(num.array(map(None, x,y)), points_utm)
301        os.remove(filename)
302
303       
304    def test_triangulationII(self):
305        #
306       
307
308        DEFAULT_ZONE = 0 # Not documented anywhere what this should be.
309       
310        filename = tempfile.mktemp("_data_manager.sww")
311        outfile = NetCDFFile(filename, netcdf_mode_w)
312        points_utm = num.array([[0.,0.],[1.,1.], [0.,1.]])
313        volumes = [[0,1,2]]
314        elevation = [0,1,2]
315        new_origin = None
316        #new_origin = Geo_reference(56, 0, 0)
317        times = [0, 10]
318        number_of_volumes = len(volumes)
319        number_of_points = len(points_utm)
320        sww = Write_sww(['elevation'], ['stage', 'xmomentum', 'ymomentum'])       
321        sww.store_header(outfile, times, number_of_volumes,
322                         number_of_points, description='fully sick testing',
323                         verbose=self.verbose,sww_precision=netcdf_float)
324        sww.store_triangulation(outfile, points_utm, volumes,
325                                new_origin=new_origin,
326                                verbose=self.verbose)
327        sww.store_static_quantities(outfile, elevation=elevation)                               
328                               
329        outfile.close()
330        fid = NetCDFFile(filename)
331
332        x = fid.variables['x'][:]
333        y = fid.variables['y'][:]
334        results_georef = Geo_reference()
335        results_georef.read_NetCDF(fid)
336        assert results_georef == Geo_reference(DEFAULT_ZONE, 0, 0)
337        fid.close()
338
339        assert num.allclose(num.array(map(None, x,y)), points_utm)
340        os.remove(filename)
341
342       
343    def test_triangulation_new_origin(self):
344        #
345       
346       
347        filename = tempfile.mktemp("_data_manager.sww")
348        outfile = NetCDFFile(filename, netcdf_mode_w)
349        points_utm = num.array([[0.,0.],[1.,1.], [0.,1.]])
350        volumes = [[0,1,2]]
351        elevation = [0,1,2]
352        new_origin = None
353        new_origin = Geo_reference(56, 1, 554354)
354        points_utm = new_origin.change_points_geo_ref(points_utm)
355        times = [0, 10]
356        number_of_volumes = len(volumes)
357        number_of_points = len(points_utm)
358        sww = Write_sww(['elevation'], ['stage', 'xmomentum', 'ymomentum'])       
359        sww.store_header(outfile, times, number_of_volumes,
360                         number_of_points, description='fully sick testing',
361                         verbose=self.verbose,sww_precision=netcdf_float)
362        sww.store_triangulation(outfile, points_utm, volumes,
363                                elevation,  new_origin=new_origin,
364                                verbose=self.verbose)
365        outfile.close()
366        fid = NetCDFFile(filename)
367
368        x = fid.variables['x'][:]
369        y = fid.variables['y'][:]
370        results_georef = Geo_reference()
371        results_georef.read_NetCDF(fid)
372        assert results_georef == new_origin
373        fid.close()
374
375        absolute = Geo_reference(56, 0,0)
376        assert num.allclose(num.array(
377            absolute.change_points_geo_ref(map(None, x,y),
378                                           new_origin)),points_utm)
379       
380        os.remove(filename)
381       
382    def test_triangulation_points_georeference(self):
383        #
384       
385       
386        filename = tempfile.mktemp("_data_manager.sww")
387        outfile = NetCDFFile(filename, netcdf_mode_w)
388        points_utm = num.array([[0.,0.],[1.,1.], [0.,1.]])
389        volumes = [[0,1,2]]
390        elevation = [0,1,2]
391        new_origin = None
392        points_georeference = Geo_reference(56, 1, 554354)
393        points_utm = points_georeference.change_points_geo_ref(points_utm)
394        times = [0, 10]
395        number_of_volumes = len(volumes)
396        number_of_points = len(points_utm)
397        sww = Write_sww(['elevation'], ['stage', 'xmomentum', 'ymomentum'])       
398        sww.store_header(outfile, times, number_of_volumes,
399                         number_of_points, description='fully sick testing',
400                         verbose=self.verbose,sww_precision=netcdf_float)
401        sww.store_triangulation(outfile, points_utm, volumes,
402                                elevation,  new_origin=new_origin,
403                                points_georeference=points_georeference,
404                                verbose=self.verbose)       
405        outfile.close()
406        fid = NetCDFFile(filename)
407
408        x = fid.variables['x'][:]
409        y = fid.variables['y'][:]
410        results_georef = Geo_reference()
411        results_georef.read_NetCDF(fid)
412        assert results_georef == points_georeference
413        fid.close()
414
415        assert num.allclose(num.array(map(None, x,y)), points_utm)
416        os.remove(filename)
417       
418    def test_triangulation_2_geo_refs(self):
419        #
420       
421       
422        filename = tempfile.mktemp("_data_manager.sww")
423        outfile = NetCDFFile(filename, netcdf_mode_w)
424        points_utm = num.array([[0.,0.],[1.,1.], [0.,1.]])
425        volumes = [[0,1,2]]
426        elevation = [0,1,2]
427        new_origin = Geo_reference(56, 1, 1)
428        points_georeference = Geo_reference(56, 0, 0)
429        points_utm = points_georeference.change_points_geo_ref(points_utm)
430        times = [0, 10]
431        number_of_volumes = len(volumes)
432        number_of_points = len(points_utm)
433        sww = Write_sww(['elevation'], ['stage', 'xmomentum', 'ymomentum'])       
434        sww.store_header(outfile, times, number_of_volumes,
435                         number_of_points, description='fully sick testing',
436                         verbose=self.verbose,sww_precision=netcdf_float)
437        sww.store_triangulation(outfile, points_utm, volumes,
438                                elevation,  new_origin=new_origin,
439                                points_georeference=points_georeference,
440                                verbose=self.verbose)       
441        outfile.close()
442        fid = NetCDFFile(filename)
443
444        x = fid.variables['x'][:]
445        y = fid.variables['y'][:]
446        results_georef = Geo_reference()
447        results_georef.read_NetCDF(fid)
448        assert results_georef == new_origin
449        fid.close()
450
451
452        absolute = Geo_reference(56, 0,0)
453        assert num.allclose(num.array(
454            absolute.change_points_geo_ref(map(None, x,y),
455                                           new_origin)),points_utm)
456        os.remove(filename)
457
458#################################################################################
459
460if __name__ == "__main__":
461    suite = unittest.makeSuite(Test_sww, 'test')
462    runner = unittest.TextTestRunner(verbosity=1)
463    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.