source: inundation/pyvolution/test_data_manager.py @ 1884

Last change on this file since 1884 was 1875, checked in by ole, 18 years ago

Bounding box for sww2dem

File size: 72.9 KB
Line 
1#!/usr/bin/env python
2#
3
4import unittest
5import copy
6from Numeric import zeros, array, allclose, Float
7from util import mean
8
9from data_manager import *
10from shallow_water import *
11from config import epsilon
12
13from coordinate_transforms.geo_reference import Geo_reference
14
15class Test_Data_Manager(unittest.TestCase):
16    def setUp(self):
17        import time
18        from mesh_factory import rectangular
19
20        #Create basic mesh
21        points, vertices, boundary = rectangular(2, 2)
22
23        #Create shallow water domain
24        domain = Domain(points, vertices, boundary)
25        domain.default_order=2
26
27
28        #Set some field values
29        domain.set_quantity('elevation', lambda x,y: -x)
30        domain.set_quantity('friction', 0.03)
31
32
33        ######################
34        # Boundary conditions
35        B = Transmissive_boundary(domain)
36        domain.set_boundary( {'left': B, 'right': B, 'top': B, 'bottom': B})
37
38
39        ######################
40        #Initial condition - with jumps
41
42
43        bed = domain.quantities['elevation'].vertex_values
44        stage = zeros(bed.shape, Float)
45
46        h = 0.3
47        for i in range(stage.shape[0]):
48            if i % 2 == 0:
49                stage[i,:] = bed[i,:] + h
50            else:
51                stage[i,:] = bed[i,:]
52
53        domain.set_quantity('stage', stage)
54        self.initial_stage = copy.copy(domain.quantities['stage'].vertex_values)
55
56        domain.distribute_to_vertices_and_edges()
57
58
59        self.domain = domain
60
61        C = domain.get_vertex_coordinates()
62        self.X = C[:,0:6:2].copy()
63        self.Y = C[:,1:6:2].copy()
64
65        self.F = bed
66
67
68    def tearDown(self):
69        pass
70
71
72
73
74#     def test_xya(self):
75#         import os
76#         from Numeric import concatenate
77
78#         import time, os
79#         from Numeric import array, zeros, allclose, Float, concatenate
80
81#         domain = self.domain
82
83#         domain.filename = 'datatest' + str(time.time())
84#         domain.format = 'xya'
85#         domain.smooth = True
86
87#         xya = get_dataobject(self.domain)
88#         xya.store_all()
89
90
91#         #Read back
92#         file = open(xya.filename)
93#         lFile = file.read().split('\n')
94#         lFile = lFile[:-1]
95
96#         file.close()
97#         os.remove(xya.filename)
98
99#         #Check contents
100#         if domain.smooth:
101#             self.failUnless(lFile[0] == '9 3 # <vertex #> <x> <y> [attributes]')
102#         else:
103#             self.failUnless(lFile[0] == '24 3 # <vertex #> <x> <y> [attributes]')
104
105#         #Get smoothed field values with X and Y
106#         X,Y,F,V = domain.get_vertex_values(xy=True, value_array='field_values',
107#                                            indices = (0,1), precision = Float)
108
109
110#         Q,V = domain.get_vertex_values(xy=False, value_array='conserved_quantities',
111#                                            indices = (0,), precision = Float)
112
113
114
115#         for i, line in enumerate(lFile[1:]):
116#             fields = line.split()
117
118#             assert len(fields) == 5
119
120#             assert allclose(float(fields[0]), X[i])
121#             assert allclose(float(fields[1]), Y[i])
122#             assert allclose(float(fields[2]), F[i,0])
123#             assert allclose(float(fields[3]), Q[i,0])
124#             assert allclose(float(fields[4]), F[i,1])
125
126
127
128
129    def test_sww_constant(self):
130        """Test that constant sww information can be written correctly
131        (non smooth)
132        """
133
134        import time, os
135        from Numeric import array, zeros, allclose, Float, concatenate
136        from Scientific.IO.NetCDF import NetCDFFile
137
138        self.domain.filename = 'datatest' + str(id(self))
139        self.domain.format = 'sww'
140        self.domain.smooth = False
141
142        sww = get_dataobject(self.domain)
143        sww.store_connectivity()
144
145        #Check contents
146        #Get NetCDF
147        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
148
149        # Get the variables
150        x = fid.variables['x']
151        y = fid.variables['y']
152        z = fid.variables['elevation']
153
154        volumes = fid.variables['volumes']
155
156
157        assert allclose (x[:], self.X.flat)
158        assert allclose (y[:], self.Y.flat)
159        assert allclose (z[:], self.F.flat)
160
161        V = volumes
162
163        P = len(self.domain)
164        for k in range(P):
165            assert V[k, 0] == 3*k
166            assert V[k, 1] == 3*k+1
167            assert V[k, 2] == 3*k+2
168
169
170        fid.close()
171
172        #Cleanup
173        os.remove(sww.filename)
174
175
176    def test_sww_constant_smooth(self):
177        """Test that constant sww information can be written correctly
178        (non smooth)
179        """
180
181        import time, os
182        from Numeric import array, zeros, allclose, Float, concatenate
183        from Scientific.IO.NetCDF import NetCDFFile
184
185        self.domain.filename = 'datatest' + str(id(self))
186        self.domain.format = 'sww'
187        self.domain.smooth = True
188
189        sww = get_dataobject(self.domain)
190        sww.store_connectivity()
191
192        #Check contents
193        #Get NetCDF
194        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
195
196        # Get the variables
197        x = fid.variables['x']
198        y = fid.variables['y']
199        z = fid.variables['elevation']
200
201        volumes = fid.variables['volumes']
202
203        X = x[:]
204        Y = y[:]
205
206        assert allclose([X[0], Y[0]], array([0.0, 0.0]))
207        assert allclose([X[1], Y[1]], array([0.0, 0.5]))
208        assert allclose([X[2], Y[2]], array([0.0, 1.0]))
209
210        assert allclose([X[4], Y[4]], array([0.5, 0.5]))
211
212        assert allclose([X[7], Y[7]], array([1.0, 0.5]))
213
214        Z = z[:]
215        assert Z[4] == -0.5
216
217        V = volumes
218        assert V[2,0] == 4
219        assert V[2,1] == 5
220        assert V[2,2] == 1
221
222        assert V[4,0] == 6
223        assert V[4,1] == 7
224        assert V[4,2] == 3
225
226
227        fid.close()
228
229        #Cleanup
230        os.remove(sww.filename)
231
232
233
234    def test_sww_variable(self):
235        """Test that sww information can be written correctly
236        """
237
238        import time, os
239        from Numeric import array, zeros, allclose, Float, concatenate
240        from Scientific.IO.NetCDF import NetCDFFile
241
242        self.domain.filename = 'datatest' + str(id(self))
243        self.domain.format = 'sww'
244        self.domain.smooth = True
245        self.domain.reduction = mean
246
247        sww = get_dataobject(self.domain)
248        sww.store_connectivity()
249        sww.store_timestep('stage')
250
251        #Check contents
252        #Get NetCDF
253        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
254
255
256        # Get the variables
257        x = fid.variables['x']
258        y = fid.variables['y']
259        z = fid.variables['elevation']
260        time = fid.variables['time']
261        stage = fid.variables['stage']
262
263
264        Q = self.domain.quantities['stage']
265        Q0 = Q.vertex_values[:,0]
266        Q1 = Q.vertex_values[:,1]
267        Q2 = Q.vertex_values[:,2]
268
269        A = stage[0,:]
270        #print A[0], (Q2[0,0] + Q1[1,0])/2
271        assert allclose(A[0], (Q2[0] + Q1[1])/2)
272        assert allclose(A[1], (Q0[1] + Q1[3] + Q2[2])/3)
273        assert allclose(A[2], Q0[3])
274        assert allclose(A[3], (Q0[0] + Q1[5] + Q2[4])/3)
275
276        #Center point
277        assert allclose(A[4], (Q1[0] + Q2[1] + Q0[2] +\
278                                 Q0[5] + Q2[6] + Q1[7])/6)
279
280
281
282        fid.close()
283
284        #Cleanup
285        os.remove(sww.filename)
286
287
288    def test_sww_variable2(self):
289        """Test that sww information can be written correctly
290        multiple timesteps. Use average as reduction operator
291        """
292
293        import time, os
294        from Numeric import array, zeros, allclose, Float, concatenate
295        from Scientific.IO.NetCDF import NetCDFFile
296
297        self.domain.filename = 'datatest' + str(id(self))
298        self.domain.format = 'sww'
299        self.domain.smooth = True
300
301        self.domain.reduction = mean
302
303        sww = get_dataobject(self.domain)
304        sww.store_connectivity()
305        sww.store_timestep('stage')
306        self.domain.evolve_to_end(finaltime = 0.01)
307        sww.store_timestep('stage')
308
309
310        #Check contents
311        #Get NetCDF
312        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
313
314        # Get the variables
315        x = fid.variables['x']
316        y = fid.variables['y']
317        z = fid.variables['elevation']
318        time = fid.variables['time']
319        stage = fid.variables['stage']
320
321        #Check values
322        Q = self.domain.quantities['stage']
323        Q0 = Q.vertex_values[:,0]
324        Q1 = Q.vertex_values[:,1]
325        Q2 = Q.vertex_values[:,2]
326
327        A = stage[1,:]
328        assert allclose(A[0], (Q2[0] + Q1[1])/2)
329        assert allclose(A[1], (Q0[1] + Q1[3] + Q2[2])/3)
330        assert allclose(A[2], Q0[3])
331        assert allclose(A[3], (Q0[0] + Q1[5] + Q2[4])/3)
332
333        #Center point
334        assert allclose(A[4], (Q1[0] + Q2[1] + Q0[2] +\
335                                 Q0[5] + Q2[6] + Q1[7])/6)
336
337
338        fid.close()
339
340        #Cleanup
341        os.remove(sww.filename)
342
343    def test_sww_variable3(self):
344        """Test that sww information can be written correctly
345        multiple timesteps using a different reduction operator (min)
346        """
347
348        import time, os
349        from Numeric import array, zeros, allclose, Float, concatenate
350        from Scientific.IO.NetCDF import NetCDFFile
351
352        self.domain.filename = 'datatest' + str(id(self))
353        self.domain.format = 'sww'
354        self.domain.smooth = True
355        self.domain.reduction = min
356
357        sww = get_dataobject(self.domain)
358        sww.store_connectivity()
359        sww.store_timestep('stage')
360
361        self.domain.evolve_to_end(finaltime = 0.01)
362        sww.store_timestep('stage')
363
364
365        #Check contents
366        #Get NetCDF
367        fid = NetCDFFile(sww.filename, 'r')
368
369
370        # Get the variables
371        x = fid.variables['x']
372        y = fid.variables['y']
373        z = fid.variables['elevation']
374        time = fid.variables['time']
375        stage = fid.variables['stage']
376
377        #Check values
378        Q = self.domain.quantities['stage']
379        Q0 = Q.vertex_values[:,0]
380        Q1 = Q.vertex_values[:,1]
381        Q2 = Q.vertex_values[:,2]
382
383        A = stage[1,:]
384        assert allclose(A[0], min(Q2[0], Q1[1]))
385        assert allclose(A[1], min(Q0[1], Q1[3], Q2[2]))
386        assert allclose(A[2], Q0[3])
387        assert allclose(A[3], min(Q0[0], Q1[5], Q2[4]))
388
389        #Center point
390        assert allclose(A[4], min(Q1[0], Q2[1], Q0[2],\
391                                  Q0[5], Q2[6], Q1[7]))
392
393
394        fid.close()
395
396        #Cleanup
397        os.remove(sww.filename)
398
399
400    def test_sync(self):
401        """Test info stored at each timestep is as expected (incl initial condition)
402        """
403
404        import time, os, config
405        from Numeric import array, zeros, allclose, Float, concatenate
406        from Scientific.IO.NetCDF import NetCDFFile
407
408        self.domain.filename = 'synctest'
409        self.domain.format = 'sww'
410        self.domain.smooth = False
411        self.domain.store = True
412        self.domain.beta_h = 0
413
414        #Evolution
415        for t in self.domain.evolve(yieldstep = 1.0, finaltime = 4.0):
416            stage = self.domain.quantities['stage'].vertex_values
417
418            #Get NetCDF
419            fid = NetCDFFile(self.domain.writer.filename, 'r')
420            stage_file = fid.variables['stage']
421
422            if t == 0.0:
423                assert allclose(stage, self.initial_stage)
424                assert allclose(stage_file[:], stage.flat)
425            else:
426                assert not allclose(stage, self.initial_stage)
427                assert not allclose(stage_file[:], stage.flat)
428
429            fid.close()
430
431        os.remove(self.domain.writer.filename)
432
433
434
435    def test_sww_DSG(self):
436        """Not a test, rather a look at the sww format
437        """
438
439        import time, os
440        from Numeric import array, zeros, allclose, Float, concatenate
441        from Scientific.IO.NetCDF import NetCDFFile
442
443        self.domain.filename = 'datatest' + str(id(self))
444        self.domain.format = 'sww'
445        self.domain.smooth = True
446        self.domain.reduction = mean
447
448        sww = get_dataobject(self.domain)
449        sww.store_connectivity()
450        sww.store_timestep('stage')
451
452        #Check contents
453        #Get NetCDF
454        fid = NetCDFFile(sww.filename, 'r')
455
456        # Get the variables
457        x = fid.variables['x']
458        y = fid.variables['y']
459        z = fid.variables['elevation']
460
461        volumes = fid.variables['volumes']
462        time = fid.variables['time']
463
464        # 2D
465        stage = fid.variables['stage']
466
467        X = x[:]
468        Y = y[:]
469        Z = z[:]
470        V = volumes[:]
471        T = time[:]
472        S = stage[:,:]
473
474#         print "****************************"
475#         print "X ",X
476#         print "****************************"
477#         print "Y ",Y
478#         print "****************************"
479#         print "Z ",Z
480#         print "****************************"
481#         print "V ",V
482#         print "****************************"
483#         print "Time ",T
484#         print "****************************"
485#         print "Stage ",S
486#         print "****************************"
487
488
489        fid.close()
490
491        #Cleanup
492        os.remove(sww.filename)
493
494
495    def test_write_pts(self):
496        #Get (enough) datapoints
497
498        from Numeric import array
499        points = array([[ 0.66666667, 0.66666667],
500                        [ 1.33333333, 1.33333333],
501                        [ 2.66666667, 0.66666667],
502                        [ 0.66666667, 2.66666667],
503                        [ 0.0, 1.0],
504                        [ 0.0, 3.0],
505                        [ 1.0, 0.0],
506                        [ 1.0, 1.0],
507                        [ 1.0, 2.0],
508                        [ 1.0, 3.0],
509                        [ 2.0, 1.0],
510                        [ 3.0, 0.0],
511                        [ 3.0, 1.0]])
512
513        z = points[:,0] + 2*points[:,1]
514
515        ptsfile = 'testptsfile.pts'
516        write_ptsfile(ptsfile, points, z,
517                      attribute_name = 'linear_combination')
518
519        #Check contents
520        #Get NetCDF
521        from Scientific.IO.NetCDF import NetCDFFile       
522        fid = NetCDFFile(ptsfile, 'r')
523
524        # Get the variables
525        #print fid.variables.keys()
526        points1 = fid.variables['points']
527        z1 = fid.variables['linear_combination']
528
529        #Check values
530
531        #print points[:]
532        #print ref_points
533        assert allclose(points, points1)
534
535        #print attributes[:]
536        #print ref_elevation
537        assert allclose(z, z1)
538
539        #Cleanup
540        fid.close()
541
542        import os
543        os.remove(ptsfile)
544       
545
546
547
548    def test_dem2pts(self):
549        """Test conversion from dem in ascii format to native NetCDF xya format
550        """
551
552        import time, os
553        from Numeric import array, zeros, allclose, Float, concatenate
554        from Scientific.IO.NetCDF import NetCDFFile
555
556        #Write test asc file
557        root = 'demtest'
558
559        filename = root+'.asc'
560        fid = open(filename, 'w')
561        fid.write("""ncols         5
562nrows         6
563xllcorner     2000.5
564yllcorner     3000.5
565cellsize      25
566NODATA_value  -9999
567""")
568        #Create linear function
569
570        ref_points = []
571        ref_elevation = []
572        for i in range(6):
573            y = (6-i)*25.0
574            for j in range(5):
575                x = j*25.0
576                z = x+2*y
577
578                ref_points.append( [x,y] )
579                ref_elevation.append(z)
580                fid.write('%f ' %z)
581            fid.write('\n')
582
583        fid.close()
584
585        #Write prj file with metadata
586        metafilename = root+'.prj'
587        fid = open(metafilename, 'w')
588
589
590        fid.write("""Projection UTM
591Zone 56
592Datum WGS84
593Zunits NO
594Units METERS
595Spheroid WGS84
596Xshift 0.0000000000
597Yshift 10000000.0000000000
598Parameters
599""")
600        fid.close()
601
602        #Convert to NetCDF pts
603        convert_dem_from_ascii2netcdf(root)
604        dem2pts(root)
605
606        #Check contents
607        #Get NetCDF
608        fid = NetCDFFile(root+'.pts', 'r')
609
610        # Get the variables
611        #print fid.variables.keys()
612        points = fid.variables['points']
613        elevation = fid.variables['elevation']
614
615        #Check values
616
617        #print points[:]
618        #print ref_points
619        assert allclose(points, ref_points)
620
621        #print attributes[:]
622        #print ref_elevation
623        assert allclose(elevation, ref_elevation)
624
625        #Cleanup
626        fid.close()
627
628
629        os.remove(root + '.pts')
630        os.remove(root + '.dem')
631        os.remove(root + '.asc')
632        os.remove(root + '.prj')
633
634
635
636    def test_dem2pts_bounding_box(self):
637        """Test conversion from dem in ascii format to native NetCDF xya format
638        """
639
640        import time, os
641        from Numeric import array, zeros, allclose, Float, concatenate
642        from Scientific.IO.NetCDF import NetCDFFile
643
644        #Write test asc file
645        root = 'demtest'
646
647        filename = root+'.asc'
648        fid = open(filename, 'w')
649        fid.write("""ncols         5
650nrows         6
651xllcorner     2000.5
652yllcorner     3000.5
653cellsize      25
654NODATA_value  -9999
655""")
656        #Create linear function
657
658        ref_points = []
659        ref_elevation = []
660        for i in range(6):
661            y = (6-i)*25.0
662            for j in range(5):
663                x = j*25.0
664                z = x+2*y
665
666                ref_points.append( [x,y] )
667                ref_elevation.append(z)
668                fid.write('%f ' %z)
669            fid.write('\n')
670
671        fid.close()
672
673        #Write prj file with metadata
674        metafilename = root+'.prj'
675        fid = open(metafilename, 'w')
676
677
678        fid.write("""Projection UTM
679Zone 56
680Datum WGS84
681Zunits NO
682Units METERS
683Spheroid WGS84
684Xshift 0.0000000000
685Yshift 10000000.0000000000
686Parameters
687""")
688        fid.close()
689
690        #Convert to NetCDF pts
691        convert_dem_from_ascii2netcdf(root)
692        dem2pts(root, easting_min=2010.0, easting_max=2110.0,
693                northing_min=3035.0, northing_max=3125.5)
694
695        #Check contents
696        #Get NetCDF
697        fid = NetCDFFile(root+'.pts', 'r')
698
699        # Get the variables
700        #print fid.variables.keys()
701        points = fid.variables['points']
702        elevation = fid.variables['elevation']
703
704        #Check values
705        assert fid.xllcorner[0] == 2010.0
706        assert fid.yllcorner[0] == 3035.0
707
708        #create new reference points
709        ref_points = []
710        ref_elevation = []
711        for i in range(4):
712            y = (4-i)*25.0 + 25.0
713            y_new = y + 3000.5 - 3035.0
714            for j in range(4):
715                x = j*25.0 + 25.0
716                x_new = x + 2000.5 - 2010.0
717                z = x+2*y
718
719                ref_points.append( [x_new,y_new] )
720                ref_elevation.append(z)
721
722        #print points[:]
723        #print ref_points
724        assert allclose(points, ref_points)
725
726        #print attributes[:]
727        #print ref_elevation
728        assert allclose(elevation, ref_elevation)
729
730        #Cleanup
731        fid.close()
732
733
734        os.remove(root + '.pts')
735        os.remove(root + '.dem')
736        os.remove(root + '.asc')
737        os.remove(root + '.prj')
738
739
740
741    def test_sww2dem_asc_elevation(self):
742        """Test that sww information can be converted correctly to asc/prj
743        format readable by e.g. ArcView
744        """
745
746        import time, os
747        from Numeric import array, zeros, allclose, Float, concatenate
748        from Scientific.IO.NetCDF import NetCDFFile
749
750        #Setup
751        self.domain.filename = 'datatest'
752
753        prjfile = self.domain.filename + '_elevation.prj'
754        ascfile = self.domain.filename + '_elevation.asc'
755        swwfile = self.domain.filename + '.sww'
756
757        self.domain.set_datadir('.')
758        self.domain.format = 'sww'
759        self.domain.smooth = True
760        self.domain.set_quantity('elevation', lambda x,y: -x-y)
761
762        self.domain.geo_reference = Geo_reference(56,308500,6189000)
763
764        sww = get_dataobject(self.domain)
765        sww.store_connectivity()
766        sww.store_timestep('stage')
767
768        self.domain.evolve_to_end(finaltime = 0.01)
769        sww.store_timestep('stage')
770
771        cellsize = 0.25
772        #Check contents
773        #Get NetCDF
774
775        fid = NetCDFFile(sww.filename, 'r')
776
777        # Get the variables
778        x = fid.variables['x'][:]
779        y = fid.variables['y'][:]
780        z = fid.variables['elevation'][:]
781        time = fid.variables['time'][:]
782        stage = fid.variables['stage'][:]
783
784
785        #Export to ascii/prj files
786        sww2dem(self.domain.filename,
787                quantity = 'elevation',
788                cellsize = cellsize,
789                verbose = False,
790                format = 'asc')
791               
792        #Check prj (meta data)
793        prjid = open(prjfile)
794        lines = prjid.readlines()
795        prjid.close()
796
797        L = lines[0].strip().split()
798        assert L[0].strip().lower() == 'projection'
799        assert L[1].strip().lower() == 'utm'
800
801        L = lines[1].strip().split()
802        assert L[0].strip().lower() == 'zone'
803        assert L[1].strip().lower() == '56'
804
805        L = lines[2].strip().split()
806        assert L[0].strip().lower() == 'datum'
807        assert L[1].strip().lower() == 'wgs84'
808
809        L = lines[3].strip().split()
810        assert L[0].strip().lower() == 'zunits'
811        assert L[1].strip().lower() == 'no'
812
813        L = lines[4].strip().split()
814        assert L[0].strip().lower() == 'units'
815        assert L[1].strip().lower() == 'meters'
816
817        L = lines[5].strip().split()
818        assert L[0].strip().lower() == 'spheroid'
819        assert L[1].strip().lower() == 'wgs84'
820
821        L = lines[6].strip().split()
822        assert L[0].strip().lower() == 'xshift'
823        assert L[1].strip().lower() == '500000'
824
825        L = lines[7].strip().split()
826        assert L[0].strip().lower() == 'yshift'
827        assert L[1].strip().lower() == '10000000'
828
829        L = lines[8].strip().split()
830        assert L[0].strip().lower() == 'parameters'
831
832
833        #Check asc file
834        ascid = open(ascfile)
835        lines = ascid.readlines()
836        ascid.close()
837
838        L = lines[0].strip().split()
839        assert L[0].strip().lower() == 'ncols'
840        assert L[1].strip().lower() == '5'
841
842        L = lines[1].strip().split()
843        assert L[0].strip().lower() == 'nrows'
844        assert L[1].strip().lower() == '5'
845
846        L = lines[2].strip().split()
847        assert L[0].strip().lower() == 'xllcorner'
848        assert allclose(float(L[1].strip().lower()), 308500)
849
850        L = lines[3].strip().split()
851        assert L[0].strip().lower() == 'yllcorner'
852        assert allclose(float(L[1].strip().lower()), 6189000)
853
854        L = lines[4].strip().split()
855        assert L[0].strip().lower() == 'cellsize'
856        assert allclose(float(L[1].strip().lower()), cellsize)
857
858        L = lines[5].strip().split()
859        assert L[0].strip() == 'NODATA_value'
860        assert L[1].strip().lower() == '-9999'
861
862        #Check grid values
863        for j in range(5):
864            L = lines[6+j].strip().split()
865            y = (4-j) * cellsize
866            for i in range(5):
867                assert allclose(float(L[i]), -i*cellsize - y)
868
869
870        fid.close()
871
872        #Cleanup
873        os.remove(prjfile)
874        os.remove(ascfile)
875        os.remove(swwfile)
876
877
878
879    def test_sww2dem_larger(self):
880        """Test that sww information can be converted correctly to asc/prj
881        format readable by e.g. ArcView. Here:
882
883        ncols         11
884        nrows         11
885        xllcorner     308500
886        yllcorner     6189000
887        cellsize      10.000000
888        NODATA_value  -9999
889        -100 -110 -120 -130 -140 -150 -160 -170 -180 -190 -200
890         -90 -100 -110 -120 -130 -140 -150 -160 -170 -180 -190
891         -80  -90 -100 -110 -120 -130 -140 -150 -160 -170 -180
892         -70  -80  -90 -100 -110 -120 -130 -140 -150 -160 -170
893         -60  -70  -80  -90 -100 -110 -120 -130 -140 -150 -160
894         -50  -60  -70  -80  -90 -100 -110 -120 -130 -140 -150
895         -40  -50  -60  -70  -80  -90 -100 -110 -120 -130 -140
896         -30  -40  -50  -60  -70  -80  -90 -100 -110 -120 -130
897         -20  -30  -40  -50  -60  -70  -80  -90 -100 -110 -120
898         -10  -20  -30  -40  -50  -60  -70  -80  -90 -100 -110
899           0  -10  -20  -30  -40  -50  -60  -70  -80  -90 -100
900       
901        """
902
903        import time, os
904        from Numeric import array, zeros, allclose, Float, concatenate
905        from Scientific.IO.NetCDF import NetCDFFile
906
907        #Setup
908
909        from mesh_factory import rectangular
910
911        #Create basic mesh (100m x 100m)
912        points, vertices, boundary = rectangular(2, 2, 100, 100)
913
914        #Create shallow water domain
915        domain = Domain(points, vertices, boundary)
916        domain.default_order = 2
917
918        domain.filename = 'datatest'
919
920        prjfile = domain.filename + '_elevation.prj'
921        ascfile = domain.filename + '_elevation.asc'
922        swwfile = domain.filename + '.sww'
923
924        domain.set_datadir('.')
925        domain.format = 'sww'
926        domain.smooth = True
927        domain.geo_reference = Geo_reference(56, 308500, 6189000)
928
929        #
930        domain.set_quantity('elevation', lambda x,y: -x-y)
931        domain.set_quantity('stage', 0)       
932
933        B = Transmissive_boundary(domain)
934        domain.set_boundary( {'left': B, 'right': B, 'top': B, 'bottom': B})
935
936
937        #
938        sww = get_dataobject(domain)
939        sww.store_connectivity()
940        sww.store_timestep('stage')
941
942        domain.evolve_to_end(finaltime = 0.01)
943        sww.store_timestep('stage')
944
945        cellsize = 10  #10m grid
946
947       
948        #Check contents
949        #Get NetCDF
950
951        fid = NetCDFFile(sww.filename, 'r')
952
953        # Get the variables
954        x = fid.variables['x'][:]
955        y = fid.variables['y'][:]
956        z = fid.variables['elevation'][:]
957        time = fid.variables['time'][:]
958        stage = fid.variables['stage'][:]
959
960
961        #Export to ascii/prj files
962        sww2dem(domain.filename,
963                quantity = 'elevation',
964                cellsize = cellsize,
965                verbose = False,
966                format = 'asc')
967       
968               
969        #Check prj (meta data)
970        prjid = open(prjfile)
971        lines = prjid.readlines()
972        prjid.close()
973
974        L = lines[0].strip().split()
975        assert L[0].strip().lower() == 'projection'
976        assert L[1].strip().lower() == 'utm'
977
978        L = lines[1].strip().split()
979        assert L[0].strip().lower() == 'zone'
980        assert L[1].strip().lower() == '56'
981
982        L = lines[2].strip().split()
983        assert L[0].strip().lower() == 'datum'
984        assert L[1].strip().lower() == 'wgs84'
985
986        L = lines[3].strip().split()
987        assert L[0].strip().lower() == 'zunits'
988        assert L[1].strip().lower() == 'no'
989
990        L = lines[4].strip().split()
991        assert L[0].strip().lower() == 'units'
992        assert L[1].strip().lower() == 'meters'
993
994        L = lines[5].strip().split()
995        assert L[0].strip().lower() == 'spheroid'
996        assert L[1].strip().lower() == 'wgs84'
997
998        L = lines[6].strip().split()
999        assert L[0].strip().lower() == 'xshift'
1000        assert L[1].strip().lower() == '500000'
1001
1002        L = lines[7].strip().split()
1003        assert L[0].strip().lower() == 'yshift'
1004        assert L[1].strip().lower() == '10000000'
1005
1006        L = lines[8].strip().split()
1007        assert L[0].strip().lower() == 'parameters'
1008
1009
1010        #Check asc file
1011        ascid = open(ascfile)
1012        lines = ascid.readlines()
1013        ascid.close()
1014
1015        L = lines[0].strip().split()
1016        assert L[0].strip().lower() == 'ncols'
1017        assert L[1].strip().lower() == '11'
1018
1019        L = lines[1].strip().split()
1020        assert L[0].strip().lower() == 'nrows'
1021        assert L[1].strip().lower() == '11'
1022
1023        L = lines[2].strip().split()
1024        assert L[0].strip().lower() == 'xllcorner'
1025        assert allclose(float(L[1].strip().lower()), 308500)
1026
1027        L = lines[3].strip().split()
1028        assert L[0].strip().lower() == 'yllcorner'
1029        assert allclose(float(L[1].strip().lower()), 6189000)
1030
1031        L = lines[4].strip().split()
1032        assert L[0].strip().lower() == 'cellsize'
1033        assert allclose(float(L[1].strip().lower()), cellsize)
1034
1035        L = lines[5].strip().split()
1036        assert L[0].strip() == 'NODATA_value'
1037        assert L[1].strip().lower() == '-9999'
1038
1039        #Check grid values (FIXME: Use same strategy for other sww2dem tests)
1040        for i, line in enumerate(lines[6:]):
1041            for j, value in enumerate( line.split() ):
1042                #assert allclose(float(value), -(10-i+j)*cellsize)
1043                assert float(value) == -(10-i+j)*cellsize
1044       
1045
1046        fid.close()
1047
1048        #Cleanup
1049        os.remove(prjfile)
1050        os.remove(ascfile)
1051        os.remove(swwfile)
1052
1053
1054
1055    def test_sww2dem_boundingbox(self):
1056        """Test that sww information can be converted correctly to asc/prj
1057        format readable by e.g. ArcView.
1058        This will test that mesh can be restricted by bounding box
1059
1060        Original extent is 100m x 100m:
1061
1062        Eastings:   308500 -  308600
1063        Northings: 6189000 - 6189100
1064
1065        Bounding box changes this to the 50m x 50m square defined by
1066
1067        Eastings:   308530 -  308570
1068        Northings: 6189050 - 6189100
1069
1070        The cropped values should be
1071
1072         -130 -140 -150 -160 -170 
1073         -120 -130 -140 -150 -160 
1074         -110 -120 -130 -140 -150 
1075         -100 -110 -120 -130 -140 
1076          -90 -100 -110 -120 -130 
1077          -80  -90 -100 -110 -120
1078
1079        and the new lower reference point should be   
1080        Eastings:   308530
1081        Northings: 6189050
1082       
1083        Original dataset is the same as in test_sww2dem_larger()
1084       
1085        """
1086
1087        import time, os
1088        from Numeric import array, zeros, allclose, Float, concatenate
1089        from Scientific.IO.NetCDF import NetCDFFile
1090
1091        #Setup
1092
1093        from mesh_factory import rectangular
1094
1095        #Create basic mesh (100m x 100m)
1096        points, vertices, boundary = rectangular(2, 2, 100, 100)
1097
1098        #Create shallow water domain
1099        domain = Domain(points, vertices, boundary)
1100        domain.default_order = 2
1101
1102        domain.filename = 'datatest'
1103
1104        prjfile = domain.filename + '_elevation.prj'
1105        ascfile = domain.filename + '_elevation.asc'
1106        swwfile = domain.filename + '.sww'
1107
1108        domain.set_datadir('.')
1109        domain.format = 'sww'
1110        domain.smooth = True
1111        domain.geo_reference = Geo_reference(56, 308500, 6189000)
1112
1113        #
1114        domain.set_quantity('elevation', lambda x,y: -x-y)
1115        domain.set_quantity('stage', 0)       
1116
1117        B = Transmissive_boundary(domain)
1118        domain.set_boundary( {'left': B, 'right': B, 'top': B, 'bottom': B})
1119
1120
1121        #
1122        sww = get_dataobject(domain)
1123        sww.store_connectivity()
1124        sww.store_timestep('stage')
1125
1126        domain.evolve_to_end(finaltime = 0.01)
1127        sww.store_timestep('stage')
1128
1129        cellsize = 10  #10m grid
1130
1131       
1132        #Check contents
1133        #Get NetCDF
1134
1135        fid = NetCDFFile(sww.filename, 'r')
1136
1137        # Get the variables
1138        x = fid.variables['x'][:]
1139        y = fid.variables['y'][:]
1140        z = fid.variables['elevation'][:]
1141        time = fid.variables['time'][:]
1142        stage = fid.variables['stage'][:]
1143
1144
1145        #Export to ascii/prj files
1146        sww2dem(domain.filename,
1147                quantity = 'elevation',
1148                cellsize = cellsize,
1149                easting_min = 308530,
1150                easting_max = 308570,
1151                northing_min = 6189050, 
1152                northing_max = 6189100,                               
1153                verbose = False,
1154                format = 'asc')
1155       
1156        fid.close()
1157
1158       
1159        #Check prj (meta data)
1160        prjid = open(prjfile)
1161        lines = prjid.readlines()
1162        prjid.close()
1163
1164        L = lines[0].strip().split()
1165        assert L[0].strip().lower() == 'projection'
1166        assert L[1].strip().lower() == 'utm'
1167
1168        L = lines[1].strip().split()
1169        assert L[0].strip().lower() == 'zone'
1170        assert L[1].strip().lower() == '56'
1171
1172        L = lines[2].strip().split()
1173        assert L[0].strip().lower() == 'datum'
1174        assert L[1].strip().lower() == 'wgs84'
1175
1176        L = lines[3].strip().split()
1177        assert L[0].strip().lower() == 'zunits'
1178        assert L[1].strip().lower() == 'no'
1179
1180        L = lines[4].strip().split()
1181        assert L[0].strip().lower() == 'units'
1182        assert L[1].strip().lower() == 'meters'
1183
1184        L = lines[5].strip().split()
1185        assert L[0].strip().lower() == 'spheroid'
1186        assert L[1].strip().lower() == 'wgs84'
1187
1188        L = lines[6].strip().split()
1189        assert L[0].strip().lower() == 'xshift'
1190        assert L[1].strip().lower() == '500000'
1191
1192        L = lines[7].strip().split()
1193        assert L[0].strip().lower() == 'yshift'
1194        assert L[1].strip().lower() == '10000000'
1195
1196        L = lines[8].strip().split()
1197        assert L[0].strip().lower() == 'parameters'
1198
1199
1200        #Check asc file
1201        ascid = open(ascfile)
1202        lines = ascid.readlines()
1203        ascid.close()
1204
1205        L = lines[0].strip().split()
1206        assert L[0].strip().lower() == 'ncols'
1207        assert L[1].strip().lower() == '5'
1208
1209        L = lines[1].strip().split()
1210        assert L[0].strip().lower() == 'nrows'
1211        assert L[1].strip().lower() == '6'
1212
1213        L = lines[2].strip().split()
1214        assert L[0].strip().lower() == 'xllcorner'
1215        assert allclose(float(L[1].strip().lower()), 308530)
1216
1217        L = lines[3].strip().split()
1218        assert L[0].strip().lower() == 'yllcorner'
1219        assert allclose(float(L[1].strip().lower()), 6189050)
1220
1221        L = lines[4].strip().split()
1222        assert L[0].strip().lower() == 'cellsize'
1223        assert allclose(float(L[1].strip().lower()), cellsize)
1224
1225        L = lines[5].strip().split()
1226        assert L[0].strip() == 'NODATA_value'
1227        assert L[1].strip().lower() == '-9999'
1228
1229        #Check grid values
1230        for i, line in enumerate(lines[6:]):
1231            for j, value in enumerate( line.split() ):
1232                #assert float(value) == -(10-i+j)*cellsize               
1233                assert float(value) == -(10-i+j+3)*cellsize
1234       
1235
1236
1237        #Cleanup
1238        os.remove(prjfile)
1239        os.remove(ascfile)
1240        os.remove(swwfile)
1241
1242
1243
1244    def test_sww2dem_asc_stage_reduction(self):
1245        """Test that sww information can be converted correctly to asc/prj
1246        format readable by e.g. ArcView
1247
1248        This tests the reduction of quantity stage using min
1249        """
1250
1251        import time, os
1252        from Numeric import array, zeros, allclose, Float, concatenate
1253        from Scientific.IO.NetCDF import NetCDFFile
1254
1255        #Setup
1256        self.domain.filename = 'datatest'
1257
1258        prjfile = self.domain.filename + '_stage.prj'
1259        ascfile = self.domain.filename + '_stage.asc'
1260        swwfile = self.domain.filename + '.sww'
1261
1262        self.domain.set_datadir('.')
1263        self.domain.format = 'sww'
1264        self.domain.smooth = True
1265        self.domain.set_quantity('elevation', lambda x,y: -x-y)
1266
1267        self.domain.geo_reference = Geo_reference(56,308500,6189000)
1268
1269
1270        sww = get_dataobject(self.domain)
1271        sww.store_connectivity()
1272        sww.store_timestep('stage')
1273
1274        self.domain.evolve_to_end(finaltime = 0.01)
1275        sww.store_timestep('stage')
1276
1277        cellsize = 0.25
1278        #Check contents
1279        #Get NetCDF
1280
1281        fid = NetCDFFile(sww.filename, 'r')
1282
1283        # Get the variables
1284        x = fid.variables['x'][:]
1285        y = fid.variables['y'][:]
1286        z = fid.variables['elevation'][:]
1287        time = fid.variables['time'][:]
1288        stage = fid.variables['stage'][:]
1289
1290
1291        #Export to ascii/prj files
1292        sww2dem(self.domain.filename,
1293                quantity = 'stage',
1294                cellsize = cellsize,
1295                reduction = min,
1296                format = 'asc')
1297
1298
1299        #Check asc file
1300        ascid = open(ascfile)
1301        lines = ascid.readlines()
1302        ascid.close()
1303
1304        L = lines[0].strip().split()
1305        assert L[0].strip().lower() == 'ncols'
1306        assert L[1].strip().lower() == '5'
1307
1308        L = lines[1].strip().split()
1309        assert L[0].strip().lower() == 'nrows'
1310        assert L[1].strip().lower() == '5'
1311
1312        L = lines[2].strip().split()
1313        assert L[0].strip().lower() == 'xllcorner'
1314        assert allclose(float(L[1].strip().lower()), 308500)
1315
1316        L = lines[3].strip().split()
1317        assert L[0].strip().lower() == 'yllcorner'
1318        assert allclose(float(L[1].strip().lower()), 6189000)
1319
1320        L = lines[4].strip().split()
1321        assert L[0].strip().lower() == 'cellsize'
1322        assert allclose(float(L[1].strip().lower()), cellsize)
1323
1324        L = lines[5].strip().split()
1325        assert L[0].strip() == 'NODATA_value'
1326        assert L[1].strip().lower() == '-9999'
1327
1328
1329        #Check grid values (where applicable)
1330        for j in range(5):
1331            if j%2 == 0:
1332                L = lines[6+j].strip().split()
1333                jj = 4-j
1334                for i in range(5):
1335                    if i%2 == 0:
1336                        index = jj/2 + i/2*3
1337                        val0 = stage[0,index]
1338                        val1 = stage[1,index]
1339
1340                        #print i, j, index, ':', L[i], val0, val1
1341                        assert allclose(float(L[i]), min(val0, val1))
1342
1343
1344        fid.close()
1345
1346        #Cleanup
1347        os.remove(prjfile)
1348        os.remove(ascfile)
1349        #os.remove(swwfile)
1350
1351
1352
1353
1354    def test_sww2dem_asc_missing_points(self):
1355        """Test that sww information can be converted correctly to asc/prj
1356        format readable by e.g. ArcView
1357
1358        This test includes the writing of missing values
1359        """
1360
1361        import time, os
1362        from Numeric import array, zeros, allclose, Float, concatenate
1363        from Scientific.IO.NetCDF import NetCDFFile
1364
1365        #Setup mesh not coinciding with rectangle.
1366        #This will cause missing values to occur in gridded data
1367
1368
1369        points = [                        [1.0, 1.0],
1370                              [0.5, 0.5], [1.0, 0.5],
1371                  [0.0, 0.0], [0.5, 0.0], [1.0, 0.0]]
1372
1373        vertices = [ [4,1,3], [5,2,4], [1,4,2], [2,0,1]]
1374
1375        #Create shallow water domain
1376        domain = Domain(points, vertices)
1377        domain.default_order=2
1378
1379
1380        #Set some field values
1381        domain.set_quantity('elevation', lambda x,y: -x-y)
1382        domain.set_quantity('friction', 0.03)
1383
1384
1385        ######################
1386        # Boundary conditions
1387        B = Transmissive_boundary(domain)
1388        domain.set_boundary( {'exterior': B} )
1389
1390
1391        ######################
1392        #Initial condition - with jumps
1393
1394        bed = domain.quantities['elevation'].vertex_values
1395        stage = zeros(bed.shape, Float)
1396
1397        h = 0.3
1398        for i in range(stage.shape[0]):
1399            if i % 2 == 0:
1400                stage[i,:] = bed[i,:] + h
1401            else:
1402                stage[i,:] = bed[i,:]
1403
1404        domain.set_quantity('stage', stage)
1405        domain.distribute_to_vertices_and_edges()
1406
1407        domain.filename = 'datatest'
1408
1409        prjfile = domain.filename + '_elevation.prj'
1410        ascfile = domain.filename + '_elevation.asc'
1411        swwfile = domain.filename + '.sww'
1412
1413        domain.set_datadir('.')
1414        domain.format = 'sww'
1415        domain.smooth = True
1416
1417        domain.geo_reference = Geo_reference(56,308500,6189000)
1418
1419        sww = get_dataobject(domain)
1420        sww.store_connectivity()
1421        sww.store_timestep('stage')
1422
1423        cellsize = 0.25
1424        #Check contents
1425        #Get NetCDF
1426
1427        fid = NetCDFFile(swwfile, 'r')
1428
1429        # Get the variables
1430        x = fid.variables['x'][:]
1431        y = fid.variables['y'][:]
1432        z = fid.variables['elevation'][:]
1433        time = fid.variables['time'][:]
1434
1435        try:
1436            geo_reference = Geo_reference(NetCDFObject=fid)
1437        except AttributeError, e:
1438            geo_reference = Geo_reference(DEFAULT_ZONE,0,0)
1439
1440        #Export to ascii/prj files
1441        sww2dem(domain.filename,
1442                quantity = 'elevation',
1443                cellsize = cellsize,
1444                verbose = False,
1445                format = 'asc')
1446
1447
1448        #Check asc file
1449        ascid = open(ascfile)
1450        lines = ascid.readlines()
1451        ascid.close()
1452
1453        L = lines[0].strip().split()
1454        assert L[0].strip().lower() == 'ncols'
1455        assert L[1].strip().lower() == '5'
1456
1457        L = lines[1].strip().split()
1458        assert L[0].strip().lower() == 'nrows'
1459        assert L[1].strip().lower() == '5'
1460
1461        L = lines[2].strip().split()
1462        assert L[0].strip().lower() == 'xllcorner'
1463        assert allclose(float(L[1].strip().lower()), 308500)
1464
1465        L = lines[3].strip().split()
1466        assert L[0].strip().lower() == 'yllcorner'
1467        assert allclose(float(L[1].strip().lower()), 6189000)
1468
1469        L = lines[4].strip().split()
1470        assert L[0].strip().lower() == 'cellsize'
1471        assert allclose(float(L[1].strip().lower()), cellsize)
1472
1473        L = lines[5].strip().split()
1474        assert L[0].strip() == 'NODATA_value'
1475        assert L[1].strip().lower() == '-9999'
1476
1477
1478        #Check grid values
1479        for j in range(5):
1480            L = lines[6+j].strip().split()
1481            y = (4-j) * cellsize
1482            for i in range(5):
1483                if i+j >= 4:
1484                    assert allclose(float(L[i]), -i*cellsize - y)
1485                else:
1486                    #Missing values
1487                    assert allclose(float(L[i]), -9999)
1488
1489
1490
1491        fid.close()
1492
1493        #Cleanup
1494        os.remove(prjfile)
1495        os.remove(ascfile)
1496        os.remove(swwfile)
1497
1498    def test_sww2ers_simple(self):
1499        """Test that sww information can be converted correctly to asc/prj
1500        format readable by e.g. ArcView
1501        """
1502
1503        import time, os
1504        from Numeric import array, zeros, allclose, Float, concatenate
1505        from Scientific.IO.NetCDF import NetCDFFile
1506
1507        #Setup
1508        self.domain.filename = 'datatest'
1509
1510        headerfile = self.domain.filename + '.ers'
1511        swwfile = self.domain.filename + '.sww'
1512
1513        self.domain.set_datadir('.')
1514        self.domain.format = 'sww'
1515        self.domain.smooth = True
1516        self.domain.set_quantity('elevation', lambda x,y: -x-y)
1517
1518        self.domain.geo_reference = Geo_reference(56,308500,6189000)
1519
1520        sww = get_dataobject(self.domain)
1521        sww.store_connectivity()
1522        sww.store_timestep('stage')
1523
1524        self.domain.evolve_to_end(finaltime = 0.01)
1525        sww.store_timestep('stage')
1526
1527        cellsize = 0.25
1528        #Check contents
1529        #Get NetCDF
1530
1531        fid = NetCDFFile(sww.filename, 'r')
1532
1533        # Get the variables
1534        x = fid.variables['x'][:]
1535        y = fid.variables['y'][:]
1536        z = fid.variables['elevation'][:]
1537        time = fid.variables['time'][:]
1538        stage = fid.variables['stage'][:]
1539
1540
1541        #Export to ers files
1542        #sww2ers(self.domain.filename,
1543        #        quantity = 'elevation',
1544        #        cellsize = cellsize,
1545        #        verbose = False)
1546               
1547        sww2dem(self.domain.filename,
1548                quantity = 'elevation',
1549                cellsize = cellsize,
1550                verbose = False,
1551                format = 'ers')
1552
1553        #Check header data
1554        from ermapper_grids import read_ermapper_header, read_ermapper_data
1555       
1556        header = read_ermapper_header(self.domain.filename + '_elevation.ers')
1557        #print header
1558        assert header['projection'].lower() == '"utm-56"'
1559        assert header['datum'].lower() == '"wgs84"'
1560        assert header['units'].lower() == '"meters"'   
1561        assert header['value'].lower() == '"elevation"'         
1562        assert header['xdimension'] == '0.25'
1563        assert header['ydimension'] == '0.25'   
1564        assert float(header['eastings']) == 308500.0   #xllcorner
1565        assert float(header['northings']) == 6189000.0 #yllcorner       
1566        assert int(header['nroflines']) == 5
1567        assert int(header['nrofcellsperline']) == 5     
1568        assert int(header['nullcellvalue']) == 0   #?           
1569        #FIXME - there is more in the header                   
1570
1571               
1572        #Check grid data               
1573        grid = read_ermapper_data(self.domain.filename + '_elevation') 
1574       
1575        #FIXME (Ole): Why is this the desired reference grid for -x-y?
1576        ref_grid = [0,      0,     0,     0,     0,
1577                    -1,    -1.25, -1.5,  -1.75, -2.0,
1578                    -0.75, -1.0,  -1.25, -1.5,  -1.75,             
1579                    -0.5,  -0.75, -1.0,  -1.25, -1.5,
1580                    -0.25, -0.5,  -0.75, -1.0,  -1.25]             
1581                                         
1582                                         
1583        assert allclose(grid, ref_grid)
1584
1585        fid.close()
1586       
1587        #Cleanup
1588        #FIXME the file clean-up doesn't work (eg Permission Denied Error)
1589        #Done (Ole) - it was because sww2ers didn't close it's sww file
1590        os.remove(sww.filename)
1591
1592
1593
1594    def test_ferret2sww(self):
1595        """Test that georeferencing etc works when converting from
1596        ferret format (lat/lon) to sww format (UTM)
1597        """
1598        from Scientific.IO.NetCDF import NetCDFFile
1599
1600        #The test file has
1601        # LON = 150.66667, 150.83334, 151, 151.16667
1602        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1603        # TIME = 0, 0.1, 0.6, 1.1, 1.6, 2.1 ;
1604        #
1605        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1606        # Fourth value (index==3) is -6.50198 cm
1607
1608
1609        from coordinate_transforms.redfearn import redfearn
1610
1611        fid = NetCDFFile('small_ha.nc')
1612        first_value = fid.variables['HA'][:][0,0,0]
1613        fourth_value = fid.variables['HA'][:][0,0,3]
1614
1615
1616        #Call conversion (with zero origin)
1617        ferret2sww('small', verbose=False,
1618                   origin = (56, 0, 0))
1619
1620
1621        #Work out the UTM coordinates for first point
1622        zone, e, n = redfearn(-34.5, 150.66667)
1623        #print zone, e, n
1624
1625        #Read output file 'small.sww'
1626        fid = NetCDFFile('small.sww')
1627
1628        x = fid.variables['x'][:]
1629        y = fid.variables['y'][:]
1630
1631        #Check that first coordinate is correctly represented
1632        assert allclose(x[0], e)
1633        assert allclose(y[0], n)
1634
1635        #Check first value
1636        stage = fid.variables['stage'][:]
1637        xmomentum = fid.variables['xmomentum'][:]
1638        ymomentum = fid.variables['ymomentum'][:]
1639
1640        #print ymomentum
1641
1642        assert allclose(stage[0,0], first_value/100)  #Meters
1643
1644        #Check fourth value
1645        assert allclose(stage[0,3], fourth_value/100)  #Meters
1646
1647        fid.close()
1648
1649        #Cleanup
1650        import os
1651        os.remove('small.sww')
1652
1653
1654
1655    def test_ferret2sww_2(self):
1656        """Test that georeferencing etc works when converting from
1657        ferret format (lat/lon) to sww format (UTM)
1658        """
1659        from Scientific.IO.NetCDF import NetCDFFile
1660
1661        #The test file has
1662        # LON = 150.66667, 150.83334, 151, 151.16667
1663        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1664        # TIME = 0, 0.1, 0.6, 1.1, 1.6, 2.1 ;
1665        #
1666        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1667        # Fourth value (index==3) is -6.50198 cm
1668
1669
1670        from coordinate_transforms.redfearn import redfearn
1671
1672        fid = NetCDFFile('small_ha.nc')
1673
1674        #Pick a coordinate and a value
1675
1676        time_index = 1
1677        lat_index = 0
1678        lon_index = 2
1679
1680        test_value = fid.variables['HA'][:][time_index, lat_index, lon_index]
1681        test_time = fid.variables['TIME'][:][time_index]
1682        test_lat = fid.variables['LAT'][:][lat_index]
1683        test_lon = fid.variables['LON'][:][lon_index]
1684
1685        linear_point_index = lat_index*4 + lon_index
1686        fid.close()
1687
1688        #Call conversion (with zero origin)
1689        ferret2sww('small', verbose=False,
1690                   origin = (56, 0, 0))
1691
1692
1693        #Work out the UTM coordinates for test point
1694        zone, e, n = redfearn(test_lat, test_lon)
1695
1696        #Read output file 'small.sww'
1697        fid = NetCDFFile('small.sww')
1698
1699        x = fid.variables['x'][:]
1700        y = fid.variables['y'][:]
1701
1702        #Check that test coordinate is correctly represented
1703        assert allclose(x[linear_point_index], e)
1704        assert allclose(y[linear_point_index], n)
1705
1706        #Check test value
1707        stage = fid.variables['stage'][:]
1708
1709        assert allclose(stage[time_index, linear_point_index], test_value/100)
1710
1711        fid.close()
1712
1713        #Cleanup
1714        import os
1715        os.remove('small.sww')
1716
1717
1718
1719    def test_ferret2sww3(self):
1720        """Elevation included
1721        """
1722        from Scientific.IO.NetCDF import NetCDFFile
1723
1724        #The test file has
1725        # LON = 150.66667, 150.83334, 151, 151.16667
1726        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1727        # ELEVATION = [-1 -2 -3 -4
1728        #             -5 -6 -7 -8
1729        #              ...
1730        #              ...      -16]
1731        # where the top left corner is -1m,
1732        # and the ll corner is -13.0m
1733        #
1734        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1735        # Fourth value (index==3) is -6.50198 cm
1736
1737        from coordinate_transforms.redfearn import redfearn
1738        import os
1739        fid1 = NetCDFFile('test_ha.nc','w')
1740        fid2 = NetCDFFile('test_ua.nc','w')
1741        fid3 = NetCDFFile('test_va.nc','w')
1742        fid4 = NetCDFFile('test_e.nc','w')
1743
1744        h1_list = [150.66667,150.83334,151.]
1745        h2_list = [-34.5,-34.33333]
1746
1747        long_name = 'LON'
1748        lat_name = 'LAT'
1749
1750        nx = 3
1751        ny = 2
1752
1753        for fid in [fid1,fid2,fid3]:
1754            fid.createDimension(long_name,nx)
1755            fid.createVariable(long_name,'d',(long_name,))
1756            fid.variables[long_name].point_spacing='uneven'
1757            fid.variables[long_name].units='degrees_east'
1758            fid.variables[long_name].assignValue(h1_list)
1759
1760            fid.createDimension(lat_name,ny)
1761            fid.createVariable(lat_name,'d',(lat_name,))
1762            fid.variables[lat_name].point_spacing='uneven'
1763            fid.variables[lat_name].units='degrees_north'
1764            fid.variables[lat_name].assignValue(h2_list)
1765
1766            fid.createDimension('TIME',2)
1767            fid.createVariable('TIME','d',('TIME',))
1768            fid.variables['TIME'].point_spacing='uneven'
1769            fid.variables['TIME'].units='seconds'
1770            fid.variables['TIME'].assignValue([0.,1.])
1771            if fid == fid3: break
1772
1773
1774        for fid in [fid4]:
1775            fid.createDimension(long_name,nx)
1776            fid.createVariable(long_name,'d',(long_name,))
1777            fid.variables[long_name].point_spacing='uneven'
1778            fid.variables[long_name].units='degrees_east'
1779            fid.variables[long_name].assignValue(h1_list)
1780
1781            fid.createDimension(lat_name,ny)
1782            fid.createVariable(lat_name,'d',(lat_name,))
1783            fid.variables[lat_name].point_spacing='uneven'
1784            fid.variables[lat_name].units='degrees_north'
1785            fid.variables[lat_name].assignValue(h2_list)
1786
1787        name = {}
1788        name[fid1]='HA'
1789        name[fid2]='UA'
1790        name[fid3]='VA'
1791        name[fid4]='ELEVATION'
1792
1793        units = {}
1794        units[fid1]='cm'
1795        units[fid2]='cm/s'
1796        units[fid3]='cm/s'
1797        units[fid4]='m'
1798
1799        values = {}
1800        values[fid1]=[[[5., 10.,15.], [13.,18.,23.]],[[50.,100.,150.],[130.,180.,230.]]]
1801        values[fid2]=[[[1., 2.,3.], [4.,5.,6.]],[[7.,8.,9.],[10.,11.,12.]]]
1802        values[fid3]=[[[13., 12.,11.], [10.,9.,8.]],[[7.,6.,5.],[4.,3.,2.]]]
1803        values[fid4]=[[-3000,-3100,-3200],[-4000,-5000,-6000]]
1804
1805        for fid in [fid1,fid2,fid3]:
1806          fid.createVariable(name[fid],'d',('TIME',lat_name,long_name))
1807          fid.variables[name[fid]].point_spacing='uneven'
1808          fid.variables[name[fid]].units=units[fid]
1809          fid.variables[name[fid]].assignValue(values[fid])
1810          fid.variables[name[fid]].missing_value = -99999999.
1811          if fid == fid3: break
1812
1813        for fid in [fid4]:
1814            fid.createVariable(name[fid],'d',(lat_name,long_name))
1815            fid.variables[name[fid]].point_spacing='uneven'
1816            fid.variables[name[fid]].units=units[fid]
1817            fid.variables[name[fid]].assignValue(values[fid])
1818            fid.variables[name[fid]].missing_value = -99999999.
1819
1820
1821        fid1.sync(); fid1.close()
1822        fid2.sync(); fid2.close()
1823        fid3.sync(); fid3.close()
1824        fid4.sync(); fid4.close()
1825
1826        fid1 = NetCDFFile('test_ha.nc','r')
1827        fid2 = NetCDFFile('test_e.nc','r')
1828        fid3 = NetCDFFile('test_va.nc','r')
1829
1830
1831        first_amp = fid1.variables['HA'][:][0,0,0]
1832        third_amp = fid1.variables['HA'][:][0,0,2]
1833        first_elevation = fid2.variables['ELEVATION'][0,0]
1834        third_elevation= fid2.variables['ELEVATION'][:][0,2]
1835        first_speed = fid3.variables['VA'][0,0,0]
1836        third_speed = fid3.variables['VA'][:][0,0,2]
1837
1838        fid1.close()
1839        fid2.close()
1840        fid3.close()
1841
1842        #Call conversion (with zero origin)
1843        ferret2sww('test', verbose=False,
1844                   origin = (56, 0, 0))
1845
1846        os.remove('test_va.nc')
1847        os.remove('test_ua.nc')
1848        os.remove('test_ha.nc')
1849        os.remove('test_e.nc')
1850
1851        #Read output file 'test.sww'
1852        fid = NetCDFFile('test.sww')
1853
1854
1855        #Check first value
1856        elevation = fid.variables['elevation'][:]
1857        stage = fid.variables['stage'][:]
1858        xmomentum = fid.variables['xmomentum'][:]
1859        ymomentum = fid.variables['ymomentum'][:]
1860
1861        #print ymomentum
1862        first_height = first_amp/100 - first_elevation
1863        third_height = third_amp/100 - third_elevation
1864        first_momentum=first_speed*first_height/100
1865        third_momentum=third_speed*third_height/100
1866
1867        assert allclose(ymomentum[0][0],first_momentum)  #Meters
1868        assert allclose(ymomentum[0][2],third_momentum)  #Meters
1869
1870        fid.close()
1871
1872        #Cleanup
1873        os.remove('test.sww')
1874
1875
1876
1877
1878    def test_ferret2sww_nz_origin(self):
1879        from Scientific.IO.NetCDF import NetCDFFile
1880        from coordinate_transforms.redfearn import redfearn
1881
1882        #Call conversion (with nonzero origin)
1883        ferret2sww('small', verbose=False,
1884                   origin = (56, 100000, 200000))
1885
1886
1887        #Work out the UTM coordinates for first point
1888        zone, e, n = redfearn(-34.5, 150.66667)
1889
1890        #Read output file 'small.sww'
1891        fid = NetCDFFile('small.sww', 'r')
1892
1893        x = fid.variables['x'][:]
1894        y = fid.variables['y'][:]
1895
1896        #Check that first coordinate is correctly represented
1897        assert allclose(x[0], e-100000)
1898        assert allclose(y[0], n-200000)
1899
1900        fid.close()
1901
1902        #Cleanup
1903        import os
1904        os.remove('small.sww')
1905
1906
1907
1908    def test_sww_extent(self):
1909        """Not a test, rather a look at the sww format
1910        """
1911
1912        import time, os
1913        from Numeric import array, zeros, allclose, Float, concatenate
1914        from Scientific.IO.NetCDF import NetCDFFile
1915
1916        self.domain.filename = 'datatest' + str(id(self))
1917        self.domain.format = 'sww'
1918        self.domain.smooth = True
1919        self.domain.reduction = mean
1920        self.domain.set_datadir('.')
1921
1922
1923        sww = get_dataobject(self.domain)
1924        sww.store_connectivity()
1925        sww.store_timestep('stage')
1926        self.domain.time = 2.
1927
1928        #Modify stage at second timestep
1929        stage = self.domain.quantities['stage'].vertex_values
1930        self.domain.set_quantity('stage', stage/2)
1931
1932        sww.store_timestep('stage')
1933
1934        file_and_extension_name = self.domain.filename + ".sww"
1935        #print "file_and_extension_name",file_and_extension_name
1936        [xmin, xmax, ymin, ymax, stagemin, stagemax] = \
1937               extent_sww(file_and_extension_name )
1938
1939        assert allclose(xmin, 0.0)
1940        assert allclose(xmax, 1.0)
1941        assert allclose(ymin, 0.0)
1942        assert allclose(ymax, 1.0)
1943        assert allclose(stagemin, -0.85)
1944        assert allclose(stagemax, 0.15)
1945
1946
1947        #Cleanup
1948        os.remove(sww.filename)
1949
1950
1951
1952    def test_sww2domain(self):
1953        ################################################
1954        #Create a test domain, and evolve and save it.
1955        ################################################
1956        from mesh_factory import rectangular
1957        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1958             Constant_height, Time_boundary, Transmissive_boundary
1959        from Numeric import array
1960
1961        #Create basic mesh
1962
1963        yiel=0.01
1964        points, vertices, boundary = rectangular(10,10)
1965
1966        #Create shallow water domain
1967        domain = Domain(points, vertices, boundary)
1968        domain.geo_reference = Geo_reference(56,11,11)
1969        domain.smooth = False
1970        domain.visualise = False
1971        domain.store = True
1972        domain.filename = 'bedslope'
1973        domain.default_order=2
1974        #Bed-slope and friction
1975        domain.set_quantity('elevation', lambda x,y: -x/3)
1976        domain.set_quantity('friction', 0.1)
1977        # Boundary conditions
1978        from math import sin, pi
1979        Br = Reflective_boundary(domain)
1980        Bt = Transmissive_boundary(domain)
1981        Bd = Dirichlet_boundary([0.2,0.,0.])
1982        Bw = Time_boundary(domain=domain,
1983                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1984
1985        #domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
1986        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1987
1988        domain.quantities_to_be_stored.extend(['xmomentum','ymomentum'])
1989        #Initial condition
1990        h = 0.05
1991        elevation = domain.quantities['elevation'].vertex_values
1992        domain.set_quantity('stage', elevation + h)
1993        #elevation = domain.get_quantity('elevation')
1994        #domain.set_quantity('stage', elevation + h)
1995
1996        domain.check_integrity()
1997        #Evolution
1998        for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
1999        #    domain.write_time()
2000            pass
2001
2002
2003        ##########################################
2004        #Import the example's file as a new domain
2005        ##########################################
2006        from data_manager import sww2domain
2007        from Numeric import allclose
2008        import os
2009
2010        filename = domain.datadir+os.sep+domain.filename+'.sww'
2011        domain2 = sww2domain(filename,None,fail_if_NaN=False,verbose = False)
2012        #points, vertices, boundary = rectangular(15,15)
2013        #domain2.boundary = boundary
2014        ###################
2015        ##NOW TEST IT!!!
2016        ###################
2017
2018        bits = ['vertex_coordinates']
2019        for quantity in ['elevation']+domain.quantities_to_be_stored:
2020            bits.append('quantities["%s"].get_integral()'%quantity)
2021            bits.append('get_quantity("%s")'%quantity)
2022
2023        for bit in bits:
2024            #print 'testing that domain.'+bit+' has been restored'
2025            #print bit
2026        #print 'done'
2027            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
2028
2029        ######################################
2030        #Now evolve them both, just to be sure
2031        ######################################x
2032        visualise = False
2033        #visualise = True
2034        domain.visualise = visualise
2035        domain.time = 0.
2036        from time import sleep
2037
2038        final = .1
2039        domain.set_quantity('friction', 0.1)
2040        domain.store = False
2041        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
2042
2043
2044        for t in domain.evolve(yieldstep = yiel, finaltime = final):
2045            if visualise: sleep(1.)
2046            #domain.write_time()
2047            pass
2048
2049        final = final - (domain2.starttime-domain.starttime)
2050        #BUT since domain1 gets time hacked back to 0:
2051        final = final + (domain2.starttime-domain.starttime)
2052
2053        domain2.smooth = False
2054        domain2.visualise = visualise
2055        domain2.store = False
2056        domain2.default_order=2
2057        domain2.set_quantity('friction', 0.1)
2058        #Bed-slope and friction
2059        # Boundary conditions
2060        Bd2=Dirichlet_boundary([0.2,0.,0.])
2061        domain2.boundary = domain.boundary
2062        #print 'domain2.boundary'
2063        #print domain2.boundary
2064        domain2.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
2065        #domain2.set_boundary({'exterior': Bd})
2066
2067        domain2.check_integrity()
2068
2069        for t in domain2.evolve(yieldstep = yiel, finaltime = final):
2070            if visualise: sleep(1.)
2071            #domain2.write_time()
2072            pass
2073
2074        ###################
2075        ##NOW TEST IT!!!
2076        ##################
2077
2078        bits = [ 'vertex_coordinates']
2079
2080        for quantity in ['elevation','xmomentum','ymomentum']:#+domain.quantities_to_be_stored:
2081            bits.append('quantities["%s"].get_integral()'%quantity)
2082            bits.append('get_quantity("%s")'%quantity)
2083
2084        for bit in bits:
2085            #print bit
2086            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
2087
2088
2089    def test_sww2domain2(self):
2090        ##################################################################
2091        #Same as previous test, but this checks how NaNs are handled.
2092        ##################################################################
2093
2094
2095        from mesh_factory import rectangular
2096        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
2097             Constant_height, Time_boundary, Transmissive_boundary
2098        from Numeric import array
2099
2100        #Create basic mesh
2101        points, vertices, boundary = rectangular(2,2)
2102
2103        #Create shallow water domain
2104        domain = Domain(points, vertices, boundary)
2105        domain.smooth = False
2106        domain.visualise = False
2107        domain.store = True
2108        domain.filename = 'bedslope'
2109        domain.default_order=2
2110        domain.quantities_to_be_stored=['stage']
2111
2112        domain.set_quantity('elevation', lambda x,y: -x/3)
2113        domain.set_quantity('friction', 0.1)
2114
2115        from math import sin, pi
2116        Br = Reflective_boundary(domain)
2117        Bt = Transmissive_boundary(domain)
2118        Bd = Dirichlet_boundary([0.2,0.,0.])
2119        Bw = Time_boundary(domain=domain,
2120                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
2121
2122        domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
2123
2124        h = 0.05
2125        elevation = domain.quantities['elevation'].vertex_values
2126        domain.set_quantity('stage', elevation + h)
2127
2128        domain.check_integrity()
2129
2130        for t in domain.evolve(yieldstep = 1, finaltime = 2.0):
2131            pass
2132            #domain.write_time()
2133
2134
2135
2136        ##################################
2137        #Import the file as a new domain
2138        ##################################
2139        from data_manager import sww2domain
2140        from Numeric import allclose
2141        import os
2142
2143        filename = domain.datadir+os.sep+domain.filename+'.sww'
2144
2145        #Fail because NaNs are present
2146        try:
2147            domain2 = sww2domain(filename,boundary,fail_if_NaN=True,verbose=False)
2148            assert True == False
2149        except:
2150            #Now import it, filling NaNs to be 0
2151            filler = 0
2152            domain2 = sww2domain(filename,None,fail_if_NaN=False,NaN_filler = filler,verbose=False)
2153        bits = [ 'geo_reference.get_xllcorner()',
2154                'geo_reference.get_yllcorner()',
2155                'vertex_coordinates']
2156
2157        for quantity in ['elevation']+domain.quantities_to_be_stored:
2158            bits.append('quantities["%s"].get_integral()'%quantity)
2159            bits.append('get_quantity("%s")'%quantity)
2160
2161        for bit in bits:
2162        #    print 'testing that domain.'+bit+' has been restored'
2163            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
2164
2165        assert max(max(domain2.get_quantity('xmomentum')))==filler
2166        assert min(min(domain2.get_quantity('xmomentum')))==filler
2167        assert max(max(domain2.get_quantity('ymomentum')))==filler
2168        assert min(min(domain2.get_quantity('ymomentum')))==filler
2169
2170        #print 'passed'
2171
2172        #cleanup
2173        #import os
2174        #os.remove(domain.datadir+'/'+domain.filename+'.sww')
2175
2176
2177    #def test_weed(self):
2178        from data_manager import weed
2179
2180        coordinates1 = [[0.,0.],[1.,0.],[1.,1.],[1.,0.],[2.,0.],[1.,1.]]
2181        volumes1 = [[0,1,2],[3,4,5]]
2182        boundary1= {(0,1): 'external',(1,2): 'not external',(2,0): 'external',(3,4): 'external',(4,5): 'external',(5,3): 'not external'}
2183        coordinates2,volumes2,boundary2=weed(coordinates1,volumes1,boundary1)
2184
2185        points2 = {(0.,0.):None,(1.,0.):None,(1.,1.):None,(2.,0.):None}
2186
2187        assert len(points2)==len(coordinates2)
2188        for i in range(len(coordinates2)):
2189            coordinate = tuple(coordinates2[i])
2190            assert points2.has_key(coordinate)
2191            points2[coordinate]=i
2192
2193        for triangle in volumes1:
2194            for coordinate in triangle:
2195                assert coordinates2[points2[tuple(coordinates1[coordinate])]][0]==coordinates1[coordinate][0]
2196                assert coordinates2[points2[tuple(coordinates1[coordinate])]][1]==coordinates1[coordinate][1]
2197
2198
2199    #FIXME This fails - smooth makes the comparism too hard for allclose
2200    def ztest_sww2domain3(self):
2201        ################################################
2202        #DOMAIN.SMOOTH = TRUE !!!!!!!!!!!!!!!!!!!!!!!!!!
2203        ################################################
2204        from mesh_factory import rectangular
2205        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
2206             Constant_height, Time_boundary, Transmissive_boundary
2207        from Numeric import array
2208        #Create basic mesh
2209
2210        yiel=0.01
2211        points, vertices, boundary = rectangular(10,10)
2212
2213        #Create shallow water domain
2214        domain = Domain(points, vertices, boundary)
2215        domain.geo_reference = Geo_reference(56,11,11)
2216        domain.smooth = True
2217        domain.visualise = False
2218        domain.store = True
2219        domain.filename = 'bedslope'
2220        domain.default_order=2
2221        #Bed-slope and friction
2222        domain.set_quantity('elevation', lambda x,y: -x/3)
2223        domain.set_quantity('friction', 0.1)
2224        # Boundary conditions
2225        from math import sin, pi
2226        Br = Reflective_boundary(domain)
2227        Bt = Transmissive_boundary(domain)
2228        Bd = Dirichlet_boundary([0.2,0.,0.])
2229        Bw = Time_boundary(domain=domain,
2230                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
2231
2232        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
2233
2234        domain.quantities_to_be_stored.extend(['xmomentum','ymomentum'])
2235        #Initial condition
2236        h = 0.05
2237        elevation = domain.quantities['elevation'].vertex_values
2238        domain.set_quantity('stage', elevation + h)
2239        #elevation = domain.get_quantity('elevation')
2240        #domain.set_quantity('stage', elevation + h)
2241
2242        domain.check_integrity()
2243        #Evolution
2244        for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
2245        #    domain.write_time()
2246            pass
2247
2248
2249        ##########################################
2250        #Import the example's file as a new domain
2251        ##########################################
2252        from data_manager import sww2domain
2253        from Numeric import allclose
2254        import os
2255
2256        filename = domain.datadir+os.sep+domain.filename+'.sww'
2257        domain2 = sww2domain(filename,None,fail_if_NaN=False,verbose = False)
2258        #points, vertices, boundary = rectangular(15,15)
2259        #domain2.boundary = boundary
2260        ###################
2261        ##NOW TEST IT!!!
2262        ###################
2263
2264        #FIXME smooth domain so that they can be compared
2265
2266
2267        bits = []#'vertex_coordinates']
2268        for quantity in ['elevation']+domain.quantities_to_be_stored:
2269            bits.append('quantities["%s"].get_integral()'%quantity)
2270            #bits.append('get_quantity("%s")'%quantity)
2271
2272        for bit in bits:
2273            #print 'testing that domain.'+bit+' has been restored'
2274            #print bit
2275            #print 'done'
2276            #print ('domain.'+bit), eval('domain.'+bit)
2277            #print ('domain2.'+bit), eval('domain2.'+bit)
2278            assert allclose(eval('domain.'+bit),eval('domain2.'+bit),rtol=1.0e-1,atol=1.e-3)
2279            pass
2280
2281        ######################################
2282        #Now evolve them both, just to be sure
2283        ######################################x
2284        visualise = False
2285        visualise = True
2286        domain.visualise = visualise
2287        domain.time = 0.
2288        from time import sleep
2289
2290        final = .5
2291        domain.set_quantity('friction', 0.1)
2292        domain.store = False
2293        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Br})
2294
2295        for t in domain.evolve(yieldstep = yiel, finaltime = final):
2296            if visualise: sleep(.03)
2297            #domain.write_time()
2298            pass
2299
2300        domain2.smooth = True
2301        domain2.visualise = visualise
2302        domain2.store = False
2303        domain2.default_order=2
2304        domain2.set_quantity('friction', 0.1)
2305        #Bed-slope and friction
2306        # Boundary conditions
2307        Bd2=Dirichlet_boundary([0.2,0.,0.])
2308        Br2 = Reflective_boundary(domain2)
2309        domain2.boundary = domain.boundary
2310        #print 'domain2.boundary'
2311        #print domain2.boundary
2312        domain2.set_boundary({'left': Bd2, 'right': Bd2, 'top': Bd2, 'bottom': Br2})
2313        #domain2.boundary = domain.boundary
2314        #domain2.set_boundary({'exterior': Bd})
2315
2316        domain2.check_integrity()
2317
2318        for t in domain2.evolve(yieldstep = yiel, finaltime = final):
2319            if visualise: sleep(.03)
2320            #domain2.write_time()
2321            pass
2322
2323        ###################
2324        ##NOW TEST IT!!!
2325        ##################
2326
2327        bits = [ 'vertex_coordinates']
2328
2329        for quantity in ['elevation','xmomentum','ymomentum']:#+domain.quantities_to_be_stored:
2330            #bits.append('quantities["%s"].get_integral()'%quantity)
2331            bits.append('get_quantity("%s")'%quantity)
2332
2333        for bit in bits:
2334            print bit
2335            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
2336
2337
2338    def test_decimate_dem(self):
2339        """Test decimation of dem file
2340        """
2341
2342        import os
2343        from Numeric import ones, allclose, Float, arange
2344        from Scientific.IO.NetCDF import NetCDFFile
2345
2346        #Write test dem file
2347        root = 'decdemtest'
2348
2349        filename = root + '.dem'
2350        fid = NetCDFFile(filename, 'w')
2351
2352        fid.institution = 'Geoscience Australia'
2353        fid.description = 'NetCDF DEM format for compact and portable ' +\
2354                          'storage of spatial point data'
2355
2356        nrows = 15
2357        ncols = 18
2358
2359        fid.ncols = ncols
2360        fid.nrows = nrows
2361        fid.xllcorner = 2000.5
2362        fid.yllcorner = 3000.5
2363        fid.cellsize = 25
2364        fid.NODATA_value = -9999
2365
2366        fid.zone = 56
2367        fid.false_easting = 0.0
2368        fid.false_northing = 0.0
2369        fid.projection = 'UTM'
2370        fid.datum = 'WGS84'
2371        fid.units = 'METERS'
2372
2373        fid.createDimension('number_of_points', nrows*ncols)
2374
2375        fid.createVariable('elevation', Float, ('number_of_points',))
2376
2377        elevation = fid.variables['elevation']
2378
2379        elevation[:] = (arange(nrows*ncols))
2380
2381        fid.close()
2382
2383        #generate the elevation values expected in the decimated file
2384        ref_elevation = [(  0+  1+  2+ 18+ 19+ 20+ 36+ 37+ 38) / 9.0,
2385                         (  4+  5+  6+ 22+ 23+ 24+ 40+ 41+ 42) / 9.0,
2386                         (  8+  9+ 10+ 26+ 27+ 28+ 44+ 45+ 46) / 9.0,
2387                         ( 12+ 13+ 14+ 30+ 31+ 32+ 48+ 49+ 50) / 9.0,
2388                         ( 72+ 73+ 74+ 90+ 91+ 92+108+109+110) / 9.0,
2389                         ( 76+ 77+ 78+ 94+ 95+ 96+112+113+114) / 9.0,
2390                         ( 80+ 81+ 82+ 98+ 99+100+116+117+118) / 9.0,
2391                         ( 84+ 85+ 86+102+103+104+120+121+122) / 9.0,
2392                         (144+145+146+162+163+164+180+181+182) / 9.0,
2393                         (148+149+150+166+167+168+184+185+186) / 9.0,
2394                         (152+153+154+170+171+172+188+189+190) / 9.0,
2395                         (156+157+158+174+175+176+192+193+194) / 9.0,
2396                         (216+217+218+234+235+236+252+253+254) / 9.0,
2397                         (220+221+222+238+239+240+256+257+258) / 9.0,
2398                         (224+225+226+242+243+244+260+261+262) / 9.0,
2399                         (228+229+230+246+247+248+264+265+266) / 9.0]
2400
2401        #generate a stencil for computing the decimated values
2402        stencil = ones((3,3), Float) / 9.0
2403
2404        decimate_dem(root, stencil=stencil, cellsize_new=100)
2405
2406        #Open decimated NetCDF file
2407        fid = NetCDFFile(root + '_100.dem', 'r')
2408
2409        # Get decimated elevation
2410        elevation = fid.variables['elevation']
2411
2412        #Check values
2413        assert allclose(elevation, ref_elevation)
2414
2415        #Cleanup
2416        fid.close()
2417
2418        os.remove(root + '.dem')
2419        os.remove(root + '_100.dem')
2420
2421    def test_decimate_dem_NODATA(self):
2422        """Test decimation of dem file that includes NODATA values
2423        """
2424
2425        import os
2426        from Numeric import ones, allclose, Float, arange, reshape
2427        from Scientific.IO.NetCDF import NetCDFFile
2428
2429        #Write test dem file
2430        root = 'decdemtest'
2431
2432        filename = root + '.dem'
2433        fid = NetCDFFile(filename, 'w')
2434
2435        fid.institution = 'Geoscience Australia'
2436        fid.description = 'NetCDF DEM format for compact and portable ' +\
2437                          'storage of spatial point data'
2438
2439        nrows = 15
2440        ncols = 18
2441        NODATA_value = -9999
2442
2443        fid.ncols = ncols
2444        fid.nrows = nrows
2445        fid.xllcorner = 2000.5
2446        fid.yllcorner = 3000.5
2447        fid.cellsize = 25
2448        fid.NODATA_value = NODATA_value
2449
2450        fid.zone = 56
2451        fid.false_easting = 0.0
2452        fid.false_northing = 0.0
2453        fid.projection = 'UTM'
2454        fid.datum = 'WGS84'
2455        fid.units = 'METERS'
2456
2457        fid.createDimension('number_of_points', nrows*ncols)
2458
2459        fid.createVariable('elevation', Float, ('number_of_points',))
2460
2461        elevation = fid.variables['elevation']
2462
2463        #generate initial elevation values
2464        elevation_tmp = (arange(nrows*ncols))
2465        #add some NODATA values
2466        elevation_tmp[0]   = NODATA_value
2467        elevation_tmp[95]  = NODATA_value
2468        elevation_tmp[188] = NODATA_value
2469        elevation_tmp[189] = NODATA_value
2470        elevation_tmp[190] = NODATA_value
2471        elevation_tmp[209] = NODATA_value
2472        elevation_tmp[252] = NODATA_value
2473
2474        elevation[:] = elevation_tmp
2475
2476        fid.close()
2477
2478        #generate the elevation values expected in the decimated file
2479        ref_elevation = [NODATA_value,
2480                         (  4+  5+  6+ 22+ 23+ 24+ 40+ 41+ 42) / 9.0,
2481                         (  8+  9+ 10+ 26+ 27+ 28+ 44+ 45+ 46) / 9.0,
2482                         ( 12+ 13+ 14+ 30+ 31+ 32+ 48+ 49+ 50) / 9.0,
2483                         ( 72+ 73+ 74+ 90+ 91+ 92+108+109+110) / 9.0,
2484                         NODATA_value,
2485                         ( 80+ 81+ 82+ 98+ 99+100+116+117+118) / 9.0,
2486                         ( 84+ 85+ 86+102+103+104+120+121+122) / 9.0,
2487                         (144+145+146+162+163+164+180+181+182) / 9.0,
2488                         (148+149+150+166+167+168+184+185+186) / 9.0,
2489                         NODATA_value,
2490                         (156+157+158+174+175+176+192+193+194) / 9.0,
2491                         NODATA_value,
2492                         (220+221+222+238+239+240+256+257+258) / 9.0,
2493                         (224+225+226+242+243+244+260+261+262) / 9.0,
2494                         (228+229+230+246+247+248+264+265+266) / 9.0]
2495
2496        #generate a stencil for computing the decimated values
2497        stencil = ones((3,3), Float) / 9.0
2498
2499        decimate_dem(root, stencil=stencil, cellsize_new=100)
2500
2501        #Open decimated NetCDF file
2502        fid = NetCDFFile(root + '_100.dem', 'r')
2503
2504        # Get decimated elevation
2505        elevation = fid.variables['elevation']
2506
2507        #Check values
2508        assert allclose(elevation, ref_elevation)
2509
2510        #Cleanup
2511        fid.close()
2512
2513        os.remove(root + '.dem')
2514        os.remove(root + '_100.dem')
2515
2516
2517
2518
2519#-------------------------------------------------------------
2520if __name__ == "__main__":
2521    suite = unittest.makeSuite(Test_Data_Manager,'test')
2522    #suite = unittest.makeSuite(Test_Data_Manager,'test_sww2dem_boundingbox')   
2523    #suite = unittest.makeSuite(Test_Data_Manager,'test_dem2pts_bounding_box')
2524    #suite = unittest.makeSuite(Test_Data_Manager,'test_decimate_dem')
2525    #suite = unittest.makeSuite(Test_Data_Manager,'test_decimate_dem_NODATA')
2526    runner = unittest.TextTestRunner()
2527    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.