source: anuga_core/source/anuga/abstract_2d_finite_volumes/test_util.py @ 4910

Last change on this file since 4910 was 4910, checked in by nick, 16 years ago

changed name of ploting function from gauges_sww2csv to sww2csv_gauges
and added some print statments to sww2timeseries to point users to the newer code.

File size: 54.3 KB
Line 
1#!/usr/bin/env python
2
3
4import unittest
5from Numeric import zeros, array, allclose, Float
6from math import sqrt, pi
7import tempfile, os
8from os import access, F_OK,sep, removedirs,remove,mkdir,getcwd
9
10from anuga.abstract_2d_finite_volumes.util import *
11from anuga.config import epsilon
12from anuga.shallow_water.data_manager import timefile2netcdf,del_dir
13
14from anuga.utilities.numerical_tools import NAN
15
16from sys import platform
17
18from anuga.pmesh.mesh import Mesh
19from anuga.shallow_water import Domain, Transmissive_boundary
20from anuga.shallow_water.data_manager import get_dataobject
21from csv import reader,writer
22import time
23import string
24
25def test_function(x, y):
26    return x+y
27
28class Test_Util(unittest.TestCase):
29    def setUp(self):
30        pass
31
32    def tearDown(self):
33        pass
34
35
36
37
38    #Geometric
39    #def test_distance(self):
40    #    from anuga.abstract_2d_finite_volumes.util import distance#
41    #
42    #    self.failUnless( distance([4,2],[7,6]) == 5.0,
43    #                     'Distance is wrong!')
44    #    self.failUnless( allclose(distance([7,6],[9,8]), 2.82842712475),
45    #                    'distance is wrong!')
46    #    self.failUnless( allclose(distance([9,8],[4,2]), 7.81024967591),
47    #                    'distance is wrong!')
48    #
49    #    self.failUnless( distance([9,8],[4,2]) == distance([4,2],[9,8]),
50    #                    'distance is wrong!')
51
52
53    def test_file_function_time1(self):
54        """Test that File function interpolates correctly
55        between given times. No x,y dependency here.
56        """
57
58        #Write file
59        import os, time
60        from anuga.config import time_format
61        from math import sin, pi
62
63        #Typical ASCII file
64        finaltime = 1200
65        filename = 'test_file_function'
66        fid = open(filename + '.txt', 'w')
67        start = time.mktime(time.strptime('2000', '%Y'))
68        dt = 60  #One minute intervals
69        t = 0.0
70        while t <= finaltime:
71            t_string = time.strftime(time_format, time.gmtime(t+start))
72            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
73            t += dt
74
75        fid.close()
76
77        #Convert ASCII file to NetCDF (Which is what we really like!)
78        timefile2netcdf(filename)
79
80
81        #Create file function from time series
82        F = file_function(filename + '.tms',
83                          quantities = ['Attribute0',
84                                        'Attribute1',
85                                        'Attribute2'])
86       
87        #Now try interpolation
88        for i in range(20):
89            t = i*10
90            q = F(t)
91
92            #Exact linear intpolation
93            assert allclose(q[0], 2*t)
94            if i%6 == 0:
95                assert allclose(q[1], t**2)
96                assert allclose(q[2], sin(t*pi/600))
97
98        #Check non-exact
99
100        t = 90 #Halfway between 60 and 120
101        q = F(t)
102        assert allclose( (120**2 + 60**2)/2, q[1] )
103        assert allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
104
105
106        t = 100 #Two thirds of the way between between 60 and 120
107        q = F(t)
108        assert allclose( 2*120**2/3 + 60**2/3, q[1] )
109        assert allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
110
111        os.remove(filename + '.txt')
112        os.remove(filename + '.tms')       
113
114
115       
116    def test_spatio_temporal_file_function_basic(self):
117        """Test that spatio temporal file function performs the correct
118        interpolations in both time and space
119        NetCDF version (x,y,t dependency)       
120        """
121        import time
122
123        #Create sww file of simple propagation from left to right
124        #through rectangular domain
125        from shallow_water import Domain, Dirichlet_boundary
126        from mesh_factory import rectangular
127        from Numeric import take, concatenate, reshape
128
129        #Create basic mesh and shallow water domain
130        points, vertices, boundary = rectangular(3, 3)
131        domain1 = Domain(points, vertices, boundary)
132
133        from anuga.utilities.numerical_tools import mean
134        domain1.reduction = mean
135        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
136                              # only one value.
137
138        domain1.default_order = 2
139        domain1.store = True
140        domain1.set_datadir('.')
141        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
142        domain1.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']
143
144        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
145        domain1.set_quantity('elevation', 0)
146        domain1.set_quantity('friction', 0)
147        domain1.set_quantity('stage', 0)
148
149        # Boundary conditions
150        B0 = Dirichlet_boundary([0,0,0])
151        B6 = Dirichlet_boundary([0.6,0,0])
152        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
153        domain1.check_integrity()
154
155        finaltime = 8
156        #Evolution
157        t0 = -1
158        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
159            #print 'Timesteps: %.16f, %.16f' %(t0, t)
160            #if t == t0:
161            #    msg = 'Duplicate timestep found: %f, %f' %(t0, t)
162            #   raise msg
163            t0 = t
164             
165            #domain1.write_time()
166
167
168        #Now read data from sww and check
169        from Scientific.IO.NetCDF import NetCDFFile
170        filename = domain1.get_name() + '.' + domain1.format
171        fid = NetCDFFile(filename)
172
173        x = fid.variables['x'][:]
174        y = fid.variables['y'][:]
175        stage = fid.variables['stage'][:]
176        xmomentum = fid.variables['xmomentum'][:]
177        ymomentum = fid.variables['ymomentum'][:]
178        time = fid.variables['time'][:]
179
180        #Take stage vertex values at last timestep on diagonal
181        #Diagonal is identified by vertices: 0, 5, 10, 15
182
183        last_time_index = len(time)-1 #Last last_time_index
184        d_stage = reshape(take(stage[last_time_index, :], [0,5,10,15]), (4,1))
185        d_uh = reshape(take(xmomentum[last_time_index, :], [0,5,10,15]), (4,1))
186        d_vh = reshape(take(ymomentum[last_time_index, :], [0,5,10,15]), (4,1))
187        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
188
189        #Reference interpolated values at midpoints on diagonal at
190        #this timestep are
191        r0 = (D[0] + D[1])/2
192        r1 = (D[1] + D[2])/2
193        r2 = (D[2] + D[3])/2
194
195        #And the midpoints are found now
196        Dx = take(reshape(x, (16,1)), [0,5,10,15])
197        Dy = take(reshape(y, (16,1)), [0,5,10,15])
198
199        diag = concatenate( (Dx, Dy), axis=1)
200        d_midpoints = (diag[1:] + diag[:-1])/2
201
202        #Let us see if the file function can find the correct
203        #values at the midpoints at the last timestep:
204        f = file_function(filename, domain1,
205                          interpolation_points = d_midpoints)
206
207        T = f.get_time()
208        msg = 'duplicate timesteps: %.16f and %.16f' %(T[-1], T[-2])
209        assert not T[-1] == T[-2], msg
210        t = time[last_time_index]
211        q = f(t, point_id=0); assert allclose(r0, q)
212        q = f(t, point_id=1); assert allclose(r1, q)
213        q = f(t, point_id=2); assert allclose(r2, q)
214
215
216        ##################
217        #Now do the same for the first timestep
218
219        timestep = 0 #First timestep
220        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
221        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
222        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
223        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
224
225        #Reference interpolated values at midpoints on diagonal at
226        #this timestep are
227        r0 = (D[0] + D[1])/2
228        r1 = (D[1] + D[2])/2
229        r2 = (D[2] + D[3])/2
230
231        #Let us see if the file function can find the correct
232        #values
233        q = f(0, point_id=0); assert allclose(r0, q)
234        q = f(0, point_id=1); assert allclose(r1, q)
235        q = f(0, point_id=2); assert allclose(r2, q)
236
237
238        ##################
239        #Now do it again for a timestep in the middle
240
241        timestep = 33
242        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
243        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
244        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
245        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
246
247        #Reference interpolated values at midpoints on diagonal at
248        #this timestep are
249        r0 = (D[0] + D[1])/2
250        r1 = (D[1] + D[2])/2
251        r2 = (D[2] + D[3])/2
252
253        q = f(timestep/10., point_id=0); assert allclose(r0, q)
254        q = f(timestep/10., point_id=1); assert allclose(r1, q)
255        q = f(timestep/10., point_id=2); assert allclose(r2, q)
256
257
258        ##################
259        #Now check temporal interpolation
260        #Halfway between timestep 15 and 16
261
262        timestep = 15
263        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
264        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
265        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
266        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
267
268        #Reference interpolated values at midpoints on diagonal at
269        #this timestep are
270        r0_0 = (D[0] + D[1])/2
271        r1_0 = (D[1] + D[2])/2
272        r2_0 = (D[2] + D[3])/2
273
274        #
275        timestep = 16
276        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
277        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
278        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
279        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
280
281        #Reference interpolated values at midpoints on diagonal at
282        #this timestep are
283        r0_1 = (D[0] + D[1])/2
284        r1_1 = (D[1] + D[2])/2
285        r2_1 = (D[2] + D[3])/2
286
287        # The reference values are
288        r0 = (r0_0 + r0_1)/2
289        r1 = (r1_0 + r1_1)/2
290        r2 = (r2_0 + r2_1)/2
291
292        q = f((timestep - 0.5)/10., point_id=0); assert allclose(r0, q)
293        q = f((timestep - 0.5)/10., point_id=1); assert allclose(r1, q)
294        q = f((timestep - 0.5)/10., point_id=2); assert allclose(r2, q)
295
296        ##################
297        #Finally check interpolation 2 thirds of the way
298        #between timestep 15 and 16
299
300        # The reference values are
301        r0 = (r0_0 + 2*r0_1)/3
302        r1 = (r1_0 + 2*r1_1)/3
303        r2 = (r2_0 + 2*r2_1)/3
304
305        #And the file function gives
306        q = f((timestep - 1.0/3)/10., point_id=0); assert allclose(r0, q)
307        q = f((timestep - 1.0/3)/10., point_id=1); assert allclose(r1, q)
308        q = f((timestep - 1.0/3)/10., point_id=2); assert allclose(r2, q)
309
310        fid.close()
311        import os
312        os.remove(filename)
313
314
315
316    def test_spatio_temporal_file_function_different_origin(self):
317        """Test that spatio temporal file function performs the correct
318        interpolations in both time and space where space is offset by
319        xllcorner and yllcorner
320        NetCDF version (x,y,t dependency)       
321        """
322        import time
323
324        #Create sww file of simple propagation from left to right
325        #through rectangular domain
326        from shallow_water import Domain, Dirichlet_boundary
327        from mesh_factory import rectangular
328        from Numeric import take, concatenate, reshape
329
330
331        from anuga.coordinate_transforms.geo_reference import Geo_reference
332        xllcorner = 2048
333        yllcorner = 11000
334        zone = 2
335
336        #Create basic mesh and shallow water domain
337        points, vertices, boundary = rectangular(3, 3)
338        domain1 = Domain(points, vertices, boundary,
339                         geo_reference = Geo_reference(xllcorner = xllcorner,
340                                                       yllcorner = yllcorner))
341       
342
343        from anuga.utilities.numerical_tools import mean       
344        domain1.reduction = mean
345        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
346                              # only one value.
347
348        domain1.default_order = 2
349        domain1.store = True
350        domain1.set_datadir('.')
351        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
352        domain1.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']
353
354        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
355        domain1.set_quantity('elevation', 0)
356        domain1.set_quantity('friction', 0)
357        domain1.set_quantity('stage', 0)
358
359        # Boundary conditions
360        B0 = Dirichlet_boundary([0,0,0])
361        B6 = Dirichlet_boundary([0.6,0,0])
362        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
363        domain1.check_integrity()
364
365        finaltime = 8
366        #Evolution
367        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
368            pass
369            #domain1.write_time()
370
371
372        #Now read data from sww and check
373        from Scientific.IO.NetCDF import NetCDFFile
374        filename = domain1.get_name() + '.' + domain1.format
375        fid = NetCDFFile(filename)
376
377        x = fid.variables['x'][:]
378        y = fid.variables['y'][:]
379        stage = fid.variables['stage'][:]
380        xmomentum = fid.variables['xmomentum'][:]
381        ymomentum = fid.variables['ymomentum'][:]
382        time = fid.variables['time'][:]
383
384        #Take stage vertex values at last timestep on diagonal
385        #Diagonal is identified by vertices: 0, 5, 10, 15
386
387        last_time_index = len(time)-1 #Last last_time_index     
388        d_stage = reshape(take(stage[last_time_index, :], [0,5,10,15]), (4,1))
389        d_uh = reshape(take(xmomentum[last_time_index, :], [0,5,10,15]), (4,1))
390        d_vh = reshape(take(ymomentum[last_time_index, :], [0,5,10,15]), (4,1))
391        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
392
393        #Reference interpolated values at midpoints on diagonal at
394        #this timestep are
395        r0 = (D[0] + D[1])/2
396        r1 = (D[1] + D[2])/2
397        r2 = (D[2] + D[3])/2
398
399        #And the midpoints are found now
400        Dx = take(reshape(x, (16,1)), [0,5,10,15])
401        Dy = take(reshape(y, (16,1)), [0,5,10,15])
402
403        diag = concatenate( (Dx, Dy), axis=1)
404        d_midpoints = (diag[1:] + diag[:-1])/2
405
406
407        #Adjust for georef - make interpolation points absolute
408        d_midpoints[:,0] += xllcorner
409        d_midpoints[:,1] += yllcorner               
410
411        #Let us see if the file function can find the correct
412        #values at the midpoints at the last timestep:
413        f = file_function(filename, domain1,
414                          interpolation_points = d_midpoints)
415
416        t = time[last_time_index]                         
417        q = f(t, point_id=0); assert allclose(r0, q)
418        q = f(t, point_id=1); assert allclose(r1, q)
419        q = f(t, point_id=2); assert allclose(r2, q)
420
421
422        ##################
423        #Now do the same for the first timestep
424
425        timestep = 0 #First timestep
426        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
427        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
428        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
429        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
430
431        #Reference interpolated values at midpoints on diagonal at
432        #this timestep are
433        r0 = (D[0] + D[1])/2
434        r1 = (D[1] + D[2])/2
435        r2 = (D[2] + D[3])/2
436
437        #Let us see if the file function can find the correct
438        #values
439        q = f(0, point_id=0); assert allclose(r0, q)
440        q = f(0, point_id=1); assert allclose(r1, q)
441        q = f(0, point_id=2); assert allclose(r2, q)
442
443
444        ##################
445        #Now do it again for a timestep in the middle
446
447        timestep = 33
448        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
449        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
450        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
451        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
452
453        #Reference interpolated values at midpoints on diagonal at
454        #this timestep are
455        r0 = (D[0] + D[1])/2
456        r1 = (D[1] + D[2])/2
457        r2 = (D[2] + D[3])/2
458
459        q = f(timestep/10., point_id=0); assert allclose(r0, q)
460        q = f(timestep/10., point_id=1); assert allclose(r1, q)
461        q = f(timestep/10., point_id=2); assert allclose(r2, q)
462
463
464        ##################
465        #Now check temporal interpolation
466        #Halfway between timestep 15 and 16
467
468        timestep = 15
469        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
470        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
471        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
472        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
473
474        #Reference interpolated values at midpoints on diagonal at
475        #this timestep are
476        r0_0 = (D[0] + D[1])/2
477        r1_0 = (D[1] + D[2])/2
478        r2_0 = (D[2] + D[3])/2
479
480        #
481        timestep = 16
482        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
483        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
484        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
485        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
486
487        #Reference interpolated values at midpoints on diagonal at
488        #this timestep are
489        r0_1 = (D[0] + D[1])/2
490        r1_1 = (D[1] + D[2])/2
491        r2_1 = (D[2] + D[3])/2
492
493        # The reference values are
494        r0 = (r0_0 + r0_1)/2
495        r1 = (r1_0 + r1_1)/2
496        r2 = (r2_0 + r2_1)/2
497
498        q = f((timestep - 0.5)/10., point_id=0); assert allclose(r0, q)
499        q = f((timestep - 0.5)/10., point_id=1); assert allclose(r1, q)
500        q = f((timestep - 0.5)/10., point_id=2); assert allclose(r2, q)
501
502        ##################
503        #Finally check interpolation 2 thirds of the way
504        #between timestep 15 and 16
505
506        # The reference values are
507        r0 = (r0_0 + 2*r0_1)/3
508        r1 = (r1_0 + 2*r1_1)/3
509        r2 = (r2_0 + 2*r2_1)/3
510
511        #And the file function gives
512        q = f((timestep - 1.0/3)/10., point_id=0); assert allclose(r0, q)
513        q = f((timestep - 1.0/3)/10., point_id=1); assert allclose(r1, q)
514        q = f((timestep - 1.0/3)/10., point_id=2); assert allclose(r2, q)
515
516        fid.close()
517        import os
518        os.remove(filename)
519
520       
521
522
523    def qtest_spatio_temporal_file_function_time(self):
524        """Test that File function interpolates correctly
525        between given times.
526        NetCDF version (x,y,t dependency)
527        """
528
529        #Create NetCDF (sww) file to be read
530        # x: 0, 5, 10, 15
531        # y: -20, -10, 0, 10
532        # t: 0, 60, 120, ...., 1200
533        #
534        # test quantities (arbitrary but non-trivial expressions):
535        #
536        #   stage     = 3*x - y**2 + 2*t
537        #   xmomentum = exp( -((x-7)**2 + (y+5)**2)/20 ) * t**2
538        #   ymomentum = x**2 + y**2 * sin(t*pi/600)
539
540        #NOTE: Nice test that may render some of the others redundant.
541
542        import os, time
543        from anuga.config import time_format
544        from Numeric import sin, pi, exp
545        from mesh_factory import rectangular
546        from shallow_water import Domain
547        import anuga.shallow_water.data_manager
548
549        finaltime = 1200
550        filename = 'test_file_function'
551
552        #Create a domain to hold test grid
553        #(0:15, -20:10)
554        points, vertices, boundary =\
555                rectangular(4, 4, 15, 30, origin = (0, -20))
556        print "points", points
557
558        #print 'Number of elements', len(vertices)
559        domain = Domain(points, vertices, boundary)
560        domain.smooth = False
561        domain.default_order = 2
562        domain.set_datadir('.')
563        domain.set_name(filename)
564        domain.store = True
565        domain.format = 'sww'   #Native netcdf visualisation format
566
567        #print points
568        start = time.mktime(time.strptime('2000', '%Y'))
569        domain.starttime = start
570
571
572        #Store structure
573        domain.initialise_storage()
574
575        #Compute artificial time steps and store
576        dt = 60  #One minute intervals
577        t = 0.0
578        while t <= finaltime:
579            #Compute quantities
580            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
581            domain.set_quantity('stage', f1)
582
583            f2 = lambda x,y: x+y+t**2
584            domain.set_quantity('xmomentum', f2)
585
586            f3 = lambda x,y: x**2 + y**2 * sin(t*pi/600)
587            domain.set_quantity('ymomentum', f3)
588
589            #Store and advance time
590            domain.time = t
591            domain.store_timestep(domain.conserved_quantities)
592            t += dt
593
594
595        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5]]
596     
597        #Deliberately set domain.starttime to too early
598        domain.starttime = start - 1
599
600        #Create file function
601        F = file_function(filename + '.sww', domain,
602                          quantities = domain.conserved_quantities,
603                          interpolation_points = interpolation_points)
604
605        #Check that FF updates fixes domain starttime
606        assert allclose(domain.starttime, start)
607
608        #Check that domain.starttime isn't updated if later
609        domain.starttime = start + 1
610        F = file_function(filename + '.sww', domain,
611                          quantities = domain.conserved_quantities,
612                          interpolation_points = interpolation_points)
613        assert allclose(domain.starttime, start+1)
614        domain.starttime = start
615
616
617        #Check linear interpolation in time
618        F = file_function(filename + '.sww', domain,
619                          quantities = domain.conserved_quantities,
620                          interpolation_points = interpolation_points)               
621        for id in range(len(interpolation_points)):
622            x = interpolation_points[id][0]
623            y = interpolation_points[id][1]
624
625            for i in range(20):
626                t = i*10
627                k = i%6
628
629                if k == 0:
630                    q0 = F(t, point_id=id)
631                    q1 = F(t+60, point_id=id)
632
633                if q0 == NAN:
634                    actual = q0
635                else:
636                    actual = (k*q1 + (6-k)*q0)/6
637                q = F(t, point_id=id)
638                #print i, k, t, q
639                #print ' ', q0
640                #print ' ', q1
641                print "q",q
642                print "actual", actual
643                #print
644                if q0 == NAN:
645                     self.failUnless( q == actual, 'Fail!')
646                else:
647                    assert allclose(q, actual)
648
649
650        #Another check of linear interpolation in time
651        for id in range(len(interpolation_points)):
652            q60 = F(60, point_id=id)
653            q120 = F(120, point_id=id)
654
655            t = 90 #Halfway between 60 and 120
656            q = F(t, point_id=id)
657            assert allclose( (q120+q60)/2, q )
658
659            t = 100 #Two thirds of the way between between 60 and 120
660            q = F(t, point_id=id)
661            assert allclose(q60/3 + 2*q120/3, q)
662
663
664
665        #Check that domain.starttime isn't updated if later than file starttime but earlier
666        #than file end time
667        delta = 23
668        domain.starttime = start + delta
669        F = file_function(filename + '.sww', domain,
670                          quantities = domain.conserved_quantities,
671                          interpolation_points = interpolation_points)
672        assert allclose(domain.starttime, start+delta)
673
674
675
676
677        #Now try interpolation with delta offset
678        for id in range(len(interpolation_points)):           
679            x = interpolation_points[id][0]
680            y = interpolation_points[id][1]
681
682            for i in range(20):
683                t = i*10
684                k = i%6
685
686                if k == 0:
687                    q0 = F(t-delta, point_id=id)
688                    q1 = F(t+60-delta, point_id=id)
689
690                q = F(t-delta, point_id=id)
691                assert allclose(q, (k*q1 + (6-k)*q0)/6)
692
693
694        os.remove(filename + '.sww')
695
696
697
698    def NOtest_spatio_temporal_file_function_time(self):
699        # FIXME: This passes but needs some TLC
700        # Test that File function interpolates correctly
701        # When some points are outside the mesh
702
703        import os, time
704        from anuga.config import time_format
705        from Numeric import sin, pi, exp
706        from mesh_factory import rectangular
707        from shallow_water import Domain
708        import anuga.shallow_water.data_manager 
709        from anuga.pmesh.mesh_interface import create_mesh_from_regions
710        finaltime = 1200
711       
712        filename = tempfile.mktemp()
713        #print "filename",filename
714        filename = 'test_file_function'
715
716        meshfilename = tempfile.mktemp(".tsh")
717
718        boundary_tags = {'walls':[0,1],'bom':[2]}
719       
720        polygon_absolute = [[0,-20],[10,-20],[10,15],[-20,15]]
721       
722        create_mesh_from_regions(polygon_absolute,
723                                 boundary_tags,
724                                 10000000,
725                                 filename=meshfilename)
726        domain = Domain(mesh_filename=meshfilename)
727        domain.smooth = False
728        domain.default_order = 2
729        domain.set_datadir('.')
730        domain.set_name(filename)
731        domain.store = True
732        domain.format = 'sww'   #Native netcdf visualisation format
733
734        #print points
735        start = time.mktime(time.strptime('2000', '%Y'))
736        domain.starttime = start
737       
738
739        #Store structure
740        domain.initialise_storage()
741
742        #Compute artificial time steps and store
743        dt = 60  #One minute intervals
744        t = 0.0
745        while t <= finaltime:
746            #Compute quantities
747            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
748            domain.set_quantity('stage', f1)
749
750            f2 = lambda x,y: x+y+t**2
751            domain.set_quantity('xmomentum', f2)
752
753            f3 = lambda x,y: x**2 + y**2 * sin(t*pi/600)
754            domain.set_quantity('ymomentum', f3)
755
756            #Store and advance time
757            domain.time = t
758            domain.store_timestep(domain.conserved_quantities)
759            t += dt
760
761        interpolation_points = [[1,0]]
762        interpolation_points = [[100,1000]]
763       
764        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5],
765                                [78787,78787],[7878,3432]]
766           
767        #Deliberately set domain.starttime to too early
768        domain.starttime = start - 1
769
770        #Create file function
771        F = file_function(filename + '.sww', domain,
772                          quantities = domain.conserved_quantities,
773                          interpolation_points = interpolation_points)
774
775        #Check that FF updates fixes domain starttime
776        assert allclose(domain.starttime, start)
777
778        #Check that domain.starttime isn't updated if later
779        domain.starttime = start + 1
780        F = file_function(filename + '.sww', domain,
781                          quantities = domain.conserved_quantities,
782                          interpolation_points = interpolation_points)
783        assert allclose(domain.starttime, start+1)
784        domain.starttime = start
785
786
787        #Check linear interpolation in time
788        # checking points inside and outside the mesh
789        F = file_function(filename + '.sww', domain,
790                          quantities = domain.conserved_quantities,
791                          interpolation_points = interpolation_points)
792       
793        for id in range(len(interpolation_points)):
794            x = interpolation_points[id][0]
795            y = interpolation_points[id][1]
796
797            for i in range(20):
798                t = i*10
799                k = i%6
800
801                if k == 0:
802                    q0 = F(t, point_id=id)
803                    q1 = F(t+60, point_id=id)
804
805                if q0 == NAN:
806                    actual = q0
807                else:
808                    actual = (k*q1 + (6-k)*q0)/6
809                q = F(t, point_id=id)
810                #print i, k, t, q
811                #print ' ', q0
812                #print ' ', q1
813                #print "q",q
814                #print "actual", actual
815                #print
816                if q0 == NAN:
817                     self.failUnless( q == actual, 'Fail!')
818                else:
819                    assert allclose(q, actual)
820
821        # now lets check points inside the mesh
822        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14]] #, [10,-12.5]] - this point doesn't work WHY?
823        interpolation_points = [[10,-12.5]]
824           
825        print "len(interpolation_points)",len(interpolation_points) 
826        F = file_function(filename + '.sww', domain,
827                          quantities = domain.conserved_quantities,
828                          interpolation_points = interpolation_points)
829
830        domain.starttime = start
831
832
833        #Check linear interpolation in time
834        F = file_function(filename + '.sww', domain,
835                          quantities = domain.conserved_quantities,
836                          interpolation_points = interpolation_points)               
837        for id in range(len(interpolation_points)):
838            x = interpolation_points[id][0]
839            y = interpolation_points[id][1]
840
841            for i in range(20):
842                t = i*10
843                k = i%6
844
845                if k == 0:
846                    q0 = F(t, point_id=id)
847                    q1 = F(t+60, point_id=id)
848
849                if q0 == NAN:
850                    actual = q0
851                else:
852                    actual = (k*q1 + (6-k)*q0)/6
853                q = F(t, point_id=id)
854                print "############"
855                print "id, x, y ", id, x, y #k, t, q
856                print "t", t
857                #print ' ', q0
858                #print ' ', q1
859                print "q",q
860                print "actual", actual
861                #print
862                if q0 == NAN:
863                     self.failUnless( q == actual, 'Fail!')
864                else:
865                    assert allclose(q, actual)
866
867
868        #Another check of linear interpolation in time
869        for id in range(len(interpolation_points)):
870            q60 = F(60, point_id=id)
871            q120 = F(120, point_id=id)
872
873            t = 90 #Halfway between 60 and 120
874            q = F(t, point_id=id)
875            assert allclose( (q120+q60)/2, q )
876
877            t = 100 #Two thirds of the way between between 60 and 120
878            q = F(t, point_id=id)
879            assert allclose(q60/3 + 2*q120/3, q)
880
881
882
883        #Check that domain.starttime isn't updated if later than file starttime but earlier
884        #than file end time
885        delta = 23
886        domain.starttime = start + delta
887        F = file_function(filename + '.sww', domain,
888                          quantities = domain.conserved_quantities,
889                          interpolation_points = interpolation_points)
890        assert allclose(domain.starttime, start+delta)
891
892
893
894
895        #Now try interpolation with delta offset
896        for id in range(len(interpolation_points)):           
897            x = interpolation_points[id][0]
898            y = interpolation_points[id][1]
899
900            for i in range(20):
901                t = i*10
902                k = i%6
903
904                if k == 0:
905                    q0 = F(t-delta, point_id=id)
906                    q1 = F(t+60-delta, point_id=id)
907
908                q = F(t-delta, point_id=id)
909                assert allclose(q, (k*q1 + (6-k)*q0)/6)
910
911
912        os.remove(filename + '.sww')
913
914    def test_file_function_time_with_domain(self):
915        """Test that File function interpolates correctly
916        between given times. No x,y dependency here.
917        Use domain with starttime
918        """
919
920        #Write file
921        import os, time, calendar
922        from anuga.config import time_format
923        from math import sin, pi
924        from domain import Domain
925
926        finaltime = 1200
927        filename = 'test_file_function'
928        fid = open(filename + '.txt', 'w')
929        start = time.mktime(time.strptime('2000', '%Y'))
930        dt = 60  #One minute intervals
931        t = 0.0
932        while t <= finaltime:
933            t_string = time.strftime(time_format, time.gmtime(t+start))
934            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
935            t += dt
936
937        fid.close()
938
939
940        #Convert ASCII file to NetCDF (Which is what we really like!)
941        timefile2netcdf(filename)
942
943
944
945        a = [0.0, 0.0]
946        b = [4.0, 0.0]
947        c = [0.0, 3.0]
948
949        points = [a, b, c]
950        vertices = [[0,1,2]]
951        domain = Domain(points, vertices)
952
953        #Check that domain.starttime is updated if non-existing
954        F = file_function(filename + '.tms', domain)
955       
956        assert allclose(domain.starttime, start)
957
958        #Check that domain.starttime is updated if too early
959        domain.starttime = start - 1
960        F = file_function(filename + '.tms', domain)
961        assert allclose(domain.starttime, start)
962
963        #Check that domain.starttime isn't updated if later
964        domain.starttime = start + 1
965        F = file_function(filename + '.tms', domain)       
966        assert allclose(domain.starttime, start+1)
967
968        domain.starttime = start
969        F = file_function(filename + '.tms', domain,
970                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'],
971                          use_cache=True)
972       
973
974        #print F.precomputed_values
975        #print 'F(60)', F(60)
976       
977        #Now try interpolation
978        for i in range(20):
979            t = i*10
980            q = F(t)
981
982            #Exact linear intpolation
983            assert allclose(q[0], 2*t)
984            if i%6 == 0:
985                assert allclose(q[1], t**2)
986                assert allclose(q[2], sin(t*pi/600))
987
988        #Check non-exact
989
990        t = 90 #Halfway between 60 and 120
991        q = F(t)
992        assert allclose( (120**2 + 60**2)/2, q[1] )
993        assert allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
994
995
996        t = 100 #Two thirds of the way between between 60 and 120
997        q = F(t)
998        assert allclose( 2*120**2/3 + 60**2/3, q[1] )
999        assert allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1000
1001        os.remove(filename + '.tms')
1002        os.remove(filename + '.txt')       
1003
1004    def test_file_function_time_with_domain_different_start(self):
1005        """Test that File function interpolates correctly
1006        between given times. No x,y dependency here.
1007        Use domain with a starttime later than that of file
1008
1009        ASCII version
1010        """
1011
1012        #Write file
1013        import os, time, calendar
1014        from anuga.config import time_format
1015        from math import sin, pi
1016        from domain import Domain
1017
1018        finaltime = 1200
1019        filename = 'test_file_function'
1020        fid = open(filename + '.txt', 'w')
1021        start = time.mktime(time.strptime('2000', '%Y'))
1022        dt = 60  #One minute intervals
1023        t = 0.0
1024        while t <= finaltime:
1025            t_string = time.strftime(time_format, time.gmtime(t+start))
1026            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1027            t += dt
1028
1029        fid.close()
1030
1031        #Convert ASCII file to NetCDF (Which is what we really like!)
1032        timefile2netcdf(filename)       
1033
1034        a = [0.0, 0.0]
1035        b = [4.0, 0.0]
1036        c = [0.0, 3.0]
1037
1038        points = [a, b, c]
1039        vertices = [[0,1,2]]
1040        domain = Domain(points, vertices)
1041
1042        #Check that domain.starttime isn't updated if later than file starttime but earlier
1043        #than file end time
1044        delta = 23
1045        domain.starttime = start + delta
1046        F = file_function(filename + '.tms', domain,
1047                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])       
1048        assert allclose(domain.starttime, start+delta)
1049
1050
1051
1052
1053        #Now try interpolation with delta offset
1054        for i in range(20):
1055            t = i*10
1056            q = F(t-delta)
1057
1058            #Exact linear intpolation
1059            assert allclose(q[0], 2*t)
1060            if i%6 == 0:
1061                assert allclose(q[1], t**2)
1062                assert allclose(q[2], sin(t*pi/600))
1063
1064        #Check non-exact
1065
1066        t = 90 #Halfway between 60 and 120
1067        q = F(t-delta)
1068        assert allclose( (120**2 + 60**2)/2, q[1] )
1069        assert allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1070
1071
1072        t = 100 #Two thirds of the way between between 60 and 120
1073        q = F(t-delta)
1074        assert allclose( 2*120**2/3 + 60**2/3, q[1] )
1075        assert allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1076
1077
1078        os.remove(filename + '.tms')
1079        os.remove(filename + '.txt')               
1080
1081
1082
1083    def test_apply_expression_to_dictionary(self):
1084
1085        #FIXME: Division is not expected to work for integers.
1086        #This must be caught.
1087        foo = array([[1,2,3],
1088                     [4,5,6]], Float)
1089
1090        bar = array([[-1,0,5],
1091                     [6,1,1]], Float)                 
1092
1093        D = {'X': foo, 'Y': bar}
1094
1095        Z = apply_expression_to_dictionary('X+Y', D)       
1096        assert allclose(Z, foo+bar)
1097
1098        Z = apply_expression_to_dictionary('X*Y', D)       
1099        assert allclose(Z, foo*bar)       
1100
1101        Z = apply_expression_to_dictionary('4*X+Y', D)       
1102        assert allclose(Z, 4*foo+bar)       
1103
1104        # test zero division is OK
1105        Z = apply_expression_to_dictionary('X/Y', D)
1106        assert allclose(1/Z, 1/(foo/bar)) # can't compare inf to inf
1107
1108        # make an error for zero on zero
1109        # this is really an error in Numeric, SciPy core can handle it
1110        # Z = apply_expression_to_dictionary('0/Y', D)
1111
1112        #Check exceptions
1113        try:
1114            #Wrong name
1115            Z = apply_expression_to_dictionary('4*X+A', D)       
1116        except NameError:
1117            pass
1118        else:
1119            msg = 'Should have raised a NameError Exception'
1120            raise msg
1121
1122
1123        try:
1124            #Wrong order
1125            Z = apply_expression_to_dictionary(D, '4*X+A')       
1126        except AssertionError:
1127            pass
1128        else:
1129            msg = 'Should have raised a AssertionError Exception'
1130            raise msg       
1131       
1132
1133    def test_multiple_replace(self):
1134        """Hard test that checks a true word-by-word simultaneous replace
1135        """
1136       
1137        D = {'x': 'xi', 'y': 'eta', 'xi':'lam'}
1138        exp = '3*x+y + xi'
1139       
1140        new = multiple_replace(exp, D)
1141       
1142        assert new == '3*xi+eta + lam'
1143                         
1144
1145
1146    def test_point_on_line_obsolete(self):
1147        """Test that obsolete call issues appropriate warning"""
1148
1149        #Turn warning into an exception
1150        import warnings
1151        warnings.filterwarnings('error')
1152
1153        try:
1154            assert point_on_line( 0, 0.5, 0,1, 0,0 )
1155        except DeprecationWarning:
1156            pass
1157        else:
1158            msg = 'point_on_line should have issued a DeprecationWarning'
1159            raise Exception(msg)   
1160
1161        warnings.resetwarnings()
1162   
1163    def test_get_revision_number(self):
1164        """test_get_revision_number(self):
1165
1166        Test that revision number can be retrieved.
1167        """
1168        if os.environ.has_key('USER') and os.environ['USER'] == 'dgray':
1169            # I have a known snv incompatability issue,
1170            # so I'm skipping this test.
1171            # FIXME when SVN is upgraded on our clusters
1172            pass
1173        else:   
1174            n = get_revision_number()
1175            assert n>=0
1176
1177
1178       
1179    def test_add_directories(self):
1180       
1181        import tempfile
1182        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1183        directories = ['ja','ne','ke']
1184        kens_dir = add_directories(root_dir, directories)
1185        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1186               sep + 'ke'
1187        assert access(root_dir,F_OK)
1188
1189        add_directories(root_dir, directories)
1190        assert access(root_dir,F_OK)
1191       
1192        #clean up!
1193        os.rmdir(kens_dir)
1194        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1195        os.rmdir(root_dir + sep + 'ja')
1196        os.rmdir(root_dir)
1197
1198    def test_add_directories_bad(self):
1199       
1200        import tempfile
1201        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1202        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1203       
1204        try:
1205            kens_dir = add_directories(root_dir, directories)
1206        except OSError:
1207            pass
1208        else:
1209            msg = 'bad dir name should give OSError'
1210            raise Exception(msg)   
1211           
1212        #clean up!
1213        os.rmdir(root_dir)
1214
1215    def test_check_list(self):
1216
1217        check_list(['stage','xmomentum'])
1218
1219       
1220    def test_add_directories(self):
1221       
1222        import tempfile
1223        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1224        directories = ['ja','ne','ke']
1225        kens_dir = add_directories(root_dir, directories)
1226        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1227               sep + 'ke'
1228        assert access(root_dir,F_OK)
1229
1230        add_directories(root_dir, directories)
1231        assert access(root_dir,F_OK)
1232       
1233        #clean up!
1234        os.rmdir(kens_dir)
1235        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1236        os.rmdir(root_dir + sep + 'ja')
1237        os.rmdir(root_dir)
1238
1239    def test_add_directories_bad(self):
1240       
1241        import tempfile
1242        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1243        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1244       
1245        try:
1246            kens_dir = add_directories(root_dir, directories)
1247        except OSError:
1248            pass
1249        else:
1250            msg = 'bad dir name should give OSError'
1251            raise Exception(msg)   
1252           
1253        #clean up!
1254        os.rmdir(root_dir)
1255
1256    def test_check_list(self):
1257
1258        check_list(['stage','xmomentum'])
1259       
1260    def test_remove_lone_verts_d(self):
1261        verts = [[0,0],[1,0],[0,1]]
1262        tris = [[0,1,2]]
1263        new_verts, new_tris = remove_lone_verts(verts, tris)
1264        assert new_verts == verts
1265        assert new_tris == tris
1266     
1267
1268    def test_remove_lone_verts_e(self):
1269        verts = [[0,0],[1,0],[0,1],[99,99]]
1270        tris = [[0,1,2]]
1271        new_verts, new_tris = remove_lone_verts(verts, tris)
1272        assert new_verts == verts[0:3]
1273        assert new_tris == tris
1274       
1275    def test_remove_lone_verts_a(self):
1276        verts = [[99,99],[0,0],[1,0],[99,99],[0,1],[99,99]]
1277        tris = [[1,2,4]]
1278        new_verts, new_tris = remove_lone_verts(verts, tris)
1279        #print "new_verts", new_verts
1280        assert new_verts == [[0,0],[1,0],[0,1]]
1281        assert new_tris == [[0,1,2]]
1282     
1283    def test_remove_lone_verts_c(self):
1284        verts = [[0,0],[1,0],[99,99],[0,1]]
1285        tris = [[0,1,3]]
1286        new_verts, new_tris = remove_lone_verts(verts, tris)
1287        #print "new_verts", new_verts
1288        assert new_verts == [[0,0],[1,0],[0,1]]
1289        assert new_tris == [[0,1,2]]
1290       
1291    def test_remove_lone_verts_b(self):
1292        verts = [[0,0],[1,0],[0,1],[99,99],[99,99],[99,99]]
1293        tris = [[0,1,2]]
1294        new_verts, new_tris = remove_lone_verts(verts, tris)
1295        assert new_verts == verts[0:3]
1296        assert new_tris == tris
1297     
1298
1299    def test_remove_lone_verts_e(self):
1300        verts = [[0,0],[1,0],[0,1],[99,99]]
1301        tris = [[0,1,2]]
1302        new_verts, new_tris = remove_lone_verts(verts, tris)
1303        assert new_verts == verts[0:3]
1304        assert new_tris == tris
1305       
1306    def test_get_min_max_values(self):
1307       
1308        list=[8,9,6,1,4]
1309        min1, max1 = get_min_max_values(list)
1310       
1311        assert min1==1 
1312        assert max1==9
1313       
1314    def test_get_min_max_values1(self):
1315       
1316        list=[-8,-9,-6,-1,-4]
1317        min1, max1 = get_min_max_values(list,10,-10)
1318       
1319#        print 'min1,max1',min1,max1
1320        assert min1==-9 
1321        assert max1==-1
1322
1323    def test_get_min_max_values2(self):
1324        '''
1325        The min and max supplied are greater than the ones in the
1326        list and therefore are the ones returned
1327        '''
1328        list=[-8,-9,-6,-1,-4]
1329        min1, max1 = get_min_max_values(list,-10,10)
1330       
1331#        print 'min1,max1',min1,max1
1332        assert min1==-10 
1333        assert max1==10
1334       
1335    def bad_test_make_plots_from_csv_files(self):
1336       
1337        try: 
1338            import pylab
1339        except ImportError:
1340            #ANUGA don't need pylab to work so the system doesn't
1341            #rely on pylab being installed
1342            return
1343       
1344        if sys.platform == 'win32':  #Windows
1345       
1346            current_dir=getcwd()+sep+'abstract_2d_finite_volumes'
1347            temp_dir = tempfile.mkdtemp('','figures')
1348    #        print 'temp_dir',temp_dir
1349            fileName = temp_dir+sep+'time_series_3.csv'
1350            file = open(fileName,"w")
1351            file.write("Time,Stage,Speed,Momentum,Elevation\n\
13521.0, 0, 0, 0, 10 \n\
13532.0, 5, 2, 4, 10 \n\
13543.0, 3, 3, 5, 10 \n")
1355            file.close()
1356   
1357            fileName1 = temp_dir+sep+'time_series_4.csv'
1358            file1 = open(fileName1,"w")
1359            file1.write("Time,Stage,Speed,Momentum,Elevation\n\
13601.0, 0, 0, 0, 5 \n\
13612.0, -5, -2, -4, 5 \n\
13623.0, -4, -3, -5, 5 \n")
1363            file1.close()
1364   
1365            fileName2 = temp_dir+sep+'time_series_5.csv'
1366            file2 = open(fileName2,"w")
1367            file2.write("Time,Stage,Speed,Momentum,Elevation\n\
13681.0, 0, 0, 0, 7 \n\
13692.0, 4, -0.45, 57, 7 \n\
13703.0, 6, -0.5, 56, 7 \n")
1371            file2.close()
1372           
1373            dir, name=os.path.split(fileName)
1374            make_plots_from_csv_file(directories_dic={dir:['gauge', 0, 0]},
1375                                output_dir=temp_dir,
1376                                base_name='time_series_',
1377                                plot_numbers=['3-5'],
1378                                quantities=['Speed','Stage','Momentum'],
1379                                assess_all_csv_files=True,
1380                                extra_plot_name='test')
1381           
1382    #        print 'stage+fileName[:-4]+test.png',dir+sep+'stage_'+name[:-4]+'_test.png'
1383            assert(access(dir+sep+'stage_'+name[:-4]+'_test.png',F_OK)==True)
1384            assert(access(dir+sep+'speed_'+name[:-4]+'_test.png',F_OK)==True)
1385            assert(access(dir+sep+'momentum_'+name[:-4]+'_test.png',F_OK)==True)
1386   
1387            dir1, name1=os.path.split(fileName1)
1388            assert(access(dir+sep+'stage_'+name1[:-4]+'_test.png',F_OK)==True)
1389            assert(access(dir+sep+'speed_'+name1[:-4]+'_test.png',F_OK)==True)
1390            assert(access(dir+sep+'momentum_'+name1[:-4]+'_test.png',F_OK)==True)
1391   
1392   
1393            dir2, name2=os.path.split(fileName2)
1394            assert(access(dir+sep+'stage_'+name2[:-4]+'_test.png',F_OK)==True)
1395            assert(access(dir+sep+'speed_'+name2[:-4]+'_test.png',F_OK)==True)
1396            assert(access(dir+sep+'momentum_'+name2[:-4]+'_test.png',F_OK)==True)
1397   
1398            del_dir(temp_dir)
1399       
1400
1401    def test_sww2csv_gauges(self):
1402
1403        def elevation_function(x, y):
1404            return -x
1405       
1406        """Most of this test was copied from test_interpolate test_interpole_sww2csv
1407       
1408        This is testing the gauge_sww2csv function, by creating a sww file and
1409        then exporting the gauges and checking the results.
1410        """
1411       
1412        # create mesh
1413        mesh_file = tempfile.mktemp(".tsh")   
1414        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1415        m = Mesh()
1416        m.add_vertices(points)
1417        m.auto_segment()
1418        m.generate_mesh(verbose=False)
1419        m.export_mesh_file(mesh_file)
1420       
1421        #Create shallow water domain
1422        domain = Domain(mesh_file)
1423        os.remove(mesh_file)
1424       
1425        domain.default_order=2
1426        domain.beta_h = 0
1427
1428        #Set some field values
1429        domain.set_quantity('elevation', elevation_function)
1430        domain.set_quantity('friction', 0.03)
1431        domain.set_quantity('xmomentum', 3.0)
1432        domain.set_quantity('ymomentum', 4.0)
1433
1434        ######################
1435        # Boundary conditions
1436        B = Transmissive_boundary(domain)
1437        domain.set_boundary( {'exterior': B})
1438
1439        # This call mangles the stage values.
1440        domain.distribute_to_vertices_and_edges()
1441        domain.set_quantity('stage', 1.0)
1442
1443
1444        domain.set_name('datatest' + str(time.time()))
1445        domain.format = 'sww'
1446        domain.smooth = True
1447        domain.reduction = mean
1448
1449        sww = get_dataobject(domain)
1450        sww.store_connectivity()
1451        sww.store_timestep(['stage', 'xmomentum', 'ymomentum','elevation'])
1452        domain.set_quantity('stage', 10.0) # This is automatically limited
1453        # so it will not be less than the elevation
1454        domain.time = 2.
1455        sww.store_timestep(['stage','elevation', 'xmomentum', 'ymomentum'])
1456
1457        # test the function
1458        points = [[5.0,1.],[0.5,2.]]
1459
1460        points_file = tempfile.mktemp(".csv")
1461#        points_file = 'test_point.csv'
1462        file_id = open(points_file,"w")
1463        file_id.write("name, easting, northing, elevation \n\
1464point1, 5.0, 1.0, 3.0\n\
1465point2, 0.5, 2.0, 9.0\n")
1466        file_id.close()
1467
1468       
1469        sww2csv_gauges(sww.filename, 
1470                            points_file,
1471                            verbose=False,
1472                            use_cache=False)
1473
1474        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1475        point1_filename = 'gauge_point1.csv'
1476        point1_handle = file(point1_filename)
1477        point1_reader = reader(point1_handle)
1478        point1_reader.next()
1479
1480        line=[]
1481        for i,row in enumerate(point1_reader):
1482#            print 'i',i,'row',row
1483            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),float(row[4])])
1484            #print 'assert line',line[i],'point1',point1_answers_array[i]
1485            assert allclose(line[i], point1_answers_array[i])
1486
1487        point2_answers_array = [[0.0,1.0,-0.5,3.0,4.0], [2.0,10.0,-0.5,3.0,4.0]]
1488        point2_filename = 'gauge_point2.csv' 
1489        point2_handle = file(point2_filename)
1490        point2_reader = reader(point2_handle)
1491        point2_reader.next()
1492                       
1493        line=[]
1494        for i,row in enumerate(point2_reader):
1495#            print 'i',i,'row',row
1496            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),float(row[4])])
1497            #print 'assert line',line[i],'point1',point1_answers_array[i]
1498            assert allclose(line[i], point2_answers_array[i])
1499                         
1500        # clean up
1501        point1_handle.close()
1502        point2_handle.close()
1503        #print "sww.filename",sww.filename
1504        os.remove(sww.filename)
1505        os.remove(points_file)
1506        os.remove(point1_filename)
1507        os.remove(point2_filename)
1508
1509
1510
1511    def test_sww2csv_gauges1(self):
1512        from anuga.pmesh.mesh import Mesh
1513        from anuga.shallow_water import Domain, Transmissive_boundary
1514        from anuga.shallow_water.data_manager import get_dataobject
1515        from csv import reader,writer
1516        import time
1517        import string
1518
1519        def elevation_function(x, y):
1520            return -x
1521       
1522        """Most of this test was copied from test_interpolate test_interpole_sww2csv
1523       
1524        This is testing the gauge_sww2csv function, by creating a sww file and
1525        then exporting the gauges and checking the results.
1526       
1527        This tests the ablity not to have elevation in the points file and
1528        not store xmomentum and ymomentum
1529        """
1530       
1531        # create mesh
1532        mesh_file = tempfile.mktemp(".tsh")   
1533        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1534        m = Mesh()
1535        m.add_vertices(points)
1536        m.auto_segment()
1537        m.generate_mesh(verbose=False)
1538        m.export_mesh_file(mesh_file)
1539       
1540        #Create shallow water domain
1541        domain = Domain(mesh_file)
1542        os.remove(mesh_file)
1543       
1544        domain.default_order=2
1545        domain.beta_h = 0
1546
1547        #Set some field values
1548        domain.set_quantity('elevation', elevation_function)
1549        domain.set_quantity('friction', 0.03)
1550        domain.set_quantity('xmomentum', 3.0)
1551        domain.set_quantity('ymomentum', 4.0)
1552
1553        ######################
1554        # Boundary conditions
1555        B = Transmissive_boundary(domain)
1556        domain.set_boundary( {'exterior': B})
1557
1558        # This call mangles the stage values.
1559        domain.distribute_to_vertices_and_edges()
1560        domain.set_quantity('stage', 1.0)
1561
1562
1563        domain.set_name('datatest' + str(time.time()))
1564        domain.format = 'sww'
1565        domain.smooth = True
1566        domain.reduction = mean
1567
1568        sww = get_dataobject(domain)
1569        sww.store_connectivity()
1570        sww.store_timestep(['stage', 'xmomentum', 'ymomentum'])
1571        domain.set_quantity('stage', 10.0) # This is automatically limited
1572        # so it will not be less than the elevation
1573        domain.time = 2.
1574        sww.store_timestep(['stage', 'xmomentum', 'ymomentum'])
1575
1576        # test the function
1577        points = [[5.0,1.],[0.5,2.]]
1578
1579        points_file = tempfile.mktemp(".csv")
1580#        points_file = 'test_point.csv'
1581        file_id = open(points_file,"w")
1582        file_id.write("name, easting, northing \n\
1583point1, 5.0, 1.0\n\
1584point2, 0.5, 2.0\n")
1585        file_id.close()
1586
1587        sww2csv_gauges(sww.filename, 
1588                            points_file,
1589                            quantities=['Stage', 'elevation'],
1590                            use_cache=False,
1591                            verbose=False)
1592
1593        point1_answers_array = [[0.0,1.0,-5.0], [2.0,10.0,-5.0]]
1594        point1_filename = 'gauge_point1.csv'
1595        point1_handle = file(point1_filename)
1596        point1_reader = reader(point1_handle)
1597        point1_reader.next()
1598
1599        line=[]
1600        for i,row in enumerate(point1_reader):
1601#            print 'i',i,'row',row
1602            line.append([float(row[0]),float(row[1]),float(row[2])])
1603            #print 'line',line[i],'point1',point1_answers_array[i]
1604            assert allclose(line[i], point1_answers_array[i])
1605
1606        point2_answers_array = [[0.0,1.0,-0.5], [2.0,10.0,-0.5]]
1607        point2_filename = 'gauge_point2.csv' 
1608        point2_handle = file(point2_filename)
1609        point2_reader = reader(point2_handle)
1610        point2_reader.next()
1611                       
1612        line=[]
1613        for i,row in enumerate(point2_reader):
1614#            print 'i',i,'row',row
1615            line.append([float(row[0]),float(row[1]),float(row[2])])
1616#            print 'line',line[i],'point1',point1_answers_array[i]
1617            assert allclose(line[i], point2_answers_array[i])
1618                         
1619        # clean up
1620        point1_handle.close()
1621        point2_handle.close()
1622        #print "sww.filename",sww.filename
1623        os.remove(sww.filename)
1624        os.remove(points_file)
1625        os.remove(point1_filename)
1626        os.remove(point2_filename)
1627
1628    def test_greens_law(self):
1629
1630        from math import sqrt
1631       
1632        d1 = 80.0
1633        d2 = 20.0
1634        h1 = 1.0
1635        h2 = greens_law(d1,d2,h1)
1636
1637        assert h2==sqrt(2.0)
1638       
1639
1640#-------------------------------------------------------------
1641if __name__ == "__main__":
1642    suite = unittest.makeSuite(Test_Util,'test')
1643#    suite = unittest.makeSuite(Test_Util,'test_gauges_sww')
1644#    runner = unittest.TextTestRunner(verbosity=2)
1645    runner = unittest.TextTestRunner(verbosity=1)
1646    runner.run(suite)
1647
1648
1649
1650
Note: See TracBrowser for help on using the repository browser.