source: inundation/ga/storm_surge/pyvolution/test_data_manager.py @ 1145

Last change on this file since 1145 was 1145, checked in by ole, 20 years ago

Improved optimisation in sww2asc (happy Easter everyone!)

File size: 41.9 KB
Line 
1#!/usr/bin/env python
2#
3
4import unittest
5import copy
6from Numeric import zeros, array, allclose, Float
7from util import mean
8
9from data_manager import *
10from shallow_water import *
11from config import epsilon
12
13class Test_Data_Manager(unittest.TestCase):
14    def setUp(self):
15        import time
16        from mesh_factory import rectangular
17
18
19        #Create basic mesh
20        points, vertices, boundary = rectangular(2, 2)
21
22        #Create shallow water domain
23        domain = Domain(points, vertices, boundary)
24        domain.default_order=2
25
26
27        #Set some field values
28        domain.set_quantity('elevation', lambda x,y: -x)
29        domain.set_quantity('friction', 0.03)
30
31
32        ######################
33        # Boundary conditions
34        B = Transmissive_boundary(domain)
35        domain.set_boundary( {'left': B, 'right': B, 'top': B, 'bottom': B})
36
37
38        ######################
39        #Initial condition - with jumps
40
41
42        bed = domain.quantities['elevation'].vertex_values
43        stage = zeros(bed.shape, Float)
44
45        h = 0.3
46        for i in range(stage.shape[0]):
47            if i % 2 == 0:
48                stage[i,:] = bed[i,:] + h
49            else:
50                stage[i,:] = bed[i,:]
51
52        domain.set_quantity('stage', stage)
53        self.initial_stage = copy.copy(domain.quantities['stage'].vertex_values)
54
55        domain.distribute_to_vertices_and_edges()
56
57
58        self.domain = domain
59
60        C = domain.get_vertex_coordinates()
61        self.X = C[:,0:6:2].copy()
62        self.Y = C[:,1:6:2].copy()
63
64        self.F = bed
65
66
67    def tearDown(self):
68        pass
69
70
71
72
73#     def test_xya(self):
74#         import os
75#         from Numeric import concatenate
76
77#         import time, os
78#         from Numeric import array, zeros, allclose, Float, concatenate
79
80#         domain = self.domain
81
82#         domain.filename = 'datatest' + str(time.time())
83#         domain.format = 'xya'
84#         domain.smooth = True
85
86#         xya = get_dataobject(self.domain)
87#         xya.store_all()
88
89
90#         #Read back
91#         file = open(xya.filename)
92#         lFile = file.read().split('\n')
93#         lFile = lFile[:-1]
94
95#         file.close()
96#         os.remove(xya.filename)
97
98#         #Check contents
99#         if domain.smooth:
100#             self.failUnless(lFile[0] == '9 3 # <vertex #> <x> <y> [attributes]')
101#         else:
102#             self.failUnless(lFile[0] == '24 3 # <vertex #> <x> <y> [attributes]')
103
104#         #Get smoothed field values with X and Y
105#         X,Y,F,V = domain.get_vertex_values(xy=True, value_array='field_values',
106#                                            indices = (0,1), precision = Float)
107
108
109#         Q,V = domain.get_vertex_values(xy=False, value_array='conserved_quantities',
110#                                            indices = (0,), precision = Float)
111
112
113
114#         for i, line in enumerate(lFile[1:]):
115#             fields = line.split()
116
117#             assert len(fields) == 5
118
119#             assert allclose(float(fields[0]), X[i])
120#             assert allclose(float(fields[1]), Y[i])
121#             assert allclose(float(fields[2]), F[i,0])
122#             assert allclose(float(fields[3]), Q[i,0])
123#             assert allclose(float(fields[4]), F[i,1])
124
125
126
127
128    def test_sww_constant(self):
129        """Test that constant sww information can be written correctly
130        (non smooth)
131        """
132
133        import time, os
134        from Numeric import array, zeros, allclose, Float, concatenate
135        from Scientific.IO.NetCDF import NetCDFFile
136
137        self.domain.filename = 'datatest' + str(id(self))
138        self.domain.format = 'sww'
139        self.domain.smooth = False
140
141        sww = get_dataobject(self.domain)
142        sww.store_connectivity()
143
144        #Check contents
145        #Get NetCDF
146        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
147
148        # Get the variables
149        x = fid.variables['x']
150        y = fid.variables['y']
151        z = fid.variables['elevation']
152
153        volumes = fid.variables['volumes']
154
155
156        assert allclose (x[:], self.X.flat)
157        assert allclose (y[:], self.Y.flat)
158        assert allclose (z[:], self.F.flat)
159
160        V = volumes
161
162        P = len(self.domain)
163        for k in range(P):
164            assert V[k, 0] == 3*k
165            assert V[k, 1] == 3*k+1
166            assert V[k, 2] == 3*k+2
167
168
169        fid.close()
170
171        #Cleanup
172        os.remove(sww.filename)
173
174
175    def test_sww_constant_smooth(self):
176        """Test that constant sww information can be written correctly
177        (non smooth)
178        """
179
180        import time, os
181        from Numeric import array, zeros, allclose, Float, concatenate
182        from Scientific.IO.NetCDF import NetCDFFile
183
184        self.domain.filename = 'datatest' + str(id(self))
185        self.domain.format = 'sww'
186        self.domain.smooth = True
187
188        sww = get_dataobject(self.domain)
189        sww.store_connectivity()
190
191        #Check contents
192        #Get NetCDF
193        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
194
195        # Get the variables
196        x = fid.variables['x']
197        y = fid.variables['y']
198        z = fid.variables['elevation']
199
200        volumes = fid.variables['volumes']
201
202        X = x[:]
203        Y = y[:]
204
205        assert allclose([X[0], Y[0]], array([0.0, 0.0]))
206        assert allclose([X[1], Y[1]], array([0.0, 0.5]))
207        assert allclose([X[2], Y[2]], array([0.0, 1.0]))
208
209        assert allclose([X[4], Y[4]], array([0.5, 0.5]))
210
211        assert allclose([X[7], Y[7]], array([1.0, 0.5]))
212
213        Z = z[:]
214        assert Z[4] == -0.5
215
216        V = volumes
217        assert V[2,0] == 4
218        assert V[2,1] == 5
219        assert V[2,2] == 1
220
221        assert V[4,0] == 6
222        assert V[4,1] == 7
223        assert V[4,2] == 3
224
225
226        fid.close()
227
228        #Cleanup
229        os.remove(sww.filename)
230
231
232
233    def test_sww_variable(self):
234        """Test that sww information can be written correctly
235        """
236
237        import time, os
238        from Numeric import array, zeros, allclose, Float, concatenate
239        from Scientific.IO.NetCDF import NetCDFFile
240
241        self.domain.filename = 'datatest' + str(id(self))
242        self.domain.format = 'sww'
243        self.domain.smooth = True
244        self.domain.reduction = mean
245
246        sww = get_dataobject(self.domain)
247        sww.store_connectivity()
248        sww.store_timestep('stage')
249
250        #Check contents
251        #Get NetCDF
252        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
253
254
255        # Get the variables
256        x = fid.variables['x']
257        y = fid.variables['y']
258        z = fid.variables['elevation']
259        time = fid.variables['time']
260        stage = fid.variables['stage']
261
262
263        Q = self.domain.quantities['stage']
264        Q0 = Q.vertex_values[:,0]
265        Q1 = Q.vertex_values[:,1]
266        Q2 = Q.vertex_values[:,2]
267
268        A = stage[0,:]
269        #print A[0], (Q2[0,0] + Q1[1,0])/2
270        assert allclose(A[0], (Q2[0] + Q1[1])/2)
271        assert allclose(A[1], (Q0[1] + Q1[3] + Q2[2])/3)
272        assert allclose(A[2], Q0[3])
273        assert allclose(A[3], (Q0[0] + Q1[5] + Q2[4])/3)
274
275        #Center point
276        assert allclose(A[4], (Q1[0] + Q2[1] + Q0[2] +\
277                                 Q0[5] + Q2[6] + Q1[7])/6)
278
279
280
281        fid.close()
282
283        #Cleanup
284        os.remove(sww.filename)
285
286
287    def test_sww_variable2(self):
288        """Test that sww information can be written correctly
289        multiple timesteps. Use average as reduction operator
290        """
291
292        import time, os
293        from Numeric import array, zeros, allclose, Float, concatenate
294        from Scientific.IO.NetCDF import NetCDFFile
295
296        self.domain.filename = 'datatest' + str(id(self))
297        self.domain.format = 'sww'
298        self.domain.smooth = True
299
300        self.domain.reduction = mean
301
302        sww = get_dataobject(self.domain)
303        sww.store_connectivity()
304        sww.store_timestep('stage')
305        self.domain.evolve_to_end(finaltime = 0.01)
306        sww.store_timestep('stage')
307
308
309        #Check contents
310        #Get NetCDF
311        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
312
313        # Get the variables
314        x = fid.variables['x']
315        y = fid.variables['y']
316        z = fid.variables['elevation']
317        time = fid.variables['time']
318        stage = fid.variables['stage']
319
320        #Check values
321        Q = self.domain.quantities['stage']
322        Q0 = Q.vertex_values[:,0]
323        Q1 = Q.vertex_values[:,1]
324        Q2 = Q.vertex_values[:,2]
325
326        A = stage[1,:]
327        assert allclose(A[0], (Q2[0] + Q1[1])/2)
328        assert allclose(A[1], (Q0[1] + Q1[3] + Q2[2])/3)
329        assert allclose(A[2], Q0[3])
330        assert allclose(A[3], (Q0[0] + Q1[5] + Q2[4])/3)
331
332        #Center point
333        assert allclose(A[4], (Q1[0] + Q2[1] + Q0[2] +\
334                                 Q0[5] + Q2[6] + Q1[7])/6)
335
336
337
338
339
340        fid.close()
341
342        #Cleanup
343        os.remove(sww.filename)
344
345    def test_sww_variable3(self):
346        """Test that sww information can be written correctly
347        multiple timesteps using a different reduction operator (min)
348        """
349
350        import time, os
351        from Numeric import array, zeros, allclose, Float, concatenate
352        from Scientific.IO.NetCDF import NetCDFFile
353
354        self.domain.filename = 'datatest' + str(id(self))
355        self.domain.format = 'sww'
356        self.domain.smooth = True
357        self.domain.reduction = min
358
359        sww = get_dataobject(self.domain)
360        sww.store_connectivity()
361        sww.store_timestep('stage')
362
363        self.domain.evolve_to_end(finaltime = 0.01)
364        sww.store_timestep('stage')
365
366
367        #Check contents
368        #Get NetCDF
369        fid = NetCDFFile(sww.filename, 'r')
370
371
372        # Get the variables
373        x = fid.variables['x']
374        y = fid.variables['y']
375        z = fid.variables['elevation']
376        time = fid.variables['time']
377        stage = fid.variables['stage']
378
379        #Check values
380        Q = self.domain.quantities['stage']
381        Q0 = Q.vertex_values[:,0]
382        Q1 = Q.vertex_values[:,1]
383        Q2 = Q.vertex_values[:,2]
384
385        A = stage[1,:]
386        assert allclose(A[0], min(Q2[0], Q1[1]))
387        assert allclose(A[1], min(Q0[1], Q1[3], Q2[2]))
388        assert allclose(A[2], Q0[3])
389        assert allclose(A[3], min(Q0[0], Q1[5], Q2[4]))
390
391        #Center point
392        assert allclose(A[4], min(Q1[0], Q2[1], Q0[2],\
393                                  Q0[5], Q2[6], Q1[7]))
394
395
396        fid.close()
397
398        #Cleanup
399        os.remove(sww.filename)
400
401
402    def test_sync(self):
403        """Test info stored at each timestep is as expected (incl initial condition)
404        """
405
406        import time, os, config
407        from Numeric import array, zeros, allclose, Float, concatenate
408        from Scientific.IO.NetCDF import NetCDFFile
409
410        self.domain.filename = 'synctest'
411        self.domain.format = 'sww'
412        self.domain.smooth = False
413        self.domain.store = True
414        self.domain.beta_h = 0
415
416        #Evolution
417        for t in self.domain.evolve(yieldstep = 1.0, finaltime = 4.0):
418            stage = self.domain.quantities['stage'].vertex_values
419
420            #Get NetCDF
421            fid = NetCDFFile(self.domain.writer.filename, 'r')
422            stage_file = fid.variables['stage']
423
424            if t == 0.0:
425                assert allclose(stage, self.initial_stage)
426                assert allclose(stage_file[:], stage.flat)
427            else:
428                assert not allclose(stage, self.initial_stage)
429                assert not allclose(stage_file[:], stage.flat)
430
431            fid.close()
432        os.remove(self.domain.writer.filename)
433
434
435
436    def test_sww_DSG(self):
437        """Not a test, rather a look at the sww format
438        """
439
440        import time, os
441        from Numeric import array, zeros, allclose, Float, concatenate
442        from Scientific.IO.NetCDF import NetCDFFile
443
444        self.domain.filename = 'datatest' + str(id(self))
445        self.domain.format = 'sww'
446        self.domain.smooth = True
447        self.domain.reduction = mean
448
449        sww = get_dataobject(self.domain)
450        sww.store_connectivity()
451        sww.store_timestep('stage')
452
453        #Check contents
454        #Get NetCDF
455        fid = NetCDFFile(sww.filename, 'r')
456
457        # Get the variables
458        x = fid.variables['x']
459        y = fid.variables['y']
460        z = fid.variables['elevation']
461
462        volumes = fid.variables['volumes']
463        time = fid.variables['time']
464
465        # 2D
466        stage = fid.variables['stage']
467
468        X = x[:]
469        Y = y[:]
470        Z = z[:]
471        V = volumes[:]
472        T = time[:]
473        S = stage[:,:]
474
475#         print "****************************"
476#         print "X ",X
477#         print "****************************"
478#         print "Y ",Y
479#         print "****************************"
480#         print "Z ",Z
481#         print "****************************"
482#         print "V ",V
483#         print "****************************"
484#         print "Time ",T
485#         print "****************************"
486#         print "Stage ",S
487#         print "****************************"
488
489
490        fid.close()
491
492        #Cleanup
493        os.remove(sww.filename)
494
495
496
497    def test_dem2pts(self):
498        """Test conversion from dem in ascii format to native NetCDF xya format
499        """
500
501        import time, os
502        from Numeric import array, zeros, allclose, Float, concatenate
503        from Scientific.IO.NetCDF import NetCDFFile
504
505        #Write test asc file
506        root = 'demtest'
507
508        filename = root+'.asc'
509        fid = open(filename, 'w')
510        fid.write("""ncols         5
511nrows         6
512xllcorner     2000.5
513yllcorner     3000.5
514cellsize      25
515NODATA_value  -9999
516""")
517        #Create linear function
518
519        ref_points = []
520        ref_elevation = []
521        for i in range(6):
522            y = (6-i)*25.0
523            for j in range(5):
524                x = j*25.0
525                z = x+2*y
526
527                ref_points.append( [x,y] )
528                ref_elevation.append(z)
529                fid.write('%f ' %z)
530            fid.write('\n')
531
532        fid.close()
533
534        #Write prj file with metadata
535        metafilename = root+'.prj'
536        fid = open(metafilename, 'w')
537
538
539        fid.write("""Projection UTM
540Zone 56
541Datum WGS84
542Zunits NO
543Units METERS
544Spheroid WGS84
545Xshift 0.0000000000
546Yshift 10000000.0000000000
547Parameters
548""")
549        fid.close()
550
551        #Convert to NetCDF pts
552        convert_dem_from_ascii2netcdf(root)
553        dem2pts(root)
554
555        #Check contents
556        #Get NetCDF
557        fid = NetCDFFile(root+'.pts', 'r')
558
559        # Get the variables
560        #print fid.variables.keys()
561        points = fid.variables['points']
562        elevation = fid.variables['elevation']
563
564        #Check values
565
566        #print points[:]
567        #print ref_points
568        assert allclose(points, ref_points)
569
570        #print attributes[:]
571        #print ref_elevation
572        assert allclose(elevation, ref_elevation)
573
574        #Cleanup
575        fid.close()
576
577
578        os.remove(root + '.pts')
579        os.remove(root + '.dem')
580        os.remove(root + '.asc')
581        os.remove(root + '.prj')
582
583
584
585    def test_sww2asc_elevation(self):
586        """Test that sww information can be converted correctly to asc/prj
587        format readable by e.g. ArcView
588        """
589
590        import time, os
591        from Numeric import array, zeros, allclose, Float, concatenate
592        from Scientific.IO.NetCDF import NetCDFFile
593
594        #Setup
595        self.domain.filename = 'datatest'
596       
597        prjfile = self.domain.filename + '.prj'
598        ascfile = self.domain.filename + '.asc'       
599        swwfile = self.domain.filename + '.sww'
600       
601        self.domain.set_datadir('.')
602        self.domain.format = 'sww'
603        self.domain.smooth = True
604        self.domain.set_quantity('elevation', lambda x,y: -x-y)
605
606        self.domain.xllcorner = 308500
607        self.domain.yllcorner = 6189000
608        self.domain.zone = 56
609       
610       
611        sww = get_dataobject(self.domain)
612        sww.store_connectivity()
613        sww.store_timestep('stage')
614
615        self.domain.evolve_to_end(finaltime = 0.01)
616        sww.store_timestep('stage')
617
618        cellsize = 0.25
619        #Check contents
620        #Get NetCDF
621
622        fid = NetCDFFile(sww.filename, 'r')
623
624        # Get the variables
625        x = fid.variables['x'][:]
626        y = fid.variables['y'][:]
627        z = fid.variables['elevation'][:]
628        time = fid.variables['time'][:]
629        stage = fid.variables['stage'][:]
630
631
632        #Export to ascii/prj files
633        sww2asc(self.domain.filename, 
634                quantity = 'elevation',                         
635                cellsize = cellsize)
636
637
638        #Check prj (meta data)
639        prjid = open(prjfile)
640
641        lines = prjid.readlines()
642        prjid.close()
643
644        L = lines[0].strip().split()
645        assert L[0].strip().lower() == 'projection'
646        assert L[1].strip().lower() == 'utm'
647
648        L = lines[1].strip().split()
649        assert L[0].strip().lower() == 'zone'
650        assert L[1].strip().lower() == '56'
651
652        L = lines[2].strip().split()
653        assert L[0].strip().lower() == 'datum'
654        assert L[1].strip().lower() == 'wgs84'
655
656        L = lines[3].strip().split()
657        assert L[0].strip().lower() == 'zunits'
658        assert L[1].strip().lower() == 'no'                       
659
660        L = lines[4].strip().split()
661        assert L[0].strip().lower() == 'units'
662        assert L[1].strip().lower() == 'meters'               
663
664        L = lines[5].strip().split()
665        assert L[0].strip().lower() == 'spheroid'
666        assert L[1].strip().lower() == 'wgs84'
667
668        L = lines[6].strip().split()
669        assert L[0].strip().lower() == 'xshift'
670        assert L[1].strip().lower() == '500000'
671
672        L = lines[7].strip().split()
673        assert L[0].strip().lower() == 'yshift'
674        assert L[1].strip().lower() == '10000000'       
675
676        L = lines[8].strip().split()
677        assert L[0].strip().lower() == 'parameters'
678       
679
680        #Check asc file
681        ascid = open(ascfile)
682        lines = ascid.readlines()
683        ascid.close()       
684
685        L = lines[0].strip().split()
686        assert L[0].strip().lower() == 'ncols'
687        assert L[1].strip().lower() == '5'
688
689        L = lines[1].strip().split()
690        assert L[0].strip().lower() == 'nrows'
691        assert L[1].strip().lower() == '5'       
692
693        L = lines[2].strip().split()
694        assert L[0].strip().lower() == 'xllcorner'
695        assert allclose(float(L[1].strip().lower()), 308500)
696
697        L = lines[3].strip().split()
698        assert L[0].strip().lower() == 'yllcorner'
699        assert allclose(float(L[1].strip().lower()), 6189000)
700
701        L = lines[4].strip().split()
702        assert L[0].strip().lower() == 'cellsize'
703        assert allclose(float(L[1].strip().lower()), cellsize)
704
705        L = lines[5].strip().split()
706        assert L[0].strip() == 'NODATA_value'
707        assert L[1].strip().lower() == '-9999'       
708
709        #Check grid values
710        for j in range(5):
711            L = lines[6+j].strip().split()           
712            y = (4-j) * cellsize
713            for i in range(5):
714                assert allclose(float(L[i]), -i*cellsize - y)
715       
716
717        fid.close()
718
719        #Cleanup
720        os.remove(prjfile)
721        os.remove(ascfile)       
722        os.remove(swwfile)
723
724
725    def test_sww2asc_stage_reduction(self):
726        """Test that sww information can be converted correctly to asc/prj
727        format readable by e.g. ArcView
728
729        This tests the reduction of quantity stage using min
730        """
731
732        import time, os
733        from Numeric import array, zeros, allclose, Float, concatenate
734        from Scientific.IO.NetCDF import NetCDFFile
735
736        #Setup
737        self.domain.filename = 'datatest'
738       
739        prjfile = self.domain.filename + '.prj'
740        ascfile = self.domain.filename + '.asc'       
741        swwfile = self.domain.filename + '.sww'
742       
743        self.domain.set_datadir('.')
744        self.domain.format = 'sww'
745        self.domain.smooth = True
746        self.domain.set_quantity('elevation', lambda x,y: -x-y)
747
748        self.domain.xllcorner = 308500
749        self.domain.yllcorner = 6189000
750        self.domain.zone = 56
751       
752       
753        sww = get_dataobject(self.domain)
754        sww.store_connectivity()
755        sww.store_timestep('stage')
756
757        self.domain.evolve_to_end(finaltime = 0.01)
758        sww.store_timestep('stage')
759
760        cellsize = 0.25
761        #Check contents
762        #Get NetCDF
763
764        fid = NetCDFFile(sww.filename, 'r')
765
766        # Get the variables
767        x = fid.variables['x'][:]
768        y = fid.variables['y'][:]
769        z = fid.variables['elevation'][:]
770        time = fid.variables['time'][:]
771        stage = fid.variables['stage'][:]
772
773
774        #Export to ascii/prj files
775        sww2asc(self.domain.filename, 
776                quantity = 'stage',                         
777                cellsize = cellsize,
778                reduction = min)
779
780
781        #Check asc file
782        ascid = open(ascfile)
783        lines = ascid.readlines()
784        ascid.close()       
785
786        L = lines[0].strip().split()
787        assert L[0].strip().lower() == 'ncols'
788        assert L[1].strip().lower() == '5'
789
790        L = lines[1].strip().split()
791        assert L[0].strip().lower() == 'nrows'
792        assert L[1].strip().lower() == '5'       
793
794        L = lines[2].strip().split()
795        assert L[0].strip().lower() == 'xllcorner'
796        assert allclose(float(L[1].strip().lower()), 308500)
797
798        L = lines[3].strip().split()
799        assert L[0].strip().lower() == 'yllcorner'
800        assert allclose(float(L[1].strip().lower()), 6189000)
801
802        L = lines[4].strip().split()
803        assert L[0].strip().lower() == 'cellsize'
804        assert allclose(float(L[1].strip().lower()), cellsize)
805
806        L = lines[5].strip().split()
807        assert L[0].strip() == 'NODATA_value'
808        assert L[1].strip().lower() == '-9999'       
809
810       
811        #Check grid values (where applicable)
812        for j in range(5):
813            if j%2 == 0:
814                L = lines[6+j].strip().split()           
815                jj = 4-j
816                for i in range(5):
817                    if i%2 == 0:
818                        index = jj/2 + i/2*3
819                        val0 = stage[0,index]
820                        val1 = stage[1,index]
821
822                        #print i, j, index, ':', L[i], val0, val1
823                        assert allclose(float(L[i]), min(val0, val1))
824
825
826        fid.close()
827
828        #Cleanup
829        os.remove(prjfile)
830        os.remove(ascfile)       
831        #os.remove(swwfile)
832
833
834
835
836    def test_sww2asc_missing_points(self):
837        """Test that sww information can be converted correctly to asc/prj
838        format readable by e.g. ArcView
839
840        This test includes the writing of missing values
841        """
842       
843        import time, os
844        from Numeric import array, zeros, allclose, Float, concatenate
845        from Scientific.IO.NetCDF import NetCDFFile
846
847        #Setup mesh not coinciding with rectangle.
848        #This will cause missing values to occur in gridded data
849
850
851        points = [                        [1.0, 1.0],
852                              [0.5, 0.5], [1.0, 0.5],
853                  [0.0, 0.0], [0.5, 0.0], [1.0, 0.0]]           
854
855        vertices = [ [4,1,3], [5,2,4], [1,4,2], [2,0,1]]
856       
857        #Create shallow water domain
858        domain = Domain(points, vertices)
859        domain.default_order=2
860
861
862        #Set some field values
863        domain.set_quantity('elevation', lambda x,y: -x-y)       
864        domain.set_quantity('friction', 0.03)
865
866
867        ######################
868        # Boundary conditions
869        B = Transmissive_boundary(domain)
870        domain.set_boundary( {'exterior': B} )
871
872
873        ######################
874        #Initial condition - with jumps
875
876        bed = domain.quantities['elevation'].vertex_values
877        stage = zeros(bed.shape, Float)
878
879        h = 0.3
880        for i in range(stage.shape[0]):
881            if i % 2 == 0:
882                stage[i,:] = bed[i,:] + h
883            else:
884                stage[i,:] = bed[i,:]
885
886        domain.set_quantity('stage', stage)
887        domain.distribute_to_vertices_and_edges()
888
889        domain.filename = 'datatest'
890       
891        prjfile = domain.filename + '.prj'
892        ascfile = domain.filename + '.asc'       
893        swwfile = domain.filename + '.sww'
894       
895        domain.set_datadir('.')
896        domain.format = 'sww'
897        domain.smooth = True
898
899
900        domain.xllcorner = 308500
901        domain.yllcorner = 6189000
902        domain.zone = 56
903       
904       
905        sww = get_dataobject(domain)
906        sww.store_connectivity()
907        sww.store_timestep('stage')
908
909        cellsize = 0.25
910        #Check contents
911        #Get NetCDF
912
913        fid = NetCDFFile(swwfile, 'r')
914
915        # Get the variables
916        x = fid.variables['x'][:]
917        y = fid.variables['y'][:]
918        z = fid.variables['elevation'][:]
919        time = fid.variables['time'][:]
920
921        #Export to ascii/prj files
922        sww2asc(domain.filename, 
923                quantity = 'elevation',                         
924                cellsize = cellsize)
925
926
927        #Check asc file
928        ascid = open(ascfile)
929        lines = ascid.readlines()
930        ascid.close()       
931
932        L = lines[0].strip().split()
933        assert L[0].strip().lower() == 'ncols'
934        assert L[1].strip().lower() == '5'
935
936        L = lines[1].strip().split()
937        assert L[0].strip().lower() == 'nrows'
938        assert L[1].strip().lower() == '5'       
939
940        L = lines[2].strip().split()
941        assert L[0].strip().lower() == 'xllcorner'
942        assert allclose(float(L[1].strip().lower()), 308500)
943
944        L = lines[3].strip().split()
945        assert L[0].strip().lower() == 'yllcorner'
946        assert allclose(float(L[1].strip().lower()), 6189000)
947
948        L = lines[4].strip().split()
949        assert L[0].strip().lower() == 'cellsize'
950        assert allclose(float(L[1].strip().lower()), cellsize)
951
952        L = lines[5].strip().split()
953        assert L[0].strip() == 'NODATA_value'
954        assert L[1].strip().lower() == '-9999'       
955
956
957        #Check grid values
958        for j in range(5):
959            L = lines[6+j].strip().split()           
960            y = (4-j) * cellsize
961            for i in range(5):
962                if i+j >= 4:
963                    assert allclose(float(L[i]), -i*cellsize - y)
964                else:
965                    #Missing values
966                    assert allclose(float(L[i]), -9999)
967
968
969
970        fid.close()
971
972        #Cleanup
973        os.remove(prjfile)
974        os.remove(ascfile)       
975        os.remove(swwfile)
976
977
978    def test_ferret2sww(self):
979        """Test that georeferencing etc works when converting from
980        ferret format (lat/lon) to sww format (UTM)
981        """
982        from Scientific.IO.NetCDF import NetCDFFile
983
984        #The test file has
985        # LON = 150.66667, 150.83334, 151, 151.16667
986        # LAT = -34.5, -34.33333, -34.16667, -34 ;
987        # TIME = 0, 0.1, 0.6, 1.1, 1.6, 2.1 ;
988        #
989        # First value (index=0) in small_ha.nc is 0.3400644 cm,
990        # Fourth value (index==3) is -6.50198 cm
991
992
993        from coordinate_transforms.redfearn import redfearn
994
995        fid = NetCDFFile('small_ha.nc')
996        first_value = fid.variables['HA'][:][0,0,0]
997        fourth_value = fid.variables['HA'][:][0,0,3]
998
999
1000        #Call conversion (with zero origin)
1001        ferret2sww('small', verbose=False,
1002                   origin = (56, 0, 0))
1003
1004
1005        #Work out the UTM coordinates for first point
1006        zone, e, n = redfearn(-34.5, 150.66667)
1007        #print zone, e, n
1008
1009        #Read output file 'small.sww'
1010        fid = NetCDFFile('small.sww')
1011
1012        x = fid.variables['x'][:]
1013        y = fid.variables['y'][:]
1014
1015        #Check that first coordinate is correctly represented
1016        assert allclose(x[0], e)
1017        assert allclose(y[0], n)
1018
1019        #Check first value
1020        stage = fid.variables['stage'][:]
1021        xmomentum = fid.variables['xmomentum'][:]
1022        ymomentum = fid.variables['ymomentum'][:]       
1023
1024        #print ymomentum
1025               
1026        assert allclose(stage[0,0], first_value/100)  #Meters
1027
1028        #Check fourth value
1029        assert allclose(stage[0,3], fourth_value/100)  #Meters
1030
1031        fid.close()
1032
1033        #Cleanup
1034        import os
1035        os.remove('small.sww')
1036
1037
1038
1039    def test_ferret2sww_2(self):
1040        """Test that georeferencing etc works when converting from
1041        ferret format (lat/lon) to sww format (UTM)
1042        """
1043        from Scientific.IO.NetCDF import NetCDFFile
1044
1045        #The test file has
1046        # LON = 150.66667, 150.83334, 151, 151.16667
1047        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1048        # TIME = 0, 0.1, 0.6, 1.1, 1.6, 2.1 ;
1049        #
1050        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1051        # Fourth value (index==3) is -6.50198 cm
1052
1053
1054        from coordinate_transforms.redfearn import redfearn
1055
1056        fid = NetCDFFile('small_ha.nc')
1057
1058        #Pick a coordinate and a value
1059
1060        time_index = 1
1061        lat_index = 0
1062        lon_index = 2
1063
1064        test_value = fid.variables['HA'][:][time_index, lat_index, lon_index]
1065        test_time = fid.variables['TIME'][:][time_index]
1066        test_lat = fid.variables['LAT'][:][lat_index]
1067        test_lon = fid.variables['LON'][:][lon_index]
1068
1069        linear_point_index = lat_index*4 + lon_index
1070        fid.close()
1071
1072        #Call conversion (with zero origin)
1073        ferret2sww('small', verbose=False,
1074                   origin = (56, 0, 0))
1075
1076
1077        #Work out the UTM coordinates for test point
1078        zone, e, n = redfearn(test_lat, test_lon)
1079
1080        #Read output file 'small.sww'
1081        fid = NetCDFFile('small.sww')
1082
1083        x = fid.variables['x'][:]
1084        y = fid.variables['y'][:]
1085
1086        #Check that test coordinate is correctly represented
1087        assert allclose(x[linear_point_index], e)
1088        assert allclose(y[linear_point_index], n)
1089
1090        #Check test value
1091        stage = fid.variables['stage'][:]
1092
1093        assert allclose(stage[time_index, linear_point_index], test_value/100)
1094
1095        fid.close()
1096
1097        #Cleanup
1098        import os
1099        os.remove('small.sww')
1100
1101
1102
1103    def test_ferret2sww3(self):
1104        """
1105        """
1106        from Scientific.IO.NetCDF import NetCDFFile
1107
1108        #The test file has
1109        # LON = 150.66667, 150.83334, 151, 151.16667
1110        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1111        # ELEVATION = [-1 -2 -3 -4
1112        #             -5 -6 -7 -8
1113        #              ...
1114        #              ...      -16]
1115        # where the top left corner is -1m,
1116        # and the ll corner is -13.0m
1117        #
1118        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1119        # Fourth value (index==3) is -6.50198 cm
1120
1121        from coordinate_transforms.redfearn import redfearn
1122        import os
1123        fid1 = NetCDFFile('test_ha.nc','w')
1124        fid2 = NetCDFFile('test_ua.nc','w')
1125        fid3 = NetCDFFile('test_va.nc','w')
1126        fid4 = NetCDFFile('test_e.nc','w')
1127
1128        h1_list = [150.66667,150.83334,151.]
1129        h2_list = [-34.5,-34.33333]
1130
1131        long_name = 'LON'
1132        lat_name = 'LAT'
1133
1134        nx = 3
1135        ny = 2
1136
1137        for fid in [fid1,fid2,fid3]:
1138            fid.createDimension(long_name,nx)
1139            fid.createVariable(long_name,'d',(long_name,))
1140            fid.variables[long_name].point_spacing='uneven'
1141            fid.variables[long_name].units='degrees_east'
1142            fid.variables[long_name].assignValue(h1_list)
1143
1144            fid.createDimension(lat_name,ny)
1145            fid.createVariable(lat_name,'d',(lat_name,))
1146            fid.variables[lat_name].point_spacing='uneven'
1147            fid.variables[lat_name].units='degrees_north'
1148            fid.variables[lat_name].assignValue(h2_list)
1149       
1150            fid.createDimension('TIME',2)
1151            fid.createVariable('TIME','d',('TIME',))
1152            fid.variables['TIME'].point_spacing='uneven'
1153            fid.variables['TIME'].units='seconds'
1154            fid.variables['TIME'].assignValue([0.,1.])
1155            if fid == fid3: break
1156       
1157
1158        for fid in [fid4]:
1159            fid.createDimension(long_name,nx)
1160            fid.createVariable(long_name,'d',(long_name,))
1161            fid.variables[long_name].point_spacing='uneven'
1162            fid.variables[long_name].units='degrees_east'
1163            fid.variables[long_name].assignValue(h1_list)
1164
1165            fid.createDimension(lat_name,ny)
1166            fid.createVariable(lat_name,'d',(lat_name,))
1167            fid.variables[lat_name].point_spacing='uneven'
1168            fid.variables[lat_name].units='degrees_north'
1169            fid.variables[lat_name].assignValue(h2_list)
1170
1171        name = {}
1172        name[fid1]='HA'
1173        name[fid2]='UA'
1174        name[fid3]='VA'
1175        name[fid4]='ELEVATION'
1176       
1177        units = {}
1178        units[fid1]='cm'
1179        units[fid2]='cm/s'
1180        units[fid3]='cm/s'
1181        units[fid4]='m'
1182
1183        values = {}
1184        values[fid1]=[[[5., 10.,15.], [13.,18.,23.]],[[50.,100.,150.],[130.,180.,230.]]]
1185        values[fid2]=[[[1., 2.,3.], [4.,5.,6.]],[[7.,8.,9.],[10.,11.,12.]]]
1186        values[fid3]=[[[13., 12.,11.], [10.,9.,8.]],[[7.,6.,5.],[4.,3.,2.]]]
1187        values[fid4]=[[-3000,-3100,-3200],[-4000,-5000,-6000]]
1188       
1189        for fid in [fid1,fid2,fid3]:
1190          fid.createVariable(name[fid],'d',('TIME',lat_name,long_name))
1191          fid.variables[name[fid]].point_spacing='uneven'
1192          fid.variables[name[fid]].units=units[fid]
1193          fid.variables[name[fid]].assignValue(values[fid])
1194          fid.variables[name[fid]].missing_value = -99999999.
1195          if fid == fid3: break
1196
1197        for fid in [fid4]:
1198            fid.createVariable(name[fid],'d',(lat_name,long_name))
1199            fid.variables[name[fid]].point_spacing='uneven'
1200            fid.variables[name[fid]].units=units[fid]
1201            fid.variables[name[fid]].assignValue(values[fid])
1202            fid.variables[name[fid]].missing_value = -99999999.
1203
1204       
1205        fid1.sync(); fid1.close()
1206        fid2.sync(); fid2.close()
1207        fid3.sync(); fid3.close()
1208        fid4.sync(); fid4.close()
1209
1210        fid1 = NetCDFFile('test_ha.nc','r')
1211        fid2 = NetCDFFile('test_e.nc','r')
1212        fid3 = NetCDFFile('test_va.nc','r')
1213       
1214
1215        first_amp = fid1.variables['HA'][:][0,0,0]
1216        third_amp = fid1.variables['HA'][:][0,0,2]
1217        first_elevation = fid2.variables['ELEVATION'][0,0]
1218        third_elevation= fid2.variables['ELEVATION'][:][0,2]
1219        first_speed = fid3.variables['VA'][0,0,0]
1220        third_speed = fid3.variables['VA'][:][0,0,2]
1221
1222        fid1.close()
1223        fid2.close()
1224        fid3.close()
1225
1226        #Call conversion (with zero origin)
1227        ferret2sww('test', verbose=False,
1228                   origin = (56, 0, 0))
1229
1230        os.remove('test_va.nc')
1231        os.remove('test_ua.nc')
1232        os.remove('test_ha.nc')
1233        os.remove('test_e.nc')
1234
1235        #Read output file 'test.sww'
1236        fid = NetCDFFile('test.sww')
1237
1238
1239        #Check first value
1240        elevation = fid.variables['elevation'][:]
1241        stage = fid.variables['stage'][:]
1242        xmomentum = fid.variables['xmomentum'][:]
1243        ymomentum = fid.variables['ymomentum'][:]       
1244
1245        #print ymomentum
1246        first_height = first_amp/100 - first_elevation
1247        third_height = third_amp/100 - third_elevation
1248        first_momentum=first_speed*first_height/100
1249        third_momentum=third_speed*third_height/100
1250
1251        assert allclose(ymomentum[0][0],first_momentum)  #Meters
1252        assert allclose(ymomentum[0][2],third_momentum)  #Meters
1253
1254        fid.close()
1255
1256        #Cleanup
1257        os.remove('test.sww')
1258
1259
1260
1261
1262    def test_sww_extent(self):
1263        """Not a test, rather a look at the sww format
1264        """
1265
1266        import time, os
1267        from Numeric import array, zeros, allclose, Float, concatenate
1268        from Scientific.IO.NetCDF import NetCDFFile
1269
1270        self.domain.filename = 'datatest' + str(id(self))
1271        self.domain.format = 'sww'
1272        self.domain.smooth = True
1273        self.domain.reduction = mean
1274        self.domain.set_datadir('.')
1275
1276
1277        sww = get_dataobject(self.domain)
1278        sww.store_connectivity()
1279        sww.store_timestep('stage')
1280        self.domain.time = 2.
1281
1282        #Modify stage at second timestep
1283        stage = self.domain.quantities['stage'].vertex_values
1284        self.domain.set_quantity('stage', stage/2)
1285
1286        sww.store_timestep('stage')
1287
1288        file_and_extension_name = self.domain.filename + ".sww"
1289        #print "file_and_extension_name",file_and_extension_name
1290        [xmin, xmax, ymin, ymax, stagemin, stagemax] = \
1291               extent_sww(file_and_extension_name )
1292
1293        assert allclose(xmin, 0.0)
1294        assert allclose(xmax, 1.0)
1295        assert allclose(ymin, 0.0)
1296        assert allclose(ymax, 1.0)
1297        assert allclose(stagemin, -0.85)
1298        assert allclose(stagemax, 0.15)
1299
1300
1301        #Cleanup
1302        os.remove(sww.filename)
1303
1304
1305    def test_ferret2sww_nz_origin(self):
1306        from Scientific.IO.NetCDF import NetCDFFile
1307        from coordinate_transforms.redfearn import redfearn
1308
1309        #Call conversion (with nonzero origin)
1310        ferret2sww('small', verbose=False,
1311                   origin = (56, 100000, 200000))
1312
1313
1314        #Work out the UTM coordinates for first point
1315        zone, e, n = redfearn(-34.5, 150.66667)
1316
1317        #Read output file 'small.sww'
1318        fid = NetCDFFile('small.sww', 'r')
1319
1320        x = fid.variables['x'][:]
1321        y = fid.variables['y'][:]
1322
1323        #Check that first coordinate is correctly represented
1324        assert allclose(x[0], e-100000)
1325        assert allclose(y[0], n-200000)
1326
1327        fid.close()
1328
1329        #Cleanup
1330        import os
1331        os.remove('small.sww')
1332
1333    def test_sww2domain(self):
1334        ################################################
1335        #Create a test domain, and evolve and save it.
1336        ################################################
1337        from mesh_factory import rectangular
1338        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1339             Constant_height, Time_boundary, Transmissive_boundary
1340        from Numeric import array
1341
1342        #Create basic mesh
1343        points, vertices, boundary = rectangular(2,2)
1344
1345        #Create shallow water domain
1346        domain = Domain(points, vertices, boundary)
1347        domain.smooth = False
1348        domain.visualise = False
1349        domain.store = True
1350        domain.filename = 'bedslope'
1351        domain.default_order=2
1352        #Bed-slope and friction
1353        domain.set_quantity('elevation', lambda x,y: -x/3)
1354        domain.set_quantity('friction', 0.1)
1355        # Boundary conditions
1356        from math import sin, pi
1357        Br = Reflective_boundary(domain)
1358        Bt = Transmissive_boundary(domain)
1359        Bd = Dirichlet_boundary([0.2,0.,0.])
1360        Bw = Time_boundary(domain=domain,
1361                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1362
1363        domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
1364        domain.quantities_to_be_stored.extend(['xmomentum','ymomentum'])
1365        #Initial condition
1366        h = 0.05
1367        elevation = domain.quantities['elevation'].vertex_values
1368        domain.set_quantity('stage', elevation + h)
1369        #elevation = domain.get_quantity('elevation',location='unique vertices')
1370        #domain.set_quantity('stage', elevation + h,location='unique vertices')
1371
1372        domain.check_integrity()
1373        dir(domain)
1374        #Evolution
1375        for t in domain.evolve(yieldstep = 1, finaltime = 2.0):
1376        #    domain.write_time()
1377            pass
1378
1379
1380        ##########################################
1381        #Import the example's file as a new domain
1382        ##########################################
1383        from data_manager import sww2domain
1384        from Numeric import allclose
1385
1386        filename = domain.datadir+'\\'+domain.filename+'.sww'
1387
1388        domain2 = sww2domain(filename,fail_if_NaN=False,verbose = False)
1389
1390        ###################
1391        ##NOW TEST IT!!!
1392        ##################
1393
1394        bits = ['xllcorner','yllcorner','vertex_coordinates','time','starttime']
1395
1396        for quantity in ['elevation']+domain.quantities_to_be_stored:
1397            bits.append('get_quantity("%s")'%quantity)
1398
1399        for bit in bits:
1400        #    print 'testing that domain.'+bit+' has been restored'
1401            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1402
1403        #print 'passed'
1404
1405
1406    def test_sww2domain2(self):
1407        ##################################################################
1408        #Same as previous test, but this checks how NaNs are handled.
1409        ##################################################################
1410
1411
1412        from mesh_factory import rectangular
1413        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1414             Constant_height, Time_boundary, Transmissive_boundary
1415        from Numeric import array
1416
1417        #Create basic mesh
1418        points, vertices, boundary = rectangular(2,2)
1419
1420        #Create shallow water domain
1421        domain = Domain(points, vertices, boundary)
1422        domain.smooth = False
1423        domain.visualise = False
1424        domain.store = True
1425        domain.filename = 'bedslope'
1426        domain.default_order=2
1427        domain.quantities_to_be_stored=['stage']
1428
1429        domain.set_quantity('elevation', lambda x,y: -x/3)
1430        domain.set_quantity('friction', 0.1)
1431
1432        from math import sin, pi
1433        Br = Reflective_boundary(domain)
1434        Bt = Transmissive_boundary(domain)
1435        Bd = Dirichlet_boundary([0.2,0.,0.])
1436        Bw = Time_boundary(domain=domain,
1437                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1438
1439        domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
1440
1441        h = 0.05
1442        elevation = domain.quantities['elevation'].vertex_values
1443        domain.set_quantity('stage', elevation + h)
1444        #elevation = domain.get_quantity('elevation',location='unique vertices')
1445        #domain.set_quantity('stage', elevation + h,location='unique vertices')
1446
1447        domain.check_integrity()
1448       
1449        for t in domain.evolve(yieldstep = 1, finaltime = 2.0):
1450            pass
1451            #domain.write_time()
1452
1453        ##################################
1454        #Import the file as a new domain
1455        ##################################
1456        from data_manager import sww2domain
1457        from Numeric import allclose
1458
1459        filename = domain.datadir+'\\'+domain.filename+'.sww'
1460
1461        #Fail because NaNs are present
1462        try:
1463            domain2 = sww2domain(filename,fail_if_NaN=True,verbose=False)
1464            assert True == False
1465        except:
1466            #Now import it, filling NaNs to be 0
1467            filler = 0
1468            domain2 = sww2domain(filename,fail_if_NaN=False,NaN_filler = filler,verbose=False)
1469
1470        bits = ['xllcorner','yllcorner','vertex_coordinates','time','starttime']
1471
1472        for quantity in ['elevation']+domain.quantities_to_be_stored:
1473            bits.append('get_quantity("%s")'%quantity)
1474
1475        for bit in bits:
1476        #    print 'testing that domain.'+bit+' has been restored'
1477            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1478
1479        #print max(max(domain2.get_quantity('xmomentum')))
1480        #print min(min(domain2.get_quantity('xmomentum')))
1481        #print max(max(domain2.get_quantity('ymomentum')))
1482        #print min(min(domain2.get_quantity('ymomentum')))
1483
1484        assert max(max(domain2.get_quantity('xmomentum')))==filler
1485        assert min(min(domain2.get_quantity('xmomentum')))==filler
1486        assert max(max(domain2.get_quantity('ymomentum')))==filler
1487        assert min(min(domain2.get_quantity('ymomentum')))==filler
1488
1489        #print 'passed'
1490
1491        #cleanup
1492        #import os
1493        #os.remove(domain.datadir+'/'+domain.filename+'.sww')
1494
1495#-------------------------------------------------------------
1496if __name__ == "__main__":
1497    suite = unittest.makeSuite(Test_Data_Manager,'test')   
1498    #suite = unittest.makeSuite(Test_Data_Manager,'test_sww2asc_mis')
1499    runner = unittest.TextTestRunner()
1500    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.