source: inundation/pyvolution/test_data_manager.py @ 1740

Last change on this file since 1740 was 1740, checked in by ole, 19 years ago

Recovered data_manager as it was on 18 Aug and test_data_manager as it was on 16 Aug

File size: 59.4 KB
Line 
1#!/usr/bin/env python
2#
3
4import unittest
5import copy
6from Numeric import zeros, array, allclose, Float
7from util import mean
8
9from data_manager import *
10from shallow_water import *
11from config import epsilon
12
13from coordinate_transforms.geo_reference import Geo_reference
14
15class Test_Data_Manager(unittest.TestCase):
16    def setUp(self):
17        import time
18        from mesh_factory import rectangular
19
20
21        #Create basic mesh
22        points, vertices, boundary = rectangular(2, 2)
23
24        #Create shallow water domain
25        domain = Domain(points, vertices, boundary)
26        domain.default_order=2
27
28
29        #Set some field values
30        domain.set_quantity('elevation', lambda x,y: -x)
31        domain.set_quantity('friction', 0.03)
32
33
34        ######################
35        # Boundary conditions
36        B = Transmissive_boundary(domain)
37        domain.set_boundary( {'left': B, 'right': B, 'top': B, 'bottom': B})
38
39
40        ######################
41        #Initial condition - with jumps
42
43
44        bed = domain.quantities['elevation'].vertex_values
45        stage = zeros(bed.shape, Float)
46
47        h = 0.3
48        for i in range(stage.shape[0]):
49            if i % 2 == 0:
50                stage[i,:] = bed[i,:] + h
51            else:
52                stage[i,:] = bed[i,:]
53
54        domain.set_quantity('stage', stage)
55        self.initial_stage = copy.copy(domain.quantities['stage'].vertex_values)
56
57        domain.distribute_to_vertices_and_edges()
58
59
60        self.domain = domain
61
62        C = domain.get_vertex_coordinates()
63        self.X = C[:,0:6:2].copy()
64        self.Y = C[:,1:6:2].copy()
65
66        self.F = bed
67
68
69    def tearDown(self):
70        pass
71
72
73
74
75#     def test_xya(self):
76#         import os
77#         from Numeric import concatenate
78
79#         import time, os
80#         from Numeric import array, zeros, allclose, Float, concatenate
81
82#         domain = self.domain
83
84#         domain.filename = 'datatest' + str(time.time())
85#         domain.format = 'xya'
86#         domain.smooth = True
87
88#         xya = get_dataobject(self.domain)
89#         xya.store_all()
90
91
92#         #Read back
93#         file = open(xya.filename)
94#         lFile = file.read().split('\n')
95#         lFile = lFile[:-1]
96
97#         file.close()
98#         os.remove(xya.filename)
99
100#         #Check contents
101#         if domain.smooth:
102#             self.failUnless(lFile[0] == '9 3 # <vertex #> <x> <y> [attributes]')
103#         else:
104#             self.failUnless(lFile[0] == '24 3 # <vertex #> <x> <y> [attributes]')
105
106#         #Get smoothed field values with X and Y
107#         X,Y,F,V = domain.get_vertex_values(xy=True, value_array='field_values',
108#                                            indices = (0,1), precision = Float)
109
110
111#         Q,V = domain.get_vertex_values(xy=False, value_array='conserved_quantities',
112#                                            indices = (0,), precision = Float)
113
114
115
116#         for i, line in enumerate(lFile[1:]):
117#             fields = line.split()
118
119#             assert len(fields) == 5
120
121#             assert allclose(float(fields[0]), X[i])
122#             assert allclose(float(fields[1]), Y[i])
123#             assert allclose(float(fields[2]), F[i,0])
124#             assert allclose(float(fields[3]), Q[i,0])
125#             assert allclose(float(fields[4]), F[i,1])
126
127
128
129
130    def test_sww_constant(self):
131        """Test that constant sww information can be written correctly
132        (non smooth)
133        """
134
135        import time, os
136        from Numeric import array, zeros, allclose, Float, concatenate
137        from Scientific.IO.NetCDF import NetCDFFile
138
139        self.domain.filename = 'datatest' + str(id(self))
140        self.domain.format = 'sww'
141        self.domain.smooth = False
142
143        sww = get_dataobject(self.domain)
144        sww.store_connectivity()
145
146        #Check contents
147        #Get NetCDF
148        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
149
150        # Get the variables
151        x = fid.variables['x']
152        y = fid.variables['y']
153        z = fid.variables['elevation']
154
155        volumes = fid.variables['volumes']
156
157
158        assert allclose (x[:], self.X.flat)
159        assert allclose (y[:], self.Y.flat)
160        assert allclose (z[:], self.F.flat)
161
162        V = volumes
163
164        P = len(self.domain)
165        for k in range(P):
166            assert V[k, 0] == 3*k
167            assert V[k, 1] == 3*k+1
168            assert V[k, 2] == 3*k+2
169
170
171        fid.close()
172
173        #Cleanup
174        os.remove(sww.filename)
175
176
177    def test_sww_constant_smooth(self):
178        """Test that constant sww information can be written correctly
179        (non smooth)
180        """
181
182        import time, os
183        from Numeric import array, zeros, allclose, Float, concatenate
184        from Scientific.IO.NetCDF import NetCDFFile
185
186        self.domain.filename = 'datatest' + str(id(self))
187        self.domain.format = 'sww'
188        self.domain.smooth = True
189
190        sww = get_dataobject(self.domain)
191        sww.store_connectivity()
192
193        #Check contents
194        #Get NetCDF
195        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
196
197        # Get the variables
198        x = fid.variables['x']
199        y = fid.variables['y']
200        z = fid.variables['elevation']
201
202        volumes = fid.variables['volumes']
203
204        X = x[:]
205        Y = y[:]
206
207        assert allclose([X[0], Y[0]], array([0.0, 0.0]))
208        assert allclose([X[1], Y[1]], array([0.0, 0.5]))
209        assert allclose([X[2], Y[2]], array([0.0, 1.0]))
210
211        assert allclose([X[4], Y[4]], array([0.5, 0.5]))
212
213        assert allclose([X[7], Y[7]], array([1.0, 0.5]))
214
215        Z = z[:]
216        assert Z[4] == -0.5
217
218        V = volumes
219        assert V[2,0] == 4
220        assert V[2,1] == 5
221        assert V[2,2] == 1
222
223        assert V[4,0] == 6
224        assert V[4,1] == 7
225        assert V[4,2] == 3
226
227
228        fid.close()
229
230        #Cleanup
231        os.remove(sww.filename)
232
233
234
235    def test_sww_variable(self):
236        """Test that sww information can be written correctly
237        """
238
239        import time, os
240        from Numeric import array, zeros, allclose, Float, concatenate
241        from Scientific.IO.NetCDF import NetCDFFile
242
243        self.domain.filename = 'datatest' + str(id(self))
244        self.domain.format = 'sww'
245        self.domain.smooth = True
246        self.domain.reduction = mean
247
248        sww = get_dataobject(self.domain)
249        sww.store_connectivity()
250        sww.store_timestep('stage')
251
252        #Check contents
253        #Get NetCDF
254        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
255
256
257        # Get the variables
258        x = fid.variables['x']
259        y = fid.variables['y']
260        z = fid.variables['elevation']
261        time = fid.variables['time']
262        stage = fid.variables['stage']
263
264
265        Q = self.domain.quantities['stage']
266        Q0 = Q.vertex_values[:,0]
267        Q1 = Q.vertex_values[:,1]
268        Q2 = Q.vertex_values[:,2]
269
270        A = stage[0,:]
271        #print A[0], (Q2[0,0] + Q1[1,0])/2
272        assert allclose(A[0], (Q2[0] + Q1[1])/2)
273        assert allclose(A[1], (Q0[1] + Q1[3] + Q2[2])/3)
274        assert allclose(A[2], Q0[3])
275        assert allclose(A[3], (Q0[0] + Q1[5] + Q2[4])/3)
276
277        #Center point
278        assert allclose(A[4], (Q1[0] + Q2[1] + Q0[2] +\
279                                 Q0[5] + Q2[6] + Q1[7])/6)
280
281
282
283        fid.close()
284
285        #Cleanup
286        os.remove(sww.filename)
287
288
289    def test_sww_variable2(self):
290        """Test that sww information can be written correctly
291        multiple timesteps. Use average as reduction operator
292        """
293
294        import time, os
295        from Numeric import array, zeros, allclose, Float, concatenate
296        from Scientific.IO.NetCDF import NetCDFFile
297
298        self.domain.filename = 'datatest' + str(id(self))
299        self.domain.format = 'sww'
300        self.domain.smooth = True
301
302        self.domain.reduction = mean
303
304        sww = get_dataobject(self.domain)
305        sww.store_connectivity()
306        sww.store_timestep('stage')
307        self.domain.evolve_to_end(finaltime = 0.01)
308        sww.store_timestep('stage')
309
310
311        #Check contents
312        #Get NetCDF
313        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
314
315        # Get the variables
316        x = fid.variables['x']
317        y = fid.variables['y']
318        z = fid.variables['elevation']
319        time = fid.variables['time']
320        stage = fid.variables['stage']
321
322        #Check values
323        Q = self.domain.quantities['stage']
324        Q0 = Q.vertex_values[:,0]
325        Q1 = Q.vertex_values[:,1]
326        Q2 = Q.vertex_values[:,2]
327
328        A = stage[1,:]
329        assert allclose(A[0], (Q2[0] + Q1[1])/2)
330        assert allclose(A[1], (Q0[1] + Q1[3] + Q2[2])/3)
331        assert allclose(A[2], Q0[3])
332        assert allclose(A[3], (Q0[0] + Q1[5] + Q2[4])/3)
333
334        #Center point
335        assert allclose(A[4], (Q1[0] + Q2[1] + Q0[2] +\
336                                 Q0[5] + Q2[6] + Q1[7])/6)
337
338
339        fid.close()
340
341        #Cleanup
342        os.remove(sww.filename)
343
344    def test_sww_variable3(self):
345        """Test that sww information can be written correctly
346        multiple timesteps using a different reduction operator (min)
347        """
348
349        import time, os
350        from Numeric import array, zeros, allclose, Float, concatenate
351        from Scientific.IO.NetCDF import NetCDFFile
352
353        self.domain.filename = 'datatest' + str(id(self))
354        self.domain.format = 'sww'
355        self.domain.smooth = True
356        self.domain.reduction = min
357
358        sww = get_dataobject(self.domain)
359        sww.store_connectivity()
360        sww.store_timestep('stage')
361
362        self.domain.evolve_to_end(finaltime = 0.01)
363        sww.store_timestep('stage')
364
365
366        #Check contents
367        #Get NetCDF
368        fid = NetCDFFile(sww.filename, 'r')
369
370
371        # Get the variables
372        x = fid.variables['x']
373        y = fid.variables['y']
374        z = fid.variables['elevation']
375        time = fid.variables['time']
376        stage = fid.variables['stage']
377
378        #Check values
379        Q = self.domain.quantities['stage']
380        Q0 = Q.vertex_values[:,0]
381        Q1 = Q.vertex_values[:,1]
382        Q2 = Q.vertex_values[:,2]
383
384        A = stage[1,:]
385        assert allclose(A[0], min(Q2[0], Q1[1]))
386        assert allclose(A[1], min(Q0[1], Q1[3], Q2[2]))
387        assert allclose(A[2], Q0[3])
388        assert allclose(A[3], min(Q0[0], Q1[5], Q2[4]))
389
390        #Center point
391        assert allclose(A[4], min(Q1[0], Q2[1], Q0[2],\
392                                  Q0[5], Q2[6], Q1[7]))
393
394
395        fid.close()
396
397        #Cleanup
398        os.remove(sww.filename)
399
400
401    def test_sync(self):
402        """Test info stored at each timestep is as expected (incl initial condition)
403        """
404
405        import time, os, config
406        from Numeric import array, zeros, allclose, Float, concatenate
407        from Scientific.IO.NetCDF import NetCDFFile
408
409        self.domain.filename = 'synctest'
410        self.domain.format = 'sww'
411        self.domain.smooth = False
412        self.domain.store = True
413        self.domain.beta_h = 0
414
415        #Evolution
416        for t in self.domain.evolve(yieldstep = 1.0, finaltime = 4.0):
417            stage = self.domain.quantities['stage'].vertex_values
418
419            #Get NetCDF
420            fid = NetCDFFile(self.domain.writer.filename, 'r')
421            stage_file = fid.variables['stage']
422
423            if t == 0.0:
424                assert allclose(stage, self.initial_stage)
425                assert allclose(stage_file[:], stage.flat)
426            else:
427                assert not allclose(stage, self.initial_stage)
428                assert not allclose(stage_file[:], stage.flat)
429
430            fid.close()
431
432        os.remove(self.domain.writer.filename)
433
434
435
436    def test_sww_DSG(self):
437        """Not a test, rather a look at the sww format
438        """
439
440        import time, os
441        from Numeric import array, zeros, allclose, Float, concatenate
442        from Scientific.IO.NetCDF import NetCDFFile
443
444        self.domain.filename = 'datatest' + str(id(self))
445        self.domain.format = 'sww'
446        self.domain.smooth = True
447        self.domain.reduction = mean
448
449        sww = get_dataobject(self.domain)
450        sww.store_connectivity()
451        sww.store_timestep('stage')
452
453        #Check contents
454        #Get NetCDF
455        fid = NetCDFFile(sww.filename, 'r')
456
457        # Get the variables
458        x = fid.variables['x']
459        y = fid.variables['y']
460        z = fid.variables['elevation']
461
462        volumes = fid.variables['volumes']
463        time = fid.variables['time']
464
465        # 2D
466        stage = fid.variables['stage']
467
468        X = x[:]
469        Y = y[:]
470        Z = z[:]
471        V = volumes[:]
472        T = time[:]
473        S = stage[:,:]
474
475#         print "****************************"
476#         print "X ",X
477#         print "****************************"
478#         print "Y ",Y
479#         print "****************************"
480#         print "Z ",Z
481#         print "****************************"
482#         print "V ",V
483#         print "****************************"
484#         print "Time ",T
485#         print "****************************"
486#         print "Stage ",S
487#         print "****************************"
488
489
490        fid.close()
491
492        #Cleanup
493        os.remove(sww.filename)
494
495
496
497    def test_dem2pts(self):
498        """Test conversion from dem in ascii format to native NetCDF xya format
499        """
500
501        import time, os
502        from Numeric import array, zeros, allclose, Float, concatenate
503        from Scientific.IO.NetCDF import NetCDFFile
504
505        #Write test asc file
506        root = 'demtest'
507
508        filename = root+'.asc'
509        fid = open(filename, 'w')
510        fid.write("""ncols         5
511nrows         6
512xllcorner     2000.5
513yllcorner     3000.5
514cellsize      25
515NODATA_value  -9999
516""")
517        #Create linear function
518
519        ref_points = []
520        ref_elevation = []
521        for i in range(6):
522            y = (6-i)*25.0
523            for j in range(5):
524                x = j*25.0
525                z = x+2*y
526
527                ref_points.append( [x,y] )
528                ref_elevation.append(z)
529                fid.write('%f ' %z)
530            fid.write('\n')
531
532        fid.close()
533
534        #Write prj file with metadata
535        metafilename = root+'.prj'
536        fid = open(metafilename, 'w')
537
538
539        fid.write("""Projection UTM
540Zone 56
541Datum WGS84
542Zunits NO
543Units METERS
544Spheroid WGS84
545Xshift 0.0000000000
546Yshift 10000000.0000000000
547Parameters
548""")
549        fid.close()
550
551        #Convert to NetCDF pts
552        convert_dem_from_ascii2netcdf(root)
553        dem2pts(root)
554
555        #Check contents
556        #Get NetCDF
557        fid = NetCDFFile(root+'.pts', 'r')
558
559        # Get the variables
560        #print fid.variables.keys()
561        points = fid.variables['points']
562        elevation = fid.variables['elevation']
563
564        #Check values
565
566        #print points[:]
567        #print ref_points
568        assert allclose(points, ref_points)
569
570        #print attributes[:]
571        #print ref_elevation
572        assert allclose(elevation, ref_elevation)
573
574        #Cleanup
575        fid.close()
576
577
578        os.remove(root + '.pts')
579        os.remove(root + '.dem')
580        os.remove(root + '.asc')
581        os.remove(root + '.prj')
582
583
584
585    def test_dem2pts_bounding_box(self):
586        """Test conversion from dem in ascii format to native NetCDF xya format
587        """
588
589        import time, os
590        from Numeric import array, zeros, allclose, Float, concatenate
591        from Scientific.IO.NetCDF import NetCDFFile
592
593        #Write test asc file
594        root = 'demtest'
595
596        filename = root+'.asc'
597        fid = open(filename, 'w')
598        fid.write("""ncols         5
599nrows         6
600xllcorner     2000.5
601yllcorner     3000.5
602cellsize      25
603NODATA_value  -9999
604""")
605        #Create linear function
606
607        ref_points = []
608        ref_elevation = []
609        for i in range(6):
610            y = (6-i)*25.0
611            for j in range(5):
612                x = j*25.0
613                z = x+2*y
614
615                ref_points.append( [x,y] )
616                ref_elevation.append(z)
617                fid.write('%f ' %z)
618            fid.write('\n')
619
620        fid.close()
621
622        #Write prj file with metadata
623        metafilename = root+'.prj'
624        fid = open(metafilename, 'w')
625
626
627        fid.write("""Projection UTM
628Zone 56
629Datum WGS84
630Zunits NO
631Units METERS
632Spheroid WGS84
633Xshift 0.0000000000
634Yshift 10000000.0000000000
635Parameters
636""")
637        fid.close()
638
639        #Convert to NetCDF pts
640        convert_dem_from_ascii2netcdf(root)
641        dem2pts(root, easting_min=2010.0, easting_max=2110.0,
642                northing_min=3035.0, northing_max=3125.5)
643
644        #Check contents
645        #Get NetCDF
646        fid = NetCDFFile(root+'.pts', 'r')
647
648        # Get the variables
649        #print fid.variables.keys()
650        points = fid.variables['points']
651        elevation = fid.variables['elevation']
652
653        #Check values
654        assert fid.xllcorner[0] == 2010.0
655        assert fid.yllcorner[0] == 3035.0
656
657        #create new reference points
658        ref_points = []
659        ref_elevation = []
660        for i in range(4):
661            y = (4-i)*25.0 + 25.0
662            y_new = y + 3000.5 - 3035.0
663            for j in range(4):
664                x = j*25.0 + 25.0
665                x_new = x + 2000.5 - 2010.0
666                z = x+2*y
667
668                ref_points.append( [x_new,y_new] )
669                ref_elevation.append(z)
670
671        #print points[:]
672        #print ref_points
673        assert allclose(points, ref_points)
674
675        #print attributes[:]
676        #print ref_elevation
677        assert allclose(elevation, ref_elevation)
678
679        #Cleanup
680        fid.close()
681
682
683        os.remove(root + '.pts')
684        os.remove(root + '.dem')
685        os.remove(root + '.asc')
686        os.remove(root + '.prj')
687
688
689
690    def test_sww2asc_elevation(self):
691        """Test that sww information can be converted correctly to asc/prj
692        format readable by e.g. ArcView
693        """
694
695        import time, os
696        from Numeric import array, zeros, allclose, Float, concatenate
697        from Scientific.IO.NetCDF import NetCDFFile
698
699        #Setup
700        self.domain.filename = 'datatest'
701
702        prjfile = self.domain.filename + '_elevation.prj'
703        ascfile = self.domain.filename + '_elevation.asc'
704        swwfile = self.domain.filename + '.sww'
705
706        self.domain.set_datadir('.')
707        self.domain.format = 'sww'
708        self.domain.smooth = True
709        self.domain.set_quantity('elevation', lambda x,y: -x-y)
710
711        self.domain.geo_reference = Geo_reference(56,308500,6189000)
712
713        sww = get_dataobject(self.domain)
714        sww.store_connectivity()
715        sww.store_timestep('stage')
716
717        self.domain.evolve_to_end(finaltime = 0.01)
718        sww.store_timestep('stage')
719
720        cellsize = 0.25
721        #Check contents
722        #Get NetCDF
723
724        fid = NetCDFFile(sww.filename, 'r')
725
726        # Get the variables
727        x = fid.variables['x'][:]
728        y = fid.variables['y'][:]
729        z = fid.variables['elevation'][:]
730        time = fid.variables['time'][:]
731        stage = fid.variables['stage'][:]
732
733
734        #Export to ascii/prj files
735        sww2asc(self.domain.filename,
736                quantity = 'elevation',
737                cellsize = cellsize,
738                verbose = False)
739
740
741        #Check prj (meta data)
742        prjid = open(prjfile)
743        lines = prjid.readlines()
744        prjid.close()
745
746        L = lines[0].strip().split()
747        assert L[0].strip().lower() == 'projection'
748        assert L[1].strip().lower() == 'utm'
749
750        L = lines[1].strip().split()
751        assert L[0].strip().lower() == 'zone'
752        assert L[1].strip().lower() == '56'
753
754        L = lines[2].strip().split()
755        assert L[0].strip().lower() == 'datum'
756        assert L[1].strip().lower() == 'wgs84'
757
758        L = lines[3].strip().split()
759        assert L[0].strip().lower() == 'zunits'
760        assert L[1].strip().lower() == 'no'
761
762        L = lines[4].strip().split()
763        assert L[0].strip().lower() == 'units'
764        assert L[1].strip().lower() == 'meters'
765
766        L = lines[5].strip().split()
767        assert L[0].strip().lower() == 'spheroid'
768        assert L[1].strip().lower() == 'wgs84'
769
770        L = lines[6].strip().split()
771        assert L[0].strip().lower() == 'xshift'
772        assert L[1].strip().lower() == '500000'
773
774        L = lines[7].strip().split()
775        assert L[0].strip().lower() == 'yshift'
776        assert L[1].strip().lower() == '10000000'
777
778        L = lines[8].strip().split()
779        assert L[0].strip().lower() == 'parameters'
780
781
782        #Check asc file
783        ascid = open(ascfile)
784        lines = ascid.readlines()
785        ascid.close()
786
787        L = lines[0].strip().split()
788        assert L[0].strip().lower() == 'ncols'
789        assert L[1].strip().lower() == '5'
790
791        L = lines[1].strip().split()
792        assert L[0].strip().lower() == 'nrows'
793        assert L[1].strip().lower() == '5'
794
795        L = lines[2].strip().split()
796        assert L[0].strip().lower() == 'xllcorner'
797        assert allclose(float(L[1].strip().lower()), 308500)
798
799        L = lines[3].strip().split()
800        assert L[0].strip().lower() == 'yllcorner'
801        assert allclose(float(L[1].strip().lower()), 6189000)
802
803        L = lines[4].strip().split()
804        assert L[0].strip().lower() == 'cellsize'
805        assert allclose(float(L[1].strip().lower()), cellsize)
806
807        L = lines[5].strip().split()
808        assert L[0].strip() == 'NODATA_value'
809        assert L[1].strip().lower() == '-9999'
810
811        #Check grid values
812        for j in range(5):
813            L = lines[6+j].strip().split()
814            y = (4-j) * cellsize
815            for i in range(5):
816                assert allclose(float(L[i]), -i*cellsize - y)
817
818
819        fid.close()
820
821        #Cleanup
822        os.remove(prjfile)
823        os.remove(ascfile)
824        os.remove(swwfile)
825
826
827    def test_sww2asc_stage_reduction(self):
828        """Test that sww information can be converted correctly to asc/prj
829        format readable by e.g. ArcView
830
831        This tests the reduction of quantity stage using min
832        """
833
834        import time, os
835        from Numeric import array, zeros, allclose, Float, concatenate
836        from Scientific.IO.NetCDF import NetCDFFile
837
838        #Setup
839        self.domain.filename = 'datatest'
840
841        prjfile = self.domain.filename + '_stage.prj'
842        ascfile = self.domain.filename + '_stage.asc'
843        swwfile = self.domain.filename + '.sww'
844
845        self.domain.set_datadir('.')
846        self.domain.format = 'sww'
847        self.domain.smooth = True
848        self.domain.set_quantity('elevation', lambda x,y: -x-y)
849
850        self.domain.geo_reference = Geo_reference(56,308500,6189000)
851
852
853        sww = get_dataobject(self.domain)
854        sww.store_connectivity()
855        sww.store_timestep('stage')
856
857        self.domain.evolve_to_end(finaltime = 0.01)
858        sww.store_timestep('stage')
859
860        cellsize = 0.25
861        #Check contents
862        #Get NetCDF
863
864        fid = NetCDFFile(sww.filename, 'r')
865
866        # Get the variables
867        x = fid.variables['x'][:]
868        y = fid.variables['y'][:]
869        z = fid.variables['elevation'][:]
870        time = fid.variables['time'][:]
871        stage = fid.variables['stage'][:]
872
873
874        #Export to ascii/prj files
875        sww2asc(self.domain.filename,
876                quantity = 'stage',
877                cellsize = cellsize,
878                reduction = min)
879
880
881        #Check asc file
882        ascid = open(ascfile)
883        lines = ascid.readlines()
884        ascid.close()
885
886        L = lines[0].strip().split()
887        assert L[0].strip().lower() == 'ncols'
888        assert L[1].strip().lower() == '5'
889
890        L = lines[1].strip().split()
891        assert L[0].strip().lower() == 'nrows'
892        assert L[1].strip().lower() == '5'
893
894        L = lines[2].strip().split()
895        assert L[0].strip().lower() == 'xllcorner'
896        assert allclose(float(L[1].strip().lower()), 308500)
897
898        L = lines[3].strip().split()
899        assert L[0].strip().lower() == 'yllcorner'
900        assert allclose(float(L[1].strip().lower()), 6189000)
901
902        L = lines[4].strip().split()
903        assert L[0].strip().lower() == 'cellsize'
904        assert allclose(float(L[1].strip().lower()), cellsize)
905
906        L = lines[5].strip().split()
907        assert L[0].strip() == 'NODATA_value'
908        assert L[1].strip().lower() == '-9999'
909
910
911        #Check grid values (where applicable)
912        for j in range(5):
913            if j%2 == 0:
914                L = lines[6+j].strip().split()
915                jj = 4-j
916                for i in range(5):
917                    if i%2 == 0:
918                        index = jj/2 + i/2*3
919                        val0 = stage[0,index]
920                        val1 = stage[1,index]
921
922                        #print i, j, index, ':', L[i], val0, val1
923                        assert allclose(float(L[i]), min(val0, val1))
924
925
926        fid.close()
927
928        #Cleanup
929        os.remove(prjfile)
930        os.remove(ascfile)
931        #os.remove(swwfile)
932
933
934
935
936    def test_sww2asc_missing_points(self):
937        """Test that sww information can be converted correctly to asc/prj
938        format readable by e.g. ArcView
939
940        This test includes the writing of missing values
941        """
942
943        import time, os
944        from Numeric import array, zeros, allclose, Float, concatenate
945        from Scientific.IO.NetCDF import NetCDFFile
946
947        #Setup mesh not coinciding with rectangle.
948        #This will cause missing values to occur in gridded data
949
950
951        points = [                        [1.0, 1.0],
952                              [0.5, 0.5], [1.0, 0.5],
953                  [0.0, 0.0], [0.5, 0.0], [1.0, 0.0]]
954
955        vertices = [ [4,1,3], [5,2,4], [1,4,2], [2,0,1]]
956
957        #Create shallow water domain
958        domain = Domain(points, vertices)
959        domain.default_order=2
960
961
962        #Set some field values
963        domain.set_quantity('elevation', lambda x,y: -x-y)
964        domain.set_quantity('friction', 0.03)
965
966
967        ######################
968        # Boundary conditions
969        B = Transmissive_boundary(domain)
970        domain.set_boundary( {'exterior': B} )
971
972
973        ######################
974        #Initial condition - with jumps
975
976        bed = domain.quantities['elevation'].vertex_values
977        stage = zeros(bed.shape, Float)
978
979        h = 0.3
980        for i in range(stage.shape[0]):
981            if i % 2 == 0:
982                stage[i,:] = bed[i,:] + h
983            else:
984                stage[i,:] = bed[i,:]
985
986        domain.set_quantity('stage', stage)
987        domain.distribute_to_vertices_and_edges()
988
989        domain.filename = 'datatest'
990
991        prjfile = domain.filename + '_elevation.prj'
992        ascfile = domain.filename + '_elevation.asc'
993        swwfile = domain.filename + '.sww'
994
995        domain.set_datadir('.')
996        domain.format = 'sww'
997        domain.smooth = True
998
999        domain.geo_reference = Geo_reference(56,308500,6189000)
1000
1001        sww = get_dataobject(domain)
1002        sww.store_connectivity()
1003        sww.store_timestep('stage')
1004
1005        cellsize = 0.25
1006        #Check contents
1007        #Get NetCDF
1008
1009        fid = NetCDFFile(swwfile, 'r')
1010
1011        # Get the variables
1012        x = fid.variables['x'][:]
1013        y = fid.variables['y'][:]
1014        z = fid.variables['elevation'][:]
1015        time = fid.variables['time'][:]
1016
1017        try:
1018            geo_reference = Geo_reference(NetCDFObject=fid)
1019        except AttributeError, e:
1020            geo_reference = Geo_reference(DEFAULT_ZONE,0,0)
1021
1022        #Export to ascii/prj files
1023        sww2asc(domain.filename,
1024                quantity = 'elevation',
1025                cellsize = cellsize,
1026                verbose = False)
1027
1028
1029        #Check asc file
1030        ascid = open(ascfile)
1031        lines = ascid.readlines()
1032        ascid.close()
1033
1034        L = lines[0].strip().split()
1035        assert L[0].strip().lower() == 'ncols'
1036        assert L[1].strip().lower() == '5'
1037
1038        L = lines[1].strip().split()
1039        assert L[0].strip().lower() == 'nrows'
1040        assert L[1].strip().lower() == '5'
1041
1042        L = lines[2].strip().split()
1043        assert L[0].strip().lower() == 'xllcorner'
1044        assert allclose(float(L[1].strip().lower()), 308500)
1045
1046        L = lines[3].strip().split()
1047        assert L[0].strip().lower() == 'yllcorner'
1048        assert allclose(float(L[1].strip().lower()), 6189000)
1049
1050        L = lines[4].strip().split()
1051        assert L[0].strip().lower() == 'cellsize'
1052        assert allclose(float(L[1].strip().lower()), cellsize)
1053
1054        L = lines[5].strip().split()
1055        assert L[0].strip() == 'NODATA_value'
1056        assert L[1].strip().lower() == '-9999'
1057
1058
1059        #Check grid values
1060        for j in range(5):
1061            L = lines[6+j].strip().split()
1062            y = (4-j) * cellsize
1063            for i in range(5):
1064                if i+j >= 4:
1065                    assert allclose(float(L[i]), -i*cellsize - y)
1066                else:
1067                    #Missing values
1068                    assert allclose(float(L[i]), -9999)
1069
1070
1071
1072        fid.close()
1073
1074        #Cleanup
1075        os.remove(prjfile)
1076        os.remove(ascfile)
1077        os.remove(swwfile)
1078
1079
1080    def test_ferret2sww(self):
1081        """Test that georeferencing etc works when converting from
1082        ferret format (lat/lon) to sww format (UTM)
1083        """
1084        from Scientific.IO.NetCDF import NetCDFFile
1085
1086        #The test file has
1087        # LON = 150.66667, 150.83334, 151, 151.16667
1088        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1089        # TIME = 0, 0.1, 0.6, 1.1, 1.6, 2.1 ;
1090        #
1091        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1092        # Fourth value (index==3) is -6.50198 cm
1093
1094
1095        from coordinate_transforms.redfearn import redfearn
1096
1097        fid = NetCDFFile('small_ha.nc')
1098        first_value = fid.variables['HA'][:][0,0,0]
1099        fourth_value = fid.variables['HA'][:][0,0,3]
1100
1101
1102        #Call conversion (with zero origin)
1103        ferret2sww('small', verbose=False,
1104                   origin = (56, 0, 0))
1105
1106
1107        #Work out the UTM coordinates for first point
1108        zone, e, n = redfearn(-34.5, 150.66667)
1109        #print zone, e, n
1110
1111        #Read output file 'small.sww'
1112        fid = NetCDFFile('small.sww')
1113
1114        x = fid.variables['x'][:]
1115        y = fid.variables['y'][:]
1116
1117        #Check that first coordinate is correctly represented
1118        assert allclose(x[0], e)
1119        assert allclose(y[0], n)
1120
1121        #Check first value
1122        stage = fid.variables['stage'][:]
1123        xmomentum = fid.variables['xmomentum'][:]
1124        ymomentum = fid.variables['ymomentum'][:]
1125
1126        #print ymomentum
1127
1128        assert allclose(stage[0,0], first_value/100)  #Meters
1129
1130        #Check fourth value
1131        assert allclose(stage[0,3], fourth_value/100)  #Meters
1132
1133        fid.close()
1134
1135        #Cleanup
1136        import os
1137        os.remove('small.sww')
1138
1139
1140
1141    def test_ferret2sww_2(self):
1142        """Test that georeferencing etc works when converting from
1143        ferret format (lat/lon) to sww format (UTM)
1144        """
1145        from Scientific.IO.NetCDF import NetCDFFile
1146
1147        #The test file has
1148        # LON = 150.66667, 150.83334, 151, 151.16667
1149        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1150        # TIME = 0, 0.1, 0.6, 1.1, 1.6, 2.1 ;
1151        #
1152        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1153        # Fourth value (index==3) is -6.50198 cm
1154
1155
1156        from coordinate_transforms.redfearn import redfearn
1157
1158        fid = NetCDFFile('small_ha.nc')
1159
1160        #Pick a coordinate and a value
1161
1162        time_index = 1
1163        lat_index = 0
1164        lon_index = 2
1165
1166        test_value = fid.variables['HA'][:][time_index, lat_index, lon_index]
1167        test_time = fid.variables['TIME'][:][time_index]
1168        test_lat = fid.variables['LAT'][:][lat_index]
1169        test_lon = fid.variables['LON'][:][lon_index]
1170
1171        linear_point_index = lat_index*4 + lon_index
1172        fid.close()
1173
1174        #Call conversion (with zero origin)
1175        ferret2sww('small', verbose=False,
1176                   origin = (56, 0, 0))
1177
1178
1179        #Work out the UTM coordinates for test point
1180        zone, e, n = redfearn(test_lat, test_lon)
1181
1182        #Read output file 'small.sww'
1183        fid = NetCDFFile('small.sww')
1184
1185        x = fid.variables['x'][:]
1186        y = fid.variables['y'][:]
1187
1188        #Check that test coordinate is correctly represented
1189        assert allclose(x[linear_point_index], e)
1190        assert allclose(y[linear_point_index], n)
1191
1192        #Check test value
1193        stage = fid.variables['stage'][:]
1194
1195        assert allclose(stage[time_index, linear_point_index], test_value/100)
1196
1197        fid.close()
1198
1199        #Cleanup
1200        import os
1201        os.remove('small.sww')
1202
1203
1204
1205    def test_ferret2sww3(self):
1206        """
1207        """
1208        from Scientific.IO.NetCDF import NetCDFFile
1209
1210        #The test file has
1211        # LON = 150.66667, 150.83334, 151, 151.16667
1212        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1213        # ELEVATION = [-1 -2 -3 -4
1214        #             -5 -6 -7 -8
1215        #              ...
1216        #              ...      -16]
1217        # where the top left corner is -1m,
1218        # and the ll corner is -13.0m
1219        #
1220        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1221        # Fourth value (index==3) is -6.50198 cm
1222
1223        from coordinate_transforms.redfearn import redfearn
1224        import os
1225        fid1 = NetCDFFile('test_ha.nc','w')
1226        fid2 = NetCDFFile('test_ua.nc','w')
1227        fid3 = NetCDFFile('test_va.nc','w')
1228        fid4 = NetCDFFile('test_e.nc','w')
1229
1230        h1_list = [150.66667,150.83334,151.]
1231        h2_list = [-34.5,-34.33333]
1232
1233        long_name = 'LON'
1234        lat_name = 'LAT'
1235
1236        nx = 3
1237        ny = 2
1238
1239        for fid in [fid1,fid2,fid3]:
1240            fid.createDimension(long_name,nx)
1241            fid.createVariable(long_name,'d',(long_name,))
1242            fid.variables[long_name].point_spacing='uneven'
1243            fid.variables[long_name].units='degrees_east'
1244            fid.variables[long_name].assignValue(h1_list)
1245
1246            fid.createDimension(lat_name,ny)
1247            fid.createVariable(lat_name,'d',(lat_name,))
1248            fid.variables[lat_name].point_spacing='uneven'
1249            fid.variables[lat_name].units='degrees_north'
1250            fid.variables[lat_name].assignValue(h2_list)
1251
1252            fid.createDimension('TIME',2)
1253            fid.createVariable('TIME','d',('TIME',))
1254            fid.variables['TIME'].point_spacing='uneven'
1255            fid.variables['TIME'].units='seconds'
1256            fid.variables['TIME'].assignValue([0.,1.])
1257            if fid == fid3: break
1258
1259
1260        for fid in [fid4]:
1261            fid.createDimension(long_name,nx)
1262            fid.createVariable(long_name,'d',(long_name,))
1263            fid.variables[long_name].point_spacing='uneven'
1264            fid.variables[long_name].units='degrees_east'
1265            fid.variables[long_name].assignValue(h1_list)
1266
1267            fid.createDimension(lat_name,ny)
1268            fid.createVariable(lat_name,'d',(lat_name,))
1269            fid.variables[lat_name].point_spacing='uneven'
1270            fid.variables[lat_name].units='degrees_north'
1271            fid.variables[lat_name].assignValue(h2_list)
1272
1273        name = {}
1274        name[fid1]='HA'
1275        name[fid2]='UA'
1276        name[fid3]='VA'
1277        name[fid4]='ELEVATION'
1278
1279        units = {}
1280        units[fid1]='cm'
1281        units[fid2]='cm/s'
1282        units[fid3]='cm/s'
1283        units[fid4]='m'
1284
1285        values = {}
1286        values[fid1]=[[[5., 10.,15.], [13.,18.,23.]],[[50.,100.,150.],[130.,180.,230.]]]
1287        values[fid2]=[[[1., 2.,3.], [4.,5.,6.]],[[7.,8.,9.],[10.,11.,12.]]]
1288        values[fid3]=[[[13., 12.,11.], [10.,9.,8.]],[[7.,6.,5.],[4.,3.,2.]]]
1289        values[fid4]=[[-3000,-3100,-3200],[-4000,-5000,-6000]]
1290
1291        for fid in [fid1,fid2,fid3]:
1292          fid.createVariable(name[fid],'d',('TIME',lat_name,long_name))
1293          fid.variables[name[fid]].point_spacing='uneven'
1294          fid.variables[name[fid]].units=units[fid]
1295          fid.variables[name[fid]].assignValue(values[fid])
1296          fid.variables[name[fid]].missing_value = -99999999.
1297          if fid == fid3: break
1298
1299        for fid in [fid4]:
1300            fid.createVariable(name[fid],'d',(lat_name,long_name))
1301            fid.variables[name[fid]].point_spacing='uneven'
1302            fid.variables[name[fid]].units=units[fid]
1303            fid.variables[name[fid]].assignValue(values[fid])
1304            fid.variables[name[fid]].missing_value = -99999999.
1305
1306
1307        fid1.sync(); fid1.close()
1308        fid2.sync(); fid2.close()
1309        fid3.sync(); fid3.close()
1310        fid4.sync(); fid4.close()
1311
1312        fid1 = NetCDFFile('test_ha.nc','r')
1313        fid2 = NetCDFFile('test_e.nc','r')
1314        fid3 = NetCDFFile('test_va.nc','r')
1315
1316
1317        first_amp = fid1.variables['HA'][:][0,0,0]
1318        third_amp = fid1.variables['HA'][:][0,0,2]
1319        first_elevation = fid2.variables['ELEVATION'][0,0]
1320        third_elevation= fid2.variables['ELEVATION'][:][0,2]
1321        first_speed = fid3.variables['VA'][0,0,0]
1322        third_speed = fid3.variables['VA'][:][0,0,2]
1323
1324        fid1.close()
1325        fid2.close()
1326        fid3.close()
1327
1328        #Call conversion (with zero origin)
1329        ferret2sww('test', verbose=False,
1330                   origin = (56, 0, 0))
1331
1332        os.remove('test_va.nc')
1333        os.remove('test_ua.nc')
1334        os.remove('test_ha.nc')
1335        os.remove('test_e.nc')
1336
1337        #Read output file 'test.sww'
1338        fid = NetCDFFile('test.sww')
1339
1340
1341        #Check first value
1342        elevation = fid.variables['elevation'][:]
1343        stage = fid.variables['stage'][:]
1344        xmomentum = fid.variables['xmomentum'][:]
1345        ymomentum = fid.variables['ymomentum'][:]
1346
1347        #print ymomentum
1348        first_height = first_amp/100 - first_elevation
1349        third_height = third_amp/100 - third_elevation
1350        first_momentum=first_speed*first_height/100
1351        third_momentum=third_speed*third_height/100
1352
1353        assert allclose(ymomentum[0][0],first_momentum)  #Meters
1354        assert allclose(ymomentum[0][2],third_momentum)  #Meters
1355
1356        fid.close()
1357
1358        #Cleanup
1359        os.remove('test.sww')
1360
1361
1362
1363
1364    def test_sww_extent(self):
1365        """Not a test, rather a look at the sww format
1366        """
1367
1368        import time, os
1369        from Numeric import array, zeros, allclose, Float, concatenate
1370        from Scientific.IO.NetCDF import NetCDFFile
1371
1372        self.domain.filename = 'datatest' + str(id(self))
1373        self.domain.format = 'sww'
1374        self.domain.smooth = True
1375        self.domain.reduction = mean
1376        self.domain.set_datadir('.')
1377
1378
1379        sww = get_dataobject(self.domain)
1380        sww.store_connectivity()
1381        sww.store_timestep('stage')
1382        self.domain.time = 2.
1383
1384        #Modify stage at second timestep
1385        stage = self.domain.quantities['stage'].vertex_values
1386        self.domain.set_quantity('stage', stage/2)
1387
1388        sww.store_timestep('stage')
1389
1390        file_and_extension_name = self.domain.filename + ".sww"
1391        #print "file_and_extension_name",file_and_extension_name
1392        [xmin, xmax, ymin, ymax, stagemin, stagemax] = \
1393               extent_sww(file_and_extension_name )
1394
1395        assert allclose(xmin, 0.0)
1396        assert allclose(xmax, 1.0)
1397        assert allclose(ymin, 0.0)
1398        assert allclose(ymax, 1.0)
1399        assert allclose(stagemin, -0.85)
1400        assert allclose(stagemax, 0.15)
1401
1402
1403        #Cleanup
1404        os.remove(sww.filename)
1405
1406
1407    def test_ferret2sww_nz_origin(self):
1408        from Scientific.IO.NetCDF import NetCDFFile
1409        from coordinate_transforms.redfearn import redfearn
1410
1411        #Call conversion (with nonzero origin)
1412        ferret2sww('small', verbose=False,
1413                   origin = (56, 100000, 200000))
1414
1415
1416        #Work out the UTM coordinates for first point
1417        zone, e, n = redfearn(-34.5, 150.66667)
1418
1419        #Read output file 'small.sww'
1420        fid = NetCDFFile('small.sww', 'r')
1421
1422        x = fid.variables['x'][:]
1423        y = fid.variables['y'][:]
1424
1425        #Check that first coordinate is correctly represented
1426        assert allclose(x[0], e-100000)
1427        assert allclose(y[0], n-200000)
1428
1429        fid.close()
1430
1431        #Cleanup
1432        import os
1433        os.remove('small.sww')
1434
1435    def test_sww2domain(self):
1436        ################################################
1437        #Create a test domain, and evolve and save it.
1438        ################################################
1439        from mesh_factory import rectangular
1440        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1441             Constant_height, Time_boundary, Transmissive_boundary
1442        from Numeric import array
1443
1444        #Create basic mesh
1445
1446        yiel=0.01
1447        points, vertices, boundary = rectangular(10,10)
1448
1449        #Create shallow water domain
1450        domain = Domain(points, vertices, boundary)
1451        domain.geo_reference = Geo_reference(56,11,11)
1452        domain.smooth = False
1453        domain.visualise = False
1454        domain.store = True
1455        domain.filename = 'bedslope'
1456        domain.default_order=2
1457        #Bed-slope and friction
1458        domain.set_quantity('elevation', lambda x,y: -x/3)
1459        domain.set_quantity('friction', 0.1)
1460        # Boundary conditions
1461        from math import sin, pi
1462        Br = Reflective_boundary(domain)
1463        Bt = Transmissive_boundary(domain)
1464        Bd = Dirichlet_boundary([0.2,0.,0.])
1465        Bw = Time_boundary(domain=domain,
1466                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1467
1468        #domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
1469        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1470
1471        domain.quantities_to_be_stored.extend(['xmomentum','ymomentum'])
1472        #Initial condition
1473        h = 0.05
1474        elevation = domain.quantities['elevation'].vertex_values
1475        domain.set_quantity('stage', elevation + h)
1476        #elevation = domain.get_quantity('elevation')
1477        #domain.set_quantity('stage', elevation + h)
1478
1479        domain.check_integrity()
1480        #Evolution
1481        for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
1482        #    domain.write_time()
1483            pass
1484
1485
1486        ##########################################
1487        #Import the example's file as a new domain
1488        ##########################################
1489        from data_manager import sww2domain
1490        from Numeric import allclose
1491        import os
1492
1493        filename = domain.datadir+os.sep+domain.filename+'.sww'
1494        domain2 = sww2domain(filename,None,fail_if_NaN=False,verbose = False)
1495        #points, vertices, boundary = rectangular(15,15)
1496        #domain2.boundary = boundary
1497        ###################
1498        ##NOW TEST IT!!!
1499        ###################
1500
1501        bits = ['vertex_coordinates']
1502        for quantity in ['elevation']+domain.quantities_to_be_stored:
1503            bits.append('quantities["%s"].get_integral()'%quantity)
1504            bits.append('get_quantity("%s")'%quantity)
1505
1506        for bit in bits:
1507            #print 'testing that domain.'+bit+' has been restored'
1508            #print bit
1509        #print 'done'
1510            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1511
1512        ######################################
1513        #Now evolve them both, just to be sure
1514        ######################################x
1515        visualise = False
1516        #visualise = True
1517        domain.visualise = visualise
1518        domain.time = 0.
1519        from time import sleep
1520
1521        final = .1
1522        domain.set_quantity('friction', 0.1)
1523        domain.store = False
1524        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1525
1526
1527        for t in domain.evolve(yieldstep = yiel, finaltime = final):
1528            if visualise: sleep(1.)
1529            #domain.write_time()
1530            pass
1531
1532        final = final - (domain2.starttime-domain.starttime)
1533        #BUT since domain1 gets time hacked back to 0:
1534        final = final + (domain2.starttime-domain.starttime)
1535
1536        domain2.smooth = False
1537        domain2.visualise = visualise
1538        domain2.store = False
1539        domain2.default_order=2
1540        domain2.set_quantity('friction', 0.1)
1541        #Bed-slope and friction
1542        # Boundary conditions
1543        Bd2=Dirichlet_boundary([0.2,0.,0.])
1544        domain2.boundary = domain.boundary
1545        #print 'domain2.boundary'
1546        #print domain2.boundary
1547        domain2.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1548        #domain2.set_boundary({'exterior': Bd})
1549
1550        domain2.check_integrity()
1551
1552        for t in domain2.evolve(yieldstep = yiel, finaltime = final):
1553            if visualise: sleep(1.)
1554            #domain2.write_time()
1555            pass
1556
1557        ###################
1558        ##NOW TEST IT!!!
1559        ##################
1560
1561        bits = [ 'vertex_coordinates']
1562
1563        for quantity in ['elevation','xmomentum','ymomentum']:#+domain.quantities_to_be_stored:
1564            bits.append('quantities["%s"].get_integral()'%quantity)
1565            bits.append('get_quantity("%s")'%quantity)
1566
1567        for bit in bits:
1568            #print bit
1569            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1570
1571
1572    def test_sww2domain2(self):
1573        ##################################################################
1574        #Same as previous test, but this checks how NaNs are handled.
1575        ##################################################################
1576
1577
1578        from mesh_factory import rectangular
1579        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1580             Constant_height, Time_boundary, Transmissive_boundary
1581        from Numeric import array
1582
1583        #Create basic mesh
1584        points, vertices, boundary = rectangular(2,2)
1585
1586        #Create shallow water domain
1587        domain = Domain(points, vertices, boundary)
1588        domain.smooth = False
1589        domain.visualise = False
1590        domain.store = True
1591        domain.filename = 'bedslope'
1592        domain.default_order=2
1593        domain.quantities_to_be_stored=['stage']
1594
1595        domain.set_quantity('elevation', lambda x,y: -x/3)
1596        domain.set_quantity('friction', 0.1)
1597
1598        from math import sin, pi
1599        Br = Reflective_boundary(domain)
1600        Bt = Transmissive_boundary(domain)
1601        Bd = Dirichlet_boundary([0.2,0.,0.])
1602        Bw = Time_boundary(domain=domain,
1603                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1604
1605        domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
1606
1607        h = 0.05
1608        elevation = domain.quantities['elevation'].vertex_values
1609        domain.set_quantity('stage', elevation + h)
1610
1611        domain.check_integrity()
1612
1613        for t in domain.evolve(yieldstep = 1, finaltime = 2.0):
1614            pass
1615            #domain.write_time()
1616
1617
1618
1619        ##################################
1620        #Import the file as a new domain
1621        ##################################
1622        from data_manager import sww2domain
1623        from Numeric import allclose
1624        import os
1625
1626        filename = domain.datadir+os.sep+domain.filename+'.sww'
1627
1628        #Fail because NaNs are present
1629        try:
1630            domain2 = sww2domain(filename,boundary,fail_if_NaN=True,verbose=False)
1631            assert True == False
1632        except:
1633            #Now import it, filling NaNs to be 0
1634            filler = 0
1635            domain2 = sww2domain(filename,None,fail_if_NaN=False,NaN_filler = filler,verbose=False)
1636        bits = [ 'geo_reference.get_xllcorner()',
1637                'geo_reference.get_yllcorner()',
1638                'vertex_coordinates']
1639
1640        for quantity in ['elevation']+domain.quantities_to_be_stored:
1641            bits.append('quantities["%s"].get_integral()'%quantity)
1642            bits.append('get_quantity("%s")'%quantity)
1643
1644        for bit in bits:
1645        #    print 'testing that domain.'+bit+' has been restored'
1646            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1647
1648        assert max(max(domain2.get_quantity('xmomentum')))==filler
1649        assert min(min(domain2.get_quantity('xmomentum')))==filler
1650        assert max(max(domain2.get_quantity('ymomentum')))==filler
1651        assert min(min(domain2.get_quantity('ymomentum')))==filler
1652
1653        #print 'passed'
1654
1655        #cleanup
1656        #import os
1657        #os.remove(domain.datadir+'/'+domain.filename+'.sww')
1658
1659
1660    #def test_weed(self):
1661        from data_manager import weed
1662
1663        coordinates1 = [[0.,0.],[1.,0.],[1.,1.],[1.,0.],[2.,0.],[1.,1.]]
1664        volumes1 = [[0,1,2],[3,4,5]]
1665        boundary1= {(0,1): 'external',(1,2): 'not external',(2,0): 'external',(3,4): 'external',(4,5): 'external',(5,3): 'not external'}
1666        coordinates2,volumes2,boundary2=weed(coordinates1,volumes1,boundary1)
1667
1668        points2 = {(0.,0.):None,(1.,0.):None,(1.,1.):None,(2.,0.):None}
1669
1670        assert len(points2)==len(coordinates2)
1671        for i in range(len(coordinates2)):
1672            coordinate = tuple(coordinates2[i])
1673            assert points2.has_key(coordinate)
1674            points2[coordinate]=i
1675
1676        for triangle in volumes1:
1677            for coordinate in triangle:
1678                assert coordinates2[points2[tuple(coordinates1[coordinate])]][0]==coordinates1[coordinate][0]
1679                assert coordinates2[points2[tuple(coordinates1[coordinate])]][1]==coordinates1[coordinate][1]
1680
1681
1682     #FIXME This fails - smooth makes the comparism too hard for allclose
1683    def ztest_sww2domain3(self):
1684        ################################################
1685        #DOMAIN.SMOOTH = TRUE !!!!!!!!!!!!!!!!!!!!!!!!!!
1686        ################################################
1687        from mesh_factory import rectangular
1688        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1689             Constant_height, Time_boundary, Transmissive_boundary
1690        from Numeric import array
1691        #Create basic mesh
1692
1693        yiel=0.01
1694        points, vertices, boundary = rectangular(10,10)
1695
1696        #Create shallow water domain
1697        domain = Domain(points, vertices, boundary)
1698        domain.geo_reference = Geo_reference(56,11,11)
1699        domain.smooth = True
1700        domain.visualise = False
1701        domain.store = True
1702        domain.filename = 'bedslope'
1703        domain.default_order=2
1704        #Bed-slope and friction
1705        domain.set_quantity('elevation', lambda x,y: -x/3)
1706        domain.set_quantity('friction', 0.1)
1707        # Boundary conditions
1708        from math import sin, pi
1709        Br = Reflective_boundary(domain)
1710        Bt = Transmissive_boundary(domain)
1711        Bd = Dirichlet_boundary([0.2,0.,0.])
1712        Bw = Time_boundary(domain=domain,
1713                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1714
1715        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1716
1717        domain.quantities_to_be_stored.extend(['xmomentum','ymomentum'])
1718        #Initial condition
1719        h = 0.05
1720        elevation = domain.quantities['elevation'].vertex_values
1721        domain.set_quantity('stage', elevation + h)
1722        #elevation = domain.get_quantity('elevation')
1723        #domain.set_quantity('stage', elevation + h)
1724
1725        domain.check_integrity()
1726        #Evolution
1727        for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
1728        #    domain.write_time()
1729            pass
1730
1731
1732        ##########################################
1733        #Import the example's file as a new domain
1734        ##########################################
1735        from data_manager import sww2domain
1736        from Numeric import allclose
1737        import os
1738
1739        filename = domain.datadir+os.sep+domain.filename+'.sww'
1740        domain2 = sww2domain(filename,None,fail_if_NaN=False,verbose = False)
1741        #points, vertices, boundary = rectangular(15,15)
1742        #domain2.boundary = boundary
1743        ###################
1744        ##NOW TEST IT!!!
1745        ###################
1746
1747        #FIXME smooth domain so that they can be compared
1748
1749
1750        bits = []#'vertex_coordinates']
1751        for quantity in ['elevation']+domain.quantities_to_be_stored:
1752            bits.append('quantities["%s"].get_integral()'%quantity)
1753            #bits.append('get_quantity("%s")'%quantity)
1754
1755        for bit in bits:
1756            #print 'testing that domain.'+bit+' has been restored'
1757            #print bit
1758            #print 'done'
1759            #print ('domain.'+bit), eval('domain.'+bit)
1760            #print ('domain2.'+bit), eval('domain2.'+bit)
1761            assert allclose(eval('domain.'+bit),eval('domain2.'+bit),rtol=1.0e-1,atol=1.e-3)
1762            pass
1763
1764        ######################################
1765        #Now evolve them both, just to be sure
1766        ######################################x
1767        visualise = False
1768        visualise = True
1769        domain.visualise = visualise
1770        domain.time = 0.
1771        from time import sleep
1772
1773        final = .5
1774        domain.set_quantity('friction', 0.1)
1775        domain.store = False
1776        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Br})
1777
1778        for t in domain.evolve(yieldstep = yiel, finaltime = final):
1779            if visualise: sleep(.03)
1780            #domain.write_time()
1781            pass
1782
1783        domain2.smooth = True
1784        domain2.visualise = visualise
1785        domain2.store = False
1786        domain2.default_order=2
1787        domain2.set_quantity('friction', 0.1)
1788        #Bed-slope and friction
1789        # Boundary conditions
1790        Bd2=Dirichlet_boundary([0.2,0.,0.])
1791        Br2 = Reflective_boundary(domain2)
1792        domain2.boundary = domain.boundary
1793        #print 'domain2.boundary'
1794        #print domain2.boundary
1795        domain2.set_boundary({'left': Bd2, 'right': Bd2, 'top': Bd2, 'bottom': Br2})
1796        #domain2.boundary = domain.boundary
1797        #domain2.set_boundary({'exterior': Bd})
1798
1799        domain2.check_integrity()
1800
1801        for t in domain2.evolve(yieldstep = yiel, finaltime = final):
1802            if visualise: sleep(.03)
1803            #domain2.write_time()
1804            pass
1805
1806        ###################
1807        ##NOW TEST IT!!!
1808        ##################
1809
1810        bits = [ 'vertex_coordinates']
1811
1812        for quantity in ['elevation','xmomentum','ymomentum']:#+domain.quantities_to_be_stored:
1813            #bits.append('quantities["%s"].get_integral()'%quantity)
1814            bits.append('get_quantity("%s")'%quantity)
1815
1816        for bit in bits:
1817            print bit
1818            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1819
1820
1821    def test_decimate_dem(self):
1822        """Test decimation of dem file
1823        """
1824
1825        import os
1826        from Numeric import ones, allclose, Float, arange
1827        from Scientific.IO.NetCDF import NetCDFFile
1828
1829        #Write test dem file
1830        root = 'decdemtest'
1831
1832        filename = root + '.dem'
1833        fid = NetCDFFile(filename, 'w')
1834
1835        fid.institution = 'Geoscience Australia'
1836        fid.description = 'NetCDF DEM format for compact and portable ' +\
1837                          'storage of spatial point data'
1838
1839        nrows = 15
1840        ncols = 18
1841
1842        fid.ncols = ncols
1843        fid.nrows = nrows
1844        fid.xllcorner = 2000.5
1845        fid.yllcorner = 3000.5
1846        fid.cellsize = 25
1847        fid.NODATA_value = -9999
1848
1849        fid.zone = 56
1850        fid.false_easting = 0.0
1851        fid.false_northing = 0.0
1852        fid.projection = 'UTM'
1853        fid.datum = 'WGS84'
1854        fid.units = 'METERS'
1855
1856        fid.createDimension('number_of_points', nrows*ncols)
1857
1858        fid.createVariable('elevation', Float, ('number_of_points',))
1859
1860        elevation = fid.variables['elevation']
1861
1862        elevation[:] = (arange(nrows*ncols))
1863
1864        fid.close()
1865
1866        #generate the elevation values expected in the decimated file
1867        ref_elevation = [(  0+  1+  2+ 18+ 19+ 20+ 36+ 37+ 38) / 9.0,
1868                         (  4+  5+  6+ 22+ 23+ 24+ 40+ 41+ 42) / 9.0,
1869                         (  8+  9+ 10+ 26+ 27+ 28+ 44+ 45+ 46) / 9.0,
1870                         ( 12+ 13+ 14+ 30+ 31+ 32+ 48+ 49+ 50) / 9.0,
1871                         ( 72+ 73+ 74+ 90+ 91+ 92+108+109+110) / 9.0,
1872                         ( 76+ 77+ 78+ 94+ 95+ 96+112+113+114) / 9.0,
1873                         ( 80+ 81+ 82+ 98+ 99+100+116+117+118) / 9.0,
1874                         ( 84+ 85+ 86+102+103+104+120+121+122) / 9.0,
1875                         (144+145+146+162+163+164+180+181+182) / 9.0,
1876                         (148+149+150+166+167+168+184+185+186) / 9.0,
1877                         (152+153+154+170+171+172+188+189+190) / 9.0,
1878                         (156+157+158+174+175+176+192+193+194) / 9.0,
1879                         (216+217+218+234+235+236+252+253+254) / 9.0,
1880                         (220+221+222+238+239+240+256+257+258) / 9.0,
1881                         (224+225+226+242+243+244+260+261+262) / 9.0,
1882                         (228+229+230+246+247+248+264+265+266) / 9.0]
1883
1884        #generate a stencil for computing the decimated values
1885        stencil = ones((3,3), Float) / 9.0
1886
1887        decimate_dem(root, stencil=stencil, cellsize_new=100)
1888
1889        #Open decimated NetCDF file
1890        fid = NetCDFFile(root + '_100.dem', 'r')
1891
1892        # Get decimated elevation
1893        elevation = fid.variables['elevation']
1894
1895        #Check values
1896        assert allclose(elevation, ref_elevation)
1897
1898        #Cleanup
1899        fid.close()
1900
1901        os.remove(root + '.dem')
1902        os.remove(root + '_100.dem')
1903
1904    def test_decimate_dem_NODATA(self):
1905        """Test decimation of dem file that includes NODATA values
1906        """
1907
1908        import os
1909        from Numeric import ones, allclose, Float, arange, reshape
1910        from Scientific.IO.NetCDF import NetCDFFile
1911
1912        #Write test dem file
1913        root = 'decdemtest'
1914
1915        filename = root + '.dem'
1916        fid = NetCDFFile(filename, 'w')
1917
1918        fid.institution = 'Geoscience Australia'
1919        fid.description = 'NetCDF DEM format for compact and portable ' +\
1920                          'storage of spatial point data'
1921
1922        nrows = 15
1923        ncols = 18
1924        NODATA_value = -9999
1925
1926        fid.ncols = ncols
1927        fid.nrows = nrows
1928        fid.xllcorner = 2000.5
1929        fid.yllcorner = 3000.5
1930        fid.cellsize = 25
1931        fid.NODATA_value = NODATA_value
1932
1933        fid.zone = 56
1934        fid.false_easting = 0.0
1935        fid.false_northing = 0.0
1936        fid.projection = 'UTM'
1937        fid.datum = 'WGS84'
1938        fid.units = 'METERS'
1939
1940        fid.createDimension('number_of_points', nrows*ncols)
1941
1942        fid.createVariable('elevation', Float, ('number_of_points',))
1943
1944        elevation = fid.variables['elevation']
1945
1946        #generate initial elevation values
1947        elevation_tmp = (arange(nrows*ncols))
1948        #add some NODATA values
1949        elevation_tmp[0]   = NODATA_value
1950        elevation_tmp[95]  = NODATA_value
1951        elevation_tmp[188] = NODATA_value
1952        elevation_tmp[189] = NODATA_value
1953        elevation_tmp[190] = NODATA_value
1954        elevation_tmp[209] = NODATA_value
1955        elevation_tmp[252] = NODATA_value
1956
1957        elevation[:] = elevation_tmp
1958
1959        fid.close()
1960
1961        #generate the elevation values expected in the decimated file
1962        ref_elevation = [NODATA_value,
1963                         (  4+  5+  6+ 22+ 23+ 24+ 40+ 41+ 42) / 9.0,
1964                         (  8+  9+ 10+ 26+ 27+ 28+ 44+ 45+ 46) / 9.0,
1965                         ( 12+ 13+ 14+ 30+ 31+ 32+ 48+ 49+ 50) / 9.0,
1966                         ( 72+ 73+ 74+ 90+ 91+ 92+108+109+110) / 9.0,
1967                         NODATA_value,
1968                         ( 80+ 81+ 82+ 98+ 99+100+116+117+118) / 9.0,
1969                         ( 84+ 85+ 86+102+103+104+120+121+122) / 9.0,
1970                         (144+145+146+162+163+164+180+181+182) / 9.0,
1971                         (148+149+150+166+167+168+184+185+186) / 9.0,
1972                         NODATA_value,
1973                         (156+157+158+174+175+176+192+193+194) / 9.0,
1974                         NODATA_value,
1975                         (220+221+222+238+239+240+256+257+258) / 9.0,
1976                         (224+225+226+242+243+244+260+261+262) / 9.0,
1977                         (228+229+230+246+247+248+264+265+266) / 9.0]
1978
1979        #generate a stencil for computing the decimated values
1980        stencil = ones((3,3), Float) / 9.0
1981
1982        decimate_dem(root, stencil=stencil, cellsize_new=100)
1983
1984        #Open decimated NetCDF file
1985        fid = NetCDFFile(root + '_100.dem', 'r')
1986
1987        # Get decimated elevation
1988        elevation = fid.variables['elevation']
1989
1990        #Check values
1991        assert allclose(elevation, ref_elevation)
1992
1993        #Cleanup
1994        fid.close()
1995
1996        os.remove(root + '.dem')
1997        os.remove(root + '_100.dem')
1998
1999
2000#-------------------------------------------------------------
2001if __name__ == "__main__":
2002    suite = unittest.makeSuite(Test_Data_Manager,'test')
2003    #suite = unittest.makeSuite(Test_Data_Manager,'test_dem2pts_bounding_box')
2004    #suite = unittest.makeSuite(Test_Data_Manager,'test_decimate_dem')
2005    #suite = unittest.makeSuite(Test_Data_Manager,'test_decimate_dem_NODATA')
2006    runner = unittest.TextTestRunner()
2007    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.