source: anuga_core/source_numpy_conversion/anuga/abstract_2d_finite_volumes/test_util.py @ 5899

Last change on this file since 5899 was 5899, checked in by rwilson, 15 years ago

Initial NumPy? changes (again!).

File size: 62.4 KB
Line 
1## Automatically adapted for numpy.oldnumeric Oct 28, 2008 by alter_code1.py
2
3#!/usr/bin/env python
4
5
6import unittest
7##from numpy.oldnumeric import zeros, array, allclose, Float, alltrue
8from numpy import zeros, array, allclose, float, alltrue
9from math import sqrt, pi
10import tempfile, os
11from os import access, F_OK,sep, removedirs,remove,mkdir,getcwd
12
13from anuga.abstract_2d_finite_volumes.util import *
14from anuga.config import epsilon
15from anuga.shallow_water.data_manager import timefile2netcdf,del_dir
16
17from anuga.utilities.numerical_tools import NAN
18
19from sys import platform
20
21from anuga.pmesh.mesh import Mesh
22from anuga.shallow_water import Domain, Transmissive_boundary
23from anuga.shallow_water.data_manager import get_dataobject
24from csv import reader,writer
25import time
26import string
27
28def test_function(x, y):
29    return x+y
30
31class Test_Util(unittest.TestCase):
32    def setUp(self):
33        pass
34
35    def tearDown(self):
36        pass
37
38
39
40
41    #Geometric
42    #def test_distance(self):
43    #    from anuga.abstract_2d_finite_volumes.util import distance#
44    #
45    #    self.failUnless( distance([4,2],[7,6]) == 5.0,
46    #                     'Distance is wrong!')
47    #    self.failUnless( allclose(distance([7,6],[9,8]), 2.82842712475),
48    #                    'distance is wrong!')
49    #    self.failUnless( allclose(distance([9,8],[4,2]), 7.81024967591),
50    #                    'distance is wrong!')
51    #
52    #    self.failUnless( distance([9,8],[4,2]) == distance([4,2],[9,8]),
53    #                    'distance is wrong!')
54
55
56    def test_file_function_time1(self):
57        """Test that File function interpolates correctly
58        between given times. No x,y dependency here.
59        """
60
61        #Write file
62        import os, time
63        from anuga.config import time_format
64        from math import sin, pi
65
66        #Typical ASCII file
67        finaltime = 1200
68        filename = 'test_file_function'
69        fid = open(filename + '.txt', 'w')
70        start = time.mktime(time.strptime('2000', '%Y'))
71        dt = 60  #One minute intervals
72        t = 0.0
73        while t <= finaltime:
74            t_string = time.strftime(time_format, time.gmtime(t+start))
75            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
76            t += dt
77
78        fid.close()
79
80        #Convert ASCII file to NetCDF (Which is what we really like!)
81        timefile2netcdf(filename)
82
83
84        #Create file function from time series
85        F = file_function(filename + '.tms',
86                          quantities = ['Attribute0',
87                                        'Attribute1',
88                                        'Attribute2'])
89       
90        #Now try interpolation
91        for i in range(20):
92            t = i*10
93            q = F(t)
94
95            #Exact linear intpolation
96            assert allclose(q[0], 2*t)
97            if i%6 == 0:
98                assert allclose(q[1], t**2)
99                assert allclose(q[2], sin(t*pi/600))
100
101        #Check non-exact
102
103        t = 90 #Halfway between 60 and 120
104        q = F(t)
105        assert allclose( (120**2 + 60**2)/2, q[1] )
106        assert allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
107
108
109        t = 100 #Two thirds of the way between between 60 and 120
110        q = F(t)
111        assert allclose( 2*120**2/3 + 60**2/3, q[1] )
112        assert allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
113
114        os.remove(filename + '.txt')
115        os.remove(filename + '.tms')       
116
117
118       
119    def test_spatio_temporal_file_function_basic(self):
120        """Test that spatio temporal file function performs the correct
121        interpolations in both time and space
122        NetCDF version (x,y,t dependency)       
123        """
124        import time
125
126        #Create sww file of simple propagation from left to right
127        #through rectangular domain
128        from shallow_water import Domain, Dirichlet_boundary
129        from mesh_factory import rectangular
130        from numpy.oldnumeric import take, concatenate, reshape
131
132        #Create basic mesh and shallow water domain
133        points, vertices, boundary = rectangular(3, 3)
134        domain1 = Domain(points, vertices, boundary)
135
136        from anuga.utilities.numerical_tools import mean
137        domain1.reduction = mean
138        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
139                              # only one value.
140
141        domain1.default_order = 2
142        domain1.store = True
143        domain1.set_datadir('.')
144        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
145        domain1.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']
146
147        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
148        domain1.set_quantity('elevation', 0)
149        domain1.set_quantity('friction', 0)
150        domain1.set_quantity('stage', 0)
151
152        # Boundary conditions
153        B0 = Dirichlet_boundary([0,0,0])
154        B6 = Dirichlet_boundary([0.6,0,0])
155        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
156        domain1.check_integrity()
157
158        finaltime = 8
159        #Evolution
160        t0 = -1
161        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
162            #print 'Timesteps: %.16f, %.16f' %(t0, t)
163            #if t == t0:
164            #    msg = 'Duplicate timestep found: %f, %f' %(t0, t)
165            #   raise msg
166            t0 = t
167             
168            #domain1.write_time()
169
170
171        #Now read data from sww and check
172        from Scientific.IO.NetCDF import NetCDFFile
173        filename = domain1.get_name() + '.' + domain1.format
174        fid = NetCDFFile(filename)
175
176        x = fid.variables['x'][:]
177        y = fid.variables['y'][:]
178        stage = fid.variables['stage'][:]
179        xmomentum = fid.variables['xmomentum'][:]
180        ymomentum = fid.variables['ymomentum'][:]
181        time = fid.variables['time'][:]
182
183        #Take stage vertex values at last timestep on diagonal
184        #Diagonal is identified by vertices: 0, 5, 10, 15
185
186        last_time_index = len(time)-1 #Last last_time_index
187        d_stage = reshape(take(stage[last_time_index, :], [0,5,10,15]), (4,1))
188        d_uh = reshape(take(xmomentum[last_time_index, :], [0,5,10,15]), (4,1))
189        d_vh = reshape(take(ymomentum[last_time_index, :], [0,5,10,15]), (4,1))
190        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
191
192        #Reference interpolated values at midpoints on diagonal at
193        #this timestep are
194        r0 = (D[0] + D[1])/2
195        r1 = (D[1] + D[2])/2
196        r2 = (D[2] + D[3])/2
197
198        #And the midpoints are found now
199        Dx = take(reshape(x, (16,1)), [0,5,10,15])
200        Dy = take(reshape(y, (16,1)), [0,5,10,15])
201
202        diag = concatenate( (Dx, Dy), axis=1)
203        d_midpoints = (diag[1:] + diag[:-1])/2
204
205        #Let us see if the file function can find the correct
206        #values at the midpoints at the last timestep:
207        f = file_function(filename, domain1,
208                          interpolation_points = d_midpoints)
209
210        T = f.get_time()
211        msg = 'duplicate timesteps: %.16f and %.16f' %(T[-1], T[-2])
212        assert not T[-1] == T[-2], msg
213        t = time[last_time_index]
214        q = f(t, point_id=0); assert allclose(r0, q)
215        q = f(t, point_id=1); assert allclose(r1, q)
216        q = f(t, point_id=2); assert allclose(r2, q)
217
218
219        ##################
220        #Now do the same for the first timestep
221
222        timestep = 0 #First timestep
223        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
224        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
225        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
226        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
227
228        #Reference interpolated values at midpoints on diagonal at
229        #this timestep are
230        r0 = (D[0] + D[1])/2
231        r1 = (D[1] + D[2])/2
232        r2 = (D[2] + D[3])/2
233
234        #Let us see if the file function can find the correct
235        #values
236        q = f(0, point_id=0); assert allclose(r0, q)
237        q = f(0, point_id=1); assert allclose(r1, q)
238        q = f(0, point_id=2); assert allclose(r2, q)
239
240
241        ##################
242        #Now do it again for a timestep in the middle
243
244        timestep = 33
245        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
246        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
247        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
248        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
249
250        #Reference interpolated values at midpoints on diagonal at
251        #this timestep are
252        r0 = (D[0] + D[1])/2
253        r1 = (D[1] + D[2])/2
254        r2 = (D[2] + D[3])/2
255
256        q = f(timestep/10., point_id=0); assert allclose(r0, q)
257        q = f(timestep/10., point_id=1); assert allclose(r1, q)
258        q = f(timestep/10., point_id=2); assert allclose(r2, q)
259
260
261        ##################
262        #Now check temporal interpolation
263        #Halfway between timestep 15 and 16
264
265        timestep = 15
266        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
267        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
268        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
269        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
270
271        #Reference interpolated values at midpoints on diagonal at
272        #this timestep are
273        r0_0 = (D[0] + D[1])/2
274        r1_0 = (D[1] + D[2])/2
275        r2_0 = (D[2] + D[3])/2
276
277        #
278        timestep = 16
279        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
280        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
281        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
282        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
283
284        #Reference interpolated values at midpoints on diagonal at
285        #this timestep are
286        r0_1 = (D[0] + D[1])/2
287        r1_1 = (D[1] + D[2])/2
288        r2_1 = (D[2] + D[3])/2
289
290        # The reference values are
291        r0 = (r0_0 + r0_1)/2
292        r1 = (r1_0 + r1_1)/2
293        r2 = (r2_0 + r2_1)/2
294
295        q = f((timestep - 0.5)/10., point_id=0); assert allclose(r0, q)
296        q = f((timestep - 0.5)/10., point_id=1); assert allclose(r1, q)
297        q = f((timestep - 0.5)/10., point_id=2); assert allclose(r2, q)
298
299        ##################
300        #Finally check interpolation 2 thirds of the way
301        #between timestep 15 and 16
302
303        # The reference values are
304        r0 = (r0_0 + 2*r0_1)/3
305        r1 = (r1_0 + 2*r1_1)/3
306        r2 = (r2_0 + 2*r2_1)/3
307
308        #And the file function gives
309        q = f((timestep - 1.0/3)/10., point_id=0); assert allclose(r0, q)
310        q = f((timestep - 1.0/3)/10., point_id=1); assert allclose(r1, q)
311        q = f((timestep - 1.0/3)/10., point_id=2); assert allclose(r2, q)
312
313        fid.close()
314        import os
315        os.remove(filename)
316
317
318
319    def test_spatio_temporal_file_function_different_origin(self):
320        """Test that spatio temporal file function performs the correct
321        interpolations in both time and space where space is offset by
322        xllcorner and yllcorner
323        NetCDF version (x,y,t dependency)       
324        """
325        import time
326
327        #Create sww file of simple propagation from left to right
328        #through rectangular domain
329        from shallow_water import Domain, Dirichlet_boundary
330        from mesh_factory import rectangular
331        from numpy.oldnumeric import take, concatenate, reshape
332
333
334        from anuga.coordinate_transforms.geo_reference import Geo_reference
335        xllcorner = 2048
336        yllcorner = 11000
337        zone = 2
338
339        #Create basic mesh and shallow water domain
340        points, vertices, boundary = rectangular(3, 3)
341        domain1 = Domain(points, vertices, boundary,
342                         geo_reference = Geo_reference(xllcorner = xllcorner,
343                                                       yllcorner = yllcorner))
344       
345
346        from anuga.utilities.numerical_tools import mean       
347        domain1.reduction = mean
348        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
349                              # only one value.
350
351        domain1.default_order = 2
352        domain1.store = True
353        domain1.set_datadir('.')
354        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
355        domain1.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']
356
357        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
358        domain1.set_quantity('elevation', 0)
359        domain1.set_quantity('friction', 0)
360        domain1.set_quantity('stage', 0)
361
362        # Boundary conditions
363        B0 = Dirichlet_boundary([0,0,0])
364        B6 = Dirichlet_boundary([0.6,0,0])
365        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
366        domain1.check_integrity()
367
368        finaltime = 8
369        #Evolution
370        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
371            pass
372            #domain1.write_time()
373
374
375        #Now read data from sww and check
376        from Scientific.IO.NetCDF import NetCDFFile
377        filename = domain1.get_name() + '.' + domain1.format
378        fid = NetCDFFile(filename)
379
380        x = fid.variables['x'][:]
381        y = fid.variables['y'][:]
382        stage = fid.variables['stage'][:]
383        xmomentum = fid.variables['xmomentum'][:]
384        ymomentum = fid.variables['ymomentum'][:]
385        time = fid.variables['time'][:]
386
387        #Take stage vertex values at last timestep on diagonal
388        #Diagonal is identified by vertices: 0, 5, 10, 15
389
390        last_time_index = len(time)-1 #Last last_time_index     
391        d_stage = reshape(take(stage[last_time_index, :], [0,5,10,15]), (4,1))
392        d_uh = reshape(take(xmomentum[last_time_index, :], [0,5,10,15]), (4,1))
393        d_vh = reshape(take(ymomentum[last_time_index, :], [0,5,10,15]), (4,1))
394        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
395
396        #Reference interpolated values at midpoints on diagonal at
397        #this timestep are
398        r0 = (D[0] + D[1])/2
399        r1 = (D[1] + D[2])/2
400        r2 = (D[2] + D[3])/2
401
402        #And the midpoints are found now
403        Dx = take(reshape(x, (16,1)), [0,5,10,15])
404        Dy = take(reshape(y, (16,1)), [0,5,10,15])
405
406        diag = concatenate( (Dx, Dy), axis=1)
407        d_midpoints = (diag[1:] + diag[:-1])/2
408
409
410        #Adjust for georef - make interpolation points absolute
411        d_midpoints[:,0] += xllcorner
412        d_midpoints[:,1] += yllcorner               
413
414        #Let us see if the file function can find the correct
415        #values at the midpoints at the last timestep:
416        f = file_function(filename, domain1,
417                          interpolation_points = d_midpoints)
418
419        t = time[last_time_index]                         
420        q = f(t, point_id=0); assert allclose(r0, q)
421        q = f(t, point_id=1); assert allclose(r1, q)
422        q = f(t, point_id=2); assert allclose(r2, q)
423
424
425        ##################
426        #Now do the same for the first timestep
427
428        timestep = 0 #First timestep
429        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
430        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
431        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
432        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
433
434        #Reference interpolated values at midpoints on diagonal at
435        #this timestep are
436        r0 = (D[0] + D[1])/2
437        r1 = (D[1] + D[2])/2
438        r2 = (D[2] + D[3])/2
439
440        #Let us see if the file function can find the correct
441        #values
442        q = f(0, point_id=0); assert allclose(r0, q)
443        q = f(0, point_id=1); assert allclose(r1, q)
444        q = f(0, point_id=2); assert allclose(r2, q)
445
446
447        ##################
448        #Now do it again for a timestep in the middle
449
450        timestep = 33
451        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
452        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
453        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
454        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
455
456        #Reference interpolated values at midpoints on diagonal at
457        #this timestep are
458        r0 = (D[0] + D[1])/2
459        r1 = (D[1] + D[2])/2
460        r2 = (D[2] + D[3])/2
461
462        q = f(timestep/10., point_id=0); assert allclose(r0, q)
463        q = f(timestep/10., point_id=1); assert allclose(r1, q)
464        q = f(timestep/10., point_id=2); assert allclose(r2, q)
465
466
467        ##################
468        #Now check temporal interpolation
469        #Halfway between timestep 15 and 16
470
471        timestep = 15
472        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
473        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
474        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
475        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
476
477        #Reference interpolated values at midpoints on diagonal at
478        #this timestep are
479        r0_0 = (D[0] + D[1])/2
480        r1_0 = (D[1] + D[2])/2
481        r2_0 = (D[2] + D[3])/2
482
483        #
484        timestep = 16
485        d_stage = reshape(take(stage[timestep, :], [0,5,10,15]), (4,1))
486        d_uh = reshape(take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
487        d_vh = reshape(take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
488        D = concatenate( (d_stage, d_uh, d_vh), axis=1)
489
490        #Reference interpolated values at midpoints on diagonal at
491        #this timestep are
492        r0_1 = (D[0] + D[1])/2
493        r1_1 = (D[1] + D[2])/2
494        r2_1 = (D[2] + D[3])/2
495
496        # The reference values are
497        r0 = (r0_0 + r0_1)/2
498        r1 = (r1_0 + r1_1)/2
499        r2 = (r2_0 + r2_1)/2
500
501        q = f((timestep - 0.5)/10., point_id=0); assert allclose(r0, q)
502        q = f((timestep - 0.5)/10., point_id=1); assert allclose(r1, q)
503        q = f((timestep - 0.5)/10., point_id=2); assert allclose(r2, q)
504
505        ##################
506        #Finally check interpolation 2 thirds of the way
507        #between timestep 15 and 16
508
509        # The reference values are
510        r0 = (r0_0 + 2*r0_1)/3
511        r1 = (r1_0 + 2*r1_1)/3
512        r2 = (r2_0 + 2*r2_1)/3
513
514        #And the file function gives
515        q = f((timestep - 1.0/3)/10., point_id=0); assert allclose(r0, q)
516        q = f((timestep - 1.0/3)/10., point_id=1); assert allclose(r1, q)
517        q = f((timestep - 1.0/3)/10., point_id=2); assert allclose(r2, q)
518
519        fid.close()
520        import os
521        os.remove(filename)
522
523       
524
525
526    def test_spatio_temporal_file_function_time(self):
527        """Test that File function interpolates correctly
528        between given times.
529        NetCDF version (x,y,t dependency)
530        """
531
532        #Create NetCDF (sww) file to be read
533        # x: 0, 5, 10, 15
534        # y: -20, -10, 0, 10
535        # t: 0, 60, 120, ...., 1200
536        #
537        # test quantities (arbitrary but non-trivial expressions):
538        #
539        #   stage     = 3*x - y**2 + 2*t
540        #   xmomentum = exp( -((x-7)**2 + (y+5)**2)/20 ) * t**2
541        #   ymomentum = x**2 + y**2 * sin(t*pi/600)
542
543        #NOTE: Nice test that may render some of the others redundant.
544
545        import os, time
546        from anuga.config import time_format
547        from numpy.oldnumeric import sin, pi, exp
548        from mesh_factory import rectangular
549        from shallow_water import Domain
550        import anuga.shallow_water.data_manager
551
552        finaltime = 1200
553        filename = 'test_file_function'
554
555        #Create a domain to hold test grid
556        #(0:15, -20:10)
557        points, vertices, boundary =\
558                rectangular(4, 4, 15, 30, origin = (0, -20))
559        #print "points", points
560
561        #print 'Number of elements', len(vertices)
562        domain = Domain(points, vertices, boundary)
563        domain.smooth = False
564        domain.default_order = 2
565        domain.set_datadir('.')
566        domain.set_name(filename)
567        domain.store = True
568        domain.format = 'sww'   #Native netcdf visualisation format
569
570        #print points
571        start = time.mktime(time.strptime('2000', '%Y'))
572        domain.starttime = start
573
574
575        #Store structure
576        domain.initialise_storage()
577
578        #Compute artificial time steps and store
579        dt = 60  #One minute intervals
580        t = 0.0
581        while t <= finaltime:
582            #Compute quantities
583            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
584            domain.set_quantity('stage', f1)
585
586            f2 = lambda x,y: x+y+t**2
587            domain.set_quantity('xmomentum', f2)
588
589            f3 = lambda x,y: x**2 + y**2 * sin(t*pi/600)
590            domain.set_quantity('ymomentum', f3)
591
592            #Store and advance time
593            domain.time = t
594            domain.store_timestep(domain.conserved_quantities)
595            t += dt
596
597
598        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5]]
599     
600        #Deliberately set domain.starttime to too early
601        domain.starttime = start - 1
602
603        #Create file function
604        F = file_function(filename + '.sww', domain,
605                          quantities = domain.conserved_quantities,
606                          interpolation_points = interpolation_points)
607
608        #Check that FF updates fixes domain starttime
609        assert allclose(domain.starttime, start)
610
611        #Check that domain.starttime isn't updated if later
612        domain.starttime = start + 1
613        F = file_function(filename + '.sww', domain,
614                          quantities = domain.conserved_quantities,
615                          interpolation_points = interpolation_points)
616        assert allclose(domain.starttime, start+1)
617        domain.starttime = start
618
619
620        #Check linear interpolation in time
621        F = file_function(filename + '.sww', domain,
622                          quantities = domain.conserved_quantities,
623                          interpolation_points = interpolation_points)               
624        for id in range(len(interpolation_points)):
625            x = interpolation_points[id][0]
626            y = interpolation_points[id][1]
627
628            for i in range(20):
629                t = i*10
630                k = i%6
631
632                if k == 0:
633                    q0 = F(t, point_id=id)
634                    q1 = F(t+60, point_id=id)
635
636                if q0 == NAN:
637                    actual = q0
638                else:
639                    actual = (k*q1 + (6-k)*q0)/6
640                q = F(t, point_id=id)
641                #print i, k, t, q
642                #print ' ', q0
643                #print ' ', q1
644                #print "q",q
645                #print "actual", actual
646                #print
647                if q0 == NAN:
648                     self.failUnless( q == actual, 'Fail!')
649                else:
650                    assert allclose(q, actual)
651
652
653        #Another check of linear interpolation in time
654        for id in range(len(interpolation_points)):
655            q60 = F(60, point_id=id)
656            q120 = F(120, point_id=id)
657
658            t = 90 #Halfway between 60 and 120
659            q = F(t, point_id=id)
660            assert allclose( (q120+q60)/2, q )
661
662            t = 100 #Two thirds of the way between between 60 and 120
663            q = F(t, point_id=id)
664            assert allclose(q60/3 + 2*q120/3, q)
665
666
667
668        #Check that domain.starttime isn't updated if later than file starttime but earlier
669        #than file end time
670        delta = 23
671        domain.starttime = start + delta
672        F = file_function(filename + '.sww', domain,
673                          quantities = domain.conserved_quantities,
674                          interpolation_points = interpolation_points)
675        assert allclose(domain.starttime, start+delta)
676
677
678
679
680        #Now try interpolation with delta offset
681        for id in range(len(interpolation_points)):           
682            x = interpolation_points[id][0]
683            y = interpolation_points[id][1]
684
685            for i in range(20):
686                t = i*10
687                k = i%6
688
689                if k == 0:
690                    q0 = F(t-delta, point_id=id)
691                    q1 = F(t+60-delta, point_id=id)
692
693                q = F(t-delta, point_id=id)
694                assert allclose(q, (k*q1 + (6-k)*q0)/6)
695
696
697        os.remove(filename + '.sww')
698
699
700
701    def Xtest_spatio_temporal_file_function_time(self):
702        # FIXME: This passes but needs some TLC
703        # Test that File function interpolates correctly
704        # When some points are outside the mesh
705
706        import os, time
707        from anuga.config import time_format
708        from numpy.oldnumeric import sin, pi, exp
709        from mesh_factory import rectangular
710        from shallow_water import Domain
711        import anuga.shallow_water.data_manager 
712        from anuga.pmesh.mesh_interface import create_mesh_from_regions
713        finaltime = 1200
714       
715        filename = tempfile.mktemp()
716        #print "filename",filename
717        filename = 'test_file_function'
718
719        meshfilename = tempfile.mktemp(".tsh")
720
721        boundary_tags = {'walls':[0,1],'bom':[2]}
722       
723        polygon_absolute = [[0,-20],[10,-20],[10,15],[-20,15]]
724       
725        create_mesh_from_regions(polygon_absolute,
726                                 boundary_tags,
727                                 10000000,
728                                 filename=meshfilename)
729        domain = Domain(mesh_filename=meshfilename)
730        domain.smooth = False
731        domain.default_order = 2
732        domain.set_datadir('.')
733        domain.set_name(filename)
734        domain.store = True
735        domain.format = 'sww'   #Native netcdf visualisation format
736
737        #print points
738        start = time.mktime(time.strptime('2000', '%Y'))
739        domain.starttime = start
740       
741
742        #Store structure
743        domain.initialise_storage()
744
745        #Compute artificial time steps and store
746        dt = 60  #One minute intervals
747        t = 0.0
748        while t <= finaltime:
749            #Compute quantities
750            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
751            domain.set_quantity('stage', f1)
752
753            f2 = lambda x,y: x+y+t**2
754            domain.set_quantity('xmomentum', f2)
755
756            f3 = lambda x,y: x**2 + y**2 * sin(t*pi/600)
757            domain.set_quantity('ymomentum', f3)
758
759            #Store and advance time
760            domain.time = t
761            domain.store_timestep(domain.conserved_quantities)
762            t += dt
763
764        interpolation_points = [[1,0]]
765        interpolation_points = [[100,1000]]
766       
767        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5],
768                                [78787,78787],[7878,3432]]
769           
770        #Deliberately set domain.starttime to too early
771        domain.starttime = start - 1
772
773        #Create file function
774        F = file_function(filename + '.sww', domain,
775                          quantities = domain.conserved_quantities,
776                          interpolation_points = interpolation_points)
777
778        #Check that FF updates fixes domain starttime
779        assert allclose(domain.starttime, start)
780
781        #Check that domain.starttime isn't updated if later
782        domain.starttime = start + 1
783        F = file_function(filename + '.sww', domain,
784                          quantities = domain.conserved_quantities,
785                          interpolation_points = interpolation_points)
786        assert allclose(domain.starttime, start+1)
787        domain.starttime = start
788
789
790        #Check linear interpolation in time
791        # checking points inside and outside the mesh
792        F = file_function(filename + '.sww', domain,
793                          quantities = domain.conserved_quantities,
794                          interpolation_points = interpolation_points)
795       
796        for id in range(len(interpolation_points)):
797            x = interpolation_points[id][0]
798            y = interpolation_points[id][1]
799
800            for i in range(20):
801                t = i*10
802                k = i%6
803
804                if k == 0:
805                    q0 = F(t, point_id=id)
806                    q1 = F(t+60, point_id=id)
807
808                if q0 == NAN:
809                    actual = q0
810                else:
811                    actual = (k*q1 + (6-k)*q0)/6
812                q = F(t, point_id=id)
813                #print i, k, t, q
814                #print ' ', q0
815                #print ' ', q1
816                #print "q",q
817                #print "actual", actual
818                #print
819                if q0 == NAN:
820                     self.failUnless( q == actual, 'Fail!')
821                else:
822                    assert allclose(q, actual)
823
824        # now lets check points inside the mesh
825        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14]] #, [10,-12.5]] - this point doesn't work WHY?
826        interpolation_points = [[10,-12.5]]
827           
828        print "len(interpolation_points)",len(interpolation_points) 
829        F = file_function(filename + '.sww', domain,
830                          quantities = domain.conserved_quantities,
831                          interpolation_points = interpolation_points)
832
833        domain.starttime = start
834
835
836        #Check linear interpolation in time
837        F = file_function(filename + '.sww', domain,
838                          quantities = domain.conserved_quantities,
839                          interpolation_points = interpolation_points)               
840        for id in range(len(interpolation_points)):
841            x = interpolation_points[id][0]
842            y = interpolation_points[id][1]
843
844            for i in range(20):
845                t = i*10
846                k = i%6
847
848                if k == 0:
849                    q0 = F(t, point_id=id)
850                    q1 = F(t+60, point_id=id)
851
852                if q0 == NAN:
853                    actual = q0
854                else:
855                    actual = (k*q1 + (6-k)*q0)/6
856                q = F(t, point_id=id)
857                print "############"
858                print "id, x, y ", id, x, y #k, t, q
859                print "t", t
860                #print ' ', q0
861                #print ' ', q1
862                print "q",q
863                print "actual", actual
864                #print
865                if q0 == NAN:
866                     self.failUnless( q == actual, 'Fail!')
867                else:
868                    assert allclose(q, actual)
869
870
871        #Another check of linear interpolation in time
872        for id in range(len(interpolation_points)):
873            q60 = F(60, point_id=id)
874            q120 = F(120, point_id=id)
875
876            t = 90 #Halfway between 60 and 120
877            q = F(t, point_id=id)
878            assert allclose( (q120+q60)/2, q )
879
880            t = 100 #Two thirds of the way between between 60 and 120
881            q = F(t, point_id=id)
882            assert allclose(q60/3 + 2*q120/3, q)
883
884
885
886        #Check that domain.starttime isn't updated if later than file starttime but earlier
887        #than file end time
888        delta = 23
889        domain.starttime = start + delta
890        F = file_function(filename + '.sww', domain,
891                          quantities = domain.conserved_quantities,
892                          interpolation_points = interpolation_points)
893        assert allclose(domain.starttime, start+delta)
894
895
896
897
898        #Now try interpolation with delta offset
899        for id in range(len(interpolation_points)):           
900            x = interpolation_points[id][0]
901            y = interpolation_points[id][1]
902
903            for i in range(20):
904                t = i*10
905                k = i%6
906
907                if k == 0:
908                    q0 = F(t-delta, point_id=id)
909                    q1 = F(t+60-delta, point_id=id)
910
911                q = F(t-delta, point_id=id)
912                assert allclose(q, (k*q1 + (6-k)*q0)/6)
913
914
915        os.remove(filename + '.sww')
916
917    def test_file_function_time_with_domain(self):
918        """Test that File function interpolates correctly
919        between given times. No x,y dependency here.
920        Use domain with starttime
921        """
922
923        #Write file
924        import os, time, calendar
925        from anuga.config import time_format
926        from math import sin, pi
927        from domain import Domain
928
929        finaltime = 1200
930        filename = 'test_file_function'
931        fid = open(filename + '.txt', 'w')
932        start = time.mktime(time.strptime('2000', '%Y'))
933        dt = 60  #One minute intervals
934        t = 0.0
935        while t <= finaltime:
936            t_string = time.strftime(time_format, time.gmtime(t+start))
937            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
938            t += dt
939
940        fid.close()
941
942
943        #Convert ASCII file to NetCDF (Which is what we really like!)
944        timefile2netcdf(filename)
945
946
947
948        a = [0.0, 0.0]
949        b = [4.0, 0.0]
950        c = [0.0, 3.0]
951
952        points = [a, b, c]
953        vertices = [[0,1,2]]
954        domain = Domain(points, vertices)
955
956        # Check that domain.starttime is updated if non-existing
957        F = file_function(filename + '.tms',
958                          domain,
959                          quantities = ['Attribute0', 'Attribute1', 'Attribute2']) 
960        assert allclose(domain.starttime, start)
961
962        # Check that domain.starttime is updated if too early
963        domain.starttime = start - 1
964        F = file_function(filename + '.tms',
965                          domain,
966                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
967        assert allclose(domain.starttime, start)
968
969        # Check that domain.starttime isn't updated if later
970        domain.starttime = start + 1
971        F = file_function(filename + '.tms',
972                          domain,
973                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
974        assert allclose(domain.starttime, start+1)
975
976        domain.starttime = start
977        F = file_function(filename + '.tms',
978                          domain,
979                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'],
980                          use_cache=True)
981       
982
983        #print F.precomputed_values
984        #print 'F(60)', F(60)
985       
986        #Now try interpolation
987        for i in range(20):
988            t = i*10
989            q = F(t)
990
991            #Exact linear intpolation
992            assert allclose(q[0], 2*t)
993            if i%6 == 0:
994                assert allclose(q[1], t**2)
995                assert allclose(q[2], sin(t*pi/600))
996
997        #Check non-exact
998
999        t = 90 #Halfway between 60 and 120
1000        q = F(t)
1001        assert allclose( (120**2 + 60**2)/2, q[1] )
1002        assert allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1003
1004
1005        t = 100 #Two thirds of the way between between 60 and 120
1006        q = F(t)
1007        assert allclose( 2*120**2/3 + 60**2/3, q[1] )
1008        assert allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1009
1010        os.remove(filename + '.tms')
1011        os.remove(filename + '.txt')       
1012
1013    def test_file_function_time_with_domain_different_start(self):
1014        """Test that File function interpolates correctly
1015        between given times. No x,y dependency here.
1016        Use domain with a starttime later than that of file
1017
1018        ASCII version
1019        """
1020
1021        #Write file
1022        import os, time, calendar
1023        from anuga.config import time_format
1024        from math import sin, pi
1025        from domain import Domain
1026
1027        finaltime = 1200
1028        filename = 'test_file_function'
1029        fid = open(filename + '.txt', 'w')
1030        start = time.mktime(time.strptime('2000', '%Y'))
1031        dt = 60  #One minute intervals
1032        t = 0.0
1033        while t <= finaltime:
1034            t_string = time.strftime(time_format, time.gmtime(t+start))
1035            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1036            t += dt
1037
1038        fid.close()
1039
1040        #Convert ASCII file to NetCDF (Which is what we really like!)
1041        timefile2netcdf(filename)       
1042
1043        a = [0.0, 0.0]
1044        b = [4.0, 0.0]
1045        c = [0.0, 3.0]
1046
1047        points = [a, b, c]
1048        vertices = [[0,1,2]]
1049        domain = Domain(points, vertices)
1050
1051        #Check that domain.starttime isn't updated if later than file starttime but earlier
1052        #than file end time
1053        delta = 23
1054        domain.starttime = start + delta
1055        F = file_function(filename + '.tms', domain,
1056                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])       
1057        assert allclose(domain.starttime, start+delta)
1058
1059
1060
1061
1062        #Now try interpolation with delta offset
1063        for i in range(20):
1064            t = i*10
1065            q = F(t-delta)
1066
1067            #Exact linear intpolation
1068            assert allclose(q[0], 2*t)
1069            if i%6 == 0:
1070                assert allclose(q[1], t**2)
1071                assert allclose(q[2], sin(t*pi/600))
1072
1073        #Check non-exact
1074
1075        t = 90 #Halfway between 60 and 120
1076        q = F(t-delta)
1077        assert allclose( (120**2 + 60**2)/2, q[1] )
1078        assert allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1079
1080
1081        t = 100 #Two thirds of the way between between 60 and 120
1082        q = F(t-delta)
1083        assert allclose( 2*120**2/3 + 60**2/3, q[1] )
1084        assert allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1085
1086
1087        os.remove(filename + '.tms')
1088        os.remove(filename + '.txt')               
1089
1090
1091
1092    def test_apply_expression_to_dictionary(self):
1093
1094        #FIXME: Division is not expected to work for integers.
1095        #This must be caught.
1096        foo = array([[1,2,3],
1097                     [4,5,6]], float)
1098
1099        bar = array([[-1,0,5],
1100                     [6,1,1]], float)                 
1101
1102        D = {'X': foo, 'Y': bar}
1103
1104        Z = apply_expression_to_dictionary('X+Y', D)       
1105        assert allclose(Z, foo+bar)
1106
1107        Z = apply_expression_to_dictionary('X*Y', D)       
1108        assert allclose(Z, foo*bar)       
1109
1110        Z = apply_expression_to_dictionary('4*X+Y', D)       
1111        assert allclose(Z, 4*foo+bar)       
1112
1113        # test zero division is OK
1114        Z = apply_expression_to_dictionary('X/Y', D)
1115        assert allclose(1/Z, 1/(foo/bar)) # can't compare inf to inf
1116
1117        # make an error for zero on zero
1118        # this is really an error in Numeric, SciPy core can handle it
1119        # Z = apply_expression_to_dictionary('0/Y', D)
1120
1121        #Check exceptions
1122        try:
1123            #Wrong name
1124            Z = apply_expression_to_dictionary('4*X+A', D)       
1125        except NameError:
1126            pass
1127        else:
1128            msg = 'Should have raised a NameError Exception'
1129            raise msg
1130
1131
1132        try:
1133            #Wrong order
1134            Z = apply_expression_to_dictionary(D, '4*X+A')       
1135        except AssertionError:
1136            pass
1137        else:
1138            msg = 'Should have raised a AssertionError Exception'
1139            raise msg       
1140       
1141
1142    def test_multiple_replace(self):
1143        """Hard test that checks a true word-by-word simultaneous replace
1144        """
1145       
1146        D = {'x': 'xi', 'y': 'eta', 'xi':'lam'}
1147        exp = '3*x+y + xi'
1148       
1149        new = multiple_replace(exp, D)
1150       
1151        assert new == '3*xi+eta + lam'
1152                         
1153
1154
1155    def test_point_on_line_obsolete(self):
1156        """Test that obsolete call issues appropriate warning"""
1157
1158        #Turn warning into an exception
1159        import warnings
1160        warnings.filterwarnings('error')
1161
1162        try:
1163            assert point_on_line( 0, 0.5, 0,1, 0,0 )
1164        except DeprecationWarning:
1165            pass
1166        else:
1167            msg = 'point_on_line should have issued a DeprecationWarning'
1168            raise Exception(msg)   
1169
1170        warnings.resetwarnings()
1171   
1172    def test_get_revision_number(self):
1173        """test_get_revision_number(self):
1174
1175        Test that revision number can be retrieved.
1176        """
1177        if os.environ.has_key('USER') and os.environ['USER'] == 'dgray':
1178            # I have a known snv incompatability issue,
1179            # so I'm skipping this test.
1180            # FIXME when SVN is upgraded on our clusters
1181            pass
1182        else:   
1183            n = get_revision_number()
1184            assert n>=0
1185
1186
1187       
1188    def test_add_directories(self):
1189       
1190        import tempfile
1191        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1192        directories = ['ja','ne','ke']
1193        kens_dir = add_directories(root_dir, directories)
1194        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1195               sep + 'ke'
1196        assert access(root_dir,F_OK)
1197
1198        add_directories(root_dir, directories)
1199        assert access(root_dir,F_OK)
1200       
1201        #clean up!
1202        os.rmdir(kens_dir)
1203        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1204        os.rmdir(root_dir + sep + 'ja')
1205        os.rmdir(root_dir)
1206
1207    def test_add_directories_bad(self):
1208       
1209        import tempfile
1210        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1211        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1212       
1213        try:
1214            kens_dir = add_directories(root_dir, directories)
1215        except OSError:
1216            pass
1217        else:
1218            msg = 'bad dir name should give OSError'
1219            raise Exception(msg)   
1220           
1221        #clean up!
1222        os.rmdir(root_dir)
1223
1224    def test_check_list(self):
1225
1226        check_list(['stage','xmomentum'])
1227
1228       
1229    def test_add_directories(self):
1230       
1231        import tempfile
1232        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1233        directories = ['ja','ne','ke']
1234        kens_dir = add_directories(root_dir, directories)
1235        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1236               sep + 'ke'
1237        assert access(root_dir,F_OK)
1238
1239        add_directories(root_dir, directories)
1240        assert access(root_dir,F_OK)
1241       
1242        #clean up!
1243        os.rmdir(kens_dir)
1244        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1245        os.rmdir(root_dir + sep + 'ja')
1246        os.rmdir(root_dir)
1247
1248    def test_add_directories_bad(self):
1249       
1250        import tempfile
1251        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1252        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1253       
1254        try:
1255            kens_dir = add_directories(root_dir, directories)
1256        except OSError:
1257            pass
1258        else:
1259            msg = 'bad dir name should give OSError'
1260            raise Exception(msg)   
1261           
1262        #clean up!
1263        os.rmdir(root_dir)
1264
1265    def test_check_list(self):
1266
1267        check_list(['stage','xmomentum'])
1268       
1269    def test_remove_lone_verts_d(self):
1270        verts = [[0,0],[1,0],[0,1]]
1271        tris = [[0,1,2]]
1272        new_verts, new_tris = remove_lone_verts(verts, tris)
1273        assert alltrue(new_verts == verts)
1274        assert alltrue(new_tris == tris)
1275     
1276
1277    def test_remove_lone_verts_e(self):
1278        verts = [[0,0],[1,0],[0,1],[99,99]]
1279        tris = [[0,1,2]]
1280        new_verts, new_tris = remove_lone_verts(verts, tris)
1281        assert new_verts == verts[0:3]
1282        assert new_tris == tris
1283       
1284    def test_remove_lone_verts_a(self):
1285        verts = [[99,99],[0,0],[1,0],[99,99],[0,1],[99,99]]
1286        tris = [[1,2,4]]
1287        new_verts, new_tris = remove_lone_verts(verts, tris)
1288        #print "new_verts", new_verts
1289        assert alltrue(new_verts == [[0,0],[1,0],[0,1]])
1290        assert alltrue(new_tris == [[0,1,2]])
1291     
1292    def test_remove_lone_verts_c(self):
1293        verts = [[0,0],[1,0],[99,99],[0,1]]
1294        tris = [[0,1,3]]
1295        new_verts, new_tris = remove_lone_verts(verts, tris)
1296        print "new_verts", new_verts
1297        assert alltrue(new_verts == [[0,0],[1,0],[0,1]])
1298        assert alltrue(new_tris == [[0,1,2]])
1299       
1300    def test_remove_lone_verts_b(self):
1301        verts = [[0,0],[1,0],[0,1],[99,99],[99,99],[99,99]]
1302        tris = [[0,1,2]]
1303        new_verts, new_tris = remove_lone_verts(verts, tris)
1304        assert alltrue(new_verts == verts[0:3])
1305        assert alltrue(new_tris == tris)
1306     
1307
1308    def test_remove_lone_verts_e(self):
1309        verts = [[0,0],[1,0],[0,1],[99,99]]
1310        tris = [[0,1,2]]
1311        new_verts, new_tris = remove_lone_verts(verts, tris)
1312        assert alltrue(new_verts == verts[0:3])
1313        assert alltrue(new_tris == tris)
1314       
1315    def test_get_min_max_values(self):
1316       
1317        list=[8,9,6,1,4]
1318        min1, max1 = get_min_max_values(list)
1319       
1320        assert min1==1 
1321        assert max1==9
1322       
1323    def test_get_min_max_values1(self):
1324       
1325        list=[-8,-9,-6,-1,-4]
1326        min1, max1 = get_min_max_values(list)
1327       
1328#        print 'min1,max1',min1,max1
1329        assert min1==-9 
1330        assert max1==-1
1331
1332#    def test_get_min_max_values2(self):
1333#        '''
1334#        The min and max supplied are greater than the ones in the
1335#        list and therefore are the ones returned
1336#        '''
1337#        list=[-8,-9,-6,-1,-4]
1338#        min1, max1 = get_min_max_values(list,-10,10)
1339#       
1340##        print 'min1,max1',min1,max1
1341#        assert min1==-10
1342#        assert max1==10
1343       
1344    def test_make_plots_from_csv_files(self):
1345       
1346        #if sys.platform == 'win32':  #Windows
1347            try: 
1348                import pylab
1349            except ImportError:
1350                #ANUGA don't need pylab to work so the system doesn't
1351                #rely on pylab being installed
1352                return
1353           
1354       
1355            current_dir=getcwd()+sep+'abstract_2d_finite_volumes'
1356            temp_dir = tempfile.mkdtemp('','figures')
1357    #        print 'temp_dir',temp_dir
1358            fileName = temp_dir+sep+'time_series_3.csv'
1359            file = open(fileName,"w")
1360            file.write("time,stage,speed,momentum,elevation\n\
13611.0, 0, 0, 0, 10 \n\
13622.0, 5, 2, 4, 10 \n\
13633.0, 3, 3, 5, 10 \n")
1364            file.close()
1365   
1366            fileName1 = temp_dir+sep+'time_series_4.csv'
1367            file1 = open(fileName1,"w")
1368            file1.write("time,stage,speed,momentum,elevation\n\
13691.0, 0, 0, 0, 5 \n\
13702.0, -5, -2, -4, 5 \n\
13713.0, -4, -3, -5, 5 \n")
1372            file1.close()
1373   
1374            fileName2 = temp_dir+sep+'time_series_5.csv'
1375            file2 = open(fileName2,"w")
1376            file2.write("time,stage,speed,momentum,elevation\n\
13771.0, 0, 0, 0, 7 \n\
13782.0, 4, -0.45, 57, 7 \n\
13793.0, 6, -0.5, 56, 7 \n")
1380            file2.close()
1381           
1382            dir, name=os.path.split(fileName)
1383            csv2timeseries_graphs(directories_dic={dir:['gauge', 0, 0]},
1384                                  output_dir=temp_dir,
1385                                  base_name='time_series_',
1386                                  plot_numbers=['3-5'],
1387                                  quantities=['speed','stage','momentum'],
1388                                  assess_all_csv_files=True,
1389                                  extra_plot_name='test')
1390           
1391            #print dir+sep+name[:-4]+'_stage_test.png'
1392            assert(access(dir+sep+name[:-4]+'_stage_test.png',F_OK)==True)
1393            assert(access(dir+sep+name[:-4]+'_speed_test.png',F_OK)==True)
1394            assert(access(dir+sep+name[:-4]+'_momentum_test.png',F_OK)==True)
1395   
1396            dir1, name1=os.path.split(fileName1)
1397            assert(access(dir+sep+name1[:-4]+'_stage_test.png',F_OK)==True)
1398            assert(access(dir+sep+name1[:-4]+'_speed_test.png',F_OK)==True)
1399            assert(access(dir+sep+name1[:-4]+'_momentum_test.png',F_OK)==True)
1400   
1401   
1402            dir2, name2=os.path.split(fileName2)
1403            assert(access(dir+sep+name2[:-4]+'_stage_test.png',F_OK)==True)
1404            assert(access(dir+sep+name2[:-4]+'_speed_test.png',F_OK)==True)
1405            assert(access(dir+sep+name2[:-4]+'_momentum_test.png',F_OK)==True)
1406   
1407            del_dir(temp_dir)
1408       
1409
1410    def test_sww2csv_gauges(self):
1411
1412        def elevation_function(x, y):
1413            return -x
1414       
1415        """Most of this test was copied from test_interpolate
1416        test_interpole_sww2csv
1417       
1418        This is testing the gauge_sww2csv function, by creating a sww file and
1419        then exporting the gauges and checking the results.
1420        """
1421       
1422        # Create mesh
1423        mesh_file = tempfile.mktemp(".tsh")   
1424        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1425        m = Mesh()
1426        m.add_vertices(points)
1427        m.auto_segment()
1428        m.generate_mesh(verbose=False)
1429        m.export_mesh_file(mesh_file)
1430       
1431        # Create shallow water domain
1432        domain = Domain(mesh_file)
1433        os.remove(mesh_file)
1434       
1435        domain.default_order=2
1436       
1437        # This test was made before tight_slope_limiters were introduced
1438        # Since were are testing interpolation values this is OK
1439        domain.tight_slope_limiters = 0 
1440       
1441
1442        # Set some field values
1443        domain.set_quantity('elevation', elevation_function)
1444        domain.set_quantity('friction', 0.03)
1445        domain.set_quantity('xmomentum', 3.0)
1446        domain.set_quantity('ymomentum', 4.0)
1447
1448        ######################
1449        # Boundary conditions
1450        B = Transmissive_boundary(domain)
1451        domain.set_boundary( {'exterior': B})
1452
1453        # This call mangles the stage values.
1454        domain.distribute_to_vertices_and_edges()
1455        domain.set_quantity('stage', 1.0)
1456
1457
1458        domain.set_name('datatest' + str(time.time()))
1459        domain.format = 'sww'
1460        domain.smooth = True
1461        domain.reduction = mean
1462
1463
1464        sww = get_dataobject(domain)
1465        sww.store_connectivity()
1466        sww.store_timestep(['stage', 'xmomentum', 'ymomentum','elevation'])
1467        domain.set_quantity('stage', 10.0) # This is automatically limited
1468        # so it will not be less than the elevation
1469        domain.time = 2.
1470        sww.store_timestep(['stage','elevation', 'xmomentum', 'ymomentum'])
1471
1472        # test the function
1473        points = [[5.0,1.],[0.5,2.]]
1474
1475        points_file = tempfile.mktemp(".csv")
1476#        points_file = 'test_point.csv'
1477        file_id = open(points_file,"w")
1478        file_id.write("name, easting, northing, elevation \n\
1479point1, 5.0, 1.0, 3.0\n\
1480point2, 0.5, 2.0, 9.0\n")
1481        file_id.close()
1482
1483       
1484        sww2csv_gauges(sww.filename, 
1485                       points_file,
1486                       verbose=False,
1487                       use_cache=False)
1488
1489#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1490        point1_answers_array = [[0.0,1.0,6.0,-5.0,3.0,4.0], [2.0,10.0,15.0,-5.0,3.0,4.0]]
1491        point1_filename = 'gauge_point1.csv'
1492        point1_handle = file(point1_filename)
1493        point1_reader = reader(point1_handle)
1494        point1_reader.next()
1495
1496        line=[]
1497        for i,row in enumerate(point1_reader):
1498            #print 'i',i,'row',row
1499            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),float(row[4]),float(row[5])])
1500            #print 'assert line',line[i],'point1',point1_answers_array[i]
1501            assert allclose(line[i], point1_answers_array[i])
1502
1503        point2_answers_array = [[0.0,1.0,1.5,-0.5,3.0,4.0], [2.0,10.0,10.5,-0.5,3.0,4.0]]
1504        point2_filename = 'gauge_point2.csv' 
1505        point2_handle = file(point2_filename)
1506        point2_reader = reader(point2_handle)
1507        point2_reader.next()
1508                       
1509        line=[]
1510        for i,row in enumerate(point2_reader):
1511            #print 'i',i,'row',row
1512            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),float(row[4]),float(row[5])])
1513            #print 'assert line',line[i],'point1',point1_answers_array[i]
1514            assert allclose(line[i], point2_answers_array[i])
1515                         
1516        # clean up
1517        point1_handle.close()
1518        point2_handle.close()
1519        #print "sww.filename",sww.filename
1520        os.remove(sww.filename)
1521        os.remove(points_file)
1522        os.remove(point1_filename)
1523        os.remove(point2_filename)
1524
1525
1526
1527    def test_sww2csv_gauges1(self):
1528        from anuga.pmesh.mesh import Mesh
1529        from anuga.shallow_water import Domain, Transmissive_boundary
1530        from anuga.shallow_water.data_manager import get_dataobject
1531        from csv import reader,writer
1532        import time
1533        import string
1534
1535        def elevation_function(x, y):
1536            return -x
1537       
1538        """Most of this test was copied from test_interpolate
1539        test_interpole_sww2csv
1540       
1541        This is testing the gauge_sww2csv function, by creating a sww file and
1542        then exporting the gauges and checking the results.
1543       
1544        This tests the ablity not to have elevation in the points file and
1545        not store xmomentum and ymomentum
1546        """
1547       
1548        # Create mesh
1549        mesh_file = tempfile.mktemp(".tsh")   
1550        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1551        m = Mesh()
1552        m.add_vertices(points)
1553        m.auto_segment()
1554        m.generate_mesh(verbose=False)
1555        m.export_mesh_file(mesh_file)
1556       
1557        # Create shallow water domain
1558        domain = Domain(mesh_file)
1559        os.remove(mesh_file)
1560       
1561        domain.default_order=2
1562
1563        # Set some field values
1564        domain.set_quantity('elevation', elevation_function)
1565        domain.set_quantity('friction', 0.03)
1566        domain.set_quantity('xmomentum', 3.0)
1567        domain.set_quantity('ymomentum', 4.0)
1568
1569        ######################
1570        # Boundary conditions
1571        B = Transmissive_boundary(domain)
1572        domain.set_boundary( {'exterior': B})
1573
1574        # This call mangles the stage values.
1575        domain.distribute_to_vertices_and_edges()
1576        domain.set_quantity('stage', 1.0)
1577
1578
1579        domain.set_name('datatest' + str(time.time()))
1580        domain.format = 'sww'
1581        domain.smooth = True
1582        domain.reduction = mean
1583
1584        sww = get_dataobject(domain)
1585        sww.store_connectivity()
1586        sww.store_timestep(['stage', 'xmomentum', 'ymomentum'])
1587        domain.set_quantity('stage', 10.0) # This is automatically limited
1588        # so it will not be less than the elevation
1589        domain.time = 2.
1590        sww.store_timestep(['stage', 'xmomentum', 'ymomentum'])
1591
1592        # test the function
1593        points = [[5.0,1.],[0.5,2.]]
1594
1595        points_file = tempfile.mktemp(".csv")
1596#        points_file = 'test_point.csv'
1597        file_id = open(points_file,"w")
1598        file_id.write("name,easting,northing \n\
1599point1, 5.0, 1.0\n\
1600point2, 0.5, 2.0\n")
1601        file_id.close()
1602
1603        sww2csv_gauges(sww.filename, 
1604                            points_file,
1605                            quantities=['stage', 'elevation'],
1606                            use_cache=False,
1607                            verbose=False)
1608
1609        point1_answers_array = [[0.0,1.0,-5.0], [2.0,10.0,-5.0]]
1610        point1_filename = 'gauge_point1.csv'
1611        point1_handle = file(point1_filename)
1612        point1_reader = reader(point1_handle)
1613        point1_reader.next()
1614
1615        line=[]
1616        for i,row in enumerate(point1_reader):
1617#            print 'i',i,'row',row
1618            line.append([float(row[0]),float(row[1]),float(row[2])])
1619            #print 'line',line[i],'point1',point1_answers_array[i]
1620            assert allclose(line[i], point1_answers_array[i])
1621
1622        point2_answers_array = [[0.0,1.0,-0.5], [2.0,10.0,-0.5]]
1623        point2_filename = 'gauge_point2.csv' 
1624        point2_handle = file(point2_filename)
1625        point2_reader = reader(point2_handle)
1626        point2_reader.next()
1627                       
1628        line=[]
1629        for i,row in enumerate(point2_reader):
1630#            print 'i',i,'row',row
1631            line.append([float(row[0]),float(row[1]),float(row[2])])
1632#            print 'line',line[i],'point1',point1_answers_array[i]
1633            assert allclose(line[i], point2_answers_array[i])
1634                         
1635        # clean up
1636        point1_handle.close()
1637        point2_handle.close()
1638        #print "sww.filename",sww.filename
1639        os.remove(sww.filename)
1640        os.remove(points_file)
1641        os.remove(point1_filename)
1642        os.remove(point2_filename)
1643
1644
1645    def test_sww2csv_gauges2(self):
1646
1647        def elevation_function(x, y):
1648            return -x
1649       
1650        """Most of this test was copied from test_interpolate
1651        test_interpole_sww2csv
1652       
1653        This is testing the gauge_sww2csv function, by creating a sww file and
1654        then exporting the gauges and checking the results.
1655       
1656        This is the same as sww2csv_gauges except set domain.set_starttime to 5.
1657        Therefore testing the storing of the absolute time in the csv files
1658        """
1659       
1660        # Create mesh
1661        mesh_file = tempfile.mktemp(".tsh")   
1662        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1663        m = Mesh()
1664        m.add_vertices(points)
1665        m.auto_segment()
1666        m.generate_mesh(verbose=False)
1667        m.export_mesh_file(mesh_file)
1668       
1669        # Create shallow water domain
1670        domain = Domain(mesh_file)
1671        os.remove(mesh_file)
1672       
1673        domain.default_order=2
1674
1675        # This test was made before tight_slope_limiters were introduced
1676        # Since were are testing interpolation values this is OK
1677        domain.tight_slope_limiters = 0         
1678
1679        # Set some field values
1680        domain.set_quantity('elevation', elevation_function)
1681        domain.set_quantity('friction', 0.03)
1682        domain.set_quantity('xmomentum', 3.0)
1683        domain.set_quantity('ymomentum', 4.0)
1684        domain.set_starttime(5)
1685
1686        ######################
1687        # Boundary conditions
1688        B = Transmissive_boundary(domain)
1689        domain.set_boundary( {'exterior': B})
1690
1691        # This call mangles the stage values.
1692        domain.distribute_to_vertices_and_edges()
1693        domain.set_quantity('stage', 1.0)
1694       
1695
1696
1697        domain.set_name('datatest' + str(time.time()))
1698        domain.format = 'sww'
1699        domain.smooth = True
1700        domain.reduction = mean
1701
1702        sww = get_dataobject(domain)
1703        sww.store_connectivity()
1704        sww.store_timestep(['stage', 'xmomentum', 'ymomentum','elevation'])
1705        domain.set_quantity('stage', 10.0) # This is automatically limited
1706        # so it will not be less than the elevation
1707        domain.time = 2.
1708        sww.store_timestep(['stage','elevation', 'xmomentum', 'ymomentum'])
1709
1710        # test the function
1711        points = [[5.0,1.],[0.5,2.]]
1712
1713        points_file = tempfile.mktemp(".csv")
1714#        points_file = 'test_point.csv'
1715        file_id = open(points_file,"w")
1716        file_id.write("name, easting, northing, elevation \n\
1717point1, 5.0, 1.0, 3.0\n\
1718point2, 0.5, 2.0, 9.0\n")
1719        file_id.close()
1720
1721       
1722        sww2csv_gauges(sww.filename, 
1723                            points_file,
1724                            verbose=False,
1725                            use_cache=False)
1726
1727#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1728        point1_answers_array = [[5.0,1.0,6.0,-5.0,3.0,4.0], [7.0,10.0,15.0,-5.0,3.0,4.0]]
1729        point1_filename = 'gauge_point1.csv'
1730        point1_handle = file(point1_filename)
1731        point1_reader = reader(point1_handle)
1732        point1_reader.next()
1733
1734        line=[]
1735        for i,row in enumerate(point1_reader):
1736            #print 'i',i,'row',row
1737            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),float(row[4]),float(row[5])])
1738            #print 'assert line',line[i],'point1',point1_answers_array[i]
1739            assert allclose(line[i], point1_answers_array[i])
1740
1741        point2_answers_array = [[5.0,1.0,1.5,-0.5,3.0,4.0], [7.0,10.0,10.5,-0.5,3.0,4.0]]
1742        point2_filename = 'gauge_point2.csv' 
1743        point2_handle = file(point2_filename)
1744        point2_reader = reader(point2_handle)
1745        point2_reader.next()
1746                       
1747        line=[]
1748        for i,row in enumerate(point2_reader):
1749            #print 'i',i,'row',row
1750            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),float(row[4]),float(row[5])])
1751            #print 'assert line',line[i],'point1',point1_answers_array[i]
1752            assert allclose(line[i], point2_answers_array[i])
1753                         
1754        # clean up
1755        point1_handle.close()
1756        point2_handle.close()
1757        #print "sww.filename",sww.filename
1758        os.remove(sww.filename)
1759        os.remove(points_file)
1760        os.remove(point1_filename)
1761        os.remove(point2_filename)
1762
1763
1764    def test_greens_law(self):
1765
1766        from math import sqrt
1767       
1768        d1 = 80.0
1769        d2 = 20.0
1770        h1 = 1.0
1771        h2 = greens_law(d1,d2,h1)
1772
1773        assert h2==sqrt(2.0)
1774       
1775    def test_calc_bearings(self):
1776 
1777        from math import atan, degrees
1778        #Test East
1779        uh = 1
1780        vh = 1.e-15
1781        angle = calc_bearing(uh, vh)
1782        if 89 < angle < 91: v=1
1783        assert v==1
1784        #Test West
1785        uh = -1
1786        vh = 1.e-15
1787        angle = calc_bearing(uh, vh)
1788        if 269 < angle < 271: v=1
1789        assert v==1
1790        #Test North
1791        uh = 1.e-15
1792        vh = 1
1793        angle = calc_bearing(uh, vh)
1794        if -1 < angle < 1: v=1
1795        assert v==1
1796        #Test South
1797        uh = 1.e-15
1798        vh = -1
1799        angle = calc_bearing(uh, vh)
1800        if 179 < angle < 181: v=1
1801        assert v==1
1802        #Test South-East
1803        uh = 1
1804        vh = -1
1805        angle = calc_bearing(uh, vh)
1806        if 134 < angle < 136: v=1
1807        assert v==1
1808        #Test North-East
1809        uh = 1
1810        vh = 1
1811        angle = calc_bearing(uh, vh)
1812        if 44 < angle < 46: v=1
1813        assert v==1
1814        #Test South-West
1815        uh = -1
1816        vh = -1
1817        angle = calc_bearing(uh, vh)
1818        if 224 < angle < 226: v=1
1819        assert v==1
1820        #Test North-West
1821        uh = -1
1822        vh = 1
1823        angle = calc_bearing(uh, vh)
1824        if 314 < angle < 316: v=1
1825        assert v==1
1826
1827 
1828
1829
1830       
1831
1832#-------------------------------------------------------------
1833if __name__ == "__main__":
1834    suite = unittest.makeSuite(Test_Util,'test')
1835#    suite = unittest.makeSuite(Test_Util,'test_sww2csv')
1836#    runner = unittest.TextTestRunner(verbosity=2)
1837    runner = unittest.TextTestRunner(verbosity=1)
1838    runner.run(suite)
1839
1840
1841
1842
Note: See TracBrowser for help on using the repository browser.