source: inundation/ga/storm_surge/pyvolution/test_data_manager.py @ 1453

Last change on this file since 1453 was 1360, checked in by steve, 20 years ago
File size: 50.5 KB
Line 
1#!/usr/bin/env python
2#
3
4import unittest
5import copy
6from Numeric import zeros, array, allclose, Float
7from util import mean
8
9from data_manager import *
10from shallow_water import *
11from config import epsilon
12
13from coordinate_transforms.geo_reference import Geo_reference
14
15class Test_Data_Manager(unittest.TestCase):
16    def setUp(self):
17        import time
18        from mesh_factory import rectangular
19
20
21        #Create basic mesh
22        points, vertices, boundary = rectangular(2, 2)
23
24        #Create shallow water domain
25        domain = Domain(points, vertices, boundary)
26        domain.default_order=2
27
28
29        #Set some field values
30        domain.set_quantity('elevation', lambda x,y: -x)
31        domain.set_quantity('friction', 0.03)
32
33
34        ######################
35        # Boundary conditions
36        B = Transmissive_boundary(domain)
37        domain.set_boundary( {'left': B, 'right': B, 'top': B, 'bottom': B})
38
39
40        ######################
41        #Initial condition - with jumps
42
43
44        bed = domain.quantities['elevation'].vertex_values
45        stage = zeros(bed.shape, Float)
46
47        h = 0.3
48        for i in range(stage.shape[0]):
49            if i % 2 == 0:
50                stage[i,:] = bed[i,:] + h
51            else:
52                stage[i,:] = bed[i,:]
53
54        domain.set_quantity('stage', stage)
55        self.initial_stage = copy.copy(domain.quantities['stage'].vertex_values)
56
57        domain.distribute_to_vertices_and_edges()
58
59
60        self.domain = domain
61
62        C = domain.get_vertex_coordinates()
63        self.X = C[:,0:6:2].copy()
64        self.Y = C[:,1:6:2].copy()
65
66        self.F = bed
67
68
69    def tearDown(self):
70        pass
71
72
73
74
75#     def test_xya(self):
76#         import os
77#         from Numeric import concatenate
78
79#         import time, os
80#         from Numeric import array, zeros, allclose, Float, concatenate
81
82#         domain = self.domain
83
84#         domain.filename = 'datatest' + str(time.time())
85#         domain.format = 'xya'
86#         domain.smooth = True
87
88#         xya = get_dataobject(self.domain)
89#         xya.store_all()
90
91
92#         #Read back
93#         file = open(xya.filename)
94#         lFile = file.read().split('\n')
95#         lFile = lFile[:-1]
96
97#         file.close()
98#         os.remove(xya.filename)
99
100#         #Check contents
101#         if domain.smooth:
102#             self.failUnless(lFile[0] == '9 3 # <vertex #> <x> <y> [attributes]')
103#         else:
104#             self.failUnless(lFile[0] == '24 3 # <vertex #> <x> <y> [attributes]')
105
106#         #Get smoothed field values with X and Y
107#         X,Y,F,V = domain.get_vertex_values(xy=True, value_array='field_values',
108#                                            indices = (0,1), precision = Float)
109
110
111#         Q,V = domain.get_vertex_values(xy=False, value_array='conserved_quantities',
112#                                            indices = (0,), precision = Float)
113
114
115
116#         for i, line in enumerate(lFile[1:]):
117#             fields = line.split()
118
119#             assert len(fields) == 5
120
121#             assert allclose(float(fields[0]), X[i])
122#             assert allclose(float(fields[1]), Y[i])
123#             assert allclose(float(fields[2]), F[i,0])
124#             assert allclose(float(fields[3]), Q[i,0])
125#             assert allclose(float(fields[4]), F[i,1])
126
127
128
129
130    def test_sww_constant(self):
131        """Test that constant sww information can be written correctly
132        (non smooth)
133        """
134
135        import time, os
136        from Numeric import array, zeros, allclose, Float, concatenate
137        from Scientific.IO.NetCDF import NetCDFFile
138
139        self.domain.filename = 'datatest' + str(id(self))
140        self.domain.format = 'sww'
141        self.domain.smooth = False
142
143        sww = get_dataobject(self.domain)
144        sww.store_connectivity()
145
146        #Check contents
147        #Get NetCDF
148        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
149
150        # Get the variables
151        x = fid.variables['x']
152        y = fid.variables['y']
153        z = fid.variables['elevation']
154
155        volumes = fid.variables['volumes']
156
157
158        assert allclose (x[:], self.X.flat)
159        assert allclose (y[:], self.Y.flat)
160        assert allclose (z[:], self.F.flat)
161
162        V = volumes
163
164        P = len(self.domain)
165        for k in range(P):
166            assert V[k, 0] == 3*k
167            assert V[k, 1] == 3*k+1
168            assert V[k, 2] == 3*k+2
169
170
171        fid.close()
172
173        #Cleanup
174        os.remove(sww.filename)
175
176
177    def test_sww_constant_smooth(self):
178        """Test that constant sww information can be written correctly
179        (non smooth)
180        """
181
182        import time, os
183        from Numeric import array, zeros, allclose, Float, concatenate
184        from Scientific.IO.NetCDF import NetCDFFile
185
186        self.domain.filename = 'datatest' + str(id(self))
187        self.domain.format = 'sww'
188        self.domain.smooth = True
189
190        sww = get_dataobject(self.domain)
191        sww.store_connectivity()
192
193        #Check contents
194        #Get NetCDF
195        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
196
197        # Get the variables
198        x = fid.variables['x']
199        y = fid.variables['y']
200        z = fid.variables['elevation']
201
202        volumes = fid.variables['volumes']
203
204        X = x[:]
205        Y = y[:]
206
207        assert allclose([X[0], Y[0]], array([0.0, 0.0]))
208        assert allclose([X[1], Y[1]], array([0.0, 0.5]))
209        assert allclose([X[2], Y[2]], array([0.0, 1.0]))
210
211        assert allclose([X[4], Y[4]], array([0.5, 0.5]))
212
213        assert allclose([X[7], Y[7]], array([1.0, 0.5]))
214
215        Z = z[:]
216        assert Z[4] == -0.5
217
218        V = volumes
219        assert V[2,0] == 4
220        assert V[2,1] == 5
221        assert V[2,2] == 1
222
223        assert V[4,0] == 6
224        assert V[4,1] == 7
225        assert V[4,2] == 3
226
227
228        fid.close()
229
230        #Cleanup
231        os.remove(sww.filename)
232
233
234
235    def test_sww_variable(self):
236        """Test that sww information can be written correctly
237        """
238
239        import time, os
240        from Numeric import array, zeros, allclose, Float, concatenate
241        from Scientific.IO.NetCDF import NetCDFFile
242
243        self.domain.filename = 'datatest' + str(id(self))
244        self.domain.format = 'sww'
245        self.domain.smooth = True
246        self.domain.reduction = mean
247
248        sww = get_dataobject(self.domain)
249        sww.store_connectivity()
250        sww.store_timestep('stage')
251
252        #Check contents
253        #Get NetCDF
254        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
255
256
257        # Get the variables
258        x = fid.variables['x']
259        y = fid.variables['y']
260        z = fid.variables['elevation']
261        time = fid.variables['time']
262        stage = fid.variables['stage']
263
264
265        Q = self.domain.quantities['stage']
266        Q0 = Q.vertex_values[:,0]
267        Q1 = Q.vertex_values[:,1]
268        Q2 = Q.vertex_values[:,2]
269
270        A = stage[0,:]
271        #print A[0], (Q2[0,0] + Q1[1,0])/2
272        assert allclose(A[0], (Q2[0] + Q1[1])/2)
273        assert allclose(A[1], (Q0[1] + Q1[3] + Q2[2])/3)
274        assert allclose(A[2], Q0[3])
275        assert allclose(A[3], (Q0[0] + Q1[5] + Q2[4])/3)
276
277        #Center point
278        assert allclose(A[4], (Q1[0] + Q2[1] + Q0[2] +\
279                                 Q0[5] + Q2[6] + Q1[7])/6)
280
281
282
283        fid.close()
284
285        #Cleanup
286        os.remove(sww.filename)
287
288
289    def test_sww_variable2(self):
290        """Test that sww information can be written correctly
291        multiple timesteps. Use average as reduction operator
292        """
293
294        import time, os
295        from Numeric import array, zeros, allclose, Float, concatenate
296        from Scientific.IO.NetCDF import NetCDFFile
297
298        self.domain.filename = 'datatest' + str(id(self))
299        self.domain.format = 'sww'
300        self.domain.smooth = True
301
302        self.domain.reduction = mean
303
304        sww = get_dataobject(self.domain)
305        sww.store_connectivity()
306        sww.store_timestep('stage')
307        self.domain.evolve_to_end(finaltime = 0.01)
308        sww.store_timestep('stage')
309
310
311        #Check contents
312        #Get NetCDF
313        fid = NetCDFFile(sww.filename, 'r')  #Open existing file for append
314
315        # Get the variables
316        x = fid.variables['x']
317        y = fid.variables['y']
318        z = fid.variables['elevation']
319        time = fid.variables['time']
320        stage = fid.variables['stage']
321
322        #Check values
323        Q = self.domain.quantities['stage']
324        Q0 = Q.vertex_values[:,0]
325        Q1 = Q.vertex_values[:,1]
326        Q2 = Q.vertex_values[:,2]
327
328        A = stage[1,:]
329        assert allclose(A[0], (Q2[0] + Q1[1])/2)
330        assert allclose(A[1], (Q0[1] + Q1[3] + Q2[2])/3)
331        assert allclose(A[2], Q0[3])
332        assert allclose(A[3], (Q0[0] + Q1[5] + Q2[4])/3)
333
334        #Center point
335        assert allclose(A[4], (Q1[0] + Q2[1] + Q0[2] +\
336                                 Q0[5] + Q2[6] + Q1[7])/6)
337
338
339        fid.close()
340
341        #Cleanup
342        os.remove(sww.filename)
343
344    def test_sww_variable3(self):
345        """Test that sww information can be written correctly
346        multiple timesteps using a different reduction operator (min)
347        """
348
349        import time, os
350        from Numeric import array, zeros, allclose, Float, concatenate
351        from Scientific.IO.NetCDF import NetCDFFile
352
353        self.domain.filename = 'datatest' + str(id(self))
354        self.domain.format = 'sww'
355        self.domain.smooth = True
356        self.domain.reduction = min
357
358        sww = get_dataobject(self.domain)
359        sww.store_connectivity()
360        sww.store_timestep('stage')
361
362        self.domain.evolve_to_end(finaltime = 0.01)
363        sww.store_timestep('stage')
364
365
366        #Check contents
367        #Get NetCDF
368        fid = NetCDFFile(sww.filename, 'r')
369
370
371        # Get the variables
372        x = fid.variables['x']
373        y = fid.variables['y']
374        z = fid.variables['elevation']
375        time = fid.variables['time']
376        stage = fid.variables['stage']
377
378        #Check values
379        Q = self.domain.quantities['stage']
380        Q0 = Q.vertex_values[:,0]
381        Q1 = Q.vertex_values[:,1]
382        Q2 = Q.vertex_values[:,2]
383
384        A = stage[1,:]
385        assert allclose(A[0], min(Q2[0], Q1[1]))
386        assert allclose(A[1], min(Q0[1], Q1[3], Q2[2]))
387        assert allclose(A[2], Q0[3])
388        assert allclose(A[3], min(Q0[0], Q1[5], Q2[4]))
389
390        #Center point
391        assert allclose(A[4], min(Q1[0], Q2[1], Q0[2],\
392                                  Q0[5], Q2[6], Q1[7]))
393
394
395        fid.close()
396
397        #Cleanup
398        os.remove(sww.filename)
399
400
401    def test_sync(self):
402        """Test info stored at each timestep is as expected (incl initial condition)
403        """
404
405        import time, os, config
406        from Numeric import array, zeros, allclose, Float, concatenate
407        from Scientific.IO.NetCDF import NetCDFFile
408
409        self.domain.filename = 'synctest'
410        self.domain.format = 'sww'
411        self.domain.smooth = False
412        self.domain.store = True
413        self.domain.beta_h = 0
414
415        #Evolution
416        for t in self.domain.evolve(yieldstep = 1.0, finaltime = 4.0):
417            stage = self.domain.quantities['stage'].vertex_values
418
419            #Get NetCDF
420            fid = NetCDFFile(self.domain.writer.filename, 'r')
421            stage_file = fid.variables['stage']
422
423            if t == 0.0:
424                assert allclose(stage, self.initial_stage)
425                assert allclose(stage_file[:], stage.flat)
426            else:
427                assert not allclose(stage, self.initial_stage)
428                assert not allclose(stage_file[:], stage.flat)
429
430            fid.close()
431
432        os.remove(self.domain.writer.filename)
433
434
435
436    def test_sww_DSG(self):
437        """Not a test, rather a look at the sww format
438        """
439
440        import time, os
441        from Numeric import array, zeros, allclose, Float, concatenate
442        from Scientific.IO.NetCDF import NetCDFFile
443
444        self.domain.filename = 'datatest' + str(id(self))
445        self.domain.format = 'sww'
446        self.domain.smooth = True
447        self.domain.reduction = mean
448
449        sww = get_dataobject(self.domain)
450        sww.store_connectivity()
451        sww.store_timestep('stage')
452
453        #Check contents
454        #Get NetCDF
455        fid = NetCDFFile(sww.filename, 'r')
456
457        # Get the variables
458        x = fid.variables['x']
459        y = fid.variables['y']
460        z = fid.variables['elevation']
461
462        volumes = fid.variables['volumes']
463        time = fid.variables['time']
464
465        # 2D
466        stage = fid.variables['stage']
467
468        X = x[:]
469        Y = y[:]
470        Z = z[:]
471        V = volumes[:]
472        T = time[:]
473        S = stage[:,:]
474
475#         print "****************************"
476#         print "X ",X
477#         print "****************************"
478#         print "Y ",Y
479#         print "****************************"
480#         print "Z ",Z
481#         print "****************************"
482#         print "V ",V
483#         print "****************************"
484#         print "Time ",T
485#         print "****************************"
486#         print "Stage ",S
487#         print "****************************"
488
489
490        fid.close()
491
492        #Cleanup
493        os.remove(sww.filename)
494
495
496
497    def test_dem2pts(self):
498        """Test conversion from dem in ascii format to native NetCDF xya format
499        """
500
501        import time, os
502        from Numeric import array, zeros, allclose, Float, concatenate
503        from Scientific.IO.NetCDF import NetCDFFile
504
505        #Write test asc file
506        root = 'demtest'
507
508        filename = root+'.asc'
509        fid = open(filename, 'w')
510        fid.write("""ncols         5
511nrows         6
512xllcorner     2000.5
513yllcorner     3000.5
514cellsize      25
515NODATA_value  -9999
516""")
517        #Create linear function
518
519        ref_points = []
520        ref_elevation = []
521        for i in range(6):
522            y = (6-i)*25.0
523            for j in range(5):
524                x = j*25.0
525                z = x+2*y
526
527                ref_points.append( [x,y] )
528                ref_elevation.append(z)
529                fid.write('%f ' %z)
530            fid.write('\n')
531
532        fid.close()
533
534        #Write prj file with metadata
535        metafilename = root+'.prj'
536        fid = open(metafilename, 'w')
537
538
539        fid.write("""Projection UTM
540Zone 56
541Datum WGS84
542Zunits NO
543Units METERS
544Spheroid WGS84
545Xshift 0.0000000000
546Yshift 10000000.0000000000
547Parameters
548""")
549        fid.close()
550
551        #Convert to NetCDF pts
552        convert_dem_from_ascii2netcdf(root)
553        dem2pts(root)
554
555        #Check contents
556        #Get NetCDF
557        fid = NetCDFFile(root+'.pts', 'r')
558
559        # Get the variables
560        #print fid.variables.keys()
561        points = fid.variables['points']
562        elevation = fid.variables['elevation']
563
564        #Check values
565
566        #print points[:]
567        #print ref_points
568        assert allclose(points, ref_points)
569
570        #print attributes[:]
571        #print ref_elevation
572        assert allclose(elevation, ref_elevation)
573
574        #Cleanup
575        fid.close()
576
577
578        os.remove(root + '.pts')
579        os.remove(root + '.dem')
580        os.remove(root + '.asc')
581        os.remove(root + '.prj')
582
583
584
585    def test_sww2asc_elevation(self):
586        """Test that sww information can be converted correctly to asc/prj
587        format readable by e.g. ArcView
588        """
589
590        import time, os
591        from Numeric import array, zeros, allclose, Float, concatenate
592        from Scientific.IO.NetCDF import NetCDFFile
593
594        #Setup
595        self.domain.filename = 'datatest'
596
597        prjfile = self.domain.filename + '_elevation.prj'
598        ascfile = self.domain.filename + '_elevation.asc'
599        swwfile = self.domain.filename + '.sww'
600
601        self.domain.set_datadir('.')
602        self.domain.format = 'sww'
603        self.domain.smooth = True
604        self.domain.set_quantity('elevation', lambda x,y: -x-y)
605
606        self.domain.geo_reference = Geo_reference(56,308500,6189000)
607
608        sww = get_dataobject(self.domain)
609        sww.store_connectivity()
610        sww.store_timestep('stage')
611
612        self.domain.evolve_to_end(finaltime = 0.01)
613        sww.store_timestep('stage')
614
615        cellsize = 0.25
616        #Check contents
617        #Get NetCDF
618
619        fid = NetCDFFile(sww.filename, 'r')
620
621        # Get the variables
622        x = fid.variables['x'][:]
623        y = fid.variables['y'][:]
624        z = fid.variables['elevation'][:]
625        time = fid.variables['time'][:]
626        stage = fid.variables['stage'][:]
627
628
629        #Export to ascii/prj files
630        sww2asc(self.domain.filename,
631                quantity = 'elevation',
632                cellsize = cellsize,
633                verbose = False)
634
635
636        #Check prj (meta data)
637        prjid = open(prjfile)
638        lines = prjid.readlines()
639        prjid.close()
640
641        L = lines[0].strip().split()
642        assert L[0].strip().lower() == 'projection'
643        assert L[1].strip().lower() == 'utm'
644
645        L = lines[1].strip().split()
646        assert L[0].strip().lower() == 'zone'
647        assert L[1].strip().lower() == '56'
648
649        L = lines[2].strip().split()
650        assert L[0].strip().lower() == 'datum'
651        assert L[1].strip().lower() == 'wgs84'
652
653        L = lines[3].strip().split()
654        assert L[0].strip().lower() == 'zunits'
655        assert L[1].strip().lower() == 'no'
656
657        L = lines[4].strip().split()
658        assert L[0].strip().lower() == 'units'
659        assert L[1].strip().lower() == 'meters'
660
661        L = lines[5].strip().split()
662        assert L[0].strip().lower() == 'spheroid'
663        assert L[1].strip().lower() == 'wgs84'
664
665        L = lines[6].strip().split()
666        assert L[0].strip().lower() == 'xshift'
667        assert L[1].strip().lower() == '500000'
668
669        L = lines[7].strip().split()
670        assert L[0].strip().lower() == 'yshift'
671        assert L[1].strip().lower() == '10000000'
672
673        L = lines[8].strip().split()
674        assert L[0].strip().lower() == 'parameters'
675
676
677        #Check asc file
678        ascid = open(ascfile)
679        lines = ascid.readlines()
680        ascid.close()
681
682        L = lines[0].strip().split()
683        assert L[0].strip().lower() == 'ncols'
684        assert L[1].strip().lower() == '5'
685
686        L = lines[1].strip().split()
687        assert L[0].strip().lower() == 'nrows'
688        assert L[1].strip().lower() == '5'
689
690        L = lines[2].strip().split()
691        assert L[0].strip().lower() == 'xllcorner'
692        assert allclose(float(L[1].strip().lower()), 308500)
693
694        L = lines[3].strip().split()
695        assert L[0].strip().lower() == 'yllcorner'
696        assert allclose(float(L[1].strip().lower()), 6189000)
697
698        L = lines[4].strip().split()
699        assert L[0].strip().lower() == 'cellsize'
700        assert allclose(float(L[1].strip().lower()), cellsize)
701
702        L = lines[5].strip().split()
703        assert L[0].strip() == 'NODATA_value'
704        assert L[1].strip().lower() == '-9999'
705
706        #Check grid values
707        for j in range(5):
708            L = lines[6+j].strip().split()
709            y = (4-j) * cellsize
710            for i in range(5):
711                assert allclose(float(L[i]), -i*cellsize - y)
712
713
714        fid.close()
715
716        #Cleanup
717        os.remove(prjfile)
718        os.remove(ascfile)
719        os.remove(swwfile)
720
721
722    def test_sww2asc_stage_reduction(self):
723        """Test that sww information can be converted correctly to asc/prj
724        format readable by e.g. ArcView
725
726        This tests the reduction of quantity stage using min
727        """
728
729        import time, os
730        from Numeric import array, zeros, allclose, Float, concatenate
731        from Scientific.IO.NetCDF import NetCDFFile
732
733        #Setup
734        self.domain.filename = 'datatest'
735
736        prjfile = self.domain.filename + '_stage.prj'
737        ascfile = self.domain.filename + '_stage.asc'
738        swwfile = self.domain.filename + '.sww'
739
740        self.domain.set_datadir('.')
741        self.domain.format = 'sww'
742        self.domain.smooth = True
743        self.domain.set_quantity('elevation', lambda x,y: -x-y)
744
745        self.domain.geo_reference = Geo_reference(56,308500,6189000)
746
747
748        sww = get_dataobject(self.domain)
749        sww.store_connectivity()
750        sww.store_timestep('stage')
751
752        self.domain.evolve_to_end(finaltime = 0.01)
753        sww.store_timestep('stage')
754
755        cellsize = 0.25
756        #Check contents
757        #Get NetCDF
758
759        fid = NetCDFFile(sww.filename, 'r')
760
761        # Get the variables
762        x = fid.variables['x'][:]
763        y = fid.variables['y'][:]
764        z = fid.variables['elevation'][:]
765        time = fid.variables['time'][:]
766        stage = fid.variables['stage'][:]
767
768
769        #Export to ascii/prj files
770        sww2asc(self.domain.filename,
771                quantity = 'stage',
772                cellsize = cellsize,
773                reduction = min)
774
775
776        #Check asc file
777        ascid = open(ascfile)
778        lines = ascid.readlines()
779        ascid.close()
780
781        L = lines[0].strip().split()
782        assert L[0].strip().lower() == 'ncols'
783        assert L[1].strip().lower() == '5'
784
785        L = lines[1].strip().split()
786        assert L[0].strip().lower() == 'nrows'
787        assert L[1].strip().lower() == '5'
788
789        L = lines[2].strip().split()
790        assert L[0].strip().lower() == 'xllcorner'
791        assert allclose(float(L[1].strip().lower()), 308500)
792
793        L = lines[3].strip().split()
794        assert L[0].strip().lower() == 'yllcorner'
795        assert allclose(float(L[1].strip().lower()), 6189000)
796
797        L = lines[4].strip().split()
798        assert L[0].strip().lower() == 'cellsize'
799        assert allclose(float(L[1].strip().lower()), cellsize)
800
801        L = lines[5].strip().split()
802        assert L[0].strip() == 'NODATA_value'
803        assert L[1].strip().lower() == '-9999'
804
805
806        #Check grid values (where applicable)
807        for j in range(5):
808            if j%2 == 0:
809                L = lines[6+j].strip().split()
810                jj = 4-j
811                for i in range(5):
812                    if i%2 == 0:
813                        index = jj/2 + i/2*3
814                        val0 = stage[0,index]
815                        val1 = stage[1,index]
816
817                        #print i, j, index, ':', L[i], val0, val1
818                        assert allclose(float(L[i]), min(val0, val1))
819
820
821        fid.close()
822
823        #Cleanup
824        os.remove(prjfile)
825        os.remove(ascfile)
826        #os.remove(swwfile)
827
828
829
830
831    def test_sww2asc_missing_points(self):
832        """Test that sww information can be converted correctly to asc/prj
833        format readable by e.g. ArcView
834
835        This test includes the writing of missing values
836        """
837
838        import time, os
839        from Numeric import array, zeros, allclose, Float, concatenate
840        from Scientific.IO.NetCDF import NetCDFFile
841
842        #Setup mesh not coinciding with rectangle.
843        #This will cause missing values to occur in gridded data
844
845
846        points = [                        [1.0, 1.0],
847                              [0.5, 0.5], [1.0, 0.5],
848                  [0.0, 0.0], [0.5, 0.0], [1.0, 0.0]]
849
850        vertices = [ [4,1,3], [5,2,4], [1,4,2], [2,0,1]]
851
852        #Create shallow water domain
853        domain = Domain(points, vertices)
854        domain.default_order=2
855
856
857        #Set some field values
858        domain.set_quantity('elevation', lambda x,y: -x-y)
859        domain.set_quantity('friction', 0.03)
860
861
862        ######################
863        # Boundary conditions
864        B = Transmissive_boundary(domain)
865        domain.set_boundary( {'exterior': B} )
866
867
868        ######################
869        #Initial condition - with jumps
870
871        bed = domain.quantities['elevation'].vertex_values
872        stage = zeros(bed.shape, Float)
873
874        h = 0.3
875        for i in range(stage.shape[0]):
876            if i % 2 == 0:
877                stage[i,:] = bed[i,:] + h
878            else:
879                stage[i,:] = bed[i,:]
880
881        domain.set_quantity('stage', stage)
882        domain.distribute_to_vertices_and_edges()
883
884        domain.filename = 'datatest'
885
886        prjfile = domain.filename + '_elevation.prj'
887        ascfile = domain.filename + '_elevation.asc'
888        swwfile = domain.filename + '.sww'
889
890        domain.set_datadir('.')
891        domain.format = 'sww'
892        domain.smooth = True
893
894        domain.geo_reference = Geo_reference(56,308500,6189000)
895
896        sww = get_dataobject(domain)
897        sww.store_connectivity()
898        sww.store_timestep('stage')
899
900        cellsize = 0.25
901        #Check contents
902        #Get NetCDF
903
904        fid = NetCDFFile(swwfile, 'r')
905
906        # Get the variables
907        x = fid.variables['x'][:]
908        y = fid.variables['y'][:]
909        z = fid.variables['elevation'][:]
910        time = fid.variables['time'][:]
911
912        try:
913            geo_reference = Geo_reference(NetCDFObject=fid)
914        except AttributeError, e:
915            geo_reference = Geo_reference(DEFAULT_ZONE,0,0)
916
917        #Export to ascii/prj files
918        sww2asc(domain.filename,
919                quantity = 'elevation',
920                cellsize = cellsize,
921                verbose = False)
922
923
924        #Check asc file
925        ascid = open(ascfile)
926        lines = ascid.readlines()
927        ascid.close()
928
929        L = lines[0].strip().split()
930        assert L[0].strip().lower() == 'ncols'
931        assert L[1].strip().lower() == '5'
932
933        L = lines[1].strip().split()
934        assert L[0].strip().lower() == 'nrows'
935        assert L[1].strip().lower() == '5'
936
937        L = lines[2].strip().split()
938        assert L[0].strip().lower() == 'xllcorner'
939        assert allclose(float(L[1].strip().lower()), 308500)
940
941        L = lines[3].strip().split()
942        assert L[0].strip().lower() == 'yllcorner'
943        assert allclose(float(L[1].strip().lower()), 6189000)
944
945        L = lines[4].strip().split()
946        assert L[0].strip().lower() == 'cellsize'
947        assert allclose(float(L[1].strip().lower()), cellsize)
948
949        L = lines[5].strip().split()
950        assert L[0].strip() == 'NODATA_value'
951        assert L[1].strip().lower() == '-9999'
952
953
954        #Check grid values
955        for j in range(5):
956            L = lines[6+j].strip().split()
957            y = (4-j) * cellsize
958            for i in range(5):
959                if i+j >= 4:
960                    assert allclose(float(L[i]), -i*cellsize - y)
961                else:
962                    #Missing values
963                    assert allclose(float(L[i]), -9999)
964
965
966
967        fid.close()
968
969        #Cleanup
970        os.remove(prjfile)
971        os.remove(ascfile)
972        os.remove(swwfile)
973
974
975    def test_ferret2sww(self):
976        """Test that georeferencing etc works when converting from
977        ferret format (lat/lon) to sww format (UTM)
978        """
979        from Scientific.IO.NetCDF import NetCDFFile
980
981        #The test file has
982        # LON = 150.66667, 150.83334, 151, 151.16667
983        # LAT = -34.5, -34.33333, -34.16667, -34 ;
984        # TIME = 0, 0.1, 0.6, 1.1, 1.6, 2.1 ;
985        #
986        # First value (index=0) in small_ha.nc is 0.3400644 cm,
987        # Fourth value (index==3) is -6.50198 cm
988
989
990        from coordinate_transforms.redfearn import redfearn
991
992        fid = NetCDFFile('small_ha.nc')
993        first_value = fid.variables['HA'][:][0,0,0]
994        fourth_value = fid.variables['HA'][:][0,0,3]
995
996
997        #Call conversion (with zero origin)
998        ferret2sww('small', verbose=False,
999                   origin = (56, 0, 0))
1000
1001
1002        #Work out the UTM coordinates for first point
1003        zone, e, n = redfearn(-34.5, 150.66667)
1004        #print zone, e, n
1005
1006        #Read output file 'small.sww'
1007        fid = NetCDFFile('small.sww')
1008
1009        x = fid.variables['x'][:]
1010        y = fid.variables['y'][:]
1011
1012        #Check that first coordinate is correctly represented
1013        assert allclose(x[0], e)
1014        assert allclose(y[0], n)
1015
1016        #Check first value
1017        stage = fid.variables['stage'][:]
1018        xmomentum = fid.variables['xmomentum'][:]
1019        ymomentum = fid.variables['ymomentum'][:]
1020
1021        #print ymomentum
1022
1023        assert allclose(stage[0,0], first_value/100)  #Meters
1024
1025        #Check fourth value
1026        assert allclose(stage[0,3], fourth_value/100)  #Meters
1027
1028        fid.close()
1029
1030        #Cleanup
1031        import os
1032        os.remove('small.sww')
1033
1034
1035
1036    def test_ferret2sww_2(self):
1037        """Test that georeferencing etc works when converting from
1038        ferret format (lat/lon) to sww format (UTM)
1039        """
1040        from Scientific.IO.NetCDF import NetCDFFile
1041
1042        #The test file has
1043        # LON = 150.66667, 150.83334, 151, 151.16667
1044        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1045        # TIME = 0, 0.1, 0.6, 1.1, 1.6, 2.1 ;
1046        #
1047        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1048        # Fourth value (index==3) is -6.50198 cm
1049
1050
1051        from coordinate_transforms.redfearn import redfearn
1052
1053        fid = NetCDFFile('small_ha.nc')
1054
1055        #Pick a coordinate and a value
1056
1057        time_index = 1
1058        lat_index = 0
1059        lon_index = 2
1060
1061        test_value = fid.variables['HA'][:][time_index, lat_index, lon_index]
1062        test_time = fid.variables['TIME'][:][time_index]
1063        test_lat = fid.variables['LAT'][:][lat_index]
1064        test_lon = fid.variables['LON'][:][lon_index]
1065
1066        linear_point_index = lat_index*4 + lon_index
1067        fid.close()
1068
1069        #Call conversion (with zero origin)
1070        ferret2sww('small', verbose=False,
1071                   origin = (56, 0, 0))
1072
1073
1074        #Work out the UTM coordinates for test point
1075        zone, e, n = redfearn(test_lat, test_lon)
1076
1077        #Read output file 'small.sww'
1078        fid = NetCDFFile('small.sww')
1079
1080        x = fid.variables['x'][:]
1081        y = fid.variables['y'][:]
1082
1083        #Check that test coordinate is correctly represented
1084        assert allclose(x[linear_point_index], e)
1085        assert allclose(y[linear_point_index], n)
1086
1087        #Check test value
1088        stage = fid.variables['stage'][:]
1089
1090        assert allclose(stage[time_index, linear_point_index], test_value/100)
1091
1092        fid.close()
1093
1094        #Cleanup
1095        import os
1096        os.remove('small.sww')
1097
1098
1099
1100    def test_ferret2sww3(self):
1101        """
1102        """
1103        from Scientific.IO.NetCDF import NetCDFFile
1104
1105        #The test file has
1106        # LON = 150.66667, 150.83334, 151, 151.16667
1107        # LAT = -34.5, -34.33333, -34.16667, -34 ;
1108        # ELEVATION = [-1 -2 -3 -4
1109        #             -5 -6 -7 -8
1110        #              ...
1111        #              ...      -16]
1112        # where the top left corner is -1m,
1113        # and the ll corner is -13.0m
1114        #
1115        # First value (index=0) in small_ha.nc is 0.3400644 cm,
1116        # Fourth value (index==3) is -6.50198 cm
1117
1118        from coordinate_transforms.redfearn import redfearn
1119        import os
1120        fid1 = NetCDFFile('test_ha.nc','w')
1121        fid2 = NetCDFFile('test_ua.nc','w')
1122        fid3 = NetCDFFile('test_va.nc','w')
1123        fid4 = NetCDFFile('test_e.nc','w')
1124
1125        h1_list = [150.66667,150.83334,151.]
1126        h2_list = [-34.5,-34.33333]
1127
1128        long_name = 'LON'
1129        lat_name = 'LAT'
1130
1131        nx = 3
1132        ny = 2
1133
1134        for fid in [fid1,fid2,fid3]:
1135            fid.createDimension(long_name,nx)
1136            fid.createVariable(long_name,'d',(long_name,))
1137            fid.variables[long_name].point_spacing='uneven'
1138            fid.variables[long_name].units='degrees_east'
1139            fid.variables[long_name].assignValue(h1_list)
1140
1141            fid.createDimension(lat_name,ny)
1142            fid.createVariable(lat_name,'d',(lat_name,))
1143            fid.variables[lat_name].point_spacing='uneven'
1144            fid.variables[lat_name].units='degrees_north'
1145            fid.variables[lat_name].assignValue(h2_list)
1146
1147            fid.createDimension('TIME',2)
1148            fid.createVariable('TIME','d',('TIME',))
1149            fid.variables['TIME'].point_spacing='uneven'
1150            fid.variables['TIME'].units='seconds'
1151            fid.variables['TIME'].assignValue([0.,1.])
1152            if fid == fid3: break
1153
1154
1155        for fid in [fid4]:
1156            fid.createDimension(long_name,nx)
1157            fid.createVariable(long_name,'d',(long_name,))
1158            fid.variables[long_name].point_spacing='uneven'
1159            fid.variables[long_name].units='degrees_east'
1160            fid.variables[long_name].assignValue(h1_list)
1161
1162            fid.createDimension(lat_name,ny)
1163            fid.createVariable(lat_name,'d',(lat_name,))
1164            fid.variables[lat_name].point_spacing='uneven'
1165            fid.variables[lat_name].units='degrees_north'
1166            fid.variables[lat_name].assignValue(h2_list)
1167
1168        name = {}
1169        name[fid1]='HA'
1170        name[fid2]='UA'
1171        name[fid3]='VA'
1172        name[fid4]='ELEVATION'
1173
1174        units = {}
1175        units[fid1]='cm'
1176        units[fid2]='cm/s'
1177        units[fid3]='cm/s'
1178        units[fid4]='m'
1179
1180        values = {}
1181        values[fid1]=[[[5., 10.,15.], [13.,18.,23.]],[[50.,100.,150.],[130.,180.,230.]]]
1182        values[fid2]=[[[1., 2.,3.], [4.,5.,6.]],[[7.,8.,9.],[10.,11.,12.]]]
1183        values[fid3]=[[[13., 12.,11.], [10.,9.,8.]],[[7.,6.,5.],[4.,3.,2.]]]
1184        values[fid4]=[[-3000,-3100,-3200],[-4000,-5000,-6000]]
1185
1186        for fid in [fid1,fid2,fid3]:
1187          fid.createVariable(name[fid],'d',('TIME',lat_name,long_name))
1188          fid.variables[name[fid]].point_spacing='uneven'
1189          fid.variables[name[fid]].units=units[fid]
1190          fid.variables[name[fid]].assignValue(values[fid])
1191          fid.variables[name[fid]].missing_value = -99999999.
1192          if fid == fid3: break
1193
1194        for fid in [fid4]:
1195            fid.createVariable(name[fid],'d',(lat_name,long_name))
1196            fid.variables[name[fid]].point_spacing='uneven'
1197            fid.variables[name[fid]].units=units[fid]
1198            fid.variables[name[fid]].assignValue(values[fid])
1199            fid.variables[name[fid]].missing_value = -99999999.
1200
1201
1202        fid1.sync(); fid1.close()
1203        fid2.sync(); fid2.close()
1204        fid3.sync(); fid3.close()
1205        fid4.sync(); fid4.close()
1206
1207        fid1 = NetCDFFile('test_ha.nc','r')
1208        fid2 = NetCDFFile('test_e.nc','r')
1209        fid3 = NetCDFFile('test_va.nc','r')
1210
1211
1212        first_amp = fid1.variables['HA'][:][0,0,0]
1213        third_amp = fid1.variables['HA'][:][0,0,2]
1214        first_elevation = fid2.variables['ELEVATION'][0,0]
1215        third_elevation= fid2.variables['ELEVATION'][:][0,2]
1216        first_speed = fid3.variables['VA'][0,0,0]
1217        third_speed = fid3.variables['VA'][:][0,0,2]
1218
1219        fid1.close()
1220        fid2.close()
1221        fid3.close()
1222
1223        #Call conversion (with zero origin)
1224        ferret2sww('test', verbose=False,
1225                   origin = (56, 0, 0))
1226
1227        os.remove('test_va.nc')
1228        os.remove('test_ua.nc')
1229        os.remove('test_ha.nc')
1230        os.remove('test_e.nc')
1231
1232        #Read output file 'test.sww'
1233        fid = NetCDFFile('test.sww')
1234
1235
1236        #Check first value
1237        elevation = fid.variables['elevation'][:]
1238        stage = fid.variables['stage'][:]
1239        xmomentum = fid.variables['xmomentum'][:]
1240        ymomentum = fid.variables['ymomentum'][:]
1241
1242        #print ymomentum
1243        first_height = first_amp/100 - first_elevation
1244        third_height = third_amp/100 - third_elevation
1245        first_momentum=first_speed*first_height/100
1246        third_momentum=third_speed*third_height/100
1247
1248        assert allclose(ymomentum[0][0],first_momentum)  #Meters
1249        assert allclose(ymomentum[0][2],third_momentum)  #Meters
1250
1251        fid.close()
1252
1253        #Cleanup
1254        os.remove('test.sww')
1255
1256
1257
1258
1259    def test_sww_extent(self):
1260        """Not a test, rather a look at the sww format
1261        """
1262
1263        import time, os
1264        from Numeric import array, zeros, allclose, Float, concatenate
1265        from Scientific.IO.NetCDF import NetCDFFile
1266
1267        self.domain.filename = 'datatest' + str(id(self))
1268        self.domain.format = 'sww'
1269        self.domain.smooth = True
1270        self.domain.reduction = mean
1271        self.domain.set_datadir('.')
1272
1273
1274        sww = get_dataobject(self.domain)
1275        sww.store_connectivity()
1276        sww.store_timestep('stage')
1277        self.domain.time = 2.
1278
1279        #Modify stage at second timestep
1280        stage = self.domain.quantities['stage'].vertex_values
1281        self.domain.set_quantity('stage', stage/2)
1282
1283        sww.store_timestep('stage')
1284
1285        file_and_extension_name = self.domain.filename + ".sww"
1286        #print "file_and_extension_name",file_and_extension_name
1287        [xmin, xmax, ymin, ymax, stagemin, stagemax] = \
1288               extent_sww(file_and_extension_name )
1289
1290        assert allclose(xmin, 0.0)
1291        assert allclose(xmax, 1.0)
1292        assert allclose(ymin, 0.0)
1293        assert allclose(ymax, 1.0)
1294        assert allclose(stagemin, -0.85)
1295        assert allclose(stagemax, 0.15)
1296
1297
1298        #Cleanup
1299        os.remove(sww.filename)
1300
1301
1302    def test_ferret2sww_nz_origin(self):
1303        from Scientific.IO.NetCDF import NetCDFFile
1304        from coordinate_transforms.redfearn import redfearn
1305
1306        #Call conversion (with nonzero origin)
1307        ferret2sww('small', verbose=False,
1308                   origin = (56, 100000, 200000))
1309
1310
1311        #Work out the UTM coordinates for first point
1312        zone, e, n = redfearn(-34.5, 150.66667)
1313
1314        #Read output file 'small.sww'
1315        fid = NetCDFFile('small.sww', 'r')
1316
1317        x = fid.variables['x'][:]
1318        y = fid.variables['y'][:]
1319
1320        #Check that first coordinate is correctly represented
1321        assert allclose(x[0], e-100000)
1322        assert allclose(y[0], n-200000)
1323
1324        fid.close()
1325
1326        #Cleanup
1327        import os
1328        os.remove('small.sww')
1329
1330    def test_sww2domain(self):
1331        ################################################
1332        #Create a test domain, and evolve and save it.
1333        ################################################
1334        from mesh_factory import rectangular
1335        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1336             Constant_height, Time_boundary, Transmissive_boundary
1337        from Numeric import array
1338
1339        #Create basic mesh
1340
1341        yiel=0.01
1342        points, vertices, boundary = rectangular(10,10)
1343
1344        #Create shallow water domain
1345        domain = Domain(points, vertices, boundary)
1346        domain.geo_reference = Geo_reference(56,11,11)
1347        domain.smooth = False
1348        domain.visualise = False
1349        domain.store = True
1350        domain.filename = 'bedslope'
1351        domain.default_order=2
1352        #Bed-slope and friction
1353        domain.set_quantity('elevation', lambda x,y: -x/3)
1354        domain.set_quantity('friction', 0.1)
1355        # Boundary conditions
1356        from math import sin, pi
1357        Br = Reflective_boundary(domain)
1358        Bt = Transmissive_boundary(domain)
1359        Bd = Dirichlet_boundary([0.2,0.,0.])
1360        Bw = Time_boundary(domain=domain,
1361                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1362
1363        #domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
1364        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1365
1366        domain.quantities_to_be_stored.extend(['xmomentum','ymomentum'])
1367        #Initial condition
1368        h = 0.05
1369        elevation = domain.quantities['elevation'].vertex_values
1370        domain.set_quantity('stage', elevation + h)
1371        #elevation = domain.get_quantity('elevation')
1372        #domain.set_quantity('stage', elevation + h)
1373
1374        domain.check_integrity()
1375        #Evolution
1376        for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
1377        #    domain.write_time()
1378            pass
1379
1380
1381        ##########################################
1382        #Import the example's file as a new domain
1383        ##########################################
1384        from data_manager import sww2domain
1385        from Numeric import allclose
1386        import os
1387
1388        filename = domain.datadir+os.sep+domain.filename+'.sww'
1389        domain2 = sww2domain(filename,None,fail_if_NaN=False,verbose = False)
1390        #points, vertices, boundary = rectangular(15,15)
1391        #domain2.boundary = boundary
1392        ###################
1393        ##NOW TEST IT!!!
1394        ###################
1395
1396        bits = ['vertex_coordinates']
1397        for quantity in ['elevation']+domain.quantities_to_be_stored:
1398            bits.append('quantities["%s"].get_integral()'%quantity)
1399            bits.append('get_quantity("%s")'%quantity)
1400
1401        for bit in bits:
1402            #print 'testing that domain.'+bit+' has been restored'
1403            #print bit
1404        #print 'done'
1405            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1406
1407        ######################################
1408        #Now evolve them both, just to be sure
1409        ######################################x
1410        visualise = False
1411        #visualise = True
1412        domain.visualise = visualise
1413        domain.time = 0.
1414        from time import sleep
1415
1416        final = .1
1417        domain.set_quantity('friction', 0.1)
1418        domain.store = False
1419        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1420
1421
1422        for t in domain.evolve(yieldstep = yiel, finaltime = final):
1423            if visualise: sleep(1.)
1424            #domain.write_time()
1425            pass
1426
1427        final = final - (domain2.starttime-domain.starttime)
1428        #BUT since domain1 gets time hacked back to 0:
1429        final = final + (domain2.starttime-domain.starttime)
1430
1431        domain2.smooth = False
1432        domain2.visualise = visualise
1433        domain2.store = False
1434        domain2.default_order=2
1435        domain2.set_quantity('friction', 0.1)
1436        #Bed-slope and friction
1437        # Boundary conditions
1438        Bd2=Dirichlet_boundary([0.2,0.,0.])
1439        domain2.boundary = domain.boundary
1440        #print 'domain2.boundary'
1441        #print domain2.boundary
1442        domain2.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1443        #domain2.set_boundary({'exterior': Bd})
1444
1445        domain2.check_integrity()
1446
1447        for t in domain2.evolve(yieldstep = yiel, finaltime = final):
1448            if visualise: sleep(1.)
1449            #domain2.write_time()
1450            pass
1451
1452        ###################
1453        ##NOW TEST IT!!!
1454        ##################
1455
1456        bits = [ 'vertex_coordinates']
1457
1458        for quantity in ['elevation','xmomentum','ymomentum']:#+domain.quantities_to_be_stored:
1459            bits.append('quantities["%s"].get_integral()'%quantity)
1460            bits.append('get_quantity("%s")'%quantity)
1461
1462        for bit in bits:
1463            #print bit
1464            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1465
1466
1467    def test_sww2domain2(self):
1468        ##################################################################
1469        #Same as previous test, but this checks how NaNs are handled.
1470        ##################################################################
1471
1472
1473        from mesh_factory import rectangular
1474        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1475             Constant_height, Time_boundary, Transmissive_boundary
1476        from Numeric import array
1477
1478        #Create basic mesh
1479        points, vertices, boundary = rectangular(2,2)
1480
1481        #Create shallow water domain
1482        domain = Domain(points, vertices, boundary)
1483        domain.smooth = False
1484        domain.visualise = False
1485        domain.store = True
1486        domain.filename = 'bedslope'
1487        domain.default_order=2
1488        domain.quantities_to_be_stored=['stage']
1489
1490        domain.set_quantity('elevation', lambda x,y: -x/3)
1491        domain.set_quantity('friction', 0.1)
1492
1493        from math import sin, pi
1494        Br = Reflective_boundary(domain)
1495        Bt = Transmissive_boundary(domain)
1496        Bd = Dirichlet_boundary([0.2,0.,0.])
1497        Bw = Time_boundary(domain=domain,
1498                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1499
1500        domain.set_boundary({'left': Bd, 'right': Br, 'top': Br, 'bottom': Br})
1501
1502        h = 0.05
1503        elevation = domain.quantities['elevation'].vertex_values
1504        domain.set_quantity('stage', elevation + h)
1505
1506        domain.check_integrity()
1507
1508        for t in domain.evolve(yieldstep = 1, finaltime = 2.0):
1509            pass
1510            #domain.write_time()
1511
1512
1513
1514        ##################################
1515        #Import the file as a new domain
1516        ##################################
1517        from data_manager import sww2domain
1518        from Numeric import allclose
1519        import os
1520
1521        filename = domain.datadir+os.sep+domain.filename+'.sww'
1522
1523        #Fail because NaNs are present
1524        try:
1525            domain2 = sww2domain(filename,boundary,fail_if_NaN=True,verbose=False)
1526            assert True == False
1527        except:
1528            #Now import it, filling NaNs to be 0
1529            filler = 0
1530            domain2 = sww2domain(filename,None,fail_if_NaN=False,NaN_filler = filler,verbose=False)
1531        bits = [ 'geo_reference.get_xllcorner()',
1532                'geo_reference.get_yllcorner()',
1533                'vertex_coordinates']
1534
1535        for quantity in ['elevation']+domain.quantities_to_be_stored:
1536            bits.append('quantities["%s"].get_integral()'%quantity)
1537            bits.append('get_quantity("%s")'%quantity)
1538
1539        for bit in bits:
1540        #    print 'testing that domain.'+bit+' has been restored'
1541            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1542
1543        assert max(max(domain2.get_quantity('xmomentum')))==filler
1544        assert min(min(domain2.get_quantity('xmomentum')))==filler
1545        assert max(max(domain2.get_quantity('ymomentum')))==filler
1546        assert min(min(domain2.get_quantity('ymomentum')))==filler
1547
1548        #print 'passed'
1549
1550        #cleanup
1551        #import os
1552        #os.remove(domain.datadir+'/'+domain.filename+'.sww')
1553
1554
1555    #def test_weed(self):
1556        from data_manager import weed
1557
1558        coordinates1 = [[0.,0.],[1.,0.],[1.,1.],[1.,0.],[2.,0.],[1.,1.]]
1559        volumes1 = [[0,1,2],[3,4,5]]
1560        boundary1= {(0,1): 'external',(1,2): 'not external',(2,0): 'external',(3,4): 'external',(4,5): 'external',(5,3): 'not external'}
1561        coordinates2,volumes2,boundary2=weed(coordinates1,volumes1,boundary1)
1562
1563        points2 = {(0.,0.):None,(1.,0.):None,(1.,1.):None,(2.,0.):None}
1564
1565        assert len(points2)==len(coordinates2)
1566        for i in range(len(coordinates2)):
1567            coordinate = tuple(coordinates2[i])
1568            assert points2.has_key(coordinate)
1569            points2[coordinate]=i
1570
1571        for triangle in volumes1:
1572            for coordinate in triangle:
1573                assert coordinates2[points2[tuple(coordinates1[coordinate])]][0]==coordinates1[coordinate][0]
1574                assert coordinates2[points2[tuple(coordinates1[coordinate])]][1]==coordinates1[coordinate][1]
1575
1576
1577     #FIXME This fails - smooth makes the comparism too hard for allclose
1578    def ztest_sww2domain3(self):
1579        ################################################
1580        #DOMAIN.SMOOTH = TRUE !!!!!!!!!!!!!!!!!!!!!!!!!!
1581        ################################################
1582        from mesh_factory import rectangular
1583        from shallow_water import Domain, Reflective_boundary, Dirichlet_boundary,\
1584             Constant_height, Time_boundary, Transmissive_boundary
1585        from Numeric import array
1586        #Create basic mesh
1587
1588        yiel=0.01
1589        points, vertices, boundary = rectangular(10,10)
1590
1591        #Create shallow water domain
1592        domain = Domain(points, vertices, boundary)
1593        domain.geo_reference = Geo_reference(56,11,11)
1594        domain.smooth = True
1595        domain.visualise = False
1596        domain.store = True
1597        domain.filename = 'bedslope'
1598        domain.default_order=2
1599        #Bed-slope and friction
1600        domain.set_quantity('elevation', lambda x,y: -x/3)
1601        domain.set_quantity('friction', 0.1)
1602        # Boundary conditions
1603        from math import sin, pi
1604        Br = Reflective_boundary(domain)
1605        Bt = Transmissive_boundary(domain)
1606        Bd = Dirichlet_boundary([0.2,0.,0.])
1607        Bw = Time_boundary(domain=domain,
1608                           f=lambda t: [(0.1*sin(t*2*pi)), 0.0, 0.0])
1609
1610        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Bd})
1611
1612        domain.quantities_to_be_stored.extend(['xmomentum','ymomentum'])
1613        #Initial condition
1614        h = 0.05
1615        elevation = domain.quantities['elevation'].vertex_values
1616        domain.set_quantity('stage', elevation + h)
1617        #elevation = domain.get_quantity('elevation')
1618        #domain.set_quantity('stage', elevation + h)
1619
1620        domain.check_integrity()
1621        #Evolution
1622        for t in domain.evolve(yieldstep = yiel, finaltime = 0.05):
1623        #    domain.write_time()
1624            pass
1625
1626
1627        ##########################################
1628        #Import the example's file as a new domain
1629        ##########################################
1630        from data_manager import sww2domain
1631        from Numeric import allclose
1632        import os
1633
1634        filename = domain.datadir+os.sep+domain.filename+'.sww'
1635        domain2 = sww2domain(filename,None,fail_if_NaN=False,verbose = False)
1636        #points, vertices, boundary = rectangular(15,15)
1637        #domain2.boundary = boundary
1638        ###################
1639        ##NOW TEST IT!!!
1640        ###################
1641
1642        #FIXME smooth domain so that they can be compared
1643
1644
1645        bits = []#'vertex_coordinates']
1646        for quantity in ['elevation']+domain.quantities_to_be_stored:
1647            bits.append('quantities["%s"].get_integral()'%quantity)
1648            #bits.append('get_quantity("%s")'%quantity)
1649
1650        for bit in bits:
1651            #print 'testing that domain.'+bit+' has been restored'
1652            #print bit
1653            #print 'done'
1654            #print ('domain.'+bit), eval('domain.'+bit)
1655            #print ('domain2.'+bit), eval('domain2.'+bit)
1656            assert allclose(eval('domain.'+bit),eval('domain2.'+bit),rtol=1.0e-1,atol=1.e-3)
1657            pass
1658
1659        ######################################
1660        #Now evolve them both, just to be sure
1661        ######################################x
1662        visualise = False
1663        visualise = True
1664        domain.visualise = visualise
1665        domain.time = 0.
1666        from time import sleep
1667
1668        final = .5
1669        domain.set_quantity('friction', 0.1)
1670        domain.store = False
1671        domain.set_boundary({'left': Bd, 'right': Bd, 'top': Bd, 'bottom': Br})
1672
1673        for t in domain.evolve(yieldstep = yiel, finaltime = final):
1674            if visualise: sleep(.03)
1675            #domain.write_time()
1676            pass
1677
1678        domain2.smooth = True
1679        domain2.visualise = visualise
1680        domain2.store = False
1681        domain2.default_order=2
1682        domain2.set_quantity('friction', 0.1)
1683        #Bed-slope and friction
1684        # Boundary conditions
1685        Bd2=Dirichlet_boundary([0.2,0.,0.])
1686        Br2 = Reflective_boundary(domain2)
1687        domain2.boundary = domain.boundary
1688        #print 'domain2.boundary'
1689        #print domain2.boundary
1690        domain2.set_boundary({'left': Bd2, 'right': Bd2, 'top': Bd2, 'bottom': Br2})
1691        #domain2.boundary = domain.boundary
1692        #domain2.set_boundary({'exterior': Bd})
1693
1694        domain2.check_integrity()
1695
1696        for t in domain2.evolve(yieldstep = yiel, finaltime = final):
1697            if visualise: sleep(.03)
1698            #domain2.write_time()
1699            pass
1700
1701        ###################
1702        ##NOW TEST IT!!!
1703        ##################
1704
1705        bits = [ 'vertex_coordinates']
1706
1707        for quantity in ['elevation','xmomentum','ymomentum']:#+domain.quantities_to_be_stored:
1708            #bits.append('quantities["%s"].get_integral()'%quantity)
1709            bits.append('get_quantity("%s")'%quantity)
1710
1711        for bit in bits:
1712            print bit
1713            assert allclose(eval('domain.'+bit),eval('domain2.'+bit))
1714
1715
1716#-------------------------------------------------------------
1717if __name__ == "__main__":
1718    suite = unittest.makeSuite(Test_Data_Manager,'test')
1719    #suite = unittest.makeSuite(Test_Data_Manager,'test_sww2domain')
1720    runner = unittest.TextTestRunner()
1721    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.