source: inundation/pyvolution/test_data_manager.py @ 1794

Last change on this file since 1794 was 1753, checked in by ole, 19 years ago

Embedded caching functionality within quantity.set_values and modified validation example lwru2.py to illustrate the advantages that can be gained from supervised caching.

File size: 60.8 KB
Line 
1#!/usr/bin/env python
2#
3
4import unittest
5import copy
6from Numeric import zeros, array, allclose, Float
7from util import mean
8
9from data_manager import *
10from shallow_water import *
11from config import epsilon
12
13from coordinate_transforms.geo_reference import Geo_reference
14
15class Test_Data_Manager(unittest.TestCase):
16    def setUp(self):
17        import time
18        from mesh_factory import rectangular
19
20
21        #Create basic mesh
22        points, vertices, boundary = rectangular(2, 2)
23
24        #Create shallow water domain
25        domain = Domain(points, vertices, boundary)
26        domain.default_order=2
27
28
29        #Set some field values
30        domain.set_quantity('elevation', lambda x,y: -x)
31        domain.set_quantity('friction', 0.03)
32
33
34        ######################
35        # Boundary conditions
36        B = Transmissive_boundary(domain)
37        domain.set_boundary( {'left': B, 'right': B, 'top': B, 'bottom': B})
38
39
40        ######################
41        #Initial condition - with jumps
42
43
44        bed = domain.quantities['elevation'].vertex_values
45        stage = zeros(bed.shape, Float)
46
47        h = 0.3
48        for i in range(stage.shape[0]):
49            if i % 2 == 0:
50                stage[i,:] = bed[i,:] + h
51            else:
52                stage[i,:] = bed[i,:]
53
54        domain.set_quantity('stage', stage)
55        self.initial_stage = copy.copy(domain.quantities['stage'].vertex_values)
56
57        domain.distribute_to_vertices_and_edges()
58
59
60        self.domain = domain
61
62        C = domain.get_vertex_coordinates()
63        self.X = C[:,0:6:2].copy()
64        self.Y = C[:,1:6:2].copy()
65
66        self.F = bed
67
68
69    def tearDown(self):
70        pass
71
72
73
74
75#     def test_xya(self):
76#         import os
77#         from Numeric import concatenate
78
79#         import time, os
80#         from Numeric import array, zeros, allclose, Float, concatenate
81
82#         domain = self.domain
83
84#         domain.filename = 'datatest' + str(time.time())
85#         domain.format = 'xya'
86#         domain.smooth = True
87
88#         xya = get_dataobject(self.domain)
89#         xya.store_all()
90
91
92#         #Read back
93#         file = open(xya.filename)
94#         lFile = file.read().split('\n')
95#         lFile = lFile[:-1]
96
97#         file.close()
98#         os.remove(xya.filename)
99
100#         #Check contents
101#         if domain.smooth:
102#             self.failUnless(lFile[0] == '9 3 # <vertex #> <x> <y> [attributes]')
103#         else:
104#             self.failUnless(lFile[0] == '24 3 # <vertex #> <x> <y> [attributes]')
105
106#         #Get smoothed field values with X and Y
107#         X,Y,F,V = domain.get_vertex_values(xy=True, value_array='field_values',
108#                                            indices = (0,1), precision = Float)
109
110
111#         Q,V = domain.get_vertex_values(xy=False, value_array='conserved_quantities',
112#                                            indices = (0,), precision = Float)
113
114
115
116#         for i, line in enumerate(lFile[1:]):
117#             fields = line.split()
118
119#             assert len(fields) == 5
120
121#             assert allclose(float(fields[0]), X[i])
122#             assert allclose(float(fields[1]), Y[i])
123#             assert allclose(float(fields[2]), F[i,0])
124#             assert allclose(float(fields[3]), Q[i,0])
125#             assert allclose(float(fields[4]), F[i,1])
126
127
128
129
130    def test_sww_constant(self):
131        """Test that constant sww information can be written correctly
132        (non smooth)
133        """
134
135        import time, os
136        from Numeric import array, zeros, allclose, Float, concatenate
137        from Scientific.IO.NetCDF import NetCDFFile
138
139        self.domain.filename = 'datatest' + str(id(self))
140        self.domain.format = 'sww'
141        self.domain.smooth = False
142
143        sww = get_dataobject(self.domain)
144        sww.store_connectivity()
145
146        #Check contents
147        #Get NetCDF
148        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
149
150        # Get the variables
151        x = fid.variables['x']
152        y = fid.variables['y']
153        z = fid.variables['elevation']
154
155        volumes = fid.variables['volumes']
156
157
158        assert allclose (x[:], self.X.flat)
159        assert allclose (y[:], self.Y.flat)
160        assert allclose (z[:], self.F.flat)
161
162        V = volumes
163
164        P = len(self.domain)
165        for k in range(P):
166            assert V[k, 0] == 3*k
167            assert V[k, 1] == 3*k+1
168            assert V[k, 2] == 3*k+2
169
170
171        fid.close()
172
173        #Cleanup
174        os.remove(sww.filename)
175
176
177    def test_sww_constant_smooth(self):
178        """Test that constant sww information can be written correctly
179        (non smooth)
180        """
181
182        import time, os
183        from Numeric import array, zeros, allclose, Float, concatenate
184        from Scientific.IO.NetCDF import NetCDFFile
185
186        self.domain.filename = 'datatest' + str(id(self))
187        self.domain.format = 'sww'
188        self.domain.smooth = True
189
190        sww = get_dataobject(self.domain)
191        sww.store_connectivity()
192
193        #Check contents
194        #Get NetCDF
195        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
196
197        # Get the variables
198        x = fid.variables['x']
199        y = fid.variables['y']
200        z = fid.variables['elevation']
201
202        volumes = fid.variables['volumes']
203
204        X = x[:]
205        Y = y[:]
206
207        assert allclose([X[0], Y[0]], array([0.0, 0.0]))
208        assert allclose([X[1], Y[1]], array([0.0, 0.5]))
209        assert allclose([X[2], Y[2]], array([0.0, 1.0]))
210
211        assert allclose([X[4], Y[4]], array([0.5, 0.5]))
212
213        assert allclose([X[7], Y[7]], array([1.0, 0.5]))
214
215        Z = z[:]
216        assert Z[4] == -0.5
217
218        V = volumes
219        assert V[2,0] == 4
220        assert V[2,1] == 5
221        assert V[2,2] == 1
222
223        assert V[4,0] == 6
224        assert V[4,1] == 7
225        assert V[4,2] == 3
226
227
228        fid.close()
229
230        #Cleanup
231        os.remove(sww.filename)
232
233
234
235    def test_sww_variable(self):
236        """Test that sww information can be written correctly
237        """
238
239        import time, os
240        from Numeric import array, zeros, allclose, Float, concatenate
241        from Scientific.IO.NetCDF import NetCDFFile
242
243        self.domain.filename = 'datatest' + str(id(self))
244        self.domain.format = 'sww'
245        self.domain.smooth = True
246        self.domain.reduction = mean
247
248        sww = get_dataobject(self.domain)
249        sww.store_connectivity()
250        sww.store_timestep('stage')
251
252        #Check contents
253        #Get NetCDF
254        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
255
256
257        # Get the variables
258        x = fid.variables['x']
259        y = fid.variables['y']
260        z = fid.variables['elevation']
261        time = fid.variables['time']
262        stage = fid.variables['stage']
263
264
265        Q = self.domain.quantities['stage']
266        Q0 = Q.vertex_values[:,0]
267        Q1 = Q.vertex_values[:,1]
268        Q2 = Q.vertex_values[:,2]
269
270        A = stage[0,:]
271        #print A[0], (Q2[0,0] + Q1[1,0])/2
272        assert allclose(A[0], (Q2[0] + Q1[1])/2)
273        assert allclose(A[1], (Q0[1] + Q1[3] + Q2[2])/3)
274        assert allclose(A[2], Q0[3])
275        assert allclose(A[3], (Q0[0] + Q1[5] + Q2[4])/3)
276
277        #Center point
278        assert allclose(A[4], (Q1[0] + Q2[1] + Q0[2] +\
279                                 Q0[5] + Q2[6] + Q1[7])/6)
280
281
282
283        fid.close()
284
285        #Cleanup
286        os.remove(sww.filename)
287
288
289    def test_sww_variable2(self):
290        """Test that sww information can be written correctly
291        multiple timesteps. Use average as reduction operator
292        """
293
294        import time, os
295        from Numeric import array, zeros, allclose, Float, concatenate
296        from Scientific.IO.NetCDF import NetCDFFile
297
298        self.domain.filename = 'datatest' + str(id(self))
299        self.domain.format = 'sww'
300        self.domain.smooth = True
301
302        self.domain.reduction = mean
303
304        sww = get_dataobject(self.domain)
305        sww.store_connectivity()
306        sww.store_timestep('stage')
307        self.domain.evolve_to_end(finaltime = 0.01)
308        sww.store_timestep('stage')
309
310
311        #Check contents
312        #Get NetCDF
313        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
314
315        # Get the variables
316        x = fid.variables['x']
317        y = fid.variables['y']
318        z = fid.variables['elevation']
319        time = fid.variables['time']
320        stage = fid.variables['stage']
321
322        #Check values
323        Q = self.domain.quantities['stage']
324        Q0 = Q.vertex_values[:,0]
325        Q1 = Q.vertex_values[:,1]
326        Q2 = Q.vertex_values[:,2]
327
328        A = stage[1,:]
329        assert allclose(A[0], (Q2[0] + Q1[1])/2)
330        assert allclose(A[1], (Q0[1] + Q1[3] + Q2[2])/3)
331        assert allclose(A[2], Q0[3])
332        assert allclose(A[3], (Q0[0] + Q1[5] + Q2[4])/3)
333
334        #Center point
335        assert allclose(A[4], (Q1[0] + Q2[1] + Q0[2] +\
336                                 Q0[5] + Q2[6] + Q1[7])/6)
337
338
339        fid.close()
340
341        #Cleanup
342        os.remove(sww.filename)
343
344    def test_sww_variable3(self):
345        """Test that sww information can be written correctly
346        multiple timesteps using a different reduction operator (min)
347        """
348
349        import time, os
350        from Numeric import array, zeros, allclose, Float, concatenate
351        from Scientific.IO.NetCDF import NetCDFFile
352
353        self.domain.filename = 'datatest' + str(id(self))
354        self.domain.format = 'sww'
355        self.domain.smooth = True
356        self.domain.reduction = min
357
358        sww = get_dataobject(self.domain)
359        sww.store_connectivity()
360        sww.store_timestep('stage')
361
362        self.domain.evolve_to_end(finaltime = 0.01)
363        sww.store_timestep('stage')
364
365
366        #Check contents
367        #Get NetCDF
368        fid = NetCDFFile(sww.filename, 'r')
369
370
371        # Get the variables
372        x = fid.variables['x']
373        y = fid.variables['y']
374        z = fid.variables['elevation']
375        time = fid.variables['time']
376        stage = fid.variables['stage']
377
378        #Check values
379        Q = self.domain.quantities['stage']
380        Q0 = Q.vertex_values[:,0]
381        Q1 = Q.vertex_values[:,1]
382        Q2 = Q.vertex_values[:,2]
383
384        A = stage[1,:]
385        assert allclose(A[0], min(Q2[0], Q1[1]))
386        assert allclose(A[1], min(Q0[1], Q1[3], Q2[2]))
387        assert allclose(A[2], Q0[3])
388        assert allclose(A[3], min(Q0[0], Q1[5], Q2[4]))
389
390        #Center point
391        assert allclose(A[4], min(Q1[0], Q2[1], Q0[2],\
392                                  Q0[5], Q2[6], Q1[7]))
393
394
395        fid.close()
396
397        #Cleanup
398        os.remove(sww.filename)
399
400
401    def test_sync(self):
402        """Test info stored at each timestep is as expected (incl initial condition)
403        """
404
405        import time, os, config
406        from Numeric import array, zeros, allclose, Float, concatenate
407        from Scientific.IO.NetCDF import NetCDFFile
408
409        self.domain.filename = 'synctest'
410        self.domain.format = 'sww'
411        self.domain.smooth = False
412        self.domain.store = True
413        self.domain.beta_h = 0
414
415        #Evolution
416        for t in self.domain.evolve(yieldstep = 1.0, finaltime = 4.0):
417            stage = self.domain.quantities['stage'].vertex_values
418
419            #Get NetCDF
420            fid = NetCDFFile(self.domain.writer.filename, 'r')
421            stage_file = fid.variables['stage']
422
423            if t == 0.0:
424                assert allclose(stage, self.initial_stage)
425                assert allclose(stage_file[:], stage.flat)
426            else:
427                assert not allclose(stage, self.initial_stage)
428                assert not allclose(stage_file[:], stage.flat)
429
430            fid.close()
431
432        os.remove(self.domain.writer.filename)
433
434
435
436    def test_sww_DSG(self):
437        """Not a test, rather a look at the sww format
438        """
439
440        import time, os
441        from Numeric import array, zeros, allclose, Float, concatenate
442        from Scientific.IO.NetCDF import NetCDFFile
443
444        self.domain.filename = 'datatest' + str(id(self))
445        self.domain.format = 'sww'
446        self.domain.smooth = True
447        self.domain.reduction = mean
448
449        sww = get_dataobject(self.domain)
450        sww.store_connectivity()
451        sww.store_timestep('stage')
452
453        #Check contents
454        #Get NetCDF
455        fid = NetCDFFile(sww.filename, 'r')
456
457        # Get the variables
458        x = fid.variables['x']
459        y = fid.variables['y']
460        z = fid.variables['elevation']
461
462        volumes = fid.variables['volumes']
463        time = fid.variables['time']
464
465        # 2D
466        stage = fid.variables['stage']
467
468        X = x[:]
469        Y = y[:]
470        Z = z[:]
471        V = volumes[:]
472        T = time[:]
473        S = stage[:,:]
474
475#         print "****************************"
476#         print "X ",X
477#         print "****************************"
478#         print "Y ",Y
479#         print "****************************"
480#         print "Z ",Z
481#         print "****************************"
482#         print "V ",V
483#         print "****************************"
484#         print "Time ",T
485#         print "****************************"
486#         print "Stage ",S
487#         print "****************************"
488
489
490        fid.close()
491
492        #Cleanup
493        os.remove(sww.filename)
494
495
496    def test_write_pts(self):
497        #Get (enough) datapoints
498
499        from Numeric import array
500        points = array([[ 0.66666667, 0.66666667],
501                        [ 1.33333333, 1.33333333],
502                        [ 2.66666667, 0.66666667],
503                        [ 0.66666667, 2.66666667],
504                        [ 0.0, 1.0],
505                        [ 0.0, 3.0],
506                        [ 1.0, 0.0],
507                        [ 1.0, 1.0],
508                        [ 1.0, 2.0],
509                        [ 1.0, 3.0],
510                        [ 2.0, 1.0],
511                        [ 3.0, 0.0],
512                        [ 3.0, 1.0]])
513
514        z = points[:,0] + 2*points[:,1]
515
516        ptsfile = 'testptsfile.pts'
517        write_ptsfile(ptsfile, points, z,
518                      attribute_name = 'linear_combination')
519
520        #Check contents
521        #Get NetCDF
522        from Scientific.IO.NetCDF import NetCDFFile       
523        fid = NetCDFFile(ptsfile, 'r')
524
525        # Get the variables
526        #print fid.variables.keys()
527        points1 = fid.variables['points']
528        z1 = fid.variables['linear_combination']
529
530        #Check values
531
532        #print points[:]
533        #print ref_points
534        assert allclose(points, points1)
535
536        #print attributes[:]
537        #print ref_elevation
538        assert allclose(z, z1)
539
540        #Cleanup
541        fid.close()
542
543        import os
544        os.remove(ptsfile)
545       
546
547
548
549    def test_dem2pts(self):
550        """Test conversion from dem in ascii format to native NetCDF xya format
551        """
552
553        import time, os
554        from Numeric import array, zeros, allclose, Float, concatenate
555        from Scientific.IO.NetCDF import NetCDFFile
556
557        #Write test asc file
558        root = 'demtest'
559
560        filename = root+'.asc'
561        fid = open(filename, 'w')
562        fid.write("""ncols         5
563nrows         6
564xllcorner     2000.5
565yllcorner     3000.5
566cellsize      25
567NODATA_value  -9999
568""")
569        #Create linear function
570
571        ref_points = []
572        ref_elevation = []
573        for i in range(6):
574            y = (6-i)*25.0
575            for j in range(5):
576                x = j*25.0
577                z = x+2*y
578
579                ref_points.append( [x,y] )
580                ref_elevation.append(z)
581                fid.write('%f ' %z)
582            fid.write('\n')
583
584        fid.close()
585
586        #Write prj file with metadata
587        metafilename = root+'.prj'
588        fid = open(metafilename, 'w')
589
590
591        fid.write("""Projection UTM
592Zone 56
593Datum WGS84
594Zunits NO
595Units METERS
596Spheroid WGS84
597Xshift 0.0000000000
598Yshift 10000000.0000000000
599Parameters
600""")
601        fid.close()
602
603        #Convert to NetCDF pts
604        convert_dem_from_ascii2netcdf(root)
605        dem2pts(root)
606
607        #Check contents
608        #Get NetCDF
609        fid = NetCDFFile(root+'.pts', 'r')
610
611        # Get the variables
612        #print fid.variables.keys()
613        points = fid.variables['points']
614        elevation = fid.variables['elevation']
615
616        #Check values
617
618        #print points[:]
619        #print ref_points
620        assert allclose(points, ref_points)
621
622        #print attributes[:]
623        #print ref_elevation
624        assert allclose(elevation, ref_elevation)
625
626        #Cleanup
627        fid.close()
628
629
630        os.remove(root + '.pts')
631        os.remove(root + '.dem')
632        os.remove(root + '.asc')
633        os.remove(root + '.prj')
634
635
636
637    def test_dem2pts_bounding_box(self):
638        """Test conversion from dem in ascii format to native NetCDF xya format
639        """
640
641        import time, os
642        from Numeric import array, zeros, allclose, Float, concatenate
643        from Scientific.IO.NetCDF import NetCDFFile
644
645        #Write test asc file
646        root = 'demtest'
647
648        filename = root+'.asc'
649        fid = open(filename, 'w')
650        fid.write("""ncols         5
651nrows         6
652xllcorner     2000.5
653yllcorner     3000.5
654cellsize      25
655NODATA_value  -9999
656""")
657        #Create linear function
658
659        ref_points = []
660        ref_elevation = []
661        for i in range(6):
662            y = (6-i)*25.0
663            for j in range(5):
664                x = j*25.0
665                z = x+2*y
666
667                ref_points.append( [x,y] )
668                ref_elevation.append(z)
669                fid.write('%f ' %z)
670            fid.write('\n')
671
672        fid.close()
673
674        #Write prj file with metadata
675        metafilename = root+'.prj'
676        fid = open(metafilename, 'w')
677
678
679        fid.write("""Projection UTM
680Zone 56
681Datum WGS84
682Zunits NO
683Units METERS
684Spheroid WGS84
685Xshift 0.0000000000
686Yshift 10000000.0000000000
687Parameters
688""")
689        fid.close()
690
691        #Convert to NetCDF pts
692        convert_dem_from_ascii2netcdf(root)
693        dem2pts(root, easting_min=2010.0, easting_max=2110.0,
694                northing_min=3035.0, northing_max=3125.5)
695
696        #Check contents
697        #Get NetCDF
698        fid = NetCDFFile(root+'.pts', 'r')
699
700        # Get the variables
701        #print fid.variables.keys()
702        points = fid.variables['points']
703        elevation = fid.variables['elevation']
704
705        #Check values
706        assert fid.xllcorner[0] == 2010.0
707        assert fid.yllcorner[0] == 3035.0
708
709        #create new reference points
710        ref_points = []
711        ref_elevation = []
712        for i in range(4):
713            y = (4-i)*25.0 + 25.0
714            y_new = y + 3000.5 - 3035.0
715            for j in range(4):
716                x = j*25.0 + 25.0
717                x_new = x + 2000.5 - 2010.0
718                z = x+2*y
719
720                ref_points.append( [x_new,y_new] )
721                ref_elevation.append(z)
722
723        #print points[:]
724        #print ref_points
725        assert allclose(points, ref_points)
726
727        #print attributes[:]
728        #print ref_elevation
729        assert allclose(elevation, ref_elevation)
730
731        #Cleanup
732        fid.close()
733
734
735        os.remove(root + '.pts')
736        os.remove(root + '.dem')
737        os.remove(root + '.asc')
738        os.remove(root + '.prj')
739
740
741
742    def test_sww2asc_elevation(self):
743        """Test that sww information can be converted correctly to asc/prj
744        format readable by e.g. ArcView
745        """
746
747        import time, os
748        from Numeric import array, zeros, allclose, Float, concatenate
749        from Scientific.IO.NetCDF import NetCDFFile
750
751        #Setup
752        self.domain.filename = 'datatest'
753
754        prjfile = self.domain.filename + '_elevation.prj'
755        ascfile = self.domain.filename + '_elevation.asc'
756        swwfile = self.domain.filename + '.sww'
757
758        self.domain.set_datadir('.')
759        self.domain.format = 'sww'
760        self.domain.smooth = True
761        self.domain.set_quantity('elevation', lambda x,y: -x-y)
762
763        self.domain.geo_reference = Geo_reference(56,308500,6189000)
764
765        sww = get_dataobject(self.domain)
766        sww.store_connectivity()
767        sww.store_timestep('stage')
768
769        self.domain.evolve_to_end(finaltime = 0.01)
770        sww.store_timestep('stage')
771
772        cellsize = 0.25
773        #Check contents
774        #Get NetCDF
775
776        fid = NetCDFFile(sww.filename, 'r')
777
778        # Get the variables
779        x = fid.variables['x'][:]
780        y = fid.variables['y'][:]
781        z = fid.variables['elevation'][:]
782        time = fid.variables['time'][:]
783        stage = fid.variables['stage'][:]
784
785
786        #Export to ascii/prj files
787        sww2asc(self.domain.filename,
788                quantity = 'elevation',
789                cellsize = cellsize,
790                verbose = False)
791
792
793        #Check prj (meta data)
794        prjid = open(prjfile)
795        lines = prjid.readlines()
796        prjid.close()
797
798        L = lines[0].strip().split()
799        assert L[0].strip().lower() == 'projection'
800        assert L[1].strip().lower() == 'utm'
801
802        L = lines[1].strip().split()
803        assert L[0].strip().lower() == 'zone'
804        assert L[1].strip().lower() == '56'
805
806        L = lines[2].strip().split()
807        assert L[0].strip().lower() == 'datum'
808        assert L[1].strip().lower() == 'wgs84'
809
810        L = lines[3].strip().split()
811        assert L[0].strip().lower() == 'zunits'
812        assert L[1].strip().lower() == 'no'
813
814        L = lines[4].strip().split()
815        assert L[0].strip().lower() == 'units'
816        assert L[1].strip().lower() == 'meters'
817
818        L = lines[5].strip().split()
819        assert L[0].strip().lower() == 'spheroid'
820        assert L[1].strip().lower() == 'wgs84'
821
822        L = lines[6].strip().split()
823        assert L[0].strip().lower() == 'xshift'
824        assert L[1].strip().lower() == '500000'
825
826        L = lines[7].strip().split()
827        assert L[0].strip().lower() == 'yshift'
828        assert L[1].strip().lower() == '10000000'
829
830        L = lines[8].strip().split()
831        assert L[0].strip().lower() == 'parameters'
832
833
834        #Check asc file
835        ascid = open(ascfile)
836        lines = ascid.readlines()
837        ascid.close()
838
839        L = lines[0].strip().split()
840        assert L[0].strip().lower() == 'ncols'
841        assert L[1].strip().lower() == '5'
842
843        L = lines[1].strip().split()
844        assert L[0].strip().lower() == 'nrows'
845        assert L[1].strip().lower() == '5'
846
847        L = lines[2].strip().split()
848        assert L[0].strip().lower() == 'xllcorner'
849        assert allclose(float(L[1].strip().lower()), 308500)
850
851        L = lines[3].strip().split()
852        assert L[0].strip().lower() == 'yllcorner'
853        assert allclose(float(L[1].strip().lower()), 6189000)
854
855        L = lines[4].strip().split()
856        assert L[0].strip().lower() == 'cellsize'
857        assert allclose(float(L[1].strip().lower()), cellsize)
858
859        L = lines[5].strip().split()
860        assert L[0].strip() == 'NODATA_value'
861        assert L[1].strip().lower() == '-9999'
862
863        #Check grid values
864        for j in range(5):
865            L = lines[6+j].strip().split()
866            y = (4-j) * cellsize
867            for i in range(5):
868                assert allclose(float(L[i]), -i*cellsize - y)
869
870
871        fid.close()
872
873        #Cleanup
874        os.remove(prjfile)
875        os.remove(ascfile)
876        os.remove(swwfile)
877
878
879    def test_sww2asc_stage_reduction(self):
880        """Test that sww information can be converted correctly to asc/prj
881        format readable by e.g. ArcView
882
883        This tests the reduction of quantity stage using min
884        """
885
886        import time, os
887        from Numeric import array, zeros, allclose, Float, concatenate
888        from Scientific.IO.NetCDF import NetCDFFile
889
890        #Setup
891        self.domain.filename = 'datatest'
892
893        prjfile = self.domain.filename + '_stage.prj'
894        ascfile = self.domain.filename + '_stage.asc'
895        swwfile = self.domain.filename + '.sww'
896
897        self.domain.set_datadir('.')
898        self.domain.format = 'sww'
899        self.domain.smooth = True
900        self.domain.set_quantity('elevation', lambda x,y: -x-y)
901
902        self.domain.geo_reference = Geo_reference(56,308500,6189000)
903
904
905        sww = get_dataobject(self.domain)
906        sww.store_connectivity()
907        sww.store_timestep('stage')
908
909        self.domain.evolve_to_end(finaltime = 0.01)
910        sww.store_timestep('stage')
911
912        cellsize = 0.25
913        #Check contents
914        #Get NetCDF
915
916        fid = NetCDFFile(sww.filename, 'r')
917
918        # Get the variables
919        x = fid.variables['x'][:]
920        y = fid.variables['y'][:]
921        z = fid.variables['elevation'][:]
922        time = fid.variables['time'][:]
923        stage = fid.variables['stage'][:]
924
925
926        #Export to ascii/prj files
927        sww2asc(self.domain.filename,
928                quantity = 'stage',
929                cellsize = cellsize,
930                reduction = min)
931
932
933        #Check asc file
934        ascid = open(ascfile)
935        lines = ascid.readlines()
936        ascid.close()
937
938        L = lines[0].strip().split()
939        assert L[0].strip().lower() == 'ncols'
940        assert L[1].strip().lower() == '5'
941
942        L = lines[1].strip().split()
943        assert L[0].strip().lower() == 'nrows'
944        assert L[1].strip().lower() == '5'
945
946        L = lines[2].strip().split()
947        assert L[0].strip().lower() == 'xllcorner'
948        assert allclose(float(L[1].strip().lower()), 308500)
949
950        L = lines[3].strip().split()
951        assert L[0].strip().lower() == 'yllcorner'
952        assert allclose(float(L[1].strip().lower()), 6189000)
953
954        L = lines[4].strip().split()
955        assert L[0].strip().lower() == 'cellsize'
956        assert allclose(float(L[1].strip().lower()), cellsize)
957
958        L = lines[5].strip().split()
959        assert L[0].strip() == 'NODATA_value'
960        assert L[1].strip().lower() == '-9999'
961
962
963        #Check grid values (where applicable)
964        for j in range(5):
965            if j%2 == 0:
966                L = lines[6+j].strip().split()
967                jj = 4-j
968                for i in range(5):
969                    if i%2 == 0:
970                        index = jj/2 + i/2*3
971                        val0 = stage[0,index]
972                        val1 = stage[1,index]
973
974                        #print i, j, index, ':', L[i], val0, val1
975                        assert allclose(float(L[i]), min(val0, val1))
976
977
978        fid.close()
979
980        #Cleanup
981        os.remove(prjfile)
982        os.remove(ascfile)
983        #os.remove(swwfile)
984
985
986
987
988    def test_sww2asc_missing_points(self):
989        """Test that sww information can be converted correctly to asc/prj
990        format readable by e.g. ArcView
991
992        This test includes the writing of missing values
993        """
994
995        import time, os
996        from Numeric import array, zeros, allclose, Float, concatenate
997        from Scientific.IO.NetCDF import NetCDFFile
998
999        #Setup mesh not coinciding with rectangle.
1000        #This will cause missing values to occur in gridded data
1001
1002
1003        points = [                        [1.0, 1.0],
1004                              [0.5, 0.5], [1.0, 0.5],
1005                  [0.0, 0.0], [0.5, 0.0], [1.0, 0.0]]
1006
1007        vertices = [ [4,1,3], [5,2,4], [1,4,2], [2,0,1]]
1008
1009        #Create shallow water domain
1010        domain = Domain(points, vertices)
1011        domain.default_order=2
1012
1013
1014        #Set some field values
1015        domain.set_quantity('elevation', lambda x,y: -x-y)
1016        domain.set_quantity('friction', 0.03)
1017
1018
1019        ######################
1020        # Boundary conditions
1021        B = Transmissive_boundary(domain)
1022        domain.set_boundary( {'exterior': B} )
1023
1024
1025        ######################
1026        #Initial condition - with jumps
1027
1028        bed = domain.quantities['elevation'].vertex_values
1029        stage = zeros(bed.shape, Float)
1030
1031        h = 0.3
1032        for i in range(stage.shape[0]):
1033            if i % 2 == 0:
1034                stage[i,:] = bed[i,:] + h
1035            else:
1036                stage[i,:] = bed[i,:]
1037
1038        domain.set_quantity('stage', stage)
1039        domain.distribute_to_vertices_and_edges()
1040
1041        domain.filename = 'datatest'
1042
1043        prjfile = domain.filename + '_elevation.prj'
1044        ascfile = domain.filename + '_elevation.asc'
1045        swwfile = domain.filename + '.sww'
1046
1047        domain.set_datadir('.')
1048        domain.format = 'sww'
1049        domain.smooth = True
1050
1051        domain.geo_reference = Geo_reference(56,308500,6189000)
1052
1053        sww = get_dataobject(domain)
1054        sww.store_connectivity()
1055        sww.store_timestep('stage')
1056
1057        cellsize = 0.25
1058        #Check contents
1059        #Get NetCDF
1060
1061        fid = NetCDFFile(swwfile, 'r')
1062
1063        # Get the variables
1064        x = fid.variables['x'][:]
1065        y = fid.variables['y'][:]
1066        z = fid.variables['elevation'][:]
1067        time = fid.variables['time'][:]
1068
1069        try:
1070            geo_reference = Geo_reference(NetCDFObject=fid)
1071        except AttributeError, e:
1072            geo_reference = Geo_reference(DEFAULT_ZONE,0,0)
1073
1074        #Export to ascii/prj files
1075        sww2asc(domain.filename,
1076                quantity = 'elevation',
1077                cellsize = cellsize,
1078                verbose = False)
1079
1080
1081        #Check asc file
1082        ascid = open(ascfile)
1083        lines = ascid.readlines()
1084        ascid.close()
1085
1086        L = lines[0].strip().split()
1087        assert L[0].strip().lower() == 'ncols'
1088        assert L[1].strip().lower() == '5'
1089
1090        L = lines[1].strip().split()
1091        assert L[0].strip().lower() == 'nrows'
1092        assert L[1].strip().lower() == '5'
1093
1094        L = lines[2].strip().split()
1095        assert L[0].strip().lower() == 'xllcorner'
1096        assert allclose(float(L[1].strip().lower()), 308500)
1097
1098        L = lines[3].strip().split()
1099        assert L[0].strip().lower() == 'yllcorner'
1100        assert allclose(float(L[1].strip().lower()), 6189000)
1101
1102        L = lines[4].strip().split()
1103        assert L[0].strip().lower() == 'cellsize'
1104        assert allclose(float(L[1].strip().lower()), cellsize)
1105
1106        L = lines[5].strip().split()
1107        assert L[0].strip() == 'NODATA_value'
1108        assert L[1].strip().lower() == '-9999'
1109
1110
1111        #Check grid values
1112        for j in range(5):
1113            L = lines[6+j].strip().split()
1114            y = (4-j) * cellsize
1115            for i in range(5):
1116                if i+j >= 4:
1117                    assert allclose(float(L[i]), -i*cellsize - y)
1118                else:
1119                    #Missing values
1120                    assert allclose(float(L[i]), -9999)
1121
1122
1123
1124        fid.close()
1125
1126        #Cleanup
1127        os.remove(prjfile)
1128        os.remove(ascfile)
1129        os.remove(swwfile)
1130
1131
1132    def test_ferret2sww(self):
1133        """Test that georeferencing etc works when converting from
1134        ferret format (lat/lon) to sww format (UTM)
1135        """
1136        from Scientific.IO.NetCDF import NetCDFFile
1137
1138        #The test file has
1139        # LON = 150.66667, 150.83334, 151, 151.16667
1140        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1141        # TIME = 0, 0.1, 0.6, 1.1, 1.6, 2.1 ;
1142        #
1143        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1144        # Fourth value (index==3) is -6.50198 cm
1145
1146
1147        from coordinate_transforms.redfearn import redfearn
1148
1149        fid = NetCDFFile('small_ha.nc')
1150        first_value = fid.variables['HA'][:][0,0,0]
1151        fourth_value = fid.variables['HA'][:][0,0,3]
1152
1153
1154        #Call conversion (with zero origin)
1155        ferret2sww('small', verbose=False,
1156                   origin = (56, 0, 0))
1157
1158
1159        #Work out the UTM coordinates for first point
1160        zone, e, n = redfearn(-34.5, 150.66667)
1161        #print zone, e, n
1162
1163        #Read output file 'small.sww'
1164        fid = NetCDFFile('small.sww')
1165
1166        x = fid.variables['x'][:]
1167        y = fid.variables['y'][:]
1168
1169        #Check that first coordinate is correctly represented
1170        assert allclose(x[0], e)
1171        assert allclose(y[0], n)
1172
1173        #Check first value
1174        stage = fid.variables['stage'][:]
1175        xmomentum = fid.variables['xmomentum'][:]
1176        ymomentum = fid.variables['ymomentum'][:]
1177
1178        #print ymomentum
1179
1180        assert allclose(stage[0,0], first_value/100)  #Meters
1181
1182        #Check fourth value
1183        assert allclose(stage[0,3], fourth_value/100)  #Meters
1184
1185        fid.close()
1186
1187        #Cleanup
1188        import os
1189        os.remove('small.sww')
1190
1191
1192
1193    def test_ferret2sww_2(self):
1194        """Test that georeferencing etc works when converting from
1195        ferret format (lat/lon) to sww format (UTM)
1196        """
1197        from Scientific.IO.NetCDF import NetCDFFile
1198
1199        #The test file has
1200        # LON = 150.66667, 150.83334, 151, 151.16667
1201        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1202        # TIME = 0, 0.1, 0.6, 1.1, 1.6, 2.1 ;
1203        #
1204        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1205        # Fourth value (index==3) is -6.50198 cm
1206
1207
1208        from coordinate_transforms.redfearn import redfearn
1209
1210        fid = NetCDFFile('small_ha.nc')
1211
1212        #Pick a coordinate and a value
1213
1214        time_index = 1
1215        lat_index = 0
1216        lon_index = 2
1217
1218        test_value = fid.variables['HA'][:][time_index, lat_index, lon_index]
1219        test_time = fid.variables['TIME'][:][time_index]
1220        test_lat = fid.variables['LAT'][:][lat_index]
1221        test_lon = fid.variables['LON'][:][lon_index]
1222
1223        linear_point_index = lat_index*4 + lon_index
1224        fid.close()
1225
1226        #Call conversion (with zero origin)
1227        ferret2sww('small', verbose=False,
1228                   origin = (56, 0, 0))
1229
1230
1231        #Work out the UTM coordinates for test point
1232        zone, e, n = redfearn(test_lat, test_lon)
1233
1234        #Read output file 'small.sww'
1235        fid = NetCDFFile('small.sww')
1236
1237        x = fid.variables['x'][:]
1238        y = fid.variables['y'][:]
1239
1240        #Check that test coordinate is correctly represented
1241        assert allclose(x[linear_point_index], e)
1242        assert allclose(y[linear_point_index], n)
1243
1244        #Check test value
1245        stage = fid.variables['stage'][:]
1246
1247        assert allclose(stage[time_index, linear_point_index], test_value/100)
1248
1249        fid.close()
1250
1251        #Cleanup
1252        import os
1253        os.remove('small.sww')
1254
1255
1256
1257    def test_ferret2sww3(self):
1258        """
1259        """
1260        from Scientific.IO.NetCDF import NetCDFFile
1261
1262        #The test file has
1263        # LON = 150.66667, 150.83334, 151, 151.16667
1264        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1265        # ELEVATION = [-1 -2 -3 -4
1266        #             -5 -6 -7 -8
1267        #              ...
1268        #              ...      -16]
1269        # where the top left corner is -1m,
1270        # and the ll corner is -13.0m
1271        #
1272        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1273        # Fourth value (index==3) is -6.50198 cm
1274
1275        from coordinate_transforms.redfearn import redfearn
1276        import os
1277        fid1 = NetCDFFile('test_ha.nc','w')
1278        fid2 = NetCDFFile('test_ua.nc','w')
1279        fid3 = NetCDFFile('test_va.nc','w')
1280        fid4 = NetCDFFile('test_e.nc','w')
1281
1282        h1_list = [150.66667,150.83334,151.]
1283        h2_list = [-34.5,-34.33333]
1284
1285        long_name = 'LON'
1286        lat_name = 'LAT'
1287
1288        nx = 3
1289        ny = 2
1290
1291        for fid in [fid1,fid2,fid3]:
1292            fid.createDimension(long_name,nx)
1293            fid.createVariable(long_name,'d',(long_name,))
1294            fid.variables[long_name].point_spacing='uneven'
1295            fid.variables[long_name].units='degrees_east'
1296            fid.variables[long_name].assignValue(h1_list)
1297
1298            fid.createDimension(lat_name,ny)
1299            fid.createVariable(lat_name,'d',(lat_name,))
1300            fid.variables[lat_name].point_spacing='uneven'
1301            fid.variables[lat_name].units='degrees_north'
1302            fid.variables[lat_name].assignValue(h2_list)
1303
1304            fid.createDimension('TIME',2)
1305            fid.createVariable('TIME','d',('TIME',))
1306            fid.variables['TIME'].point_spacing='uneven'
1307            fid.variables['TIME'].units='seconds'
1308            fid.variables['TIME'].assignValue([0.,1.])
1309            if fid == fid3: break
1310
1311
1312        for fid in [fid4]:
1313            fid.createDimension(long_name,nx)
1314            fid.createVariable(long_name,'d',(long_name,))
1315            fid.variables[long_name].point_spacing='uneven'
1316            fid.variables[long_name].units='degrees_east'
1317            fid.variables[long_name].assignValue(h1_list)
1318
1319            fid.createDimension(lat_name,ny)
1320            fid.createVariable(lat_name,'d',(lat_name,))
1321            fid.variables[lat_name].point_spacing='uneven'
1322            fid.variables[lat_name].units='degrees_north'
1323            fid.variables[lat_name].assignValue(h2_list)
1324
1325        name = {}
1326        name[fid1]='HA'
1327        name[fid2]='UA'
1328        name[fid3]='VA'
1329        name[fid4]='ELEVATION'
1330
1331        units = {}
1332        units[fid1]='cm'
1333        units[fid2]='cm/s'
1334        units[fid3]='cm/s'
1335        units[fid4]='m'
1336
1337        values = {}
1338        values[fid1]=[[[5., 10.,15.], [13.,18.,23.]],[[50.,100.,150.],[130.,180.,230.]]]
1339        values[fid2]=[[[1., 2.,3.], [4.,5.,6.]],[[7.,8.,9.],[10.,11.,12.]]]
1340        values[fid3]=[[[13., 12.,11.], [10.,9.,8.]],[[7.,6.,5.],[4.,3.,2.]]]
1341        values[fid4]=[[-3000,-3100,-3200],[-4000,-5000,-6000]]
1342
1343        for fid in [fid1,fid2,fid3]:
1344          fid.createVariable(name[fid],'d',('TIME',lat_name,long_name))
1345          fid.variables[name[fid]].point_spacing='uneven'
1346          fid.variables[name[fid]].units=units[fid]
1347          fid.variables[name[fid]].assignValue(values[fid])
1348          fid.variables[name[fid]].missing_value = -99999999.
1349          if fid == fid3: break
1350
1351        for fid in [fid4]:
1352            fid.createVariable(name[fid],'d',(lat_name,long_name))
1353            fid.variables[name[fid]].point_spacing='uneven'
1354            fid.variables[name[fid]].units=units[fid]
1355            fid.variables[name[fid]].assignValue(values[fid])
1356            fid.variables[name[fid]].missing_value = -99999999.
1357
1358
1359        fid1.sync(); fid1.close()
1360        fid2.sync(); fid2.close()
1361        fid3.sync(); fid3.close()
1362        fid4.sync(); fid4.close()
1363
1364        fid1 = NetCDFFile('test_ha.nc','r')
1365        fid2 = NetCDFFile('test_e.nc','r')
1366        fid3 = NetCDFFile('test_va.nc','r')
1367
1368
1369        first_amp = fid1.variables['HA'][:][0,0,0]
1370        third_amp = fid1.variables['HA'][:][0,0,2]
1371        first_elevation = fid2.variables['ELEVATION'][0,0]
1372        third_elevation= fid2.variables['ELEVATION'][:][0,2]
1373        first_speed = fid3.variables['VA'][0,0,0]
1374        third_speed = fid3.variables['VA'][:][0,0,2]
1375
1376        fid1.close()
1377        fid2.close()
1378        fid3.close()
1379
1380        #Call conversion (with zero origin)
1381        ferret2sww('test', verbose=False,
1382                   origin = (56, 0, 0))
1383
1384        os.remove('test_va.nc')
1385        os.remove('test_ua.nc')
1386        os.remove('test_ha.nc')
1387        os.remove('test_e.nc')
1388
1389        #Read output file 'test.sww'
1390        fid = NetCDFFile('test.sww')
1391
1392
1393        #Check first value
1394        elevation = fid.variables['elevation'][:]
1395        stage = fid.variables['stage'][:]
1396        xmomentum = fid.variables['xmomentum'][:]
1397        ymomentum = fid.variables['ymomentum'][:]
1398
1399        #print ymomentum
1400        first_height = first_amp/100 - first_elevation
1401        third_height = third_amp/100 - third_elevation
1402        first_momentum=first_speed*first_height/100
1403        third_momentum=third_speed*third_height/100
1404
1405        assert allclose(ymomentum[0][0],first_momentum)  #Meters
1406        assert allclose(ymomentum[0][2],third_momentum)  #Meters
1407
1408        fid.close()
1409
1410        #Cleanup
1411        os.remove('test.sww')
1412
1413
1414
1415
1416    def test_sww_extent(self):
1417        """Not a test, rather a look at the sww format
1418        """
1419
1420        import time, os
1421        from Numeric import array, zeros, allclose, Float, concatenate
1422        from Scientific.IO.NetCDF import NetCDFFile
1423
1424        self.domain.filename = 'datatest' + str(id(self))
1425        self.domain.format = 'sww'
1426        self.domain.smooth = True
1427        self.domain.reduction = mean
1428        self.domain.set_datadir('.')
1429
1430
1431        sww = get_dataobject(self.domain)
1432        sww.store_connectivity()
1433        sww.store_timestep('stage')
1434        self.domain.time = 2.
1435
1436        #Modify stage at second timestep
1437        stage = self.domain.quantities['stage'].vertex_values
1438        self.domain.set_quantity('stage', stage/2)
1439
1440        sww.store_timestep('stage')
1441
1442        file_and_extension_name = self.domain.filename + ".sww"
1443        #print "file_and_extension_name",file_and_extension_name
1444        [xmin, xmax, ymin, ymax, stagemin, stagemax] = \
1445               extent_sww(file_and_extension_name )
1446
1447        assert allclose(xmin, 0.0)
1448        assert allclose(xmax, 1.0)
1449        assert allclose(ymin, 0.0)
1450        assert allclose(ymax, 1.0)
1451        assert allclose(stagemin, -0.85)
1452        assert allclose(stagemax, 0.15)
1453
1454
1455        #Cleanup
1456        os.remove(sww.filename)
1457
1458
1459    def test_ferret2sww_nz_origin(self):
1460        from Scientific.IO.NetCDF import NetCDFFile
1461        from coordinate_transforms.redfearn import redfearn
1462
1463        #Call conversion (with nonzero origin)
1464        ferret2sww('small', verbose=False,
1465                   origin = (56, 100000, 200000))
1466
1467
1468        #Work out the UTM coordinates for first point
1469        zone, e, n = redfearn(-34.5, 150.66667)
1470
1471        #Read output file 'small.sww'
1472        fid = NetCDFFile('small.sww', 'r')
1473
1474        x = fid.variables['x'][:]
1475        y = fid.variables['y'][:]
1476
1477        #Check that first coordinate is correctly represented
1478        assert allclose(x[0], e-100000)
1479        assert allclose(y[0], n-200000)
1480
1481        fid.close()
1482
1483        #Cleanup
1484        import os
1485        os.remove('small.sww')
1486
1487    def test_sww2domain(self):
1488        ################################################
1489        #Create a test domain, and evolve and save it.
1490        ################################################
1491        from mesh_factory import rectangular
1492        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1493             Constant_height, Time_boundary, Transmissive_boundary
1494        from Numeric import array
1495
1496        #Create basic mesh
1497
1498        yiel=0.01
1499        points, vertices, boundary = rectangular(10,10)
1500
1501        #Create shallow water domain
1502        domain = Domain(points, vertices, boundary)
1503        domain.geo_reference = Geo_reference(56,11,11)
1504        domain.smooth = False
1505        domain.visualise = False
1506        domain.store = True
1507        domain.filename = 'bedslope'
1508        domain.default_order=2
1509        #Bed-slope and friction
1510        domain.set_quantity('elevation', lambda x,y: -x/3)
1511        domain.set_quantity('friction', 0.1)
1512        # Boundary conditions
1513        from math import sin, pi
1514        Br = Reflective_boundary(domain)
1515        Bt = Transmissive_boundary(domain)
1516        Bd = Dirichlet_boundary([0.2,0.,0.])
1517        Bw = Time_boundary(domain=domain,
1518                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1519
1520        #domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
1521        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1522
1523        domain.quantities_to_be_stored.extend(['xmomentum','ymomentum'])
1524        #Initial condition
1525        h = 0.05
1526        elevation = domain.quantities['elevation'].vertex_values
1527        domain.set_quantity('stage', elevation + h)
1528        #elevation = domain.get_quantity('elevation')
1529        #domain.set_quantity('stage', elevation + h)
1530
1531        domain.check_integrity()
1532        #Evolution
1533        for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
1534        #    domain.write_time()
1535            pass
1536
1537
1538        ##########################################
1539        #Import the example's file as a new domain
1540        ##########################################
1541        from data_manager import sww2domain
1542        from Numeric import allclose
1543        import os
1544
1545        filename = domain.datadir+os.sep+domain.filename+'.sww'
1546        domain2 = sww2domain(filename,None,fail_if_NaN=False,verbose = False)
1547        #points, vertices, boundary = rectangular(15,15)
1548        #domain2.boundary = boundary
1549        ###################
1550        ##NOW TEST IT!!!
1551        ###################
1552
1553        bits = ['vertex_coordinates']
1554        for quantity in ['elevation']+domain.quantities_to_be_stored:
1555            bits.append('quantities["%s"].get_integral()'%quantity)
1556            bits.append('get_quantity("%s")'%quantity)
1557
1558        for bit in bits:
1559            #print 'testing that domain.'+bit+' has been restored'
1560            #print bit
1561        #print 'done'
1562            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1563
1564        ######################################
1565        #Now evolve them both, just to be sure
1566        ######################################x
1567        visualise = False
1568        #visualise = True
1569        domain.visualise = visualise
1570        domain.time = 0.
1571        from time import sleep
1572
1573        final = .1
1574        domain.set_quantity('friction', 0.1)
1575        domain.store = False
1576        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1577
1578
1579        for t in domain.evolve(yieldstep = yiel, finaltime = final):
1580            if visualise: sleep(1.)
1581            #domain.write_time()
1582            pass
1583
1584        final = final - (domain2.starttime-domain.starttime)
1585        #BUT since domain1 gets time hacked back to 0:
1586        final = final + (domain2.starttime-domain.starttime)
1587
1588        domain2.smooth = False
1589        domain2.visualise = visualise
1590        domain2.store = False
1591        domain2.default_order=2
1592        domain2.set_quantity('friction', 0.1)
1593        #Bed-slope and friction
1594        # Boundary conditions
1595        Bd2=Dirichlet_boundary([0.2,0.,0.])
1596        domain2.boundary = domain.boundary
1597        #print 'domain2.boundary'
1598        #print domain2.boundary
1599        domain2.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1600        #domain2.set_boundary({'exterior': Bd})
1601
1602        domain2.check_integrity()
1603
1604        for t in domain2.evolve(yieldstep = yiel, finaltime = final):
1605            if visualise: sleep(1.)
1606            #domain2.write_time()
1607            pass
1608
1609        ###################
1610        ##NOW TEST IT!!!
1611        ##################
1612
1613        bits = [ 'vertex_coordinates']
1614
1615        for quantity in ['elevation','xmomentum','ymomentum']:#+domain.quantities_to_be_stored:
1616            bits.append('quantities["%s"].get_integral()'%quantity)
1617            bits.append('get_quantity("%s")'%quantity)
1618
1619        for bit in bits:
1620            #print bit
1621            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1622
1623
1624    def test_sww2domain2(self):
1625        ##################################################################
1626        #Same as previous test, but this checks how NaNs are handled.
1627        ##################################################################
1628
1629
1630        from mesh_factory import rectangular
1631        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1632             Constant_height, Time_boundary, Transmissive_boundary
1633        from Numeric import array
1634
1635        #Create basic mesh
1636        points, vertices, boundary = rectangular(2,2)
1637
1638        #Create shallow water domain
1639        domain = Domain(points, vertices, boundary)
1640        domain.smooth = False
1641        domain.visualise = False
1642        domain.store = True
1643        domain.filename = 'bedslope'
1644        domain.default_order=2
1645        domain.quantities_to_be_stored=['stage']
1646
1647        domain.set_quantity('elevation', lambda x,y: -x/3)
1648        domain.set_quantity('friction', 0.1)
1649
1650        from math import sin, pi
1651        Br = Reflective_boundary(domain)
1652        Bt = Transmissive_boundary(domain)
1653        Bd = Dirichlet_boundary([0.2,0.,0.])
1654        Bw = Time_boundary(domain=domain,
1655                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1656
1657        domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
1658
1659        h = 0.05
1660        elevation = domain.quantities['elevation'].vertex_values
1661        domain.set_quantity('stage', elevation + h)
1662
1663        domain.check_integrity()
1664
1665        for t in domain.evolve(yieldstep = 1, finaltime = 2.0):
1666            pass
1667            #domain.write_time()
1668
1669
1670
1671        ##################################
1672        #Import the file as a new domain
1673        ##################################
1674        from data_manager import sww2domain
1675        from Numeric import allclose
1676        import os
1677
1678        filename = domain.datadir+os.sep+domain.filename+'.sww'
1679
1680        #Fail because NaNs are present
1681        try:
1682            domain2 = sww2domain(filename,boundary,fail_if_NaN=True,verbose=False)
1683            assert True == False
1684        except:
1685            #Now import it, filling NaNs to be 0
1686            filler = 0
1687            domain2 = sww2domain(filename,None,fail_if_NaN=False,NaN_filler = filler,verbose=False)
1688        bits = [ 'geo_reference.get_xllcorner()',
1689                'geo_reference.get_yllcorner()',
1690                'vertex_coordinates']
1691
1692        for quantity in ['elevation']+domain.quantities_to_be_stored:
1693            bits.append('quantities["%s"].get_integral()'%quantity)
1694            bits.append('get_quantity("%s")'%quantity)
1695
1696        for bit in bits:
1697        #    print 'testing that domain.'+bit+' has been restored'
1698            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1699
1700        assert max(max(domain2.get_quantity('xmomentum')))==filler
1701        assert min(min(domain2.get_quantity('xmomentum')))==filler
1702        assert max(max(domain2.get_quantity('ymomentum')))==filler
1703        assert min(min(domain2.get_quantity('ymomentum')))==filler
1704
1705        #print 'passed'
1706
1707        #cleanup
1708        #import os
1709        #os.remove(domain.datadir+'/'+domain.filename+'.sww')
1710
1711
1712    #def test_weed(self):
1713        from data_manager import weed
1714
1715        coordinates1 = [[0.,0.],[1.,0.],[1.,1.],[1.,0.],[2.,0.],[1.,1.]]
1716        volumes1 = [[0,1,2],[3,4,5]]
1717        boundary1= {(0,1): 'external',(1,2): 'not external',(2,0): 'external',(3,4): 'external',(4,5): 'external',(5,3): 'not external'}
1718        coordinates2,volumes2,boundary2=weed(coordinates1,volumes1,boundary1)
1719
1720        points2 = {(0.,0.):None,(1.,0.):None,(1.,1.):None,(2.,0.):None}
1721
1722        assert len(points2)==len(coordinates2)
1723        for i in range(len(coordinates2)):
1724            coordinate = tuple(coordinates2[i])
1725            assert points2.has_key(coordinate)
1726            points2[coordinate]=i
1727
1728        for triangle in volumes1:
1729            for coordinate in triangle:
1730                assert coordinates2[points2[tuple(coordinates1[coordinate])]][0]==coordinates1[coordinate][0]
1731                assert coordinates2[points2[tuple(coordinates1[coordinate])]][1]==coordinates1[coordinate][1]
1732
1733
1734     #FIXME This fails - smooth makes the comparism too hard for allclose
1735    def ztest_sww2domain3(self):
1736        ################################################
1737        #DOMAIN.SMOOTH = TRUE !!!!!!!!!!!!!!!!!!!!!!!!!!
1738        ################################################
1739        from mesh_factory import rectangular
1740        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1741             Constant_height, Time_boundary, Transmissive_boundary
1742        from Numeric import array
1743        #Create basic mesh
1744
1745        yiel=0.01
1746        points, vertices, boundary = rectangular(10,10)
1747
1748        #Create shallow water domain
1749        domain = Domain(points, vertices, boundary)
1750        domain.geo_reference = Geo_reference(56,11,11)
1751        domain.smooth = True
1752        domain.visualise = False
1753        domain.store = True
1754        domain.filename = 'bedslope'
1755        domain.default_order=2
1756        #Bed-slope and friction
1757        domain.set_quantity('elevation', lambda x,y: -x/3)
1758        domain.set_quantity('friction', 0.1)
1759        # Boundary conditions
1760        from math import sin, pi
1761        Br = Reflective_boundary(domain)
1762        Bt = Transmissive_boundary(domain)
1763        Bd = Dirichlet_boundary([0.2,0.,0.])
1764        Bw = Time_boundary(domain=domain,
1765                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1766
1767        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1768
1769        domain.quantities_to_be_stored.extend(['xmomentum','ymomentum'])
1770        #Initial condition
1771        h = 0.05
1772        elevation = domain.quantities['elevation'].vertex_values
1773        domain.set_quantity('stage', elevation + h)
1774        #elevation = domain.get_quantity('elevation')
1775        #domain.set_quantity('stage', elevation + h)
1776
1777        domain.check_integrity()
1778        #Evolution
1779        for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
1780        #    domain.write_time()
1781            pass
1782
1783
1784        ##########################################
1785        #Import the example's file as a new domain
1786        ##########################################
1787        from data_manager import sww2domain
1788        from Numeric import allclose
1789        import os
1790
1791        filename = domain.datadir+os.sep+domain.filename+'.sww'
1792        domain2 = sww2domain(filename,None,fail_if_NaN=False,verbose = False)
1793        #points, vertices, boundary = rectangular(15,15)
1794        #domain2.boundary = boundary
1795        ###################
1796        ##NOW TEST IT!!!
1797        ###################
1798
1799        #FIXME smooth domain so that they can be compared
1800
1801
1802        bits = []#'vertex_coordinates']
1803        for quantity in ['elevation']+domain.quantities_to_be_stored:
1804            bits.append('quantities["%s"].get_integral()'%quantity)
1805            #bits.append('get_quantity("%s")'%quantity)
1806
1807        for bit in bits:
1808            #print 'testing that domain.'+bit+' has been restored'
1809            #print bit
1810            #print 'done'
1811            #print ('domain.'+bit), eval('domain.'+bit)
1812            #print ('domain2.'+bit), eval('domain2.'+bit)
1813            assert allclose(eval('domain.'+bit),eval('domain2.'+bit),rtol=1.0e-1,atol=1.e-3)
1814            pass
1815
1816        ######################################
1817        #Now evolve them both, just to be sure
1818        ######################################x
1819        visualise = False
1820        visualise = True
1821        domain.visualise = visualise
1822        domain.time = 0.
1823        from time import sleep
1824
1825        final = .5
1826        domain.set_quantity('friction', 0.1)
1827        domain.store = False
1828        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Br})
1829
1830        for t in domain.evolve(yieldstep = yiel, finaltime = final):
1831            if visualise: sleep(.03)
1832            #domain.write_time()
1833            pass
1834
1835        domain2.smooth = True
1836        domain2.visualise = visualise
1837        domain2.store = False
1838        domain2.default_order=2
1839        domain2.set_quantity('friction', 0.1)
1840        #Bed-slope and friction
1841        # Boundary conditions
1842        Bd2=Dirichlet_boundary([0.2,0.,0.])
1843        Br2 = Reflective_boundary(domain2)
1844        domain2.boundary = domain.boundary
1845        #print 'domain2.boundary'
1846        #print domain2.boundary
1847        domain2.set_boundary({'left': Bd2, 'right': Bd2, 'top': Bd2, 'bottom': Br2})
1848        #domain2.boundary = domain.boundary
1849        #domain2.set_boundary({'exterior': Bd})
1850
1851        domain2.check_integrity()
1852
1853        for t in domain2.evolve(yieldstep = yiel, finaltime = final):
1854            if visualise: sleep(.03)
1855            #domain2.write_time()
1856            pass
1857
1858        ###################
1859        ##NOW TEST IT!!!
1860        ##################
1861
1862        bits = [ 'vertex_coordinates']
1863
1864        for quantity in ['elevation','xmomentum','ymomentum']:#+domain.quantities_to_be_stored:
1865            #bits.append('quantities["%s"].get_integral()'%quantity)
1866            bits.append('get_quantity("%s")'%quantity)
1867
1868        for bit in bits:
1869            print bit
1870            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1871
1872
1873    def test_decimate_dem(self):
1874        """Test decimation of dem file
1875        """
1876
1877        import os
1878        from Numeric import ones, allclose, Float, arange
1879        from Scientific.IO.NetCDF import NetCDFFile
1880
1881        #Write test dem file
1882        root = 'decdemtest'
1883
1884        filename = root + '.dem'
1885        fid = NetCDFFile(filename, 'w')
1886
1887        fid.institution = 'Geoscience Australia'
1888        fid.description = 'NetCDF DEM format for compact and portable ' +\
1889                          'storage of spatial point data'
1890
1891        nrows = 15
1892        ncols = 18
1893
1894        fid.ncols = ncols
1895        fid.nrows = nrows
1896        fid.xllcorner = 2000.5
1897        fid.yllcorner = 3000.5
1898        fid.cellsize = 25
1899        fid.NODATA_value = -9999
1900
1901        fid.zone = 56
1902        fid.false_easting = 0.0
1903        fid.false_northing = 0.0
1904        fid.projection = 'UTM'
1905        fid.datum = 'WGS84'
1906        fid.units = 'METERS'
1907
1908        fid.createDimension('number_of_points', nrows*ncols)
1909
1910        fid.createVariable('elevation', Float, ('number_of_points',))
1911
1912        elevation = fid.variables['elevation']
1913
1914        elevation[:] = (arange(nrows*ncols))
1915
1916        fid.close()
1917
1918        #generate the elevation values expected in the decimated file
1919        ref_elevation = [(  0+  1+  2+ 18+ 19+ 20+ 36+ 37+ 38) / 9.0,
1920                         (  4+  5+  6+ 22+ 23+ 24+ 40+ 41+ 42) / 9.0,
1921                         (  8+  9+ 10+ 26+ 27+ 28+ 44+ 45+ 46) / 9.0,
1922                         ( 12+ 13+ 14+ 30+ 31+ 32+ 48+ 49+ 50) / 9.0,
1923                         ( 72+ 73+ 74+ 90+ 91+ 92+108+109+110) / 9.0,
1924                         ( 76+ 77+ 78+ 94+ 95+ 96+112+113+114) / 9.0,
1925                         ( 80+ 81+ 82+ 98+ 99+100+116+117+118) / 9.0,
1926                         ( 84+ 85+ 86+102+103+104+120+121+122) / 9.0,
1927                         (144+145+146+162+163+164+180+181+182) / 9.0,
1928                         (148+149+150+166+167+168+184+185+186) / 9.0,
1929                         (152+153+154+170+171+172+188+189+190) / 9.0,
1930                         (156+157+158+174+175+176+192+193+194) / 9.0,
1931                         (216+217+218+234+235+236+252+253+254) / 9.0,
1932                         (220+221+222+238+239+240+256+257+258) / 9.0,
1933                         (224+225+226+242+243+244+260+261+262) / 9.0,
1934                         (228+229+230+246+247+248+264+265+266) / 9.0]
1935
1936        #generate a stencil for computing the decimated values
1937        stencil = ones((3,3), Float) / 9.0
1938
1939        decimate_dem(root, stencil=stencil, cellsize_new=100)
1940
1941        #Open decimated NetCDF file
1942        fid = NetCDFFile(root + '_100.dem', 'r')
1943
1944        # Get decimated elevation
1945        elevation = fid.variables['elevation']
1946
1947        #Check values
1948        assert allclose(elevation, ref_elevation)
1949
1950        #Cleanup
1951        fid.close()
1952
1953        os.remove(root + '.dem')
1954        os.remove(root + '_100.dem')
1955
1956    def test_decimate_dem_NODATA(self):
1957        """Test decimation of dem file that includes NODATA values
1958        """
1959
1960        import os
1961        from Numeric import ones, allclose, Float, arange, reshape
1962        from Scientific.IO.NetCDF import NetCDFFile
1963
1964        #Write test dem file
1965        root = 'decdemtest'
1966
1967        filename = root + '.dem'
1968        fid = NetCDFFile(filename, 'w')
1969
1970        fid.institution = 'Geoscience Australia'
1971        fid.description = 'NetCDF DEM format for compact and portable ' +\
1972                          'storage of spatial point data'
1973
1974        nrows = 15
1975        ncols = 18
1976        NODATA_value = -9999
1977
1978        fid.ncols = ncols
1979        fid.nrows = nrows
1980        fid.xllcorner = 2000.5
1981        fid.yllcorner = 3000.5
1982        fid.cellsize = 25
1983        fid.NODATA_value = NODATA_value
1984
1985        fid.zone = 56
1986        fid.false_easting = 0.0
1987        fid.false_northing = 0.0
1988        fid.projection = 'UTM'
1989        fid.datum = 'WGS84'
1990        fid.units = 'METERS'
1991
1992        fid.createDimension('number_of_points', nrows*ncols)
1993
1994        fid.createVariable('elevation', Float, ('number_of_points',))
1995
1996        elevation = fid.variables['elevation']
1997
1998        #generate initial elevation values
1999        elevation_tmp = (arange(nrows*ncols))
2000        #add some NODATA values
2001        elevation_tmp[0]   = NODATA_value
2002        elevation_tmp[95]  = NODATA_value
2003        elevation_tmp[188] = NODATA_value
2004        elevation_tmp[189] = NODATA_value
2005        elevation_tmp[190] = NODATA_value
2006        elevation_tmp[209] = NODATA_value
2007        elevation_tmp[252] = NODATA_value
2008
2009        elevation[:] = elevation_tmp
2010
2011        fid.close()
2012
2013        #generate the elevation values expected in the decimated file
2014        ref_elevation = [NODATA_value,
2015                         (  4+  5+  6+ 22+ 23+ 24+ 40+ 41+ 42) / 9.0,
2016                         (  8+  9+ 10+ 26+ 27+ 28+ 44+ 45+ 46) / 9.0,
2017                         ( 12+ 13+ 14+ 30+ 31+ 32+ 48+ 49+ 50) / 9.0,
2018                         ( 72+ 73+ 74+ 90+ 91+ 92+108+109+110) / 9.0,
2019                         NODATA_value,
2020                         ( 80+ 81+ 82+ 98+ 99+100+116+117+118) / 9.0,
2021                         ( 84+ 85+ 86+102+103+104+120+121+122) / 9.0,
2022                         (144+145+146+162+163+164+180+181+182) / 9.0,
2023                         (148+149+150+166+167+168+184+185+186) / 9.0,
2024                         NODATA_value,
2025                         (156+157+158+174+175+176+192+193+194) / 9.0,
2026                         NODATA_value,
2027                         (220+221+222+238+239+240+256+257+258) / 9.0,
2028                         (224+225+226+242+243+244+260+261+262) / 9.0,
2029                         (228+229+230+246+247+248+264+265+266) / 9.0]
2030
2031        #generate a stencil for computing the decimated values
2032        stencil = ones((3,3), Float) / 9.0
2033
2034        decimate_dem(root, stencil=stencil, cellsize_new=100)
2035
2036        #Open decimated NetCDF file
2037        fid = NetCDFFile(root + '_100.dem', 'r')
2038
2039        # Get decimated elevation
2040        elevation = fid.variables['elevation']
2041
2042        #Check values
2043        assert allclose(elevation, ref_elevation)
2044
2045        #Cleanup
2046        fid.close()
2047
2048        os.remove(root + '.dem')
2049        os.remove(root + '_100.dem')
2050
2051
2052#-------------------------------------------------------------
2053if __name__ == "__main__":
2054    suite = unittest.makeSuite(Test_Data_Manager,'test')
2055    #suite = unittest.makeSuite(Test_Data_Manager,'test_dem2pts_bounding_box')
2056    #suite = unittest.makeSuite(Test_Data_Manager,'test_decimate_dem')
2057    #suite = unittest.makeSuite(Test_Data_Manager,'test_decimate_dem_NODATA')
2058    runner = unittest.TextTestRunner()
2059    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.