source: anuga_core/source/anuga/abstract_2d_finite_volumes/test_util.py @ 5221

Last change on this file since 5221 was 5221, checked in by ole, 16 years ago

Work done during Water Down Under 2008.
Hardwired the three conserved quantities from sww into file_function as it was getting messy trying to exclude irrelevant quantities. Also did some formatting and commenting.

File size: 60.4 KB
Line 
1#!/usr/bin/env python
2
3
4import unittest
5from Numeric import zeros, array, allclose, Float
6from math import sqrt, pi
7import tempfile, os
8from os import access, F_OK,sep, removedirs,remove,mkdir,getcwd
9
10from anuga.abstract_2d_finite_volumes.util import *
11from anuga.config import epsilon
12from anuga.shallow_water.data_manager import timefile2netcdf,del_dir
13
14from anuga.utilities.numerical_tools import NAN
15
16from sys import platform
17
18from anuga.pmesh.mesh import Mesh
19from anuga.shallow_water import Domain, Transmissive_boundary
20from anuga.shallow_water.data_manager import get_dataobject
21from csv import reader,writer
22import time
23import string
24
25def test_function(x, y):
26    return x+y
27
28class Test_Util(unittest.TestCase):
29    def setUp(self):
30        pass
31
32    def tearDown(self):
33        pass
34
35
36
37
38    #Geometric
39    #def test_distance(self):
40    #    from anuga.abstract_2d_finite_volumes.util import distance#
41    #
42    #    self.failUnless( distance([4,2],[7,6]) == 5.0,
43    #                     'Distance is wrong!')
44    #    self.failUnless( allclose(distance([7,6],[9,8]), 2.82842712475),
45    #                    'distance is wrong!')
46    #    self.failUnless( allclose(distance([9,8],[4,2]), 7.81024967591),
47    #                    'distance is wrong!')
48    #
49    #    self.failUnless( distance([9,8],[4,2]) == distance([4,2],[9,8]),
50    #                    'distance is wrong!')
51
52
53    def test_file_function_time1(self):
54        """Test that File function interpolates correctly
55        between given times. No x,y dependency here.
56        """
57
58        #Write file
59        import os, time
60        from anuga.config import time_format
61        from math import sin, pi
62
63        #Typical ASCII file
64        finaltime = 1200
65        filename = 'test_file_function'
66        fid = open(filename + '.txt', 'w')
67        start = time.mktime(time.strptime('2000', '%Y'))
68        dt = 60  #One minute intervals
69        t = 0.0
70        while t <= finaltime:
71            t_string = time.strftime(time_format, time.gmtime(t+start))
72            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
73            t += dt
74
75        fid.close()
76
77        #Convert ASCII file to NetCDF (Which is what we really like!)
78        timefile2netcdf(filename)
79
80
81        #Create file function from time series
82        F = file_function(filename + '.tms',
83                          quantities = ['Attribute0',
84                                        'Attribute1',
85                                        'Attribute2'])
86       
87        #Now try interpolation
88        for i in range(20):
89            t = i*10
90            q = F(t)
91
92            #Exact linear intpolation
93            assert allclose(q[0], 2*t)
94            if i%6 == 0:
95                assert allclose(q[1], t**2)
96                assert allclose(q[2], sin(t*pi/600))
97
98        #Check non-exact
99
100        t = 90 #Halfway between 60 and 120
101        q = F(t)
102        assert allclose( (120**2 + 60**2)/2, q[1] )
103        assert allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
104
105
106        t = 100 #Two thirds of the way between between 60 and 120
107        q = F(t)
108        assert allclose( 2*120**2/3 + 60**2/3, q[1] )
109        assert allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
110
111        os.remove(filename + '.txt')
112        os.remove(filename + '.tms')       
113
114
115       
116    def test_spatio_temporal_file_function_basic(self):
117        """Test that spatio temporal file function performs the correct
118        interpolations in both time and space
119        NetCDF version (x,y,t dependency)       
120        """
121        import time
122
123        #Create sww file of simple propagation from left to right
124        #through rectangular domain
125        from shallow_water import Domain, Dirichlet_boundary
126        from mesh_factory import rectangular
127        from Numeric import take, concatenate, reshape
128
129        #Create basic mesh and shallow water domain
130        points, vertices, boundary = rectangular(3, 3)
131        domain1 = Domain(points, vertices, boundary)
132
133        from anuga.utilities.numerical_tools import mean
134        domain1.reduction = mean
135        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
136                              # only one value.
137
138        domain1.default_order = 2
139        domain1.store = True
140        domain1.set_datadir('.')
141        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
142        domain1.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']
143
144        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
145        domain1.set_quantity('elevation', 0)
146        domain1.set_quantity('friction', 0)
147        domain1.set_quantity('stage', 0)
148
149        # Boundary conditions
150        B0 = Dirichlet_boundary([0,0,0])
151        B6 = Dirichlet_boundary([0.6,0,0])
152        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
153        domain1.check_integrity()
154
155        finaltime = 8
156        #Evolution
157        t0 = -1
158        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
159            #print 'Timesteps: %.16f, %.16f' %(t0, t)
160            #if t == t0:
161            #    msg = 'Duplicate timestep found: %f, %f' %(t0, t)
162            #   raise msg
163            t0 = t
164             
165            #domain1.write_time()
166
167
168        #Now read data from sww and check
169        from Scientific.IO.NetCDF import NetCDFFile
170        filename = domain1.get_name() + '.' + domain1.format
171        fid = NetCDFFile(filename)
172
173        x = fid.variables['x'][:]
174        y = fid.variables['y'][:]
175        stage = fid.variables['stage'][:]
176        xmomentum = fid.variables['xmomentum'][:]
177        ymomentum = fid.variables['ymomentum'][:]
178        time = fid.variables['time'][:]
179
180        #Take stage vertex values at last timestep on diagonal
181        #Diagonal is identified by vertices: 0, 5, 10, 15
182
183        last_time_index = len(time)-1 #Last last_time_index
184        d_stage = reshape(take(stage[last_time_index, :], [0,5,10,15]), (4,1))
185        d_uh = reshape(take(xmomentum[last_time_index, :], [0,5,10,15]), (4,1))
186        d_vh = reshape(take(ymomentum[last_time_index, :], [0,5,10,15]), (4,1))
187        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
188
189        #Reference interpolated values at midpoints on diagonal at
190        #this timestep are
191        r0 = (D[0] + D[1])/2
192        r1 = (D[1] + D[2])/2
193        r2 = (D[2] + D[3])/2
194
195        #And the midpoints are found now
196        Dx = take(reshape(x, (16,1)), [0,5,10,15])
197        Dy = take(reshape(y, (16,1)), [0,5,10,15])
198
199        diag = concatenate( (Dx, Dy), axis=1)
200        d_midpoints = (diag[1:] + diag[:-1])/2
201
202        #Let us see if the file function can find the correct
203        #values at the midpoints at the last timestep:
204        f = file_function(filename, domain1,
205                          interpolation_points = d_midpoints)
206
207        T = f.get_time()
208        msg = 'duplicate timesteps: %.16f and %.16f' %(T[-1], T[-2])
209        assert not T[-1] == T[-2], msg
210        t = time[last_time_index]
211        q = f(t, point_id=0); assert allclose(r0, q)
212        q = f(t, point_id=1); assert allclose(r1, q)
213        q = f(t, point_id=2); assert allclose(r2, q)
214
215
216        ##################
217        #Now do the same for the first timestep
218
219        timestep = 0 #First timestep
220        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
221        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
222        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
223        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
224
225        #Reference interpolated values at midpoints on diagonal at
226        #this timestep are
227        r0 = (D[0] + D[1])/2
228        r1 = (D[1] + D[2])/2
229        r2 = (D[2] + D[3])/2
230
231        #Let us see if the file function can find the correct
232        #values
233        q = f(0, point_id=0); assert allclose(r0, q)
234        q = f(0, point_id=1); assert allclose(r1, q)
235        q = f(0, point_id=2); assert allclose(r2, q)
236
237
238        ##################
239        #Now do it again for a timestep in the middle
240
241        timestep = 33
242        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
243        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
244        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
245        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
246
247        #Reference interpolated values at midpoints on diagonal at
248        #this timestep are
249        r0 = (D[0] + D[1])/2
250        r1 = (D[1] + D[2])/2
251        r2 = (D[2] + D[3])/2
252
253        q = f(timestep/10., point_id=0); assert allclose(r0, q)
254        q = f(timestep/10., point_id=1); assert allclose(r1, q)
255        q = f(timestep/10., point_id=2); assert allclose(r2, q)
256
257
258        ##################
259        #Now check temporal interpolation
260        #Halfway between timestep 15 and 16
261
262        timestep = 15
263        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
264        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
265        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
266        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
267
268        #Reference interpolated values at midpoints on diagonal at
269        #this timestep are
270        r0_0 = (D[0] + D[1])/2
271        r1_0 = (D[1] + D[2])/2
272        r2_0 = (D[2] + D[3])/2
273
274        #
275        timestep = 16
276        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
277        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
278        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
279        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
280
281        #Reference interpolated values at midpoints on diagonal at
282        #this timestep are
283        r0_1 = (D[0] + D[1])/2
284        r1_1 = (D[1] + D[2])/2
285        r2_1 = (D[2] + D[3])/2
286
287        # The reference values are
288        r0 = (r0_0 + r0_1)/2
289        r1 = (r1_0 + r1_1)/2
290        r2 = (r2_0 + r2_1)/2
291
292        q = f((timestep - 0.5)/10., point_id=0); assert allclose(r0, q)
293        q = f((timestep - 0.5)/10., point_id=1); assert allclose(r1, q)
294        q = f((timestep - 0.5)/10., point_id=2); assert allclose(r2, q)
295
296        ##################
297        #Finally check interpolation 2 thirds of the way
298        #between timestep 15 and 16
299
300        # The reference values are
301        r0 = (r0_0 + 2*r0_1)/3
302        r1 = (r1_0 + 2*r1_1)/3
303        r2 = (r2_0 + 2*r2_1)/3
304
305        #And the file function gives
306        q = f((timestep - 1.0/3)/10., point_id=0); assert allclose(r0, q)
307        q = f((timestep - 1.0/3)/10., point_id=1); assert allclose(r1, q)
308        q = f((timestep - 1.0/3)/10., point_id=2); assert allclose(r2, q)
309
310        fid.close()
311        import os
312        os.remove(filename)
313
314
315
316    def test_spatio_temporal_file_function_different_origin(self):
317        """Test that spatio temporal file function performs the correct
318        interpolations in both time and space where space is offset by
319        xllcorner and yllcorner
320        NetCDF version (x,y,t dependency)       
321        """
322        import time
323
324        #Create sww file of simple propagation from left to right
325        #through rectangular domain
326        from shallow_water import Domain, Dirichlet_boundary
327        from mesh_factory import rectangular
328        from Numeric import take, concatenate, reshape
329
330
331        from anuga.coordinate_transforms.geo_reference import Geo_reference
332        xllcorner = 2048
333        yllcorner = 11000
334        zone = 2
335
336        #Create basic mesh and shallow water domain
337        points, vertices, boundary = rectangular(3, 3)
338        domain1 = Domain(points, vertices, boundary,
339                         geo_reference = Geo_reference(xllcorner = xllcorner,
340                                                       yllcorner = yllcorner))
341       
342
343        from anuga.utilities.numerical_tools import mean       
344        domain1.reduction = mean
345        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
346                              # only one value.
347
348        domain1.default_order = 2
349        domain1.store = True
350        domain1.set_datadir('.')
351        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
352        domain1.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']
353
354        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
355        domain1.set_quantity('elevation', 0)
356        domain1.set_quantity('friction', 0)
357        domain1.set_quantity('stage', 0)
358
359        # Boundary conditions
360        B0 = Dirichlet_boundary([0,0,0])
361        B6 = Dirichlet_boundary([0.6,0,0])
362        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
363        domain1.check_integrity()
364
365        finaltime = 8
366        #Evolution
367        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
368            pass
369            #domain1.write_time()
370
371
372        #Now read data from sww and check
373        from Scientific.IO.NetCDF import NetCDFFile
374        filename = domain1.get_name() + '.' + domain1.format
375        fid = NetCDFFile(filename)
376
377        x = fid.variables['x'][:]
378        y = fid.variables['y'][:]
379        stage = fid.variables['stage'][:]
380        xmomentum = fid.variables['xmomentum'][:]
381        ymomentum = fid.variables['ymomentum'][:]
382        time = fid.variables['time'][:]
383
384        #Take stage vertex values at last timestep on diagonal
385        #Diagonal is identified by vertices: 0, 5, 10, 15
386
387        last_time_index = len(time)-1 #Last last_time_index     
388        d_stage = reshape(take(stage[last_time_index, :], [0,5,10,15]), (4,1))
389        d_uh = reshape(take(xmomentum[last_time_index, :], [0,5,10,15]), (4,1))
390        d_vh = reshape(take(ymomentum[last_time_index, :], [0,5,10,15]), (4,1))
391        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
392
393        #Reference interpolated values at midpoints on diagonal at
394        #this timestep are
395        r0 = (D[0] + D[1])/2
396        r1 = (D[1] + D[2])/2
397        r2 = (D[2] + D[3])/2
398
399        #And the midpoints are found now
400        Dx = take(reshape(x, (16,1)), [0,5,10,15])
401        Dy = take(reshape(y, (16,1)), [0,5,10,15])
402
403        diag = concatenate( (Dx, Dy), axis=1)
404        d_midpoints = (diag[1:] + diag[:-1])/2
405
406
407        #Adjust for georef - make interpolation points absolute
408        d_midpoints[:,0] += xllcorner
409        d_midpoints[:,1] += yllcorner               
410
411        #Let us see if the file function can find the correct
412        #values at the midpoints at the last timestep:
413        f = file_function(filename, domain1,
414                          interpolation_points = d_midpoints)
415
416        t = time[last_time_index]                         
417        q = f(t, point_id=0); assert allclose(r0, q)
418        q = f(t, point_id=1); assert allclose(r1, q)
419        q = f(t, point_id=2); assert allclose(r2, q)
420
421
422        ##################
423        #Now do the same for the first timestep
424
425        timestep = 0 #First timestep
426        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
427        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
428        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
429        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
430
431        #Reference interpolated values at midpoints on diagonal at
432        #this timestep are
433        r0 = (D[0] + D[1])/2
434        r1 = (D[1] + D[2])/2
435        r2 = (D[2] + D[3])/2
436
437        #Let us see if the file function can find the correct
438        #values
439        q = f(0, point_id=0); assert allclose(r0, q)
440        q = f(0, point_id=1); assert allclose(r1, q)
441        q = f(0, point_id=2); assert allclose(r2, q)
442
443
444        ##################
445        #Now do it again for a timestep in the middle
446
447        timestep = 33
448        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
449        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
450        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
451        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
452
453        #Reference interpolated values at midpoints on diagonal at
454        #this timestep are
455        r0 = (D[0] + D[1])/2
456        r1 = (D[1] + D[2])/2
457        r2 = (D[2] + D[3])/2
458
459        q = f(timestep/10., point_id=0); assert allclose(r0, q)
460        q = f(timestep/10., point_id=1); assert allclose(r1, q)
461        q = f(timestep/10., point_id=2); assert allclose(r2, q)
462
463
464        ##################
465        #Now check temporal interpolation
466        #Halfway between timestep 15 and 16
467
468        timestep = 15
469        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
470        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
471        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
472        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
473
474        #Reference interpolated values at midpoints on diagonal at
475        #this timestep are
476        r0_0 = (D[0] + D[1])/2
477        r1_0 = (D[1] + D[2])/2
478        r2_0 = (D[2] + D[3])/2
479
480        #
481        timestep = 16
482        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
483        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
484        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
485        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
486
487        #Reference interpolated values at midpoints on diagonal at
488        #this timestep are
489        r0_1 = (D[0] + D[1])/2
490        r1_1 = (D[1] + D[2])/2
491        r2_1 = (D[2] + D[3])/2
492
493        # The reference values are
494        r0 = (r0_0 + r0_1)/2
495        r1 = (r1_0 + r1_1)/2
496        r2 = (r2_0 + r2_1)/2
497
498        q = f((timestep - 0.5)/10., point_id=0); assert allclose(r0, q)
499        q = f((timestep - 0.5)/10., point_id=1); assert allclose(r1, q)
500        q = f((timestep - 0.5)/10., point_id=2); assert allclose(r2, q)
501
502        ##################
503        #Finally check interpolation 2 thirds of the way
504        #between timestep 15 and 16
505
506        # The reference values are
507        r0 = (r0_0 + 2*r0_1)/3
508        r1 = (r1_0 + 2*r1_1)/3
509        r2 = (r2_0 + 2*r2_1)/3
510
511        #And the file function gives
512        q = f((timestep - 1.0/3)/10., point_id=0); assert allclose(r0, q)
513        q = f((timestep - 1.0/3)/10., point_id=1); assert allclose(r1, q)
514        q = f((timestep - 1.0/3)/10., point_id=2); assert allclose(r2, q)
515
516        fid.close()
517        import os
518        os.remove(filename)
519
520       
521
522
523    def qtest_spatio_temporal_file_function_time(self):
524        """Test that File function interpolates correctly
525        between given times.
526        NetCDF version (x,y,t dependency)
527        """
528
529        #Create NetCDF (sww) file to be read
530        # x: 0, 5, 10, 15
531        # y: -20, -10, 0, 10
532        # t: 0, 60, 120, ...., 1200
533        #
534        # test quantities (arbitrary but non-trivial expressions):
535        #
536        #   stage     = 3*x - y**2 + 2*t
537        #   xmomentum = exp( -((x-7)**2 + (y+5)**2)/20 ) * t**2
538        #   ymomentum = x**2 + y**2 * sin(t*pi/600)
539
540        #NOTE: Nice test that may render some of the others redundant.
541
542        import os, time
543        from anuga.config import time_format
544        from Numeric import sin, pi, exp
545        from mesh_factory import rectangular
546        from shallow_water import Domain
547        import anuga.shallow_water.data_manager
548
549        finaltime = 1200
550        filename = 'test_file_function'
551
552        #Create a domain to hold test grid
553        #(0:15, -20:10)
554        points, vertices, boundary =\
555                rectangular(4, 4, 15, 30, origin = (0, -20))
556        print "points", points
557
558        #print 'Number of elements', len(vertices)
559        domain = Domain(points, vertices, boundary)
560        domain.smooth = False
561        domain.default_order = 2
562        domain.set_datadir('.')
563        domain.set_name(filename)
564        domain.store = True
565        domain.format = 'sww'   #Native netcdf visualisation format
566
567        #print points
568        start = time.mktime(time.strptime('2000', '%Y'))
569        domain.starttime = start
570
571
572        #Store structure
573        domain.initialise_storage()
574
575        #Compute artificial time steps and store
576        dt = 60  #One minute intervals
577        t = 0.0
578        while t <= finaltime:
579            #Compute quantities
580            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
581            domain.set_quantity('stage', f1)
582
583            f2 = lambda x,y: x+y+t**2
584            domain.set_quantity('xmomentum', f2)
585
586            f3 = lambda x,y: x**2 + y**2 * sin(t*pi/600)
587            domain.set_quantity('ymomentum', f3)
588
589            #Store and advance time
590            domain.time = t
591            domain.store_timestep(domain.conserved_quantities)
592            t += dt
593
594
595        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5]]
596     
597        #Deliberately set domain.starttime to too early
598        domain.starttime = start - 1
599
600        #Create file function
601        F = file_function(filename + '.sww', domain,
602                          quantities = domain.conserved_quantities,
603                          interpolation_points = interpolation_points)
604
605        #Check that FF updates fixes domain starttime
606        assert allclose(domain.starttime, start)
607
608        #Check that domain.starttime isn't updated if later
609        domain.starttime = start + 1
610        F = file_function(filename + '.sww', domain,
611                          quantities = domain.conserved_quantities,
612                          interpolation_points = interpolation_points)
613        assert allclose(domain.starttime, start+1)
614        domain.starttime = start
615
616
617        #Check linear interpolation in time
618        F = file_function(filename + '.sww', domain,
619                          quantities = domain.conserved_quantities,
620                          interpolation_points = interpolation_points)               
621        for id in range(len(interpolation_points)):
622            x = interpolation_points[id][0]
623            y = interpolation_points[id][1]
624
625            for i in range(20):
626                t = i*10
627                k = i%6
628
629                if k == 0:
630                    q0 = F(t, point_id=id)
631                    q1 = F(t+60, point_id=id)
632
633                if q0 == NAN:
634                    actual = q0
635                else:
636                    actual = (k*q1 + (6-k)*q0)/6
637                q = F(t, point_id=id)
638                #print i, k, t, q
639                #print ' ', q0
640                #print ' ', q1
641                print "q",q
642                print "actual", actual
643                #print
644                if q0 == NAN:
645                     self.failUnless( q == actual, 'Fail!')
646                else:
647                    assert allclose(q, actual)
648
649
650        #Another check of linear interpolation in time
651        for id in range(len(interpolation_points)):
652            q60 = F(60, point_id=id)
653            q120 = F(120, point_id=id)
654
655            t = 90 #Halfway between 60 and 120
656            q = F(t, point_id=id)
657            assert allclose( (q120+q60)/2, q )
658
659            t = 100 #Two thirds of the way between between 60 and 120
660            q = F(t, point_id=id)
661            assert allclose(q60/3 + 2*q120/3, q)
662
663
664
665        #Check that domain.starttime isn't updated if later than file starttime but earlier
666        #than file end time
667        delta = 23
668        domain.starttime = start + delta
669        F = file_function(filename + '.sww', domain,
670                          quantities = domain.conserved_quantities,
671                          interpolation_points = interpolation_points)
672        assert allclose(domain.starttime, start+delta)
673
674
675
676
677        #Now try interpolation with delta offset
678        for id in range(len(interpolation_points)):           
679            x = interpolation_points[id][0]
680            y = interpolation_points[id][1]
681
682            for i in range(20):
683                t = i*10
684                k = i%6
685
686                if k == 0:
687                    q0 = F(t-delta, point_id=id)
688                    q1 = F(t+60-delta, point_id=id)
689
690                q = F(t-delta, point_id=id)
691                assert allclose(q, (k*q1 + (6-k)*q0)/6)
692
693
694        os.remove(filename + '.sww')
695
696
697
698    def NOtest_spatio_temporal_file_function_time(self):
699        # FIXME: This passes but needs some TLC
700        # Test that File function interpolates correctly
701        # When some points are outside the mesh
702
703        import os, time
704        from anuga.config import time_format
705        from Numeric import sin, pi, exp
706        from mesh_factory import rectangular
707        from shallow_water import Domain
708        import anuga.shallow_water.data_manager 
709        from anuga.pmesh.mesh_interface import create_mesh_from_regions
710        finaltime = 1200
711       
712        filename = tempfile.mktemp()
713        #print "filename",filename
714        filename = 'test_file_function'
715
716        meshfilename = tempfile.mktemp(".tsh")
717
718        boundary_tags = {'walls':[0,1],'bom':[2]}
719       
720        polygon_absolute = [[0,-20],[10,-20],[10,15],[-20,15]]
721       
722        create_mesh_from_regions(polygon_absolute,
723                                 boundary_tags,
724                                 10000000,
725                                 filename=meshfilename)
726        domain = Domain(mesh_filename=meshfilename)
727        domain.smooth = False
728        domain.default_order = 2
729        domain.set_datadir('.')
730        domain.set_name(filename)
731        domain.store = True
732        domain.format = 'sww'   #Native netcdf visualisation format
733
734        #print points
735        start = time.mktime(time.strptime('2000', '%Y'))
736        domain.starttime = start
737       
738
739        #Store structure
740        domain.initialise_storage()
741
742        #Compute artificial time steps and store
743        dt = 60  #One minute intervals
744        t = 0.0
745        while t <= finaltime:
746            #Compute quantities
747            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
748            domain.set_quantity('stage', f1)
749
750            f2 = lambda x,y: x+y+t**2
751            domain.set_quantity('xmomentum', f2)
752
753            f3 = lambda x,y: x**2 + y**2 * sin(t*pi/600)
754            domain.set_quantity('ymomentum', f3)
755
756            #Store and advance time
757            domain.time = t
758            domain.store_timestep(domain.conserved_quantities)
759            t += dt
760
761        interpolation_points = [[1,0]]
762        interpolation_points = [[100,1000]]
763       
764        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5],
765                                [78787,78787],[7878,3432]]
766           
767        #Deliberately set domain.starttime to too early
768        domain.starttime = start - 1
769
770        #Create file function
771        F = file_function(filename + '.sww', domain,
772                          quantities = domain.conserved_quantities,
773                          interpolation_points = interpolation_points)
774
775        #Check that FF updates fixes domain starttime
776        assert allclose(domain.starttime, start)
777
778        #Check that domain.starttime isn't updated if later
779        domain.starttime = start + 1
780        F = file_function(filename + '.sww', domain,
781                          quantities = domain.conserved_quantities,
782                          interpolation_points = interpolation_points)
783        assert allclose(domain.starttime, start+1)
784        domain.starttime = start
785
786
787        #Check linear interpolation in time
788        # checking points inside and outside the mesh
789        F = file_function(filename + '.sww', domain,
790                          quantities = domain.conserved_quantities,
791                          interpolation_points = interpolation_points)
792       
793        for id in range(len(interpolation_points)):
794            x = interpolation_points[id][0]
795            y = interpolation_points[id][1]
796
797            for i in range(20):
798                t = i*10
799                k = i%6
800
801                if k == 0:
802                    q0 = F(t, point_id=id)
803                    q1 = F(t+60, point_id=id)
804
805                if q0 == NAN:
806                    actual = q0
807                else:
808                    actual = (k*q1 + (6-k)*q0)/6
809                q = F(t, point_id=id)
810                #print i, k, t, q
811                #print ' ', q0
812                #print ' ', q1
813                #print "q",q
814                #print "actual", actual
815                #print
816                if q0 == NAN:
817                     self.failUnless( q == actual, 'Fail!')
818                else:
819                    assert allclose(q, actual)
820
821        # now lets check points inside the mesh
822        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14]] #, [10,-12.5]] - this point doesn't work WHY?
823        interpolation_points = [[10,-12.5]]
824           
825        print "len(interpolation_points)",len(interpolation_points) 
826        F = file_function(filename + '.sww', domain,
827                          quantities = domain.conserved_quantities,
828                          interpolation_points = interpolation_points)
829
830        domain.starttime = start
831
832
833        #Check linear interpolation in time
834        F = file_function(filename + '.sww', domain,
835                          quantities = domain.conserved_quantities,
836                          interpolation_points = interpolation_points)               
837        for id in range(len(interpolation_points)):
838            x = interpolation_points[id][0]
839            y = interpolation_points[id][1]
840
841            for i in range(20):
842                t = i*10
843                k = i%6
844
845                if k == 0:
846                    q0 = F(t, point_id=id)
847                    q1 = F(t+60, point_id=id)
848
849                if q0 == NAN:
850                    actual = q0
851                else:
852                    actual = (k*q1 + (6-k)*q0)/6
853                q = F(t, point_id=id)
854                print "############"
855                print "id, x, y ", id, x, y #k, t, q
856                print "t", t
857                #print ' ', q0
858                #print ' ', q1
859                print "q",q
860                print "actual", actual
861                #print
862                if q0 == NAN:
863                     self.failUnless( q == actual, 'Fail!')
864                else:
865                    assert allclose(q, actual)
866
867
868        #Another check of linear interpolation in time
869        for id in range(len(interpolation_points)):
870            q60 = F(60, point_id=id)
871            q120 = F(120, point_id=id)
872
873            t = 90 #Halfway between 60 and 120
874            q = F(t, point_id=id)
875            assert allclose( (q120+q60)/2, q )
876
877            t = 100 #Two thirds of the way between between 60 and 120
878            q = F(t, point_id=id)
879            assert allclose(q60/3 + 2*q120/3, q)
880
881
882
883        #Check that domain.starttime isn't updated if later than file starttime but earlier
884        #than file end time
885        delta = 23
886        domain.starttime = start + delta
887        F = file_function(filename + '.sww', domain,
888                          quantities = domain.conserved_quantities,
889                          interpolation_points = interpolation_points)
890        assert allclose(domain.starttime, start+delta)
891
892
893
894
895        #Now try interpolation with delta offset
896        for id in range(len(interpolation_points)):           
897            x = interpolation_points[id][0]
898            y = interpolation_points[id][1]
899
900            for i in range(20):
901                t = i*10
902                k = i%6
903
904                if k == 0:
905                    q0 = F(t-delta, point_id=id)
906                    q1 = F(t+60-delta, point_id=id)
907
908                q = F(t-delta, point_id=id)
909                assert allclose(q, (k*q1 + (6-k)*q0)/6)
910
911
912        os.remove(filename + '.sww')
913
914    def test_file_function_time_with_domain(self):
915        """Test that File function interpolates correctly
916        between given times. No x,y dependency here.
917        Use domain with starttime
918        """
919
920        #Write file
921        import os, time, calendar
922        from anuga.config import time_format
923        from math import sin, pi
924        from domain import Domain
925
926        finaltime = 1200
927        filename = 'test_file_function'
928        fid = open(filename + '.txt', 'w')
929        start = time.mktime(time.strptime('2000', '%Y'))
930        dt = 60  #One minute intervals
931        t = 0.0
932        while t <= finaltime:
933            t_string = time.strftime(time_format, time.gmtime(t+start))
934            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
935            t += dt
936
937        fid.close()
938
939
940        #Convert ASCII file to NetCDF (Which is what we really like!)
941        timefile2netcdf(filename)
942
943
944
945        a = [0.0, 0.0]
946        b = [4.0, 0.0]
947        c = [0.0, 3.0]
948
949        points = [a, b, c]
950        vertices = [[0,1,2]]
951        domain = Domain(points, vertices)
952
953        # Check that domain.starttime is updated if non-existing
954        F = file_function(filename + '.tms',
955                          domain,
956                          quantities = ['Attribute0', 'Attribute1', 'Attribute2']) 
957        assert allclose(domain.starttime, start)
958
959        # Check that domain.starttime is updated if too early
960        domain.starttime = start - 1
961        F = file_function(filename + '.tms',
962                          domain,
963                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
964        assert allclose(domain.starttime, start)
965
966        # Check that domain.starttime isn't updated if later
967        domain.starttime = start + 1
968        F = file_function(filename + '.tms',
969                          domain,
970                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
971        assert allclose(domain.starttime, start+1)
972
973        domain.starttime = start
974        F = file_function(filename + '.tms',
975                          domain,
976                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'],
977                          use_cache=True)
978       
979
980        #print F.precomputed_values
981        #print 'F(60)', F(60)
982       
983        #Now try interpolation
984        for i in range(20):
985            t = i*10
986            q = F(t)
987
988            #Exact linear intpolation
989            assert allclose(q[0], 2*t)
990            if i%6 == 0:
991                assert allclose(q[1], t**2)
992                assert allclose(q[2], sin(t*pi/600))
993
994        #Check non-exact
995
996        t = 90 #Halfway between 60 and 120
997        q = F(t)
998        assert allclose( (120**2 + 60**2)/2, q[1] )
999        assert allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1000
1001
1002        t = 100 #Two thirds of the way between between 60 and 120
1003        q = F(t)
1004        assert allclose( 2*120**2/3 + 60**2/3, q[1] )
1005        assert allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1006
1007        os.remove(filename + '.tms')
1008        os.remove(filename + '.txt')       
1009
1010    def test_file_function_time_with_domain_different_start(self):
1011        """Test that File function interpolates correctly
1012        between given times. No x,y dependency here.
1013        Use domain with a starttime later than that of file
1014
1015        ASCII version
1016        """
1017
1018        #Write file
1019        import os, time, calendar
1020        from anuga.config import time_format
1021        from math import sin, pi
1022        from domain import Domain
1023
1024        finaltime = 1200
1025        filename = 'test_file_function'
1026        fid = open(filename + '.txt', 'w')
1027        start = time.mktime(time.strptime('2000', '%Y'))
1028        dt = 60  #One minute intervals
1029        t = 0.0
1030        while t <= finaltime:
1031            t_string = time.strftime(time_format, time.gmtime(t+start))
1032            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1033            t += dt
1034
1035        fid.close()
1036
1037        #Convert ASCII file to NetCDF (Which is what we really like!)
1038        timefile2netcdf(filename)       
1039
1040        a = [0.0, 0.0]
1041        b = [4.0, 0.0]
1042        c = [0.0, 3.0]
1043
1044        points = [a, b, c]
1045        vertices = [[0,1,2]]
1046        domain = Domain(points, vertices)
1047
1048        #Check that domain.starttime isn't updated if later than file starttime but earlier
1049        #than file end time
1050        delta = 23
1051        domain.starttime = start + delta
1052        F = file_function(filename + '.tms', domain,
1053                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])       
1054        assert allclose(domain.starttime, start+delta)
1055
1056
1057
1058
1059        #Now try interpolation with delta offset
1060        for i in range(20):
1061            t = i*10
1062            q = F(t-delta)
1063
1064            #Exact linear intpolation
1065            assert allclose(q[0], 2*t)
1066            if i%6 == 0:
1067                assert allclose(q[1], t**2)
1068                assert allclose(q[2], sin(t*pi/600))
1069
1070        #Check non-exact
1071
1072        t = 90 #Halfway between 60 and 120
1073        q = F(t-delta)
1074        assert allclose( (120**2 + 60**2)/2, q[1] )
1075        assert allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1076
1077
1078        t = 100 #Two thirds of the way between between 60 and 120
1079        q = F(t-delta)
1080        assert allclose( 2*120**2/3 + 60**2/3, q[1] )
1081        assert allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1082
1083
1084        os.remove(filename + '.tms')
1085        os.remove(filename + '.txt')               
1086
1087
1088
1089    def test_apply_expression_to_dictionary(self):
1090
1091        #FIXME: Division is not expected to work for integers.
1092        #This must be caught.
1093        foo = array([[1,2,3],
1094                     [4,5,6]], Float)
1095
1096        bar = array([[-1,0,5],
1097                     [6,1,1]], Float)                 
1098
1099        D = {'X': foo, 'Y': bar}
1100
1101        Z = apply_expression_to_dictionary('X+Y', D)       
1102        assert allclose(Z, foo+bar)
1103
1104        Z = apply_expression_to_dictionary('X*Y', D)       
1105        assert allclose(Z, foo*bar)       
1106
1107        Z = apply_expression_to_dictionary('4*X+Y', D)       
1108        assert allclose(Z, 4*foo+bar)       
1109
1110        # test zero division is OK
1111        Z = apply_expression_to_dictionary('X/Y', D)
1112        assert allclose(1/Z, 1/(foo/bar)) # can't compare inf to inf
1113
1114        # make an error for zero on zero
1115        # this is really an error in Numeric, SciPy core can handle it
1116        # Z = apply_expression_to_dictionary('0/Y', D)
1117
1118        #Check exceptions
1119        try:
1120            #Wrong name
1121            Z = apply_expression_to_dictionary('4*X+A', D)       
1122        except NameError:
1123            pass
1124        else:
1125            msg = 'Should have raised a NameError Exception'
1126            raise msg
1127
1128
1129        try:
1130            #Wrong order
1131            Z = apply_expression_to_dictionary(D, '4*X+A')       
1132        except AssertionError:
1133            pass
1134        else:
1135            msg = 'Should have raised a AssertionError Exception'
1136            raise msg       
1137       
1138
1139    def test_multiple_replace(self):
1140        """Hard test that checks a true word-by-word simultaneous replace
1141        """
1142       
1143        D = {'x': 'xi', 'y': 'eta', 'xi':'lam'}
1144        exp = '3*x+y + xi'
1145       
1146        new = multiple_replace(exp, D)
1147       
1148        assert new == '3*xi+eta + lam'
1149                         
1150
1151
1152    def test_point_on_line_obsolete(self):
1153        """Test that obsolete call issues appropriate warning"""
1154
1155        #Turn warning into an exception
1156        import warnings
1157        warnings.filterwarnings('error')
1158
1159        try:
1160            assert point_on_line( 0, 0.5, 0,1, 0,0 )
1161        except DeprecationWarning:
1162            pass
1163        else:
1164            msg = 'point_on_line should have issued a DeprecationWarning'
1165            raise Exception(msg)   
1166
1167        warnings.resetwarnings()
1168   
1169    def test_get_revision_number(self):
1170        """test_get_revision_number(self):
1171
1172        Test that revision number can be retrieved.
1173        """
1174        if os.environ.has_key('USER') and os.environ['USER'] == 'dgray':
1175            # I have a known snv incompatability issue,
1176            # so I'm skipping this test.
1177            # FIXME when SVN is upgraded on our clusters
1178            pass
1179        else:   
1180            n = get_revision_number()
1181            assert n>=0
1182
1183
1184       
1185    def test_add_directories(self):
1186       
1187        import tempfile
1188        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1189        directories = ['ja','ne','ke']
1190        kens_dir = add_directories(root_dir, directories)
1191        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1192               sep + 'ke'
1193        assert access(root_dir,F_OK)
1194
1195        add_directories(root_dir, directories)
1196        assert access(root_dir,F_OK)
1197       
1198        #clean up!
1199        os.rmdir(kens_dir)
1200        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1201        os.rmdir(root_dir + sep + 'ja')
1202        os.rmdir(root_dir)
1203
1204    def test_add_directories_bad(self):
1205       
1206        import tempfile
1207        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1208        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1209       
1210        try:
1211            kens_dir = add_directories(root_dir, directories)
1212        except OSError:
1213            pass
1214        else:
1215            msg = 'bad dir name should give OSError'
1216            raise Exception(msg)   
1217           
1218        #clean up!
1219        os.rmdir(root_dir)
1220
1221    def test_check_list(self):
1222
1223        check_list(['stage','xmomentum'])
1224
1225       
1226    def test_add_directories(self):
1227       
1228        import tempfile
1229        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1230        directories = ['ja','ne','ke']
1231        kens_dir = add_directories(root_dir, directories)
1232        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1233               sep + 'ke'
1234        assert access(root_dir,F_OK)
1235
1236        add_directories(root_dir, directories)
1237        assert access(root_dir,F_OK)
1238       
1239        #clean up!
1240        os.rmdir(kens_dir)
1241        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1242        os.rmdir(root_dir + sep + 'ja')
1243        os.rmdir(root_dir)
1244
1245    def test_add_directories_bad(self):
1246       
1247        import tempfile
1248        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1249        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1250       
1251        try:
1252            kens_dir = add_directories(root_dir, directories)
1253        except OSError:
1254            pass
1255        else:
1256            msg = 'bad dir name should give OSError'
1257            raise Exception(msg)   
1258           
1259        #clean up!
1260        os.rmdir(root_dir)
1261
1262    def test_check_list(self):
1263
1264        check_list(['stage','xmomentum'])
1265       
1266    def test_remove_lone_verts_d(self):
1267        verts = [[0,0],[1,0],[0,1]]
1268        tris = [[0,1,2]]
1269        new_verts, new_tris = remove_lone_verts(verts, tris)
1270        assert new_verts == verts
1271        assert new_tris == tris
1272     
1273
1274    def test_remove_lone_verts_e(self):
1275        verts = [[0,0],[1,0],[0,1],[99,99]]
1276        tris = [[0,1,2]]
1277        new_verts, new_tris = remove_lone_verts(verts, tris)
1278        assert new_verts == verts[0:3]
1279        assert new_tris == tris
1280       
1281    def test_remove_lone_verts_a(self):
1282        verts = [[99,99],[0,0],[1,0],[99,99],[0,1],[99,99]]
1283        tris = [[1,2,4]]
1284        new_verts, new_tris = remove_lone_verts(verts, tris)
1285        #print "new_verts", new_verts
1286        assert new_verts == [[0,0],[1,0],[0,1]]
1287        assert new_tris == [[0,1,2]]
1288     
1289    def test_remove_lone_verts_c(self):
1290        verts = [[0,0],[1,0],[99,99],[0,1]]
1291        tris = [[0,1,3]]
1292        new_verts, new_tris = remove_lone_verts(verts, tris)
1293        #print "new_verts", new_verts
1294        assert new_verts == [[0,0],[1,0],[0,1]]
1295        assert new_tris == [[0,1,2]]
1296       
1297    def test_remove_lone_verts_b(self):
1298        verts = [[0,0],[1,0],[0,1],[99,99],[99,99],[99,99]]
1299        tris = [[0,1,2]]
1300        new_verts, new_tris = remove_lone_verts(verts, tris)
1301        assert new_verts == verts[0:3]
1302        assert new_tris == tris
1303     
1304
1305    def test_remove_lone_verts_e(self):
1306        verts = [[0,0],[1,0],[0,1],[99,99]]
1307        tris = [[0,1,2]]
1308        new_verts, new_tris = remove_lone_verts(verts, tris)
1309        assert new_verts == verts[0:3]
1310        assert new_tris == tris
1311       
1312    def test_get_min_max_values(self):
1313       
1314        list=[8,9,6,1,4]
1315        min1, max1 = get_min_max_values(list)
1316       
1317        assert min1==1 
1318        assert max1==9
1319       
1320    def test_get_min_max_values1(self):
1321       
1322        list=[-8,-9,-6,-1,-4]
1323        min1, max1 = get_min_max_values(list)
1324       
1325#        print 'min1,max1',min1,max1
1326        assert min1==-9 
1327        assert max1==-1
1328
1329#    def test_get_min_max_values2(self):
1330#        '''
1331#        The min and max supplied are greater than the ones in the
1332#        list and therefore are the ones returned
1333#        '''
1334#        list=[-8,-9,-6,-1,-4]
1335#        min1, max1 = get_min_max_values(list,-10,10)
1336#       
1337##        print 'min1,max1',min1,max1
1338#        assert min1==-10
1339#        assert max1==10
1340       
1341    def test_make_plots_from_csv_files(self):
1342       
1343        #if sys.platform == 'win32':  #Windows
1344            try: 
1345                import pylab
1346            except ImportError:
1347                #ANUGA don't need pylab to work so the system doesn't
1348                #rely on pylab being installed
1349                return
1350           
1351       
1352            current_dir=getcwd()+sep+'abstract_2d_finite_volumes'
1353            temp_dir = tempfile.mkdtemp('','figures')
1354    #        print 'temp_dir',temp_dir
1355            fileName = temp_dir+sep+'time_series_3.csv'
1356            file = open(fileName,"w")
1357            file.write("time,stage,speed,momentum,elevation\n\
13581.0, 0, 0, 0, 10 \n\
13592.0, 5, 2, 4, 10 \n\
13603.0, 3, 3, 5, 10 \n")
1361            file.close()
1362   
1363            fileName1 = temp_dir+sep+'time_series_4.csv'
1364            file1 = open(fileName1,"w")
1365            file1.write("time,stage,speed,momentum,elevation\n\
13661.0, 0, 0, 0, 5 \n\
13672.0, -5, -2, -4, 5 \n\
13683.0, -4, -3, -5, 5 \n")
1369            file1.close()
1370   
1371            fileName2 = temp_dir+sep+'time_series_5.csv'
1372            file2 = open(fileName2,"w")
1373            file2.write("time,stage,speed,momentum,elevation\n\
13741.0, 0, 0, 0, 7 \n\
13752.0, 4, -0.45, 57, 7 \n\
13763.0, 6, -0.5, 56, 7 \n")
1377            file2.close()
1378           
1379            dir, name=os.path.split(fileName)
1380            csv2timeseries_graphs(directories_dic={dir:['gauge', 0, 0]},
1381                                  output_dir=temp_dir,
1382                                  base_name='time_series_',
1383                                  plot_numbers=['3-5'],
1384                                  quantities=['speed','stage','momentum'],
1385                                  assess_all_csv_files=True,
1386                                  extra_plot_name='test')
1387           
1388            #print dir+sep+name[:-4]+'_stage_test.png'
1389            assert(access(dir+sep+name[:-4]+'_stage_test.png',F_OK)==True)
1390            assert(access(dir+sep+name[:-4]+'_speed_test.png',F_OK)==True)
1391            assert(access(dir+sep+name[:-4]+'_momentum_test.png',F_OK)==True)
1392   
1393            dir1, name1=os.path.split(fileName1)
1394            assert(access(dir+sep+name1[:-4]+'_stage_test.png',F_OK)==True)
1395            assert(access(dir+sep+name1[:-4]+'_speed_test.png',F_OK)==True)
1396            assert(access(dir+sep+name1[:-4]+'_momentum_test.png',F_OK)==True)
1397   
1398   
1399            dir2, name2=os.path.split(fileName2)
1400            assert(access(dir+sep+name2[:-4]+'_stage_test.png',F_OK)==True)
1401            assert(access(dir+sep+name2[:-4]+'_speed_test.png',F_OK)==True)
1402            assert(access(dir+sep+name2[:-4]+'_momentum_test.png',F_OK)==True)
1403   
1404            del_dir(temp_dir)
1405       
1406
1407    def test_sww2csv_gauges(self):
1408
1409        def elevation_function(x, y):
1410            return -x
1411       
1412        """Most of this test was copied from test_interpolate
1413        test_interpole_sww2csv
1414       
1415        This is testing the gauge_sww2csv function, by creating a sww file and
1416        then exporting the gauges and checking the results.
1417        """
1418       
1419        # Create mesh
1420        mesh_file = tempfile.mktemp(".tsh")   
1421        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1422        m = Mesh()
1423        m.add_vertices(points)
1424        m.auto_segment()
1425        m.generate_mesh(verbose=False)
1426        m.export_mesh_file(mesh_file)
1427       
1428        # Create shallow water domain
1429        domain = Domain(mesh_file)
1430        os.remove(mesh_file)
1431       
1432        domain.default_order=2
1433        domain.beta_h = 0
1434       
1435        # This test was made before tight_slope_limiters were introduced
1436        # Since were are testing interpolation values this is OK
1437        domain.tight_slope_limiters = 0 
1438       
1439
1440        # Set some field values
1441        domain.set_quantity('elevation', elevation_function)
1442        domain.set_quantity('friction', 0.03)
1443        domain.set_quantity('xmomentum', 3.0)
1444        domain.set_quantity('ymomentum', 4.0)
1445
1446        ######################
1447        # Boundary conditions
1448        B = Transmissive_boundary(domain)
1449        domain.set_boundary( {'exterior': B})
1450
1451        # This call mangles the stage values.
1452        domain.distribute_to_vertices_and_edges()
1453        domain.set_quantity('stage', 1.0)
1454
1455
1456        domain.set_name('datatest' + str(time.time()))
1457        domain.format = 'sww'
1458        domain.smooth = True
1459        domain.reduction = mean
1460
1461
1462        sww = get_dataobject(domain)
1463        sww.store_connectivity()
1464        sww.store_timestep(['stage', 'xmomentum', 'ymomentum','elevation'])
1465        domain.set_quantity('stage', 10.0) # This is automatically limited
1466        # so it will not be less than the elevation
1467        domain.time = 2.
1468        sww.store_timestep(['stage','elevation', 'xmomentum', 'ymomentum'])
1469
1470        # test the function
1471        points = [[5.0,1.],[0.5,2.]]
1472
1473        points_file = tempfile.mktemp(".csv")
1474#        points_file = 'test_point.csv'
1475        file_id = open(points_file,"w")
1476        file_id.write("name, easting, northing, elevation \n\
1477point1, 5.0, 1.0, 3.0\n\
1478point2, 0.5, 2.0, 9.0\n")
1479        file_id.close()
1480
1481       
1482        sww2csv_gauges(sww.filename, 
1483                       points_file,
1484                       verbose=False,
1485                       use_cache=False)
1486
1487#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1488        point1_answers_array = [[0.0,1.0,6.0,-5.0,3.0,4.0], [2.0,10.0,15.0,-5.0,3.0,4.0]]
1489        point1_filename = 'gauge_point1.csv'
1490        point1_handle = file(point1_filename)
1491        point1_reader = reader(point1_handle)
1492        point1_reader.next()
1493
1494        line=[]
1495        for i,row in enumerate(point1_reader):
1496            #print 'i',i,'row',row
1497            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),float(row[4]),float(row[5])])
1498            #print 'assert line',line[i],'point1',point1_answers_array[i]
1499            assert allclose(line[i], point1_answers_array[i])
1500
1501        point2_answers_array = [[0.0,1.0,1.5,-0.5,3.0,4.0], [2.0,10.0,10.5,-0.5,3.0,4.0]]
1502        point2_filename = 'gauge_point2.csv' 
1503        point2_handle = file(point2_filename)
1504        point2_reader = reader(point2_handle)
1505        point2_reader.next()
1506                       
1507        line=[]
1508        for i,row in enumerate(point2_reader):
1509            #print 'i',i,'row',row
1510            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),float(row[4]),float(row[5])])
1511            #print 'assert line',line[i],'point1',point1_answers_array[i]
1512            assert allclose(line[i], point2_answers_array[i])
1513                         
1514        # clean up
1515        point1_handle.close()
1516        point2_handle.close()
1517        #print "sww.filename",sww.filename
1518        os.remove(sww.filename)
1519        os.remove(points_file)
1520        os.remove(point1_filename)
1521        os.remove(point2_filename)
1522
1523
1524
1525    def test_sww2csv_gauges1(self):
1526        from anuga.pmesh.mesh import Mesh
1527        from anuga.shallow_water import Domain, Transmissive_boundary
1528        from anuga.shallow_water.data_manager import get_dataobject
1529        from csv import reader,writer
1530        import time
1531        import string
1532
1533        def elevation_function(x, y):
1534            return -x
1535       
1536        """Most of this test was copied from test_interpolate
1537        test_interpole_sww2csv
1538       
1539        This is testing the gauge_sww2csv function, by creating a sww file and
1540        then exporting the gauges and checking the results.
1541       
1542        This tests the ablity not to have elevation in the points file and
1543        not store xmomentum and ymomentum
1544        """
1545       
1546        # Create mesh
1547        mesh_file = tempfile.mktemp(".tsh")   
1548        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1549        m = Mesh()
1550        m.add_vertices(points)
1551        m.auto_segment()
1552        m.generate_mesh(verbose=False)
1553        m.export_mesh_file(mesh_file)
1554       
1555        # Create shallow water domain
1556        domain = Domain(mesh_file)
1557        os.remove(mesh_file)
1558       
1559        domain.default_order=2
1560        domain.beta_h = 0
1561
1562        # Set some field values
1563        domain.set_quantity('elevation', elevation_function)
1564        domain.set_quantity('friction', 0.03)
1565        domain.set_quantity('xmomentum', 3.0)
1566        domain.set_quantity('ymomentum', 4.0)
1567
1568        ######################
1569        # Boundary conditions
1570        B = Transmissive_boundary(domain)
1571        domain.set_boundary( {'exterior': B})
1572
1573        # This call mangles the stage values.
1574        domain.distribute_to_vertices_and_edges()
1575        domain.set_quantity('stage', 1.0)
1576
1577
1578        domain.set_name('datatest' + str(time.time()))
1579        domain.format = 'sww'
1580        domain.smooth = True
1581        domain.reduction = mean
1582
1583        sww = get_dataobject(domain)
1584        sww.store_connectivity()
1585        sww.store_timestep(['stage', 'xmomentum', 'ymomentum'])
1586        domain.set_quantity('stage', 10.0) # This is automatically limited
1587        # so it will not be less than the elevation
1588        domain.time = 2.
1589        sww.store_timestep(['stage', 'xmomentum', 'ymomentum'])
1590
1591        # test the function
1592        points = [[5.0,1.],[0.5,2.]]
1593
1594        points_file = tempfile.mktemp(".csv")
1595#        points_file = 'test_point.csv'
1596        file_id = open(points_file,"w")
1597        file_id.write("name,easting,northing \n\
1598point1, 5.0, 1.0\n\
1599point2, 0.5, 2.0\n")
1600        file_id.close()
1601
1602        sww2csv_gauges(sww.filename, 
1603                            points_file,
1604                            quantities=['stage', 'elevation'],
1605                            use_cache=False,
1606                            verbose=False)
1607
1608        point1_answers_array = [[0.0,1.0,-5.0], [2.0,10.0,-5.0]]
1609        point1_filename = 'gauge_point1.csv'
1610        point1_handle = file(point1_filename)
1611        point1_reader = reader(point1_handle)
1612        point1_reader.next()
1613
1614        line=[]
1615        for i,row in enumerate(point1_reader):
1616#            print 'i',i,'row',row
1617            line.append([float(row[0]),float(row[1]),float(row[2])])
1618            #print 'line',line[i],'point1',point1_answers_array[i]
1619            assert allclose(line[i], point1_answers_array[i])
1620
1621        point2_answers_array = [[0.0,1.0,-0.5], [2.0,10.0,-0.5]]
1622        point2_filename = 'gauge_point2.csv' 
1623        point2_handle = file(point2_filename)
1624        point2_reader = reader(point2_handle)
1625        point2_reader.next()
1626                       
1627        line=[]
1628        for i,row in enumerate(point2_reader):
1629#            print 'i',i,'row',row
1630            line.append([float(row[0]),float(row[1]),float(row[2])])
1631#            print 'line',line[i],'point1',point1_answers_array[i]
1632            assert allclose(line[i], point2_answers_array[i])
1633                         
1634        # clean up
1635        point1_handle.close()
1636        point2_handle.close()
1637        #print "sww.filename",sww.filename
1638        os.remove(sww.filename)
1639        os.remove(points_file)
1640        os.remove(point1_filename)
1641        os.remove(point2_filename)
1642
1643
1644    def test_sww2csv_gauges2(self):
1645
1646        def elevation_function(x, y):
1647            return -x
1648       
1649        """Most of this test was copied from test_interpolate
1650        test_interpole_sww2csv
1651       
1652        This is testing the gauge_sww2csv function, by creating a sww file and
1653        then exporting the gauges and checking the results.
1654       
1655        This is the same as sww2csv_gauges except set domain.set_starttime to 5.
1656        Therefore testing the storing of the absolute time in the csv files
1657        """
1658       
1659        # Create mesh
1660        mesh_file = tempfile.mktemp(".tsh")   
1661        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1662        m = Mesh()
1663        m.add_vertices(points)
1664        m.auto_segment()
1665        m.generate_mesh(verbose=False)
1666        m.export_mesh_file(mesh_file)
1667       
1668        # Create shallow water domain
1669        domain = Domain(mesh_file)
1670        os.remove(mesh_file)
1671       
1672        domain.default_order=2
1673        domain.beta_h = 0
1674
1675        # This test was made before tight_slope_limiters were introduced
1676        # Since were are testing interpolation values this is OK
1677        domain.tight_slope_limiters = 0         
1678
1679        # Set some field values
1680        domain.set_quantity('elevation', elevation_function)
1681        domain.set_quantity('friction', 0.03)
1682        domain.set_quantity('xmomentum', 3.0)
1683        domain.set_quantity('ymomentum', 4.0)
1684        domain.set_starttime(5)
1685
1686        ######################
1687        # Boundary conditions
1688        B = Transmissive_boundary(domain)
1689        domain.set_boundary( {'exterior': B})
1690
1691        # This call mangles the stage values.
1692        domain.distribute_to_vertices_and_edges()
1693        domain.set_quantity('stage', 1.0)
1694       
1695
1696
1697        domain.set_name('datatest' + str(time.time()))
1698        domain.format = 'sww'
1699        domain.smooth = True
1700        domain.reduction = mean
1701
1702        sww = get_dataobject(domain)
1703        sww.store_connectivity()
1704        sww.store_timestep(['stage', 'xmomentum', 'ymomentum','elevation'])
1705        domain.set_quantity('stage', 10.0) # This is automatically limited
1706        # so it will not be less than the elevation
1707        domain.time = 2.
1708        sww.store_timestep(['stage','elevation', 'xmomentum', 'ymomentum'])
1709
1710        # test the function
1711        points = [[5.0,1.],[0.5,2.]]
1712
1713        points_file = tempfile.mktemp(".csv")
1714#        points_file = 'test_point.csv'
1715        file_id = open(points_file,"w")
1716        file_id.write("name, easting, northing, elevation \n\
1717point1, 5.0, 1.0, 3.0\n\
1718point2, 0.5, 2.0, 9.0\n")
1719        file_id.close()
1720
1721       
1722        sww2csv_gauges(sww.filename, 
1723                            points_file,
1724                            verbose=False,
1725                            use_cache=False)
1726
1727#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1728        point1_answers_array = [[5.0,1.0,6.0,-5.0,3.0,4.0], [7.0,10.0,15.0,-5.0,3.0,4.0]]
1729        point1_filename = 'gauge_point1.csv'
1730        point1_handle = file(point1_filename)
1731        point1_reader = reader(point1_handle)
1732        point1_reader.next()
1733
1734        line=[]
1735        for i,row in enumerate(point1_reader):
1736            #print 'i',i,'row',row
1737            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),float(row[4]),float(row[5])])
1738            #print 'assert line',line[i],'point1',point1_answers_array[i]
1739            assert allclose(line[i], point1_answers_array[i])
1740
1741        point2_answers_array = [[5.0,1.0,1.5,-0.5,3.0,4.0], [7.0,10.0,10.5,-0.5,3.0,4.0]]
1742        point2_filename = 'gauge_point2.csv' 
1743        point2_handle = file(point2_filename)
1744        point2_reader = reader(point2_handle)
1745        point2_reader.next()
1746                       
1747        line=[]
1748        for i,row in enumerate(point2_reader):
1749            #print 'i',i,'row',row
1750            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),float(row[4]),float(row[5])])
1751            #print 'assert line',line[i],'point1',point1_answers_array[i]
1752            assert allclose(line[i], point2_answers_array[i])
1753                         
1754        # clean up
1755        point1_handle.close()
1756        point2_handle.close()
1757        #print "sww.filename",sww.filename
1758        os.remove(sww.filename)
1759        os.remove(points_file)
1760        os.remove(point1_filename)
1761        os.remove(point2_filename)
1762
1763
1764    def test_greens_law(self):
1765
1766        from math import sqrt
1767       
1768        d1 = 80.0
1769        d2 = 20.0
1770        h1 = 1.0
1771        h2 = greens_law(d1,d2,h1)
1772
1773        assert h2==sqrt(2.0)
1774       
1775    def test_calc_bearings(self):
1776 
1777        from math import atan, degrees
1778        #Test East
1779        uh = 1
1780        vh = 1.e-15
1781        angle = calc_bearing(uh, vh)
1782        if 89 < angle < 91: v=1
1783        assert v==1
1784        #Test West
1785        uh = -1
1786        vh = 1.e-15
1787        angle = calc_bearing(uh, vh)
1788        if 269 < angle < 271: v=1
1789        assert v==1
1790        #Test North
1791        uh = 1.e-15
1792        vh = 1
1793        angle = calc_bearing(uh, vh)
1794        if -1 < angle < 1: v=1
1795        assert v==1
1796        #Test South
1797        uh = 1.e-15
1798        vh = -1
1799        angle = calc_bearing(uh, vh)
1800        if 179 < angle < 181: v=1
1801        assert v==1
1802        #Test South-East
1803        uh = 1
1804        vh = -1
1805        angle = calc_bearing(uh, vh)
1806        if 134 < angle < 136: v=1
1807        assert v==1
1808        #Test North-East
1809        uh = 1
1810        vh = 1
1811        angle = calc_bearing(uh, vh)
1812        if 44 < angle < 46: v=1
1813        assert v==1
1814        #Test South-West
1815        uh = -1
1816        vh = -1
1817        angle = calc_bearing(uh, vh)
1818        if 224 < angle < 226: v=1
1819        assert v==1
1820        #Test North-West
1821        uh = -1
1822        vh = 1
1823        angle = calc_bearing(uh, vh)
1824        if 314 < angle < 316: v=1
1825        assert v==1
1826
1827 
1828
1829
1830       
1831
1832#-------------------------------------------------------------
1833if __name__ == "__main__":
1834    suite = unittest.makeSuite(Test_Util,'test')
1835#    suite = unittest.makeSuite(Test_Util,'test_sww2csv')
1836#    runner = unittest.TextTestRunner(verbosity=2)
1837    runner = unittest.TextTestRunner(verbosity=1)
1838    runner.run(suite)
1839
1840
1841
1842
Note: See TracBrowser for help on using the repository browser.