source: inundation/pyvolution/test_data_manager.py @ 1865

Last change on this file since 1865 was 1865, checked in by ole, 18 years ago

Wrapped sww2asc and sww2ers into one new function: sww2dem
Finished unittest for ers format

File size: 62.5 KB
Line 
1#!/usr/bin/env python
2#
3
4import unittest
5import copy
6from Numeric import zeros, array, allclose, Float
7from util import mean
8
9from data_manager import *
10from shallow_water import *
11from config import epsilon
12
13from coordinate_transforms.geo_reference import Geo_reference
14
15class Test_Data_Manager(unittest.TestCase):
16    def setUp(self):
17        import time
18        from mesh_factory import rectangular
19
20
21        #Create basic mesh
22        points, vertices, boundary = rectangular(2, 2)
23
24        #Create shallow water domain
25        domain = Domain(points, vertices, boundary)
26        domain.default_order=2
27
28
29        #Set some field values
30        domain.set_quantity('elevation', lambda x,y: -x)
31        domain.set_quantity('friction', 0.03)
32
33
34        ######################
35        # Boundary conditions
36        B = Transmissive_boundary(domain)
37        domain.set_boundary( {'left': B, 'right': B, 'top': B, 'bottom': B})
38
39
40        ######################
41        #Initial condition - with jumps
42
43
44        bed = domain.quantities['elevation'].vertex_values
45        stage = zeros(bed.shape, Float)
46
47        h = 0.3
48        for i in range(stage.shape[0]):
49            if i % 2 == 0:
50                stage[i,:] = bed[i,:] + h
51            else:
52                stage[i,:] = bed[i,:]
53
54        domain.set_quantity('stage', stage)
55        self.initial_stage = copy.copy(domain.quantities['stage'].vertex_values)
56
57        domain.distribute_to_vertices_and_edges()
58
59
60        self.domain = domain
61
62        C = domain.get_vertex_coordinates()
63        self.X = C[:,0:6:2].copy()
64        self.Y = C[:,1:6:2].copy()
65
66        self.F = bed
67
68
69    def tearDown(self):
70        pass
71
72
73
74
75#     def test_xya(self):
76#         import os
77#         from Numeric import concatenate
78
79#         import time, os
80#         from Numeric import array, zeros, allclose, Float, concatenate
81
82#         domain = self.domain
83
84#         domain.filename = 'datatest' + str(time.time())
85#         domain.format = 'xya'
86#         domain.smooth = True
87
88#         xya = get_dataobject(self.domain)
89#         xya.store_all()
90
91
92#         #Read back
93#         file = open(xya.filename)
94#         lFile = file.read().split('\n')
95#         lFile = lFile[:-1]
96
97#         file.close()
98#         os.remove(xya.filename)
99
100#         #Check contents
101#         if domain.smooth:
102#             self.failUnless(lFile[0] == '9 3 # <vertex #> <x> <y> [attributes]')
103#         else:
104#             self.failUnless(lFile[0] == '24 3 # <vertex #> <x> <y> [attributes]')
105
106#         #Get smoothed field values with X and Y
107#         X,Y,F,V = domain.get_vertex_values(xy=True, value_array='field_values',
108#                                            indices = (0,1), precision = Float)
109
110
111#         Q,V = domain.get_vertex_values(xy=False, value_array='conserved_quantities',
112#                                            indices = (0,), precision = Float)
113
114
115
116#         for i, line in enumerate(lFile[1:]):
117#             fields = line.split()
118
119#             assert len(fields) == 5
120
121#             assert allclose(float(fields[0]), X[i])
122#             assert allclose(float(fields[1]), Y[i])
123#             assert allclose(float(fields[2]), F[i,0])
124#             assert allclose(float(fields[3]), Q[i,0])
125#             assert allclose(float(fields[4]), F[i,1])
126
127
128
129
130    def test_sww_constant(self):
131        """Test that constant sww information can be written correctly
132        (non smooth)
133        """
134
135        import time, os
136        from Numeric import array, zeros, allclose, Float, concatenate
137        from Scientific.IO.NetCDF import NetCDFFile
138
139        self.domain.filename = 'datatest' + str(id(self))
140        self.domain.format = 'sww'
141        self.domain.smooth = False
142
143        sww = get_dataobject(self.domain)
144        sww.store_connectivity()
145
146        #Check contents
147        #Get NetCDF
148        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
149
150        # Get the variables
151        x = fid.variables['x']
152        y = fid.variables['y']
153        z = fid.variables['elevation']
154
155        volumes = fid.variables['volumes']
156
157
158        assert allclose (x[:], self.X.flat)
159        assert allclose (y[:], self.Y.flat)
160        assert allclose (z[:], self.F.flat)
161
162        V = volumes
163
164        P = len(self.domain)
165        for k in range(P):
166            assert V[k, 0] == 3*k
167            assert V[k, 1] == 3*k+1
168            assert V[k, 2] == 3*k+2
169
170
171        fid.close()
172
173        #Cleanup
174        os.remove(sww.filename)
175
176
177    def test_sww_constant_smooth(self):
178        """Test that constant sww information can be written correctly
179        (non smooth)
180        """
181
182        import time, os
183        from Numeric import array, zeros, allclose, Float, concatenate
184        from Scientific.IO.NetCDF import NetCDFFile
185
186        self.domain.filename = 'datatest' + str(id(self))
187        self.domain.format = 'sww'
188        self.domain.smooth = True
189
190        sww = get_dataobject(self.domain)
191        sww.store_connectivity()
192
193        #Check contents
194        #Get NetCDF
195        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
196
197        # Get the variables
198        x = fid.variables['x']
199        y = fid.variables['y']
200        z = fid.variables['elevation']
201
202        volumes = fid.variables['volumes']
203
204        X = x[:]
205        Y = y[:]
206
207        assert allclose([X[0], Y[0]], array([0.0, 0.0]))
208        assert allclose([X[1], Y[1]], array([0.0, 0.5]))
209        assert allclose([X[2], Y[2]], array([0.0, 1.0]))
210
211        assert allclose([X[4], Y[4]], array([0.5, 0.5]))
212
213        assert allclose([X[7], Y[7]], array([1.0, 0.5]))
214
215        Z = z[:]
216        assert Z[4] == -0.5
217
218        V = volumes
219        assert V[2,0] == 4
220        assert V[2,1] == 5
221        assert V[2,2] == 1
222
223        assert V[4,0] == 6
224        assert V[4,1] == 7
225        assert V[4,2] == 3
226
227
228        fid.close()
229
230        #Cleanup
231        os.remove(sww.filename)
232
233
234
235    def test_sww_variable(self):
236        """Test that sww information can be written correctly
237        """
238
239        import time, os
240        from Numeric import array, zeros, allclose, Float, concatenate
241        from Scientific.IO.NetCDF import NetCDFFile
242
243        self.domain.filename = 'datatest' + str(id(self))
244        self.domain.format = 'sww'
245        self.domain.smooth = True
246        self.domain.reduction = mean
247
248        sww = get_dataobject(self.domain)
249        sww.store_connectivity()
250        sww.store_timestep('stage')
251
252        #Check contents
253        #Get NetCDF
254        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
255
256
257        # Get the variables
258        x = fid.variables['x']
259        y = fid.variables['y']
260        z = fid.variables['elevation']
261        time = fid.variables['time']
262        stage = fid.variables['stage']
263
264
265        Q = self.domain.quantities['stage']
266        Q0 = Q.vertex_values[:,0]
267        Q1 = Q.vertex_values[:,1]
268        Q2 = Q.vertex_values[:,2]
269
270        A = stage[0,:]
271        #print A[0], (Q2[0,0] + Q1[1,0])/2
272        assert allclose(A[0], (Q2[0] + Q1[1])/2)
273        assert allclose(A[1], (Q0[1] + Q1[3] + Q2[2])/3)
274        assert allclose(A[2], Q0[3])
275        assert allclose(A[3], (Q0[0] + Q1[5] + Q2[4])/3)
276
277        #Center point
278        assert allclose(A[4], (Q1[0] + Q2[1] + Q0[2] +\
279                                 Q0[5] + Q2[6] + Q1[7])/6)
280
281
282
283        fid.close()
284
285        #Cleanup
286        os.remove(sww.filename)
287
288
289    def test_sww_variable2(self):
290        """Test that sww information can be written correctly
291        multiple timesteps. Use average as reduction operator
292        """
293
294        import time, os
295        from Numeric import array, zeros, allclose, Float, concatenate
296        from Scientific.IO.NetCDF import NetCDFFile
297
298        self.domain.filename = 'datatest' + str(id(self))
299        self.domain.format = 'sww'
300        self.domain.smooth = True
301
302        self.domain.reduction = mean
303
304        sww = get_dataobject(self.domain)
305        sww.store_connectivity()
306        sww.store_timestep('stage')
307        self.domain.evolve_to_end(finaltime = 0.01)
308        sww.store_timestep('stage')
309
310
311        #Check contents
312        #Get NetCDF
313        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
314
315        # Get the variables
316        x = fid.variables['x']
317        y = fid.variables['y']
318        z = fid.variables['elevation']
319        time = fid.variables['time']
320        stage = fid.variables['stage']
321
322        #Check values
323        Q = self.domain.quantities['stage']
324        Q0 = Q.vertex_values[:,0]
325        Q1 = Q.vertex_values[:,1]
326        Q2 = Q.vertex_values[:,2]
327
328        A = stage[1,:]
329        assert allclose(A[0], (Q2[0] + Q1[1])/2)
330        assert allclose(A[1], (Q0[1] + Q1[3] + Q2[2])/3)
331        assert allclose(A[2], Q0[3])
332        assert allclose(A[3], (Q0[0] + Q1[5] + Q2[4])/3)
333
334        #Center point
335        assert allclose(A[4], (Q1[0] + Q2[1] + Q0[2] +\
336                                 Q0[5] + Q2[6] + Q1[7])/6)
337
338
339        fid.close()
340
341        #Cleanup
342        os.remove(sww.filename)
343
344    def test_sww_variable3(self):
345        """Test that sww information can be written correctly
346        multiple timesteps using a different reduction operator (min)
347        """
348
349        import time, os
350        from Numeric import array, zeros, allclose, Float, concatenate
351        from Scientific.IO.NetCDF import NetCDFFile
352
353        self.domain.filename = 'datatest' + str(id(self))
354        self.domain.format = 'sww'
355        self.domain.smooth = True
356        self.domain.reduction = min
357
358        sww = get_dataobject(self.domain)
359        sww.store_connectivity()
360        sww.store_timestep('stage')
361
362        self.domain.evolve_to_end(finaltime = 0.01)
363        sww.store_timestep('stage')
364
365
366        #Check contents
367        #Get NetCDF
368        fid = NetCDFFile(sww.filename, 'r')
369
370
371        # Get the variables
372        x = fid.variables['x']
373        y = fid.variables['y']
374        z = fid.variables['elevation']
375        time = fid.variables['time']
376        stage = fid.variables['stage']
377
378        #Check values
379        Q = self.domain.quantities['stage']
380        Q0 = Q.vertex_values[:,0]
381        Q1 = Q.vertex_values[:,1]
382        Q2 = Q.vertex_values[:,2]
383
384        A = stage[1,:]
385        assert allclose(A[0], min(Q2[0], Q1[1]))
386        assert allclose(A[1], min(Q0[1], Q1[3], Q2[2]))
387        assert allclose(A[2], Q0[3])
388        assert allclose(A[3], min(Q0[0], Q1[5], Q2[4]))
389
390        #Center point
391        assert allclose(A[4], min(Q1[0], Q2[1], Q0[2],\
392                                  Q0[5], Q2[6], Q1[7]))
393
394
395        fid.close()
396
397        #Cleanup
398        os.remove(sww.filename)
399
400
401    def test_sync(self):
402        """Test info stored at each timestep is as expected (incl initial condition)
403        """
404
405        import time, os, config
406        from Numeric import array, zeros, allclose, Float, concatenate
407        from Scientific.IO.NetCDF import NetCDFFile
408
409        self.domain.filename = 'synctest'
410        self.domain.format = 'sww'
411        self.domain.smooth = False
412        self.domain.store = True
413        self.domain.beta_h = 0
414
415        #Evolution
416        for t in self.domain.evolve(yieldstep = 1.0, finaltime = 4.0):
417            stage = self.domain.quantities['stage'].vertex_values
418
419            #Get NetCDF
420            fid = NetCDFFile(self.domain.writer.filename, 'r')
421            stage_file = fid.variables['stage']
422
423            if t == 0.0:
424                assert allclose(stage, self.initial_stage)
425                assert allclose(stage_file[:], stage.flat)
426            else:
427                assert not allclose(stage, self.initial_stage)
428                assert not allclose(stage_file[:], stage.flat)
429
430            fid.close()
431
432        os.remove(self.domain.writer.filename)
433
434
435
436    def test_sww_DSG(self):
437        """Not a test, rather a look at the sww format
438        """
439
440        import time, os
441        from Numeric import array, zeros, allclose, Float, concatenate
442        from Scientific.IO.NetCDF import NetCDFFile
443
444        self.domain.filename = 'datatest' + str(id(self))
445        self.domain.format = 'sww'
446        self.domain.smooth = True
447        self.domain.reduction = mean
448
449        sww = get_dataobject(self.domain)
450        sww.store_connectivity()
451        sww.store_timestep('stage')
452
453        #Check contents
454        #Get NetCDF
455        fid = NetCDFFile(sww.filename, 'r')
456
457        # Get the variables
458        x = fid.variables['x']
459        y = fid.variables['y']
460        z = fid.variables['elevation']
461
462        volumes = fid.variables['volumes']
463        time = fid.variables['time']
464
465        # 2D
466        stage = fid.variables['stage']
467
468        X = x[:]
469        Y = y[:]
470        Z = z[:]
471        V = volumes[:]
472        T = time[:]
473        S = stage[:,:]
474
475#         print "****************************"
476#         print "X ",X
477#         print "****************************"
478#         print "Y ",Y
479#         print "****************************"
480#         print "Z ",Z
481#         print "****************************"
482#         print "V ",V
483#         print "****************************"
484#         print "Time ",T
485#         print "****************************"
486#         print "Stage ",S
487#         print "****************************"
488
489
490        fid.close()
491
492        #Cleanup
493        os.remove(sww.filename)
494
495
496    def test_write_pts(self):
497        #Get (enough) datapoints
498
499        from Numeric import array
500        points = array([[ 0.66666667, 0.66666667],
501                        [ 1.33333333, 1.33333333],
502                        [ 2.66666667, 0.66666667],
503                        [ 0.66666667, 2.66666667],
504                        [ 0.0, 1.0],
505                        [ 0.0, 3.0],
506                        [ 1.0, 0.0],
507                        [ 1.0, 1.0],
508                        [ 1.0, 2.0],
509                        [ 1.0, 3.0],
510                        [ 2.0, 1.0],
511                        [ 3.0, 0.0],
512                        [ 3.0, 1.0]])
513
514        z = points[:,0] + 2*points[:,1]
515
516        ptsfile = 'testptsfile.pts'
517        write_ptsfile(ptsfile, points, z,
518                      attribute_name = 'linear_combination')
519
520        #Check contents
521        #Get NetCDF
522        from Scientific.IO.NetCDF import NetCDFFile       
523        fid = NetCDFFile(ptsfile, 'r')
524
525        # Get the variables
526        #print fid.variables.keys()
527        points1 = fid.variables['points']
528        z1 = fid.variables['linear_combination']
529
530        #Check values
531
532        #print points[:]
533        #print ref_points
534        assert allclose(points, points1)
535
536        #print attributes[:]
537        #print ref_elevation
538        assert allclose(z, z1)
539
540        #Cleanup
541        fid.close()
542
543        import os
544        os.remove(ptsfile)
545       
546
547
548
549    def test_dem2pts(self):
550        """Test conversion from dem in ascii format to native NetCDF xya format
551        """
552
553        import time, os
554        from Numeric import array, zeros, allclose, Float, concatenate
555        from Scientific.IO.NetCDF import NetCDFFile
556
557        #Write test asc file
558        root = 'demtest'
559
560        filename = root+'.asc'
561        fid = open(filename, 'w')
562        fid.write("""ncols         5
563nrows         6
564xllcorner     2000.5
565yllcorner     3000.5
566cellsize      25
567NODATA_value  -9999
568""")
569        #Create linear function
570
571        ref_points = []
572        ref_elevation = []
573        for i in range(6):
574            y = (6-i)*25.0
575            for j in range(5):
576                x = j*25.0
577                z = x+2*y
578
579                ref_points.append( [x,y] )
580                ref_elevation.append(z)
581                fid.write('%f ' %z)
582            fid.write('\n')
583
584        fid.close()
585
586        #Write prj file with metadata
587        metafilename = root+'.prj'
588        fid = open(metafilename, 'w')
589
590
591        fid.write("""Projection UTM
592Zone 56
593Datum WGS84
594Zunits NO
595Units METERS
596Spheroid WGS84
597Xshift 0.0000000000
598Yshift 10000000.0000000000
599Parameters
600""")
601        fid.close()
602
603        #Convert to NetCDF pts
604        convert_dem_from_ascii2netcdf(root)
605        dem2pts(root)
606
607        #Check contents
608        #Get NetCDF
609        fid = NetCDFFile(root+'.pts', 'r')
610
611        # Get the variables
612        #print fid.variables.keys()
613        points = fid.variables['points']
614        elevation = fid.variables['elevation']
615
616        #Check values
617
618        #print points[:]
619        #print ref_points
620        assert allclose(points, ref_points)
621
622        #print attributes[:]
623        #print ref_elevation
624        assert allclose(elevation, ref_elevation)
625
626        #Cleanup
627        fid.close()
628
629
630        os.remove(root + '.pts')
631        os.remove(root + '.dem')
632        os.remove(root + '.asc')
633        os.remove(root + '.prj')
634
635
636
637    def test_dem2pts_bounding_box(self):
638        """Test conversion from dem in ascii format to native NetCDF xya format
639        """
640
641        import time, os
642        from Numeric import array, zeros, allclose, Float, concatenate
643        from Scientific.IO.NetCDF import NetCDFFile
644
645        #Write test asc file
646        root = 'demtest'
647
648        filename = root+'.asc'
649        fid = open(filename, 'w')
650        fid.write("""ncols         5
651nrows         6
652xllcorner     2000.5
653yllcorner     3000.5
654cellsize      25
655NODATA_value  -9999
656""")
657        #Create linear function
658
659        ref_points = []
660        ref_elevation = []
661        for i in range(6):
662            y = (6-i)*25.0
663            for j in range(5):
664                x = j*25.0
665                z = x+2*y
666
667                ref_points.append( [x,y] )
668                ref_elevation.append(z)
669                fid.write('%f ' %z)
670            fid.write('\n')
671
672        fid.close()
673
674        #Write prj file with metadata
675        metafilename = root+'.prj'
676        fid = open(metafilename, 'w')
677
678
679        fid.write("""Projection UTM
680Zone 56
681Datum WGS84
682Zunits NO
683Units METERS
684Spheroid WGS84
685Xshift 0.0000000000
686Yshift 10000000.0000000000
687Parameters
688""")
689        fid.close()
690
691        #Convert to NetCDF pts
692        convert_dem_from_ascii2netcdf(root)
693        dem2pts(root, easting_min=2010.0, easting_max=2110.0,
694                northing_min=3035.0, northing_max=3125.5)
695
696        #Check contents
697        #Get NetCDF
698        fid = NetCDFFile(root+'.pts', 'r')
699
700        # Get the variables
701        #print fid.variables.keys()
702        points = fid.variables['points']
703        elevation = fid.variables['elevation']
704
705        #Check values
706        assert fid.xllcorner[0] == 2010.0
707        assert fid.yllcorner[0] == 3035.0
708
709        #create new reference points
710        ref_points = []
711        ref_elevation = []
712        for i in range(4):
713            y = (4-i)*25.0 + 25.0
714            y_new = y + 3000.5 - 3035.0
715            for j in range(4):
716                x = j*25.0 + 25.0
717                x_new = x + 2000.5 - 2010.0
718                z = x+2*y
719
720                ref_points.append( [x_new,y_new] )
721                ref_elevation.append(z)
722
723        #print points[:]
724        #print ref_points
725        assert allclose(points, ref_points)
726
727        #print attributes[:]
728        #print ref_elevation
729        assert allclose(elevation, ref_elevation)
730
731        #Cleanup
732        fid.close()
733
734
735        os.remove(root + '.pts')
736        os.remove(root + '.dem')
737        os.remove(root + '.asc')
738        os.remove(root + '.prj')
739
740
741
742    def test_sww2dem_asc_elevation(self):
743        """Test that sww information can be converted correctly to asc/prj
744        format readable by e.g. ArcView
745        """
746
747        import time, os
748        from Numeric import array, zeros, allclose, Float, concatenate
749        from Scientific.IO.NetCDF import NetCDFFile
750
751        #Setup
752        self.domain.filename = 'datatest'
753
754        prjfile = self.domain.filename + '_elevation.prj'
755        ascfile = self.domain.filename + '_elevation.asc'
756        swwfile = self.domain.filename + '.sww'
757
758        self.domain.set_datadir('.')
759        self.domain.format = 'sww'
760        self.domain.smooth = True
761        self.domain.set_quantity('elevation', lambda x,y: -x-y)
762
763        self.domain.geo_reference = Geo_reference(56,308500,6189000)
764
765        sww = get_dataobject(self.domain)
766        sww.store_connectivity()
767        sww.store_timestep('stage')
768
769        self.domain.evolve_to_end(finaltime = 0.01)
770        sww.store_timestep('stage')
771
772        cellsize = 0.25
773        #Check contents
774        #Get NetCDF
775
776        fid = NetCDFFile(sww.filename, 'r')
777
778        # Get the variables
779        x = fid.variables['x'][:]
780        y = fid.variables['y'][:]
781        z = fid.variables['elevation'][:]
782        time = fid.variables['time'][:]
783        stage = fid.variables['stage'][:]
784
785
786        #Export to ascii/prj files
787        sww2dem(self.domain.filename,
788                quantity = 'elevation',
789                cellsize = cellsize,
790                verbose = False,
791                format = 'asc')
792               
793        #Check prj (meta data)
794        prjid = open(prjfile)
795        lines = prjid.readlines()
796        prjid.close()
797
798        L = lines[0].strip().split()
799        assert L[0].strip().lower() == 'projection'
800        assert L[1].strip().lower() == 'utm'
801
802        L = lines[1].strip().split()
803        assert L[0].strip().lower() == 'zone'
804        assert L[1].strip().lower() == '56'
805
806        L = lines[2].strip().split()
807        assert L[0].strip().lower() == 'datum'
808        assert L[1].strip().lower() == 'wgs84'
809
810        L = lines[3].strip().split()
811        assert L[0].strip().lower() == 'zunits'
812        assert L[1].strip().lower() == 'no'
813
814        L = lines[4].strip().split()
815        assert L[0].strip().lower() == 'units'
816        assert L[1].strip().lower() == 'meters'
817
818        L = lines[5].strip().split()
819        assert L[0].strip().lower() == 'spheroid'
820        assert L[1].strip().lower() == 'wgs84'
821
822        L = lines[6].strip().split()
823        assert L[0].strip().lower() == 'xshift'
824        assert L[1].strip().lower() == '500000'
825
826        L = lines[7].strip().split()
827        assert L[0].strip().lower() == 'yshift'
828        assert L[1].strip().lower() == '10000000'
829
830        L = lines[8].strip().split()
831        assert L[0].strip().lower() == 'parameters'
832
833
834        #Check asc file
835        ascid = open(ascfile)
836        lines = ascid.readlines()
837        ascid.close()
838
839        L = lines[0].strip().split()
840        assert L[0].strip().lower() == 'ncols'
841        assert L[1].strip().lower() == '5'
842
843        L = lines[1].strip().split()
844        assert L[0].strip().lower() == 'nrows'
845        assert L[1].strip().lower() == '5'
846
847        L = lines[2].strip().split()
848        assert L[0].strip().lower() == 'xllcorner'
849        assert allclose(float(L[1].strip().lower()), 308500)
850
851        L = lines[3].strip().split()
852        assert L[0].strip().lower() == 'yllcorner'
853        assert allclose(float(L[1].strip().lower()), 6189000)
854
855        L = lines[4].strip().split()
856        assert L[0].strip().lower() == 'cellsize'
857        assert allclose(float(L[1].strip().lower()), cellsize)
858
859        L = lines[5].strip().split()
860        assert L[0].strip() == 'NODATA_value'
861        assert L[1].strip().lower() == '-9999'
862
863        #Check grid values
864        for j in range(5):
865            L = lines[6+j].strip().split()
866            y = (4-j) * cellsize
867            for i in range(5):
868                assert allclose(float(L[i]), -i*cellsize - y)
869
870
871        fid.close()
872
873        #Cleanup
874        os.remove(prjfile)
875        os.remove(ascfile)
876        os.remove(swwfile)
877
878
879    def test_sww2dem_asc_stage_reduction(self):
880        """Test that sww information can be converted correctly to asc/prj
881        format readable by e.g. ArcView
882
883        This tests the reduction of quantity stage using min
884        """
885
886        import time, os
887        from Numeric import array, zeros, allclose, Float, concatenate
888        from Scientific.IO.NetCDF import NetCDFFile
889
890        #Setup
891        self.domain.filename = 'datatest'
892
893        prjfile = self.domain.filename + '_stage.prj'
894        ascfile = self.domain.filename + '_stage.asc'
895        swwfile = self.domain.filename + '.sww'
896
897        self.domain.set_datadir('.')
898        self.domain.format = 'sww'
899        self.domain.smooth = True
900        self.domain.set_quantity('elevation', lambda x,y: -x-y)
901
902        self.domain.geo_reference = Geo_reference(56,308500,6189000)
903
904
905        sww = get_dataobject(self.domain)
906        sww.store_connectivity()
907        sww.store_timestep('stage')
908
909        self.domain.evolve_to_end(finaltime = 0.01)
910        sww.store_timestep('stage')
911
912        cellsize = 0.25
913        #Check contents
914        #Get NetCDF
915
916        fid = NetCDFFile(sww.filename, 'r')
917
918        # Get the variables
919        x = fid.variables['x'][:]
920        y = fid.variables['y'][:]
921        z = fid.variables['elevation'][:]
922        time = fid.variables['time'][:]
923        stage = fid.variables['stage'][:]
924
925
926        #Export to ascii/prj files
927        sww2dem(self.domain.filename,
928                quantity = 'stage',
929                cellsize = cellsize,
930                reduction = min,
931                format = 'asc')
932
933
934        #Check asc file
935        ascid = open(ascfile)
936        lines = ascid.readlines()
937        ascid.close()
938
939        L = lines[0].strip().split()
940        assert L[0].strip().lower() == 'ncols'
941        assert L[1].strip().lower() == '5'
942
943        L = lines[1].strip().split()
944        assert L[0].strip().lower() == 'nrows'
945        assert L[1].strip().lower() == '5'
946
947        L = lines[2].strip().split()
948        assert L[0].strip().lower() == 'xllcorner'
949        assert allclose(float(L[1].strip().lower()), 308500)
950
951        L = lines[3].strip().split()
952        assert L[0].strip().lower() == 'yllcorner'
953        assert allclose(float(L[1].strip().lower()), 6189000)
954
955        L = lines[4].strip().split()
956        assert L[0].strip().lower() == 'cellsize'
957        assert allclose(float(L[1].strip().lower()), cellsize)
958
959        L = lines[5].strip().split()
960        assert L[0].strip() == 'NODATA_value'
961        assert L[1].strip().lower() == '-9999'
962
963
964        #Check grid values (where applicable)
965        for j in range(5):
966            if j%2 == 0:
967                L = lines[6+j].strip().split()
968                jj = 4-j
969                for i in range(5):
970                    if i%2 == 0:
971                        index = jj/2 + i/2*3
972                        val0 = stage[0,index]
973                        val1 = stage[1,index]
974
975                        #print i, j, index, ':', L[i], val0, val1
976                        assert allclose(float(L[i]), min(val0, val1))
977
978
979        fid.close()
980
981        #Cleanup
982        os.remove(prjfile)
983        os.remove(ascfile)
984        #os.remove(swwfile)
985
986
987
988
989    def test_sww2dem_asc_missing_points(self):
990        """Test that sww information can be converted correctly to asc/prj
991        format readable by e.g. ArcView
992
993        This test includes the writing of missing values
994        """
995
996        import time, os
997        from Numeric import array, zeros, allclose, Float, concatenate
998        from Scientific.IO.NetCDF import NetCDFFile
999
1000        #Setup mesh not coinciding with rectangle.
1001        #This will cause missing values to occur in gridded data
1002
1003
1004        points = [                        [1.0, 1.0],
1005                              [0.5, 0.5], [1.0, 0.5],
1006                  [0.0, 0.0], [0.5, 0.0], [1.0, 0.0]]
1007
1008        vertices = [ [4,1,3], [5,2,4], [1,4,2], [2,0,1]]
1009
1010        #Create shallow water domain
1011        domain = Domain(points, vertices)
1012        domain.default_order=2
1013
1014
1015        #Set some field values
1016        domain.set_quantity('elevation', lambda x,y: -x-y)
1017        domain.set_quantity('friction', 0.03)
1018
1019
1020        ######################
1021        # Boundary conditions
1022        B = Transmissive_boundary(domain)
1023        domain.set_boundary( {'exterior': B} )
1024
1025
1026        ######################
1027        #Initial condition - with jumps
1028
1029        bed = domain.quantities['elevation'].vertex_values
1030        stage = zeros(bed.shape, Float)
1031
1032        h = 0.3
1033        for i in range(stage.shape[0]):
1034            if i % 2 == 0:
1035                stage[i,:] = bed[i,:] + h
1036            else:
1037                stage[i,:] = bed[i,:]
1038
1039        domain.set_quantity('stage', stage)
1040        domain.distribute_to_vertices_and_edges()
1041
1042        domain.filename = 'datatest'
1043
1044        prjfile = domain.filename + '_elevation.prj'
1045        ascfile = domain.filename + '_elevation.asc'
1046        swwfile = domain.filename + '.sww'
1047
1048        domain.set_datadir('.')
1049        domain.format = 'sww'
1050        domain.smooth = True
1051
1052        domain.geo_reference = Geo_reference(56,308500,6189000)
1053
1054        sww = get_dataobject(domain)
1055        sww.store_connectivity()
1056        sww.store_timestep('stage')
1057
1058        cellsize = 0.25
1059        #Check contents
1060        #Get NetCDF
1061
1062        fid = NetCDFFile(swwfile, 'r')
1063
1064        # Get the variables
1065        x = fid.variables['x'][:]
1066        y = fid.variables['y'][:]
1067        z = fid.variables['elevation'][:]
1068        time = fid.variables['time'][:]
1069
1070        try:
1071            geo_reference = Geo_reference(NetCDFObject=fid)
1072        except AttributeError, e:
1073            geo_reference = Geo_reference(DEFAULT_ZONE,0,0)
1074
1075        #Export to ascii/prj files
1076        sww2dem(domain.filename,
1077                quantity = 'elevation',
1078                cellsize = cellsize,
1079                verbose = False,
1080                format = 'asc')
1081
1082
1083        #Check asc file
1084        ascid = open(ascfile)
1085        lines = ascid.readlines()
1086        ascid.close()
1087
1088        L = lines[0].strip().split()
1089        assert L[0].strip().lower() == 'ncols'
1090        assert L[1].strip().lower() == '5'
1091
1092        L = lines[1].strip().split()
1093        assert L[0].strip().lower() == 'nrows'
1094        assert L[1].strip().lower() == '5'
1095
1096        L = lines[2].strip().split()
1097        assert L[0].strip().lower() == 'xllcorner'
1098        assert allclose(float(L[1].strip().lower()), 308500)
1099
1100        L = lines[3].strip().split()
1101        assert L[0].strip().lower() == 'yllcorner'
1102        assert allclose(float(L[1].strip().lower()), 6189000)
1103
1104        L = lines[4].strip().split()
1105        assert L[0].strip().lower() == 'cellsize'
1106        assert allclose(float(L[1].strip().lower()), cellsize)
1107
1108        L = lines[5].strip().split()
1109        assert L[0].strip() == 'NODATA_value'
1110        assert L[1].strip().lower() == '-9999'
1111
1112
1113        #Check grid values
1114        for j in range(5):
1115            L = lines[6+j].strip().split()
1116            y = (4-j) * cellsize
1117            for i in range(5):
1118                if i+j >= 4:
1119                    assert allclose(float(L[i]), -i*cellsize - y)
1120                else:
1121                    #Missing values
1122                    assert allclose(float(L[i]), -9999)
1123
1124
1125
1126        fid.close()
1127
1128        #Cleanup
1129        os.remove(prjfile)
1130        os.remove(ascfile)
1131        os.remove(swwfile)
1132
1133    def test_sww2ers_simple(self):
1134        """Test that sww information can be converted correctly to asc/prj
1135        format readable by e.g. ArcView
1136        """
1137
1138        import time, os
1139        from Numeric import array, zeros, allclose, Float, concatenate
1140        from Scientific.IO.NetCDF import NetCDFFile
1141
1142        #Setup
1143        self.domain.filename = 'datatest'
1144
1145        headerfile = self.domain.filename + '.ers'
1146        swwfile = self.domain.filename + '.sww'
1147
1148        self.domain.set_datadir('.')
1149        self.domain.format = 'sww'
1150        self.domain.smooth = True
1151        self.domain.set_quantity('elevation', lambda x,y: -x-y)
1152
1153        self.domain.geo_reference = Geo_reference(56,308500,6189000)
1154
1155        sww = get_dataobject(self.domain)
1156        sww.store_connectivity()
1157        sww.store_timestep('stage')
1158
1159        self.domain.evolve_to_end(finaltime = 0.01)
1160        sww.store_timestep('stage')
1161
1162        cellsize = 0.25
1163        #Check contents
1164        #Get NetCDF
1165
1166        fid = NetCDFFile(sww.filename, 'r')
1167
1168        # Get the variables
1169        x = fid.variables['x'][:]
1170        y = fid.variables['y'][:]
1171        z = fid.variables['elevation'][:]
1172        time = fid.variables['time'][:]
1173        stage = fid.variables['stage'][:]
1174
1175
1176        #Export to ers files
1177        #sww2ers(self.domain.filename,
1178        #        quantity = 'elevation',
1179        #        cellsize = cellsize,
1180        #        verbose = False)
1181               
1182        sww2dem(self.domain.filename,
1183                quantity = 'elevation',
1184                cellsize = cellsize,
1185                verbose = False,
1186                format = 'ers')
1187
1188        #Check header data
1189        from ermapper_grids import read_ermapper_header, read_ermapper_data
1190       
1191        header = read_ermapper_header(self.domain.filename + '_elevation.ers')
1192        #print header
1193        assert header['projection'].lower() == '"utm-56"'
1194        assert header['datum'].lower() == '"wgs84"'
1195        assert header['units'].lower() == '"meters"'   
1196        assert header['value'].lower() == '"elevation"'         
1197        assert header['xdimension'] == '0.25'
1198        assert header['ydimension'] == '0.25'   
1199        assert float(header['eastings']) == 308500.0   #xllcorner
1200        assert float(header['northings']) == 6189000.0 #yllcorner       
1201        assert int(header['nroflines']) == 5
1202        assert int(header['nrofcellsperline']) == 5     
1203        assert int(header['nullcellvalue']) == 0   #?           
1204        #FIXME - there is more in the header                   
1205
1206               
1207        #Check grid data               
1208        grid = read_ermapper_data(self.domain.filename + '_elevation') 
1209       
1210        #FIXME (Ole): Why is this the desired reference grid for -x-y?
1211        ref_grid = [0,      0,     0,     0,     0,
1212                    -1,    -1.25, -1.5,  -1.75, -2.0,
1213                    -0.75, -1.0,  -1.25, -1.5,  -1.75,             
1214                    -0.5,  -0.75, -1.0,  -1.25, -1.5,
1215                    -0.25, -0.5,  -0.75, -1.0,  -1.25]             
1216                                         
1217                                         
1218        assert allclose(grid, ref_grid)
1219
1220        fid.close()
1221       
1222        #Cleanup
1223        #FIXME the file clean-up doesn't work (eg Permission Denied Error)
1224        #Done (Ole) - it was because sww2ers didn't close it's sww file
1225        os.remove(sww.filename)
1226
1227
1228    def xxxtestz_sww2ers_real(self):
1229        """Test that sww information can be converted correctly to asc/prj
1230        format readable by e.g. ArcView
1231        """
1232
1233        import time, os
1234        from Numeric import array, zeros, allclose, Float, concatenate
1235        from Scientific.IO.NetCDF import NetCDFFile
1236
1237
1238
1239
1240        #Export to ascii/prj files
1241        sww2ers('karratha_100m.sww',
1242                quantity = 'depth',
1243                cellsize = 5,
1244                verbose = False)
1245
1246
1247    def test_ferret2sww(self):
1248        """Test that georeferencing etc works when converting from
1249        ferret format (lat/lon) to sww format (UTM)
1250        """
1251        from Scientific.IO.NetCDF import NetCDFFile
1252
1253        #The test file has
1254        # LON = 150.66667, 150.83334, 151, 151.16667
1255        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1256        # TIME = 0, 0.1, 0.6, 1.1, 1.6, 2.1 ;
1257        #
1258        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1259        # Fourth value (index==3) is -6.50198 cm
1260
1261
1262        from coordinate_transforms.redfearn import redfearn
1263
1264        fid = NetCDFFile('small_ha.nc')
1265        first_value = fid.variables['HA'][:][0,0,0]
1266        fourth_value = fid.variables['HA'][:][0,0,3]
1267
1268
1269        #Call conversion (with zero origin)
1270        ferret2sww('small', verbose=False,
1271                   origin = (56, 0, 0))
1272
1273
1274        #Work out the UTM coordinates for first point
1275        zone, e, n = redfearn(-34.5, 150.66667)
1276        #print zone, e, n
1277
1278        #Read output file 'small.sww'
1279        fid = NetCDFFile('small.sww')
1280
1281        x = fid.variables['x'][:]
1282        y = fid.variables['y'][:]
1283
1284        #Check that first coordinate is correctly represented
1285        assert allclose(x[0], e)
1286        assert allclose(y[0], n)
1287
1288        #Check first value
1289        stage = fid.variables['stage'][:]
1290        xmomentum = fid.variables['xmomentum'][:]
1291        ymomentum = fid.variables['ymomentum'][:]
1292
1293        #print ymomentum
1294
1295        assert allclose(stage[0,0], first_value/100)  #Meters
1296
1297        #Check fourth value
1298        assert allclose(stage[0,3], fourth_value/100)  #Meters
1299
1300        fid.close()
1301
1302        #Cleanup
1303        import os
1304        os.remove('small.sww')
1305
1306
1307
1308    def test_ferret2sww_2(self):
1309        """Test that georeferencing etc works when converting from
1310        ferret format (lat/lon) to sww format (UTM)
1311        """
1312        from Scientific.IO.NetCDF import NetCDFFile
1313
1314        #The test file has
1315        # LON = 150.66667, 150.83334, 151, 151.16667
1316        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1317        # TIME = 0, 0.1, 0.6, 1.1, 1.6, 2.1 ;
1318        #
1319        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1320        # Fourth value (index==3) is -6.50198 cm
1321
1322
1323        from coordinate_transforms.redfearn import redfearn
1324
1325        fid = NetCDFFile('small_ha.nc')
1326
1327        #Pick a coordinate and a value
1328
1329        time_index = 1
1330        lat_index = 0
1331        lon_index = 2
1332
1333        test_value = fid.variables['HA'][:][time_index, lat_index, lon_index]
1334        test_time = fid.variables['TIME'][:][time_index]
1335        test_lat = fid.variables['LAT'][:][lat_index]
1336        test_lon = fid.variables['LON'][:][lon_index]
1337
1338        linear_point_index = lat_index*4 + lon_index
1339        fid.close()
1340
1341        #Call conversion (with zero origin)
1342        ferret2sww('small', verbose=False,
1343                   origin = (56, 0, 0))
1344
1345
1346        #Work out the UTM coordinates for test point
1347        zone, e, n = redfearn(test_lat, test_lon)
1348
1349        #Read output file 'small.sww'
1350        fid = NetCDFFile('small.sww')
1351
1352        x = fid.variables['x'][:]
1353        y = fid.variables['y'][:]
1354
1355        #Check that test coordinate is correctly represented
1356        assert allclose(x[linear_point_index], e)
1357        assert allclose(y[linear_point_index], n)
1358
1359        #Check test value
1360        stage = fid.variables['stage'][:]
1361
1362        assert allclose(stage[time_index, linear_point_index], test_value/100)
1363
1364        fid.close()
1365
1366        #Cleanup
1367        import os
1368        os.remove('small.sww')
1369
1370
1371
1372    def test_ferret2sww3(self):
1373        """Elevation included
1374        """
1375        from Scientific.IO.NetCDF import NetCDFFile
1376
1377        #The test file has
1378        # LON = 150.66667, 150.83334, 151, 151.16667
1379        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1380        # ELEVATION = [-1 -2 -3 -4
1381        #             -5 -6 -7 -8
1382        #              ...
1383        #              ...      -16]
1384        # where the top left corner is -1m,
1385        # and the ll corner is -13.0m
1386        #
1387        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1388        # Fourth value (index==3) is -6.50198 cm
1389
1390        from coordinate_transforms.redfearn import redfearn
1391        import os
1392        fid1 = NetCDFFile('test_ha.nc','w')
1393        fid2 = NetCDFFile('test_ua.nc','w')
1394        fid3 = NetCDFFile('test_va.nc','w')
1395        fid4 = NetCDFFile('test_e.nc','w')
1396
1397        h1_list = [150.66667,150.83334,151.]
1398        h2_list = [-34.5,-34.33333]
1399
1400        long_name = 'LON'
1401        lat_name = 'LAT'
1402
1403        nx = 3
1404        ny = 2
1405
1406        for fid in [fid1,fid2,fid3]:
1407            fid.createDimension(long_name,nx)
1408            fid.createVariable(long_name,'d',(long_name,))
1409            fid.variables[long_name].point_spacing='uneven'
1410            fid.variables[long_name].units='degrees_east'
1411            fid.variables[long_name].assignValue(h1_list)
1412
1413            fid.createDimension(lat_name,ny)
1414            fid.createVariable(lat_name,'d',(lat_name,))
1415            fid.variables[lat_name].point_spacing='uneven'
1416            fid.variables[lat_name].units='degrees_north'
1417            fid.variables[lat_name].assignValue(h2_list)
1418
1419            fid.createDimension('TIME',2)
1420            fid.createVariable('TIME','d',('TIME',))
1421            fid.variables['TIME'].point_spacing='uneven'
1422            fid.variables['TIME'].units='seconds'
1423            fid.variables['TIME'].assignValue([0.,1.])
1424            if fid == fid3: break
1425
1426
1427        for fid in [fid4]:
1428            fid.createDimension(long_name,nx)
1429            fid.createVariable(long_name,'d',(long_name,))
1430            fid.variables[long_name].point_spacing='uneven'
1431            fid.variables[long_name].units='degrees_east'
1432            fid.variables[long_name].assignValue(h1_list)
1433
1434            fid.createDimension(lat_name,ny)
1435            fid.createVariable(lat_name,'d',(lat_name,))
1436            fid.variables[lat_name].point_spacing='uneven'
1437            fid.variables[lat_name].units='degrees_north'
1438            fid.variables[lat_name].assignValue(h2_list)
1439
1440        name = {}
1441        name[fid1]='HA'
1442        name[fid2]='UA'
1443        name[fid3]='VA'
1444        name[fid4]='ELEVATION'
1445
1446        units = {}
1447        units[fid1]='cm'
1448        units[fid2]='cm/s'
1449        units[fid3]='cm/s'
1450        units[fid4]='m'
1451
1452        values = {}
1453        values[fid1]=[[[5., 10.,15.], [13.,18.,23.]],[[50.,100.,150.],[130.,180.,230.]]]
1454        values[fid2]=[[[1., 2.,3.], [4.,5.,6.]],[[7.,8.,9.],[10.,11.,12.]]]
1455        values[fid3]=[[[13., 12.,11.], [10.,9.,8.]],[[7.,6.,5.],[4.,3.,2.]]]
1456        values[fid4]=[[-3000,-3100,-3200],[-4000,-5000,-6000]]
1457
1458        for fid in [fid1,fid2,fid3]:
1459          fid.createVariable(name[fid],'d',('TIME',lat_name,long_name))
1460          fid.variables[name[fid]].point_spacing='uneven'
1461          fid.variables[name[fid]].units=units[fid]
1462          fid.variables[name[fid]].assignValue(values[fid])
1463          fid.variables[name[fid]].missing_value = -99999999.
1464          if fid == fid3: break
1465
1466        for fid in [fid4]:
1467            fid.createVariable(name[fid],'d',(lat_name,long_name))
1468            fid.variables[name[fid]].point_spacing='uneven'
1469            fid.variables[name[fid]].units=units[fid]
1470            fid.variables[name[fid]].assignValue(values[fid])
1471            fid.variables[name[fid]].missing_value = -99999999.
1472
1473
1474        fid1.sync(); fid1.close()
1475        fid2.sync(); fid2.close()
1476        fid3.sync(); fid3.close()
1477        fid4.sync(); fid4.close()
1478
1479        fid1 = NetCDFFile('test_ha.nc','r')
1480        fid2 = NetCDFFile('test_e.nc','r')
1481        fid3 = NetCDFFile('test_va.nc','r')
1482
1483
1484        first_amp = fid1.variables['HA'][:][0,0,0]
1485        third_amp = fid1.variables['HA'][:][0,0,2]
1486        first_elevation = fid2.variables['ELEVATION'][0,0]
1487        third_elevation= fid2.variables['ELEVATION'][:][0,2]
1488        first_speed = fid3.variables['VA'][0,0,0]
1489        third_speed = fid3.variables['VA'][:][0,0,2]
1490
1491        fid1.close()
1492        fid2.close()
1493        fid3.close()
1494
1495        #Call conversion (with zero origin)
1496        ferret2sww('test', verbose=False,
1497                   origin = (56, 0, 0))
1498
1499        os.remove('test_va.nc')
1500        os.remove('test_ua.nc')
1501        os.remove('test_ha.nc')
1502        os.remove('test_e.nc')
1503
1504        #Read output file 'test.sww'
1505        fid = NetCDFFile('test.sww')
1506
1507
1508        #Check first value
1509        elevation = fid.variables['elevation'][:]
1510        stage = fid.variables['stage'][:]
1511        xmomentum = fid.variables['xmomentum'][:]
1512        ymomentum = fid.variables['ymomentum'][:]
1513
1514        #print ymomentum
1515        first_height = first_amp/100 - first_elevation
1516        third_height = third_amp/100 - third_elevation
1517        first_momentum=first_speed*first_height/100
1518        third_momentum=third_speed*third_height/100
1519
1520        assert allclose(ymomentum[0][0],first_momentum)  #Meters
1521        assert allclose(ymomentum[0][2],third_momentum)  #Meters
1522
1523        fid.close()
1524
1525        #Cleanup
1526        os.remove('test.sww')
1527
1528
1529
1530
1531    def test_ferret2sww_nz_origin(self):
1532        from Scientific.IO.NetCDF import NetCDFFile
1533        from coordinate_transforms.redfearn import redfearn
1534
1535        #Call conversion (with nonzero origin)
1536        ferret2sww('small', verbose=False,
1537                   origin = (56, 100000, 200000))
1538
1539
1540        #Work out the UTM coordinates for first point
1541        zone, e, n = redfearn(-34.5, 150.66667)
1542
1543        #Read output file 'small.sww'
1544        fid = NetCDFFile('small.sww', 'r')
1545
1546        x = fid.variables['x'][:]
1547        y = fid.variables['y'][:]
1548
1549        #Check that first coordinate is correctly represented
1550        assert allclose(x[0], e-100000)
1551        assert allclose(y[0], n-200000)
1552
1553        fid.close()
1554
1555        #Cleanup
1556        import os
1557        os.remove('small.sww')
1558
1559
1560
1561    def test_sww_extent(self):
1562        """Not a test, rather a look at the sww format
1563        """
1564
1565        import time, os
1566        from Numeric import array, zeros, allclose, Float, concatenate
1567        from Scientific.IO.NetCDF import NetCDFFile
1568
1569        self.domain.filename = 'datatest' + str(id(self))
1570        self.domain.format = 'sww'
1571        self.domain.smooth = True
1572        self.domain.reduction = mean
1573        self.domain.set_datadir('.')
1574
1575
1576        sww = get_dataobject(self.domain)
1577        sww.store_connectivity()
1578        sww.store_timestep('stage')
1579        self.domain.time = 2.
1580
1581        #Modify stage at second timestep
1582        stage = self.domain.quantities['stage'].vertex_values
1583        self.domain.set_quantity('stage', stage/2)
1584
1585        sww.store_timestep('stage')
1586
1587        file_and_extension_name = self.domain.filename + ".sww"
1588        #print "file_and_extension_name",file_and_extension_name
1589        [xmin, xmax, ymin, ymax, stagemin, stagemax] = \
1590               extent_sww(file_and_extension_name )
1591
1592        assert allclose(xmin, 0.0)
1593        assert allclose(xmax, 1.0)
1594        assert allclose(ymin, 0.0)
1595        assert allclose(ymax, 1.0)
1596        assert allclose(stagemin, -0.85)
1597        assert allclose(stagemax, 0.15)
1598
1599
1600        #Cleanup
1601        os.remove(sww.filename)
1602
1603
1604
1605    def test_sww2domain(self):
1606        ################################################
1607        #Create a test domain, and evolve and save it.
1608        ################################################
1609        from mesh_factory import rectangular
1610        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1611             Constant_height, Time_boundary, Transmissive_boundary
1612        from Numeric import array
1613
1614        #Create basic mesh
1615
1616        yiel=0.01
1617        points, vertices, boundary = rectangular(10,10)
1618
1619        #Create shallow water domain
1620        domain = Domain(points, vertices, boundary)
1621        domain.geo_reference = Geo_reference(56,11,11)
1622        domain.smooth = False
1623        domain.visualise = False
1624        domain.store = True
1625        domain.filename = 'bedslope'
1626        domain.default_order=2
1627        #Bed-slope and friction
1628        domain.set_quantity('elevation', lambda x,y: -x/3)
1629        domain.set_quantity('friction', 0.1)
1630        # Boundary conditions
1631        from math import sin, pi
1632        Br = Reflective_boundary(domain)
1633        Bt = Transmissive_boundary(domain)
1634        Bd = Dirichlet_boundary([0.2,0.,0.])
1635        Bw = Time_boundary(domain=domain,
1636                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1637
1638        #domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
1639        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1640
1641        domain.quantities_to_be_stored.extend(['xmomentum','ymomentum'])
1642        #Initial condition
1643        h = 0.05
1644        elevation = domain.quantities['elevation'].vertex_values
1645        domain.set_quantity('stage', elevation + h)
1646        #elevation = domain.get_quantity('elevation')
1647        #domain.set_quantity('stage', elevation + h)
1648
1649        domain.check_integrity()
1650        #Evolution
1651        for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
1652        #    domain.write_time()
1653            pass
1654
1655
1656        ##########################################
1657        #Import the example's file as a new domain
1658        ##########################################
1659        from data_manager import sww2domain
1660        from Numeric import allclose
1661        import os
1662
1663        filename = domain.datadir+os.sep+domain.filename+'.sww'
1664        domain2 = sww2domain(filename,None,fail_if_NaN=False,verbose = False)
1665        #points, vertices, boundary = rectangular(15,15)
1666        #domain2.boundary = boundary
1667        ###################
1668        ##NOW TEST IT!!!
1669        ###################
1670
1671        bits = ['vertex_coordinates']
1672        for quantity in ['elevation']+domain.quantities_to_be_stored:
1673            bits.append('quantities["%s"].get_integral()'%quantity)
1674            bits.append('get_quantity("%s")'%quantity)
1675
1676        for bit in bits:
1677            #print 'testing that domain.'+bit+' has been restored'
1678            #print bit
1679        #print 'done'
1680            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1681
1682        ######################################
1683        #Now evolve them both, just to be sure
1684        ######################################x
1685        visualise = False
1686        #visualise = True
1687        domain.visualise = visualise
1688        domain.time = 0.
1689        from time import sleep
1690
1691        final = .1
1692        domain.set_quantity('friction', 0.1)
1693        domain.store = False
1694        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1695
1696
1697        for t in domain.evolve(yieldstep = yiel, finaltime = final):
1698            if visualise: sleep(1.)
1699            #domain.write_time()
1700            pass
1701
1702        final = final - (domain2.starttime-domain.starttime)
1703        #BUT since domain1 gets time hacked back to 0:
1704        final = final + (domain2.starttime-domain.starttime)
1705
1706        domain2.smooth = False
1707        domain2.visualise = visualise
1708        domain2.store = False
1709        domain2.default_order=2
1710        domain2.set_quantity('friction', 0.1)
1711        #Bed-slope and friction
1712        # Boundary conditions
1713        Bd2=Dirichlet_boundary([0.2,0.,0.])
1714        domain2.boundary = domain.boundary
1715        #print 'domain2.boundary'
1716        #print domain2.boundary
1717        domain2.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1718        #domain2.set_boundary({'exterior': Bd})
1719
1720        domain2.check_integrity()
1721
1722        for t in domain2.evolve(yieldstep = yiel, finaltime = final):
1723            if visualise: sleep(1.)
1724            #domain2.write_time()
1725            pass
1726
1727        ###################
1728        ##NOW TEST IT!!!
1729        ##################
1730
1731        bits = [ 'vertex_coordinates']
1732
1733        for quantity in ['elevation','xmomentum','ymomentum']:#+domain.quantities_to_be_stored:
1734            bits.append('quantities["%s"].get_integral()'%quantity)
1735            bits.append('get_quantity("%s")'%quantity)
1736
1737        for bit in bits:
1738            #print bit
1739            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1740
1741
1742    def test_sww2domain2(self):
1743        ##################################################################
1744        #Same as previous test, but this checks how NaNs are handled.
1745        ##################################################################
1746
1747
1748        from mesh_factory import rectangular
1749        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1750             Constant_height, Time_boundary, Transmissive_boundary
1751        from Numeric import array
1752
1753        #Create basic mesh
1754        points, vertices, boundary = rectangular(2,2)
1755
1756        #Create shallow water domain
1757        domain = Domain(points, vertices, boundary)
1758        domain.smooth = False
1759        domain.visualise = False
1760        domain.store = True
1761        domain.filename = 'bedslope'
1762        domain.default_order=2
1763        domain.quantities_to_be_stored=['stage']
1764
1765        domain.set_quantity('elevation', lambda x,y: -x/3)
1766        domain.set_quantity('friction', 0.1)
1767
1768        from math import sin, pi
1769        Br = Reflective_boundary(domain)
1770        Bt = Transmissive_boundary(domain)
1771        Bd = Dirichlet_boundary([0.2,0.,0.])
1772        Bw = Time_boundary(domain=domain,
1773                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1774
1775        domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
1776
1777        h = 0.05
1778        elevation = domain.quantities['elevation'].vertex_values
1779        domain.set_quantity('stage', elevation + h)
1780
1781        domain.check_integrity()
1782
1783        for t in domain.evolve(yieldstep = 1, finaltime = 2.0):
1784            pass
1785            #domain.write_time()
1786
1787
1788
1789        ##################################
1790        #Import the file as a new domain
1791        ##################################
1792        from data_manager import sww2domain
1793        from Numeric import allclose
1794        import os
1795
1796        filename = domain.datadir+os.sep+domain.filename+'.sww'
1797
1798        #Fail because NaNs are present
1799        try:
1800            domain2 = sww2domain(filename,boundary,fail_if_NaN=True,verbose=False)
1801            assert True == False
1802        except:
1803            #Now import it, filling NaNs to be 0
1804            filler = 0
1805            domain2 = sww2domain(filename,None,fail_if_NaN=False,NaN_filler = filler,verbose=False)
1806        bits = [ 'geo_reference.get_xllcorner()',
1807                'geo_reference.get_yllcorner()',
1808                'vertex_coordinates']
1809
1810        for quantity in ['elevation']+domain.quantities_to_be_stored:
1811            bits.append('quantities["%s"].get_integral()'%quantity)
1812            bits.append('get_quantity("%s")'%quantity)
1813
1814        for bit in bits:
1815        #    print 'testing that domain.'+bit+' has been restored'
1816            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1817
1818        assert max(max(domain2.get_quantity('xmomentum')))==filler
1819        assert min(min(domain2.get_quantity('xmomentum')))==filler
1820        assert max(max(domain2.get_quantity('ymomentum')))==filler
1821        assert min(min(domain2.get_quantity('ymomentum')))==filler
1822
1823        #print 'passed'
1824
1825        #cleanup
1826        #import os
1827        #os.remove(domain.datadir+'/'+domain.filename+'.sww')
1828
1829
1830    #def test_weed(self):
1831        from data_manager import weed
1832
1833        coordinates1 = [[0.,0.],[1.,0.],[1.,1.],[1.,0.],[2.,0.],[1.,1.]]
1834        volumes1 = [[0,1,2],[3,4,5]]
1835        boundary1= {(0,1): 'external',(1,2): 'not external',(2,0): 'external',(3,4): 'external',(4,5): 'external',(5,3): 'not external'}
1836        coordinates2,volumes2,boundary2=weed(coordinates1,volumes1,boundary1)
1837
1838        points2 = {(0.,0.):None,(1.,0.):None,(1.,1.):None,(2.,0.):None}
1839
1840        assert len(points2)==len(coordinates2)
1841        for i in range(len(coordinates2)):
1842            coordinate = tuple(coordinates2[i])
1843            assert points2.has_key(coordinate)
1844            points2[coordinate]=i
1845
1846        for triangle in volumes1:
1847            for coordinate in triangle:
1848                assert coordinates2[points2[tuple(coordinates1[coordinate])]][0]==coordinates1[coordinate][0]
1849                assert coordinates2[points2[tuple(coordinates1[coordinate])]][1]==coordinates1[coordinate][1]
1850
1851
1852    #FIXME This fails - smooth makes the comparism too hard for allclose
1853    def ztest_sww2domain3(self):
1854        ################################################
1855        #DOMAIN.SMOOTH = TRUE !!!!!!!!!!!!!!!!!!!!!!!!!!
1856        ################################################
1857        from mesh_factory import rectangular
1858        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1859             Constant_height, Time_boundary, Transmissive_boundary
1860        from Numeric import array
1861        #Create basic mesh
1862
1863        yiel=0.01
1864        points, vertices, boundary = rectangular(10,10)
1865
1866        #Create shallow water domain
1867        domain = Domain(points, vertices, boundary)
1868        domain.geo_reference = Geo_reference(56,11,11)
1869        domain.smooth = True
1870        domain.visualise = False
1871        domain.store = True
1872        domain.filename = 'bedslope'
1873        domain.default_order=2
1874        #Bed-slope and friction
1875        domain.set_quantity('elevation', lambda x,y: -x/3)
1876        domain.set_quantity('friction', 0.1)
1877        # Boundary conditions
1878        from math import sin, pi
1879        Br = Reflective_boundary(domain)
1880        Bt = Transmissive_boundary(domain)
1881        Bd = Dirichlet_boundary([0.2,0.,0.])
1882        Bw = Time_boundary(domain=domain,
1883                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1884
1885        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1886
1887        domain.quantities_to_be_stored.extend(['xmomentum','ymomentum'])
1888        #Initial condition
1889        h = 0.05
1890        elevation = domain.quantities['elevation'].vertex_values
1891        domain.set_quantity('stage', elevation + h)
1892        #elevation = domain.get_quantity('elevation')
1893        #domain.set_quantity('stage', elevation + h)
1894
1895        domain.check_integrity()
1896        #Evolution
1897        for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
1898        #    domain.write_time()
1899            pass
1900
1901
1902        ##########################################
1903        #Import the example's file as a new domain
1904        ##########################################
1905        from data_manager import sww2domain
1906        from Numeric import allclose
1907        import os
1908
1909        filename = domain.datadir+os.sep+domain.filename+'.sww'
1910        domain2 = sww2domain(filename,None,fail_if_NaN=False,verbose = False)
1911        #points, vertices, boundary = rectangular(15,15)
1912        #domain2.boundary = boundary
1913        ###################
1914        ##NOW TEST IT!!!
1915        ###################
1916
1917        #FIXME smooth domain so that they can be compared
1918
1919
1920        bits = []#'vertex_coordinates']
1921        for quantity in ['elevation']+domain.quantities_to_be_stored:
1922            bits.append('quantities["%s"].get_integral()'%quantity)
1923            #bits.append('get_quantity("%s")'%quantity)
1924
1925        for bit in bits:
1926            #print 'testing that domain.'+bit+' has been restored'
1927            #print bit
1928            #print 'done'
1929            #print ('domain.'+bit), eval('domain.'+bit)
1930            #print ('domain2.'+bit), eval('domain2.'+bit)
1931            assert allclose(eval('domain.'+bit),eval('domain2.'+bit),rtol=1.0e-1,atol=1.e-3)
1932            pass
1933
1934        ######################################
1935        #Now evolve them both, just to be sure
1936        ######################################x
1937        visualise = False
1938        visualise = True
1939        domain.visualise = visualise
1940        domain.time = 0.
1941        from time import sleep
1942
1943        final = .5
1944        domain.set_quantity('friction', 0.1)
1945        domain.store = False
1946        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Br})
1947
1948        for t in domain.evolve(yieldstep = yiel, finaltime = final):
1949            if visualise: sleep(.03)
1950            #domain.write_time()
1951            pass
1952
1953        domain2.smooth = True
1954        domain2.visualise = visualise
1955        domain2.store = False
1956        domain2.default_order=2
1957        domain2.set_quantity('friction', 0.1)
1958        #Bed-slope and friction
1959        # Boundary conditions
1960        Bd2=Dirichlet_boundary([0.2,0.,0.])
1961        Br2 = Reflective_boundary(domain2)
1962        domain2.boundary = domain.boundary
1963        #print 'domain2.boundary'
1964        #print domain2.boundary
1965        domain2.set_boundary({'left': Bd2, 'right': Bd2, 'top': Bd2, 'bottom': Br2})
1966        #domain2.boundary = domain.boundary
1967        #domain2.set_boundary({'exterior': Bd})
1968
1969        domain2.check_integrity()
1970
1971        for t in domain2.evolve(yieldstep = yiel, finaltime = final):
1972            if visualise: sleep(.03)
1973            #domain2.write_time()
1974            pass
1975
1976        ###################
1977        ##NOW TEST IT!!!
1978        ##################
1979
1980        bits = [ 'vertex_coordinates']
1981
1982        for quantity in ['elevation','xmomentum','ymomentum']:#+domain.quantities_to_be_stored:
1983            #bits.append('quantities["%s"].get_integral()'%quantity)
1984            bits.append('get_quantity("%s")'%quantity)
1985
1986        for bit in bits:
1987            print bit
1988            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1989
1990
1991    def test_decimate_dem(self):
1992        """Test decimation of dem file
1993        """
1994
1995        import os
1996        from Numeric import ones, allclose, Float, arange
1997        from Scientific.IO.NetCDF import NetCDFFile
1998
1999        #Write test dem file
2000        root = 'decdemtest'
2001
2002        filename = root + '.dem'
2003        fid = NetCDFFile(filename, 'w')
2004
2005        fid.institution = 'Geoscience Australia'
2006        fid.description = 'NetCDF DEM format for compact and portable ' +\
2007                          'storage of spatial point data'
2008
2009        nrows = 15
2010        ncols = 18
2011
2012        fid.ncols = ncols
2013        fid.nrows = nrows
2014        fid.xllcorner = 2000.5
2015        fid.yllcorner = 3000.5
2016        fid.cellsize = 25
2017        fid.NODATA_value = -9999
2018
2019        fid.zone = 56
2020        fid.false_easting = 0.0
2021        fid.false_northing = 0.0
2022        fid.projection = 'UTM'
2023        fid.datum = 'WGS84'
2024        fid.units = 'METERS'
2025
2026        fid.createDimension('number_of_points', nrows*ncols)
2027
2028        fid.createVariable('elevation', Float, ('number_of_points',))
2029
2030        elevation = fid.variables['elevation']
2031
2032        elevation[:] = (arange(nrows*ncols))
2033
2034        fid.close()
2035
2036        #generate the elevation values expected in the decimated file
2037        ref_elevation = [(  0+  1+  2+ 18+ 19+ 20+ 36+ 37+ 38) / 9.0,
2038                         (  4+  5+  6+ 22+ 23+ 24+ 40+ 41+ 42) / 9.0,
2039                         (  8+  9+ 10+ 26+ 27+ 28+ 44+ 45+ 46) / 9.0,
2040                         ( 12+ 13+ 14+ 30+ 31+ 32+ 48+ 49+ 50) / 9.0,
2041                         ( 72+ 73+ 74+ 90+ 91+ 92+108+109+110) / 9.0,
2042                         ( 76+ 77+ 78+ 94+ 95+ 96+112+113+114) / 9.0,
2043                         ( 80+ 81+ 82+ 98+ 99+100+116+117+118) / 9.0,
2044                         ( 84+ 85+ 86+102+103+104+120+121+122) / 9.0,
2045                         (144+145+146+162+163+164+180+181+182) / 9.0,
2046                         (148+149+150+166+167+168+184+185+186) / 9.0,
2047                         (152+153+154+170+171+172+188+189+190) / 9.0,
2048                         (156+157+158+174+175+176+192+193+194) / 9.0,
2049                         (216+217+218+234+235+236+252+253+254) / 9.0,
2050                         (220+221+222+238+239+240+256+257+258) / 9.0,
2051                         (224+225+226+242+243+244+260+261+262) / 9.0,
2052                         (228+229+230+246+247+248+264+265+266) / 9.0]
2053
2054        #generate a stencil for computing the decimated values
2055        stencil = ones((3,3), Float) / 9.0
2056
2057        decimate_dem(root, stencil=stencil, cellsize_new=100)
2058
2059        #Open decimated NetCDF file
2060        fid = NetCDFFile(root + '_100.dem', 'r')
2061
2062        # Get decimated elevation
2063        elevation = fid.variables['elevation']
2064
2065        #Check values
2066        assert allclose(elevation, ref_elevation)
2067
2068        #Cleanup
2069        fid.close()
2070
2071        os.remove(root + '.dem')
2072        os.remove(root + '_100.dem')
2073
2074    def test_decimate_dem_NODATA(self):
2075        """Test decimation of dem file that includes NODATA values
2076        """
2077
2078        import os
2079        from Numeric import ones, allclose, Float, arange, reshape
2080        from Scientific.IO.NetCDF import NetCDFFile
2081
2082        #Write test dem file
2083        root = 'decdemtest'
2084
2085        filename = root + '.dem'
2086        fid = NetCDFFile(filename, 'w')
2087
2088        fid.institution = 'Geoscience Australia'
2089        fid.description = 'NetCDF DEM format for compact and portable ' +\
2090                          'storage of spatial point data'
2091
2092        nrows = 15
2093        ncols = 18
2094        NODATA_value = -9999
2095
2096        fid.ncols = ncols
2097        fid.nrows = nrows
2098        fid.xllcorner = 2000.5
2099        fid.yllcorner = 3000.5
2100        fid.cellsize = 25
2101        fid.NODATA_value = NODATA_value
2102
2103        fid.zone = 56
2104        fid.false_easting = 0.0
2105        fid.false_northing = 0.0
2106        fid.projection = 'UTM'
2107        fid.datum = 'WGS84'
2108        fid.units = 'METERS'
2109
2110        fid.createDimension('number_of_points', nrows*ncols)
2111
2112        fid.createVariable('elevation', Float, ('number_of_points',))
2113
2114        elevation = fid.variables['elevation']
2115
2116        #generate initial elevation values
2117        elevation_tmp = (arange(nrows*ncols))
2118        #add some NODATA values
2119        elevation_tmp[0]   = NODATA_value
2120        elevation_tmp[95]  = NODATA_value
2121        elevation_tmp[188] = NODATA_value
2122        elevation_tmp[189] = NODATA_value
2123        elevation_tmp[190] = NODATA_value
2124        elevation_tmp[209] = NODATA_value
2125        elevation_tmp[252] = NODATA_value
2126
2127        elevation[:] = elevation_tmp
2128
2129        fid.close()
2130
2131        #generate the elevation values expected in the decimated file
2132        ref_elevation = [NODATA_value,
2133                         (  4+  5+  6+ 22+ 23+ 24+ 40+ 41+ 42) / 9.0,
2134                         (  8+  9+ 10+ 26+ 27+ 28+ 44+ 45+ 46) / 9.0,
2135                         ( 12+ 13+ 14+ 30+ 31+ 32+ 48+ 49+ 50) / 9.0,
2136                         ( 72+ 73+ 74+ 90+ 91+ 92+108+109+110) / 9.0,
2137                         NODATA_value,
2138                         ( 80+ 81+ 82+ 98+ 99+100+116+117+118) / 9.0,
2139                         ( 84+ 85+ 86+102+103+104+120+121+122) / 9.0,
2140                         (144+145+146+162+163+164+180+181+182) / 9.0,
2141                         (148+149+150+166+167+168+184+185+186) / 9.0,
2142                         NODATA_value,
2143                         (156+157+158+174+175+176+192+193+194) / 9.0,
2144                         NODATA_value,
2145                         (220+221+222+238+239+240+256+257+258) / 9.0,
2146                         (224+225+226+242+243+244+260+261+262) / 9.0,
2147                         (228+229+230+246+247+248+264+265+266) / 9.0]
2148
2149        #generate a stencil for computing the decimated values
2150        stencil = ones((3,3), Float) / 9.0
2151
2152        decimate_dem(root, stencil=stencil, cellsize_new=100)
2153
2154        #Open decimated NetCDF file
2155        fid = NetCDFFile(root + '_100.dem', 'r')
2156
2157        # Get decimated elevation
2158        elevation = fid.variables['elevation']
2159
2160        #Check values
2161        assert allclose(elevation, ref_elevation)
2162
2163        #Cleanup
2164        fid.close()
2165
2166        os.remove(root + '.dem')
2167        os.remove(root + '_100.dem')
2168
2169
2170
2171
2172#-------------------------------------------------------------
2173if __name__ == "__main__":
2174    suite = unittest.makeSuite(Test_Data_Manager,'test')
2175    #suite = unittest.makeSuite(Test_Data_Manager,'test_dem2pts_bounding_box')
2176    #suite = unittest.makeSuite(Test_Data_Manager,'test_decimate_dem')
2177    #suite = unittest.makeSuite(Test_Data_Manager,'test_decimate_dem_NODATA')
2178    runner = unittest.TextTestRunner()
2179    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.