source: inundation/ga/storm_surge/pyvolution-parallel/test_data_manager.py @ 1452

Last change on this file since 1452 was 1201, checked in by prow, 20 years ago

fixed smoothed sww2domain

File size: 49.3 KB
Line 
1#!/usr/bin/env python
2#
3
4import unittest
5import copy
6from Numeric import zeros, array, allclose, Float
7from util import mean
8
9from data_manager import *
10from shallow_water import *
11from config import epsilon
12
13from coordinate_transforms.geo_reference import Geo_reference
14
15class Test_Data_Manager(unittest.TestCase):
16    def setUp(self):
17        import time
18        from mesh_factory import rectangular
19
20
21        #Create basic mesh
22        points, vertices, boundary = rectangular(2, 2)
23
24        #Create shallow water domain
25        domain = Domain(points, vertices, boundary)
26        domain.default_order=2
27
28
29        #Set some field values
30        domain.set_quantity('elevation', lambda x,y: -x)
31        domain.set_quantity('friction', 0.03)
32
33
34        ######################
35        # Boundary conditions
36        B = Transmissive_boundary(domain)
37        domain.set_boundary( {'left': B, 'right': B, 'top': B, 'bottom': B})
38
39
40        ######################
41        #Initial condition - with jumps
42
43
44        bed = domain.quantities['elevation'].vertex_values
45        stage = zeros(bed.shape, Float)
46
47        h = 0.3
48        for i in range(stage.shape[0]):
49            if i % 2 == 0:
50                stage[i,:] = bed[i,:] + h
51            else:
52                stage[i,:] = bed[i,:]
53
54        domain.set_quantity('stage', stage)
55        self.initial_stage = copy.copy(domain.quantities['stage'].vertex_values)
56
57        domain.distribute_to_vertices_and_edges()
58
59
60        self.domain = domain
61
62        C = domain.get_vertex_coordinates()
63        self.X = C[:,0:6:2].copy()
64        self.Y = C[:,1:6:2].copy()
65
66        self.F = bed
67
68
69    def tearDown(self):
70        pass
71
72
73
74
75#     def test_xya(self):
76#         import os
77#         from Numeric import concatenate
78
79#         import time, os
80#         from Numeric import array, zeros, allclose, Float, concatenate
81
82#         domain = self.domain
83
84#         domain.filename = 'datatest' + str(time.time())
85#         domain.format = 'xya'
86#         domain.smooth = True
87
88#         xya = get_dataobject(self.domain)
89#         xya.store_all()
90
91
92#         #Read back
93#         file = open(xya.filename)
94#         lFile = file.read().split('\n')
95#         lFile = lFile[:-1]
96
97#         file.close()
98#         os.remove(xya.filename)
99
100#         #Check contents
101#         if domain.smooth:
102#             self.failUnless(lFile[0] == '9 3 # <vertex #> <x> <y> [attributes]')
103#         else:
104#             self.failUnless(lFile[0] == '24 3 # <vertex #> <x> <y> [attributes]')
105
106#         #Get smoothed field values with X and Y
107#         X,Y,F,V = domain.get_vertex_values(xy=True, value_array='field_values',
108#                                            indices = (0,1), precision = Float)
109
110
111#         Q,V = domain.get_vertex_values(xy=False, value_array='conserved_quantities',
112#                                            indices = (0,), precision = Float)
113
114
115
116#         for i, line in enumerate(lFile[1:]):
117#             fields = line.split()
118
119#             assert len(fields) == 5
120
121#             assert allclose(float(fields[0]), X[i])
122#             assert allclose(float(fields[1]), Y[i])
123#             assert allclose(float(fields[2]), F[i,0])
124#             assert allclose(float(fields[3]), Q[i,0])
125#             assert allclose(float(fields[4]), F[i,1])
126
127
128
129
130    def test_sww_constant(self):
131        """Test that constant sww information can be written correctly
132        (non smooth)
133        """
134
135        import time, os
136        from Numeric import array, zeros, allclose, Float, concatenate
137        from Scientific.IO.NetCDF import NetCDFFile
138
139        self.domain.filename = 'datatest' + str(id(self))
140        self.domain.format = 'sww'
141        self.domain.smooth = False
142
143        sww = get_dataobject(self.domain)
144        sww.store_connectivity()
145
146        #Check contents
147        #Get NetCDF
148        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
149
150        # Get the variables
151        x = fid.variables['x']
152        y = fid.variables['y']
153        z = fid.variables['elevation']
154
155        volumes = fid.variables['volumes']
156
157
158        assert allclose (x[:], self.X.flat)
159        assert allclose (y[:], self.Y.flat)
160        assert allclose (z[:], self.F.flat)
161
162        V = volumes
163
164        P = len(self.domain)
165        for k in range(P):
166            assert V[k, 0] == 3*k
167            assert V[k, 1] == 3*k+1
168            assert V[k, 2] == 3*k+2
169
170
171        fid.close()
172
173        #Cleanup
174        os.remove(sww.filename)
175
176
177    def test_sww_constant_smooth(self):
178        """Test that constant sww information can be written correctly
179        (non smooth)
180        """
181
182        import time, os
183        from Numeric import array, zeros, allclose, Float, concatenate
184        from Scientific.IO.NetCDF import NetCDFFile
185
186        self.domain.filename = 'datatest' + str(id(self))
187        self.domain.format = 'sww'
188        self.domain.smooth = True
189
190        sww = get_dataobject(self.domain)
191        sww.store_connectivity()
192
193        #Check contents
194        #Get NetCDF
195        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
196
197        # Get the variables
198        x = fid.variables['x']
199        y = fid.variables['y']
200        z = fid.variables['elevation']
201
202        volumes = fid.variables['volumes']
203
204        X = x[:]
205        Y = y[:]
206
207        assert allclose([X[0], Y[0]], array([0.0, 0.0]))
208        assert allclose([X[1], Y[1]], array([0.0, 0.5]))
209        assert allclose([X[2], Y[2]], array([0.0, 1.0]))
210
211        assert allclose([X[4], Y[4]], array([0.5, 0.5]))
212
213        assert allclose([X[7], Y[7]], array([1.0, 0.5]))
214
215        Z = z[:]
216        assert Z[4] == -0.5
217
218        V = volumes
219        assert V[2,0] == 4
220        assert V[2,1] == 5
221        assert V[2,2] == 1
222
223        assert V[4,0] == 6
224        assert V[4,1] == 7
225        assert V[4,2] == 3
226
227
228        fid.close()
229
230        #Cleanup
231        os.remove(sww.filename)
232
233
234
235    def test_sww_variable(self):
236        """Test that sww information can be written correctly
237        """
238
239        import time, os
240        from Numeric import array, zeros, allclose, Float, concatenate
241        from Scientific.IO.NetCDF import NetCDFFile
242
243        self.domain.filename = 'datatest' + str(id(self))
244        self.domain.format = 'sww'
245        self.domain.smooth = True
246        self.domain.reduction = mean
247
248        sww = get_dataobject(self.domain)
249        sww.store_connectivity()
250        sww.store_timestep('stage')
251
252        #Check contents
253        #Get NetCDF
254        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
255
256
257        # Get the variables
258        x = fid.variables['x']
259        y = fid.variables['y']
260        z = fid.variables['elevation']
261        time = fid.variables['time']
262        stage = fid.variables['stage']
263
264
265        Q = self.domain.quantities['stage']
266        Q0 = Q.vertex_values[:,0]
267        Q1 = Q.vertex_values[:,1]
268        Q2 = Q.vertex_values[:,2]
269
270        A = stage[0,:]
271        #print A[0], (Q2[0,0] + Q1[1,0])/2
272        assert allclose(A[0], (Q2[0] + Q1[1])/2)
273        assert allclose(A[1], (Q0[1] + Q1[3] + Q2[2])/3)
274        assert allclose(A[2], Q0[3])
275        assert allclose(A[3], (Q0[0] + Q1[5] + Q2[4])/3)
276
277        #Center point
278        assert allclose(A[4], (Q1[0] + Q2[1] + Q0[2] +\
279                                 Q0[5] + Q2[6] + Q1[7])/6)
280
281
282
283        fid.close()
284
285        #Cleanup
286        os.remove(sww.filename)
287
288
289    def test_sww_variable2(self):
290        """Test that sww information can be written correctly
291        multiple timesteps. Use average as reduction operator
292        """
293
294        import time, os
295        from Numeric import array, zeros, allclose, Float, concatenate
296        from Scientific.IO.NetCDF import NetCDFFile
297
298        self.domain.filename = 'datatest' + str(id(self))
299        self.domain.format = 'sww'
300        self.domain.smooth = True
301
302        self.domain.reduction = mean
303
304        sww = get_dataobject(self.domain)
305        sww.store_connectivity()
306        sww.store_timestep('stage')
307        self.domain.evolve_to_end(finaltime = 0.01)
308        sww.store_timestep('stage')
309
310
311        #Check contents
312        #Get NetCDF
313        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
314
315        # Get the variables
316        x = fid.variables['x']
317        y = fid.variables['y']
318        z = fid.variables['elevation']
319        time = fid.variables['time']
320        stage = fid.variables['stage']
321
322        #Check values
323        Q = self.domain.quantities['stage']
324        Q0 = Q.vertex_values[:,0]
325        Q1 = Q.vertex_values[:,1]
326        Q2 = Q.vertex_values[:,2]
327
328        A = stage[1,:]
329        assert allclose(A[0], (Q2[0] + Q1[1])/2)
330        assert allclose(A[1], (Q0[1] + Q1[3] + Q2[2])/3)
331        assert allclose(A[2], Q0[3])
332        assert allclose(A[3], (Q0[0] + Q1[5] + Q2[4])/3)
333
334        #Center point
335        assert allclose(A[4], (Q1[0] + Q2[1] + Q0[2] +\
336                                 Q0[5] + Q2[6] + Q1[7])/6)
337
338
339
340
341
342        fid.close()
343
344        #Cleanup
345        os.remove(sww.filename)
346
347    def test_sww_variable3(self):
348        """Test that sww information can be written correctly
349        multiple timesteps using a different reduction operator (min)
350        """
351
352        import time, os
353        from Numeric import array, zeros, allclose, Float, concatenate
354        from Scientific.IO.NetCDF import NetCDFFile
355
356        self.domain.filename = 'datatest' + str(id(self))
357        self.domain.format = 'sww'
358        self.domain.smooth = True
359        self.domain.reduction = min
360
361        sww = get_dataobject(self.domain)
362        sww.store_connectivity()
363        sww.store_timestep('stage')
364
365        self.domain.evolve_to_end(finaltime = 0.01)
366        sww.store_timestep('stage')
367
368
369        #Check contents
370        #Get NetCDF
371        fid = NetCDFFile(sww.filename, 'r')
372
373
374        # Get the variables
375        x = fid.variables['x']
376        y = fid.variables['y']
377        z = fid.variables['elevation']
378        time = fid.variables['time']
379        stage = fid.variables['stage']
380
381        #Check values
382        Q = self.domain.quantities['stage']
383        Q0 = Q.vertex_values[:,0]
384        Q1 = Q.vertex_values[:,1]
385        Q2 = Q.vertex_values[:,2]
386
387        A = stage[1,:]
388        assert allclose(A[0], min(Q2[0], Q1[1]))
389        assert allclose(A[1], min(Q0[1], Q1[3], Q2[2]))
390        assert allclose(A[2], Q0[3])
391        assert allclose(A[3], min(Q0[0], Q1[5], Q2[4]))
392
393        #Center point
394        assert allclose(A[4], min(Q1[0], Q2[1], Q0[2],\
395                                  Q0[5], Q2[6], Q1[7]))
396
397
398        fid.close()
399
400        #Cleanup
401        os.remove(sww.filename)
402
403
404    def test_sync(self):
405        """Test info stored at each timestep is as expected (incl initial condition)
406        """
407
408        import time, os, config
409        from Numeric import array, zeros, allclose, Float, concatenate
410        from Scientific.IO.NetCDF import NetCDFFile
411
412        self.domain.filename = 'synctest'
413        self.domain.format = 'sww'
414        self.domain.smooth = False
415        self.domain.store = True
416        self.domain.beta_h = 0
417
418        #Evolution
419        for t in self.domain.evolve(yieldstep = 1.0, finaltime = 4.0):
420            stage = self.domain.quantities['stage'].vertex_values
421
422            #Get NetCDF
423            fid = NetCDFFile(self.domain.writer.filename, 'r')
424            stage_file = fid.variables['stage']
425
426            if t == 0.0:
427                assert allclose(stage, self.initial_stage)
428                assert allclose(stage_file[:], stage.flat)
429            else:
430                assert not allclose(stage, self.initial_stage)
431                assert not allclose(stage_file[:], stage.flat)
432
433            fid.close()
434        os.remove(self.domain.writer.filename)
435
436
437
438    def test_sww_DSG(self):
439        """Not a test, rather a look at the sww format
440        """
441
442        import time, os
443        from Numeric import array, zeros, allclose, Float, concatenate
444        from Scientific.IO.NetCDF import NetCDFFile
445
446        self.domain.filename = 'datatest' + str(id(self))
447        self.domain.format = 'sww'
448        self.domain.smooth = True
449        self.domain.reduction = mean
450
451        sww = get_dataobject(self.domain)
452        sww.store_connectivity()
453        sww.store_timestep('stage')
454
455        #Check contents
456        #Get NetCDF
457        fid = NetCDFFile(sww.filename, 'r')
458
459        # Get the variables
460        x = fid.variables['x']
461        y = fid.variables['y']
462        z = fid.variables['elevation']
463
464        volumes = fid.variables['volumes']
465        time = fid.variables['time']
466
467        # 2D
468        stage = fid.variables['stage']
469
470        X = x[:]
471        Y = y[:]
472        Z = z[:]
473        V = volumes[:]
474        T = time[:]
475        S = stage[:,:]
476
477#         print "****************************"
478#         print "X ",X
479#         print "****************************"
480#         print "Y ",Y
481#         print "****************************"
482#         print "Z ",Z
483#         print "****************************"
484#         print "V ",V
485#         print "****************************"
486#         print "Time ",T
487#         print "****************************"
488#         print "Stage ",S
489#         print "****************************"
490
491
492        fid.close()
493
494        #Cleanup
495        os.remove(sww.filename)
496
497
498
499    def test_dem2pts(self):
500        """Test conversion from dem in ascii format to native NetCDF xya format
501        """
502
503        import time, os
504        from Numeric import array, zeros, allclose, Float, concatenate
505        from Scientific.IO.NetCDF import NetCDFFile
506
507        #Write test asc file
508        root = 'demtest'
509
510        filename = root+'.asc'
511        fid = open(filename, 'w')
512        fid.write("""ncols         5
513nrows         6
514xllcorner     2000.5
515yllcorner     3000.5
516cellsize      25
517NODATA_value  -9999
518""")
519        #Create linear function
520
521        ref_points = []
522        ref_elevation = []
523        for i in range(6):
524            y = (6-i)*25.0
525            for j in range(5):
526                x = j*25.0
527                z = x+2*y
528
529                ref_points.append( [x,y] )
530                ref_elevation.append(z)
531                fid.write('%f ' %z)
532            fid.write('\n')
533
534        fid.close()
535
536        #Write prj file with metadata
537        metafilename = root+'.prj'
538        fid = open(metafilename, 'w')
539
540
541        fid.write("""Projection UTM
542Zone 56
543Datum WGS84
544Zunits NO
545Units METERS
546Spheroid WGS84
547Xshift 0.0000000000
548Yshift 10000000.0000000000
549Parameters
550""")
551        fid.close()
552
553        #Convert to NetCDF pts
554        convert_dem_from_ascii2netcdf(root)
555        dem2pts(root)
556
557        #Check contents
558        #Get NetCDF
559        fid = NetCDFFile(root+'.pts', 'r')
560
561        # Get the variables
562        #print fid.variables.keys()
563        points = fid.variables['points']
564        elevation = fid.variables['elevation']
565
566        #Check values
567
568        #print points[:]
569        #print ref_points
570        assert allclose(points, ref_points)
571
572        #print attributes[:]
573        #print ref_elevation
574        assert allclose(elevation, ref_elevation)
575
576        #Cleanup
577        fid.close()
578
579
580        os.remove(root + '.pts')
581        os.remove(root + '.dem')
582        os.remove(root + '.asc')
583        os.remove(root + '.prj')
584
585
586
587    def test_sww2asc_elevation(self):
588        """Test that sww information can be converted correctly to asc/prj
589        format readable by e.g. ArcView
590        """
591
592        import time, os
593        from Numeric import array, zeros, allclose, Float, concatenate
594        from Scientific.IO.NetCDF import NetCDFFile
595
596        #Setup
597        self.domain.filename = 'datatest'
598       
599        prjfile = self.domain.filename + '_elevation.prj'
600        ascfile = self.domain.filename + '_elevation.asc'       
601        swwfile = self.domain.filename + '.sww'
602       
603        self.domain.set_datadir('.')
604        self.domain.format = 'sww'
605        self.domain.smooth = True
606        self.domain.set_quantity('elevation', lambda x,y: -x-y)
607
608        self.domain.geo_reference = Geo_reference(56,308500,6189000)
609       
610        sww = get_dataobject(self.domain)
611        sww.store_connectivity()
612        sww.store_timestep('stage')
613
614        self.domain.evolve_to_end(finaltime = 0.01)
615        sww.store_timestep('stage')
616
617        cellsize = 0.25
618        #Check contents
619        #Get NetCDF
620
621        fid = NetCDFFile(sww.filename, 'r')
622
623        # Get the variables
624        x = fid.variables['x'][:]
625        y = fid.variables['y'][:]
626        z = fid.variables['elevation'][:]
627        time = fid.variables['time'][:]
628        stage = fid.variables['stage'][:]
629
630
631        #Export to ascii/prj files
632        sww2asc(self.domain.filename, 
633                quantity = 'elevation',                         
634                cellsize = cellsize,
635                verbose = False)
636
637
638        #Check prj (meta data)
639        prjid = open(prjfile)
640        lines = prjid.readlines()
641        prjid.close()
642
643        L = lines[0].strip().split()
644        assert L[0].strip().lower() == 'projection'
645        assert L[1].strip().lower() == 'utm'
646
647        L = lines[1].strip().split()
648        assert L[0].strip().lower() == 'zone'
649        assert L[1].strip().lower() == '56'
650
651        L = lines[2].strip().split()
652        assert L[0].strip().lower() == 'datum'
653        assert L[1].strip().lower() == 'wgs84'
654
655        L = lines[3].strip().split()
656        assert L[0].strip().lower() == 'zunits'
657        assert L[1].strip().lower() == 'no'                       
658
659        L = lines[4].strip().split()
660        assert L[0].strip().lower() == 'units'
661        assert L[1].strip().lower() == 'meters'               
662
663        L = lines[5].strip().split()
664        assert L[0].strip().lower() == 'spheroid'
665        assert L[1].strip().lower() == 'wgs84'
666
667        L = lines[6].strip().split()
668        assert L[0].strip().lower() == 'xshift'
669        assert L[1].strip().lower() == '500000'
670
671        L = lines[7].strip().split()
672        assert L[0].strip().lower() == 'yshift'
673        assert L[1].strip().lower() == '10000000'       
674
675        L = lines[8].strip().split()
676        assert L[0].strip().lower() == 'parameters'
677       
678
679        #Check asc file
680        ascid = open(ascfile)
681        lines = ascid.readlines()
682        ascid.close()       
683
684        L = lines[0].strip().split()
685        assert L[0].strip().lower() == 'ncols'
686        assert L[1].strip().lower() == '5'
687
688        L = lines[1].strip().split()
689        assert L[0].strip().lower() == 'nrows'
690        assert L[1].strip().lower() == '5'       
691
692        L = lines[2].strip().split()
693        assert L[0].strip().lower() == 'xllcorner'
694        assert allclose(float(L[1].strip().lower()), 308500)
695
696        L = lines[3].strip().split()
697        assert L[0].strip().lower() == 'yllcorner'
698        assert allclose(float(L[1].strip().lower()), 6189000)
699
700        L = lines[4].strip().split()
701        assert L[0].strip().lower() == 'cellsize'
702        assert allclose(float(L[1].strip().lower()), cellsize)
703
704        L = lines[5].strip().split()
705        assert L[0].strip() == 'NODATA_value'
706        assert L[1].strip().lower() == '-9999'       
707
708        #Check grid values
709        for j in range(5):
710            L = lines[6+j].strip().split()           
711            y = (4-j) * cellsize
712            for i in range(5):
713                assert allclose(float(L[i]), -i*cellsize - y)
714       
715
716        fid.close()
717
718        #Cleanup
719        os.remove(prjfile)
720        os.remove(ascfile)       
721        os.remove(swwfile)
722
723
724    def test_sww2asc_stage_reduction(self):
725        """Test that sww information can be converted correctly to asc/prj
726        format readable by e.g. ArcView
727
728        This tests the reduction of quantity stage using min
729        """
730
731        import time, os
732        from Numeric import array, zeros, allclose, Float, concatenate
733        from Scientific.IO.NetCDF import NetCDFFile
734
735        #Setup
736        self.domain.filename = 'datatest'
737       
738        prjfile = self.domain.filename + '_stage.prj'
739        ascfile = self.domain.filename + '_stage.asc'               
740        swwfile = self.domain.filename + '.sww'
741       
742        self.domain.set_datadir('.')
743        self.domain.format = 'sww'
744        self.domain.smooth = True
745        self.domain.set_quantity('elevation', lambda x,y: -x-y)
746
747        self.domain.geo_reference = Geo_reference(56,308500,6189000)
748       
749       
750        sww = get_dataobject(self.domain)
751        sww.store_connectivity()
752        sww.store_timestep('stage')
753
754        self.domain.evolve_to_end(finaltime = 0.01)
755        sww.store_timestep('stage')
756
757        cellsize = 0.25
758        #Check contents
759        #Get NetCDF
760
761        fid = NetCDFFile(sww.filename, 'r')
762
763        # Get the variables
764        x = fid.variables['x'][:]
765        y = fid.variables['y'][:]
766        z = fid.variables['elevation'][:]
767        time = fid.variables['time'][:]
768        stage = fid.variables['stage'][:]
769
770
771        #Export to ascii/prj files
772        sww2asc(self.domain.filename, 
773                quantity = 'stage',                         
774                cellsize = cellsize,
775                reduction = min)
776
777
778        #Check asc file
779        ascid = open(ascfile)
780        lines = ascid.readlines()
781        ascid.close()       
782
783        L = lines[0].strip().split()
784        assert L[0].strip().lower() == 'ncols'
785        assert L[1].strip().lower() == '5'
786
787        L = lines[1].strip().split()
788        assert L[0].strip().lower() == 'nrows'
789        assert L[1].strip().lower() == '5'       
790
791        L = lines[2].strip().split()
792        assert L[0].strip().lower() == 'xllcorner'
793        assert allclose(float(L[1].strip().lower()), 308500)
794
795        L = lines[3].strip().split()
796        assert L[0].strip().lower() == 'yllcorner'
797        assert allclose(float(L[1].strip().lower()), 6189000)
798
799        L = lines[4].strip().split()
800        assert L[0].strip().lower() == 'cellsize'
801        assert allclose(float(L[1].strip().lower()), cellsize)
802
803        L = lines[5].strip().split()
804        assert L[0].strip() == 'NODATA_value'
805        assert L[1].strip().lower() == '-9999'       
806
807       
808        #Check grid values (where applicable)
809        for j in range(5):
810            if j%2 == 0:
811                L = lines[6+j].strip().split()           
812                jj = 4-j
813                for i in range(5):
814                    if i%2 == 0:
815                        index = jj/2 + i/2*3
816                        val0 = stage[0,index]
817                        val1 = stage[1,index]
818
819                        #print i, j, index, ':', L[i], val0, val1
820                        assert allclose(float(L[i]), min(val0, val1))
821
822
823        fid.close()
824
825        #Cleanup
826        os.remove(prjfile)
827        os.remove(ascfile)       
828        #os.remove(swwfile)
829
830
831
832
833    def test_sww2asc_missing_points(self):
834        """Test that sww information can be converted correctly to asc/prj
835        format readable by e.g. ArcView
836
837        This test includes the writing of missing values
838        """
839       
840        import time, os
841        from Numeric import array, zeros, allclose, Float, concatenate
842        from Scientific.IO.NetCDF import NetCDFFile
843
844        #Setup mesh not coinciding with rectangle.
845        #This will cause missing values to occur in gridded data
846
847
848        points = [                        [1.0, 1.0],
849                              [0.5, 0.5], [1.0, 0.5],
850                  [0.0, 0.0], [0.5, 0.0], [1.0, 0.0]]           
851
852        vertices = [ [4,1,3], [5,2,4], [1,4,2], [2,0,1]]
853       
854        #Create shallow water domain
855        domain = Domain(points, vertices)
856        domain.default_order=2
857
858
859        #Set some field values
860        domain.set_quantity('elevation', lambda x,y: -x-y)       
861        domain.set_quantity('friction', 0.03)
862
863
864        ######################
865        # Boundary conditions
866        B = Transmissive_boundary(domain)
867        domain.set_boundary( {'exterior': B} )
868
869
870        ######################
871        #Initial condition - with jumps
872
873        bed = domain.quantities['elevation'].vertex_values
874        stage = zeros(bed.shape, Float)
875
876        h = 0.3
877        for i in range(stage.shape[0]):
878            if i % 2 == 0:
879                stage[i,:] = bed[i,:] + h
880            else:
881                stage[i,:] = bed[i,:]
882
883        domain.set_quantity('stage', stage)
884        domain.distribute_to_vertices_and_edges()
885
886        domain.filename = 'datatest'
887       
888        prjfile = domain.filename + '_elevation.prj'
889        ascfile = domain.filename + '_elevation.asc'               
890        swwfile = domain.filename + '.sww'
891       
892        domain.set_datadir('.')
893        domain.format = 'sww'
894        domain.smooth = True
895
896        domain.geo_reference = Geo_reference(56,308500,6189000)
897       
898        sww = get_dataobject(domain)
899        sww.store_connectivity()
900        sww.store_timestep('stage')
901
902        cellsize = 0.25
903        #Check contents
904        #Get NetCDF
905
906        fid = NetCDFFile(swwfile, 'r')
907
908        # Get the variables
909        x = fid.variables['x'][:]
910        y = fid.variables['y'][:]
911        z = fid.variables['elevation'][:]
912        time = fid.variables['time'][:]
913
914        try:
915            geo_reference = Geo_reference(NetCDFObject=fid)
916        except AttributeError, e:
917            geo_reference = Geo_reference(DEFAULT_ZONE,0,0)
918       
919        #Export to ascii/prj files
920        sww2asc(domain.filename, 
921                quantity = 'elevation',                         
922                cellsize = cellsize,
923                verbose = False)
924
925
926        #Check asc file
927        ascid = open(ascfile)
928        lines = ascid.readlines()
929        ascid.close()       
930
931        L = lines[0].strip().split()
932        assert L[0].strip().lower() == 'ncols'
933        assert L[1].strip().lower() == '5'
934
935        L = lines[1].strip().split()
936        assert L[0].strip().lower() == 'nrows'
937        assert L[1].strip().lower() == '5'       
938
939        L = lines[2].strip().split()
940        assert L[0].strip().lower() == 'xllcorner'
941        assert allclose(float(L[1].strip().lower()), 308500)
942
943        L = lines[3].strip().split()
944        assert L[0].strip().lower() == 'yllcorner'
945        assert allclose(float(L[1].strip().lower()), 6189000)
946
947        L = lines[4].strip().split()
948        assert L[0].strip().lower() == 'cellsize'
949        assert allclose(float(L[1].strip().lower()), cellsize)
950
951        L = lines[5].strip().split()
952        assert L[0].strip() == 'NODATA_value'
953        assert L[1].strip().lower() == '-9999'       
954
955
956        #Check grid values
957        for j in range(5):
958            L = lines[6+j].strip().split()           
959            y = (4-j) * cellsize
960            for i in range(5):
961                if i+j >= 4:
962                    assert allclose(float(L[i]), -i*cellsize - y)
963                else:
964                    #Missing values
965                    assert allclose(float(L[i]), -9999)
966
967
968
969        fid.close()
970
971        #Cleanup
972        os.remove(prjfile)
973        os.remove(ascfile)       
974        os.remove(swwfile)
975
976
977    def test_ferret2sww(self):
978        """Test that georeferencing etc works when converting from
979        ferret format (lat/lon) to sww format (UTM)
980        """
981        from Scientific.IO.NetCDF import NetCDFFile
982
983        #The test file has
984        # LON = 150.66667, 150.83334, 151, 151.16667
985        # LAT = -34.5, -34.33333, -34.16667, -34 ;
986        # TIME = 0, 0.1, 0.6, 1.1, 1.6, 2.1 ;
987        #
988        # First value (index=0) in small_ha.nc is 0.3400644 cm,
989        # Fourth value (index==3) is -6.50198 cm
990
991
992        from coordinate_transforms.redfearn import redfearn
993
994        fid = NetCDFFile('small_ha.nc')
995        first_value = fid.variables['HA'][:][0,0,0]
996        fourth_value = fid.variables['HA'][:][0,0,3]
997
998
999        #Call conversion (with zero origin)
1000        ferret2sww('small', verbose=False,
1001                   origin = (56, 0, 0))
1002
1003
1004        #Work out the UTM coordinates for first point
1005        zone, e, n = redfearn(-34.5, 150.66667)
1006        #print zone, e, n
1007
1008        #Read output file 'small.sww'
1009        fid = NetCDFFile('small.sww')
1010
1011        x = fid.variables['x'][:]
1012        y = fid.variables['y'][:]
1013
1014        #Check that first coordinate is correctly represented
1015        assert allclose(x[0], e)
1016        assert allclose(y[0], n)
1017
1018        #Check first value
1019        stage = fid.variables['stage'][:]
1020        xmomentum = fid.variables['xmomentum'][:]
1021        ymomentum = fid.variables['ymomentum'][:]       
1022
1023        #print ymomentum
1024               
1025        assert allclose(stage[0,0], first_value/100)  #Meters
1026
1027        #Check fourth value
1028        assert allclose(stage[0,3], fourth_value/100)  #Meters
1029
1030        fid.close()
1031
1032        #Cleanup
1033        import os
1034        os.remove('small.sww')
1035
1036
1037
1038    def test_ferret2sww_2(self):
1039        """Test that georeferencing etc works when converting from
1040        ferret format (lat/lon) to sww format (UTM)
1041        """
1042        from Scientific.IO.NetCDF import NetCDFFile
1043
1044        #The test file has
1045        # LON = 150.66667, 150.83334, 151, 151.16667
1046        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1047        # TIME = 0, 0.1, 0.6, 1.1, 1.6, 2.1 ;
1048        #
1049        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1050        # Fourth value (index==3) is -6.50198 cm
1051
1052
1053        from coordinate_transforms.redfearn import redfearn
1054
1055        fid = NetCDFFile('small_ha.nc')
1056
1057        #Pick a coordinate and a value
1058
1059        time_index = 1
1060        lat_index = 0
1061        lon_index = 2
1062
1063        test_value = fid.variables['HA'][:][time_index, lat_index, lon_index]
1064        test_time = fid.variables['TIME'][:][time_index]
1065        test_lat = fid.variables['LAT'][:][lat_index]
1066        test_lon = fid.variables['LON'][:][lon_index]
1067
1068        linear_point_index = lat_index*4 + lon_index
1069        fid.close()
1070
1071        #Call conversion (with zero origin)
1072        ferret2sww('small', verbose=False,
1073                   origin = (56, 0, 0))
1074
1075
1076        #Work out the UTM coordinates for test point
1077        zone, e, n = redfearn(test_lat, test_lon)
1078
1079        #Read output file 'small.sww'
1080        fid = NetCDFFile('small.sww')
1081
1082        x = fid.variables['x'][:]
1083        y = fid.variables['y'][:]
1084
1085        #Check that test coordinate is correctly represented
1086        assert allclose(x[linear_point_index], e)
1087        assert allclose(y[linear_point_index], n)
1088
1089        #Check test value
1090        stage = fid.variables['stage'][:]
1091
1092        assert allclose(stage[time_index, linear_point_index], test_value/100)
1093
1094        fid.close()
1095
1096        #Cleanup
1097        import os
1098        os.remove('small.sww')
1099
1100
1101
1102    def test_ferret2sww3(self):
1103        """
1104        """
1105        from Scientific.IO.NetCDF import NetCDFFile
1106
1107        #The test file has
1108        # LON = 150.66667, 150.83334, 151, 151.16667
1109        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1110        # ELEVATION = [-1 -2 -3 -4
1111        #             -5 -6 -7 -8
1112        #              ...
1113        #              ...      -16]
1114        # where the top left corner is -1m,
1115        # and the ll corner is -13.0m
1116        #
1117        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1118        # Fourth value (index==3) is -6.50198 cm
1119
1120        from coordinate_transforms.redfearn import redfearn
1121        import os
1122        fid1 = NetCDFFile('test_ha.nc','w')
1123        fid2 = NetCDFFile('test_ua.nc','w')
1124        fid3 = NetCDFFile('test_va.nc','w')
1125        fid4 = NetCDFFile('test_e.nc','w')
1126
1127        h1_list = [150.66667,150.83334,151.]
1128        h2_list = [-34.5,-34.33333]
1129
1130        long_name = 'LON'
1131        lat_name = 'LAT'
1132
1133        nx = 3
1134        ny = 2
1135
1136        for fid in [fid1,fid2,fid3]:
1137            fid.createDimension(long_name,nx)
1138            fid.createVariable(long_name,'d',(long_name,))
1139            fid.variables[long_name].point_spacing='uneven'
1140            fid.variables[long_name].units='degrees_east'
1141            fid.variables[long_name].assignValue(h1_list)
1142
1143            fid.createDimension(lat_name,ny)
1144            fid.createVariable(lat_name,'d',(lat_name,))
1145            fid.variables[lat_name].point_spacing='uneven'
1146            fid.variables[lat_name].units='degrees_north'
1147            fid.variables[lat_name].assignValue(h2_list)
1148       
1149            fid.createDimension('TIME',2)
1150            fid.createVariable('TIME','d',('TIME',))
1151            fid.variables['TIME'].point_spacing='uneven'
1152            fid.variables['TIME'].units='seconds'
1153            fid.variables['TIME'].assignValue([0.,1.])
1154            if fid == fid3: break
1155       
1156
1157        for fid in [fid4]:
1158            fid.createDimension(long_name,nx)
1159            fid.createVariable(long_name,'d',(long_name,))
1160            fid.variables[long_name].point_spacing='uneven'
1161            fid.variables[long_name].units='degrees_east'
1162            fid.variables[long_name].assignValue(h1_list)
1163
1164            fid.createDimension(lat_name,ny)
1165            fid.createVariable(lat_name,'d',(lat_name,))
1166            fid.variables[lat_name].point_spacing='uneven'
1167            fid.variables[lat_name].units='degrees_north'
1168            fid.variables[lat_name].assignValue(h2_list)
1169
1170        name = {}
1171        name[fid1]='HA'
1172        name[fid2]='UA'
1173        name[fid3]='VA'
1174        name[fid4]='ELEVATION'
1175       
1176        units = {}
1177        units[fid1]='cm'
1178        units[fid2]='cm/s'
1179        units[fid3]='cm/s'
1180        units[fid4]='m'
1181
1182        values = {}
1183        values[fid1]=[[[5., 10.,15.], [13.,18.,23.]],[[50.,100.,150.],[130.,180.,230.]]]
1184        values[fid2]=[[[1., 2.,3.], [4.,5.,6.]],[[7.,8.,9.],[10.,11.,12.]]]
1185        values[fid3]=[[[13., 12.,11.], [10.,9.,8.]],[[7.,6.,5.],[4.,3.,2.]]]
1186        values[fid4]=[[-3000,-3100,-3200],[-4000,-5000,-6000]]
1187       
1188        for fid in [fid1,fid2,fid3]:
1189          fid.createVariable(name[fid],'d',('TIME',lat_name,long_name))
1190          fid.variables[name[fid]].point_spacing='uneven'
1191          fid.variables[name[fid]].units=units[fid]
1192          fid.variables[name[fid]].assignValue(values[fid])
1193          fid.variables[name[fid]].missing_value = -99999999.
1194          if fid == fid3: break
1195
1196        for fid in [fid4]:
1197            fid.createVariable(name[fid],'d',(lat_name,long_name))
1198            fid.variables[name[fid]].point_spacing='uneven'
1199            fid.variables[name[fid]].units=units[fid]
1200            fid.variables[name[fid]].assignValue(values[fid])
1201            fid.variables[name[fid]].missing_value = -99999999.
1202
1203       
1204        fid1.sync(); fid1.close()
1205        fid2.sync(); fid2.close()
1206        fid3.sync(); fid3.close()
1207        fid4.sync(); fid4.close()
1208
1209        fid1 = NetCDFFile('test_ha.nc','r')
1210        fid2 = NetCDFFile('test_e.nc','r')
1211        fid3 = NetCDFFile('test_va.nc','r')
1212       
1213
1214        first_amp = fid1.variables['HA'][:][0,0,0]
1215        third_amp = fid1.variables['HA'][:][0,0,2]
1216        first_elevation = fid2.variables['ELEVATION'][0,0]
1217        third_elevation= fid2.variables['ELEVATION'][:][0,2]
1218        first_speed = fid3.variables['VA'][0,0,0]
1219        third_speed = fid3.variables['VA'][:][0,0,2]
1220
1221        fid1.close()
1222        fid2.close()
1223        fid3.close()
1224
1225        #Call conversion (with zero origin)
1226        ferret2sww('test', verbose=False,
1227                   origin = (56, 0, 0))
1228
1229        os.remove('test_va.nc')
1230        os.remove('test_ua.nc')
1231        os.remove('test_ha.nc')
1232        os.remove('test_e.nc')
1233
1234        #Read output file 'test.sww'
1235        fid = NetCDFFile('test.sww')
1236
1237
1238        #Check first value
1239        elevation = fid.variables['elevation'][:]
1240        stage = fid.variables['stage'][:]
1241        xmomentum = fid.variables['xmomentum'][:]
1242        ymomentum = fid.variables['ymomentum'][:]       
1243
1244        #print ymomentum
1245        first_height = first_amp/100 - first_elevation
1246        third_height = third_amp/100 - third_elevation
1247        first_momentum=first_speed*first_height/100
1248        third_momentum=third_speed*third_height/100
1249
1250        assert allclose(ymomentum[0][0],first_momentum)  #Meters
1251        assert allclose(ymomentum[0][2],third_momentum)  #Meters
1252
1253        fid.close()
1254
1255        #Cleanup
1256        os.remove('test.sww')
1257
1258
1259
1260
1261    def test_sww_extent(self):
1262        """Not a test, rather a look at the sww format
1263        """
1264
1265        import time, os
1266        from Numeric import array, zeros, allclose, Float, concatenate
1267        from Scientific.IO.NetCDF import NetCDFFile
1268
1269        self.domain.filename = 'datatest' + str(id(self))
1270        self.domain.format = 'sww'
1271        self.domain.smooth = True
1272        self.domain.reduction = mean
1273        self.domain.set_datadir('.')
1274
1275
1276        sww = get_dataobject(self.domain)
1277        sww.store_connectivity()
1278        sww.store_timestep('stage')
1279        self.domain.time = 2.
1280
1281        #Modify stage at second timestep
1282        stage = self.domain.quantities['stage'].vertex_values
1283        self.domain.set_quantity('stage', stage/2)
1284
1285        sww.store_timestep('stage')
1286
1287        file_and_extension_name = self.domain.filename + ".sww"
1288        #print "file_and_extension_name",file_and_extension_name
1289        [xmin, xmax, ymin, ymax, stagemin, stagemax] = \
1290               extent_sww(file_and_extension_name )
1291
1292        assert allclose(xmin, 0.0)
1293        assert allclose(xmax, 1.0)
1294        assert allclose(ymin, 0.0)
1295        assert allclose(ymax, 1.0)
1296        assert allclose(stagemin, -0.85)
1297        assert allclose(stagemax, 0.15)
1298
1299
1300        #Cleanup
1301        os.remove(sww.filename)
1302
1303
1304    def test_ferret2sww_nz_origin(self):
1305        from Scientific.IO.NetCDF import NetCDFFile
1306        from coordinate_transforms.redfearn import redfearn
1307
1308        #Call conversion (with nonzero origin)
1309        ferret2sww('small', verbose=False,
1310                   origin = (56, 100000, 200000))
1311
1312
1313        #Work out the UTM coordinates for first point
1314        zone, e, n = redfearn(-34.5, 150.66667)
1315
1316        #Read output file 'small.sww'
1317        fid = NetCDFFile('small.sww', 'r')
1318
1319        x = fid.variables['x'][:]
1320        y = fid.variables['y'][:]
1321
1322        #Check that first coordinate is correctly represented
1323        assert allclose(x[0], e-100000)
1324        assert allclose(y[0], n-200000)
1325
1326        fid.close()
1327
1328        #Cleanup
1329        import os
1330        os.remove('small.sww')
1331
1332    def test_sww2domain(self):
1333        ################################################
1334        #Create a test domain, and evolve and save it.
1335        ################################################
1336        from mesh_factory import rectangular
1337        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1338             Constant_height, Time_boundary, Transmissive_boundary
1339        from Numeric import array
1340
1341        #Create basic mesh
1342
1343        yiel=0.01
1344        points, vertices, boundary = rectangular(10,10)
1345       
1346        #Create shallow water domain
1347        domain = Domain(points, vertices, boundary)
1348        domain.geo_reference = Geo_reference(56,11,11)
1349        domain.smooth = False
1350        domain.visualise = False
1351        domain.store = True
1352        domain.filename = 'bedslope'
1353        domain.default_order=2
1354        #Bed-slope and friction
1355        domain.set_quantity('elevation', lambda x,y: -x/3)
1356        domain.set_quantity('friction', 0.1)
1357        # Boundary conditions
1358        from math import sin, pi
1359        Br = Reflective_boundary(domain)
1360        Bt = Transmissive_boundary(domain)
1361        Bd = Dirichlet_boundary([0.2,0.,0.])
1362        Bw = Time_boundary(domain=domain,
1363                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1364
1365        #domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
1366        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1367
1368        domain.quantities_to_be_stored.extend(['xmomentum','ymomentum'])
1369        #Initial condition
1370        h = 0.05
1371        elevation = domain.quantities['elevation'].vertex_values
1372        domain.set_quantity('stage', elevation + h)
1373        #elevation = domain.get_quantity('elevation')
1374        #domain.set_quantity('stage', elevation + h)
1375
1376        domain.check_integrity()
1377        #Evolution
1378        for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
1379        #    domain.write_time()
1380            pass
1381
1382
1383        ##########################################
1384        #Import the example's file as a new domain
1385        ##########################################
1386        from data_manager import sww2domain
1387        from Numeric import allclose
1388        import os
1389
1390        filename = domain.datadir+os.sep+domain.filename+'.sww'
1391        domain2 = sww2domain(filename,None,fail_if_NaN=False,verbose = False)
1392        #points, vertices, boundary = rectangular(15,15)
1393        #domain2.boundary = boundary
1394        ###################
1395        ##NOW TEST IT!!!
1396        ###################
1397
1398        bits = ['vertex_coordinates']
1399        for quantity in ['elevation']+domain.quantities_to_be_stored:
1400            bits.append('quantities["%s"].get_integral()'%quantity)
1401            bits.append('get_quantity("%s")'%quantity)
1402
1403        for bit in bits:
1404            #print 'testing that domain.'+bit+' has been restored'
1405            #print bit
1406        #print 'done'
1407            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1408
1409        ######################################
1410        #Now evolve them both, just to be sure
1411        ######################################x
1412        visualise = False
1413        #visualise = True
1414        domain.visualise = visualise
1415        domain.time = 0.
1416        from time import sleep
1417
1418        final = .1
1419        domain.set_quantity('friction', 0.1)
1420        domain.store = False
1421        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1422
1423
1424        for t in domain.evolve(yieldstep = yiel, finaltime = final):
1425            if visualise: sleep(1.)
1426            #domain.write_time()
1427            pass
1428
1429        final = final - (domain2.starttime-domain.starttime)
1430        #BUT since domain1 gets time hacked back to 0:
1431        final = final + (domain2.starttime-domain.starttime)
1432
1433        domain2.smooth = False
1434        domain2.visualise = visualise
1435        domain2.store = False
1436        domain2.default_order=2
1437        domain2.set_quantity('friction', 0.1)
1438        #Bed-slope and friction
1439        # Boundary conditions
1440        Bd2=Dirichlet_boundary([0.2,0.,0.])
1441        domain2.boundary = domain.boundary
1442        #print 'domain2.boundary'
1443        #print domain2.boundary
1444        domain2.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1445        #domain2.set_boundary({'exterior': Bd})
1446
1447        domain2.check_integrity()
1448
1449        for t in domain2.evolve(yieldstep = yiel, finaltime = final):
1450            if visualise: sleep(1.)
1451            #domain2.write_time()
1452            pass
1453
1454        ###################
1455        ##NOW TEST IT!!!
1456        ##################
1457
1458        bits = [ 'vertex_coordinates']
1459
1460        for quantity in ['elevation','xmomentum','ymomentum']:#+domain.quantities_to_be_stored:
1461            bits.append('quantities["%s"].get_integral()'%quantity)
1462            bits.append('get_quantity("%s")'%quantity)
1463
1464        for bit in bits:
1465            #print bit
1466            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1467       
1468
1469    def test_sww2domain2(self):
1470        ##################################################################
1471        #Same as previous test, but this checks how NaNs are handled.
1472        ##################################################################
1473
1474
1475        from mesh_factory import rectangular
1476        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1477             Constant_height, Time_boundary, Transmissive_boundary
1478        from Numeric import array
1479
1480        #Create basic mesh
1481        points, vertices, boundary = rectangular(2,2)
1482
1483        #Create shallow water domain
1484        domain = Domain(points, vertices, boundary)
1485        domain.smooth = False
1486        domain.visualise = False
1487        domain.store = True
1488        domain.filename = 'bedslope'
1489        domain.default_order=2
1490        domain.quantities_to_be_stored=['stage']
1491
1492        domain.set_quantity('elevation', lambda x,y: -x/3)
1493        domain.set_quantity('friction', 0.1)
1494
1495        from math import sin, pi
1496        Br = Reflective_boundary(domain)
1497        Bt = Transmissive_boundary(domain)
1498        Bd = Dirichlet_boundary([0.2,0.,0.])
1499        Bw = Time_boundary(domain=domain,
1500                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1501
1502        domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
1503
1504        h = 0.05
1505        elevation = domain.quantities['elevation'].vertex_values
1506        domain.set_quantity('stage', elevation + h)
1507
1508        domain.check_integrity()
1509       
1510        for t in domain.evolve(yieldstep = 1, finaltime = 2.0):
1511            pass
1512            #domain.write_time()
1513
1514       
1515
1516        ##################################
1517        #Import the file as a new domain
1518        ##################################
1519        from data_manager import sww2domain
1520        from Numeric import allclose
1521        import os
1522
1523        filename = domain.datadir+os.sep+domain.filename+'.sww'
1524
1525        #Fail because NaNs are present
1526        try:
1527            domain2 = sww2domain(filename,boundary,fail_if_NaN=True,verbose=False)
1528            assert True == False
1529        except:
1530            #Now import it, filling NaNs to be 0
1531            filler = 0
1532            domain2 = sww2domain(filename,None,fail_if_NaN=False,NaN_filler = filler,verbose=False)
1533        bits = [ 'geo_reference.get_xllcorner()',
1534                'geo_reference.get_yllcorner()',
1535                'vertex_coordinates']
1536
1537        for quantity in ['elevation']+domain.quantities_to_be_stored:
1538            bits.append('quantities["%s"].get_integral()'%quantity)
1539            bits.append('get_quantity("%s")'%quantity)
1540
1541        for bit in bits:
1542        #    print 'testing that domain.'+bit+' has been restored'
1543            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1544
1545        assert max(max(domain2.get_quantity('xmomentum')))==filler
1546        assert min(min(domain2.get_quantity('xmomentum')))==filler
1547        assert max(max(domain2.get_quantity('ymomentum')))==filler
1548        assert min(min(domain2.get_quantity('ymomentum')))==filler
1549
1550        #print 'passed'
1551
1552        #cleanup
1553        #import os
1554        #os.remove(domain.datadir+'/'+domain.filename+'.sww')
1555
1556
1557    #def test_weed(self):
1558        from data_manager import weed
1559
1560        coordinates1 = [[0.,0.],[1.,0.],[1.,1.],[1.,0.],[2.,0.],[1.,1.]]
1561        volumes1 = [[0,1,2],[3,4,5]]
1562        boundary1= {(0,1): 'external',(1,2): 'not external',(2,0): 'external',(3,4): 'external',(4,5): 'external',(5,3): 'not external'}
1563        coordinates2,volumes2,boundary2=weed(coordinates1,volumes1,boundary1)
1564
1565        points2 = {(0.,0.):None,(1.,0.):None,(1.,1.):None,(2.,0.):None}       
1566
1567        assert len(points2)==len(coordinates2)
1568        for i in range(len(coordinates2)):
1569            coordinate = tuple(coordinates2[i])
1570            assert points2.has_key(coordinate)
1571            points2[coordinate]=i
1572
1573        for triangle in volumes1:
1574            for coordinate in triangle:
1575                assert coordinates2[points2[tuple(coordinates1[coordinate])]][0]==coordinates1[coordinate][0]
1576                assert coordinates2[points2[tuple(coordinates1[coordinate])]][1]==coordinates1[coordinate][1]
1577
1578
1579     #FIXME This fails - smooth makes the comparism too hard for allclose
1580    def ztest_sww2domain3(self):
1581        ################################################
1582        #DOMAIN.SMOOTH = TRUE !!!!!!!!!!!!!!!!!!!!!!!!!!
1583        ################################################
1584        from mesh_factory import rectangular
1585        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1586             Constant_height, Time_boundary, Transmissive_boundary
1587        from Numeric import array
1588        #Create basic mesh
1589
1590        yiel=0.01
1591        points, vertices, boundary = rectangular(10,10)
1592       
1593        #Create shallow water domain
1594        domain = Domain(points, vertices, boundary)
1595        domain.geo_reference = Geo_reference(56,11,11)
1596        domain.smooth = True
1597        domain.visualise = False
1598        domain.store = True
1599        domain.filename = 'bedslope'
1600        domain.default_order=2
1601        #Bed-slope and friction
1602        domain.set_quantity('elevation', lambda x,y: -x/3)
1603        domain.set_quantity('friction', 0.1)
1604        # Boundary conditions
1605        from math import sin, pi
1606        Br = Reflective_boundary(domain)
1607        Bt = Transmissive_boundary(domain)
1608        Bd = Dirichlet_boundary([0.2,0.,0.])
1609        Bw = Time_boundary(domain=domain,
1610                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1611
1612        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1613
1614        domain.quantities_to_be_stored.extend(['xmomentum','ymomentum'])
1615        #Initial condition
1616        h = 0.05
1617        elevation = domain.quantities['elevation'].vertex_values
1618        domain.set_quantity('stage', elevation + h)
1619        #elevation = domain.get_quantity('elevation')
1620        #domain.set_quantity('stage', elevation + h)
1621
1622        domain.check_integrity()
1623        #Evolution
1624        for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
1625        #    domain.write_time()
1626            pass
1627
1628
1629        ##########################################
1630        #Import the example's file as a new domain
1631        ##########################################
1632        from data_manager import sww2domain
1633        from Numeric import allclose
1634        import os
1635
1636        filename = domain.datadir+os.sep+domain.filename+'.sww'
1637        domain2 = sww2domain(filename,None,fail_if_NaN=False,verbose = False)
1638        #points, vertices, boundary = rectangular(15,15)
1639        #domain2.boundary = boundary
1640        ###################
1641        ##NOW TEST IT!!!
1642        ###################
1643
1644        #FIXME smooth domain so that they can be compared
1645
1646
1647        bits = []#'vertex_coordinates']
1648        for quantity in ['elevation']+domain.quantities_to_be_stored:
1649            bits.append('quantities["%s"].get_integral()'%quantity)
1650            #bits.append('get_quantity("%s")'%quantity)
1651
1652        for bit in bits:
1653            #print 'testing that domain.'+bit+' has been restored'
1654            #print bit
1655            #print 'done'
1656            #print ('domain.'+bit), eval('domain.'+bit)
1657            #print ('domain2.'+bit), eval('domain2.'+bit)
1658            assert allclose(eval('domain.'+bit),eval('domain2.'+bit),rtol=1.0e-1,atol=1.e-3)
1659            pass
1660
1661        ######################################
1662        #Now evolve them both, just to be sure
1663        ######################################x
1664        visualise = False
1665        visualise = True
1666        domain.visualise = visualise
1667        domain.time = 0.
1668        from time import sleep
1669
1670        final = .5
1671        domain.set_quantity('friction', 0.1)
1672        domain.store = False
1673        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Br})
1674
1675        for t in domain.evolve(yieldstep = yiel, finaltime = final):
1676            if visualise: sleep(.03)
1677            #domain.write_time()
1678            pass
1679
1680        domain2.smooth = True
1681        domain2.visualise = visualise
1682        domain2.store = False
1683        domain2.default_order=2
1684        domain2.set_quantity('friction', 0.1)
1685        #Bed-slope and friction
1686        # Boundary conditions
1687        Bd2=Dirichlet_boundary([0.2,0.,0.])
1688        Br2 = Reflective_boundary(domain2)
1689        domain2.boundary = domain.boundary
1690        #print 'domain2.boundary'
1691        #print domain2.boundary
1692        domain2.set_boundary({'left': Bd2, 'right': Bd2, 'top': Bd2, 'bottom': Br2})
1693        #domain2.boundary = domain.boundary
1694        #domain2.set_boundary({'exterior': Bd})
1695
1696        domain2.check_integrity()
1697
1698        for t in domain2.evolve(yieldstep = yiel, finaltime = final):
1699            if visualise: sleep(.03)
1700            #domain2.write_time()
1701            pass
1702
1703        ###################
1704        ##NOW TEST IT!!!
1705        ##################
1706
1707        bits = [ 'vertex_coordinates']
1708
1709        for quantity in ['elevation','xmomentum','ymomentum']:#+domain.quantities_to_be_stored:
1710            #bits.append('quantities["%s"].get_integral()'%quantity)
1711            bits.append('get_quantity("%s")'%quantity)
1712
1713        for bit in bits:
1714            print bit
1715            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1716       
1717
1718#-------------------------------------------------------------
1719if __name__ == "__main__":
1720    suite = unittest.makeSuite(Test_Data_Manager,'test')   
1721    #suite = unittest.makeSuite(Test_Data_Manager,'test_sww2domain')
1722    runner = unittest.TextTestRunner()
1723    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.