source: anuga_core/source/anuga/abstract_2d_finite_volumes/test_util.py @ 7342

Last change on this file since 7342 was 7342, checked in by ole, 15 years ago

Refactored stofrage of quantities to use new concept of static and dynamic quantities.
This is in preparation for flexibly allowing quantities such as elevation or friction
to be time dependent.

All tests and the okushiri validation pass.

File size: 66.1 KB
Line 
1#!/usr/bin/env python
2
3
4import unittest
5from math import sqrt, pi
6import tempfile, os
7from os import access, F_OK,sep, removedirs,remove,mkdir,getcwd
8
9from anuga.abstract_2d_finite_volumes.util import *
10from anuga.config import epsilon
11from anuga.shallow_water.data_manager import timefile2netcdf,del_dir
12
13from anuga.utilities.numerical_tools import NAN
14
15from sys import platform
16
17from anuga.pmesh.mesh import Mesh
18from anuga.shallow_water import Domain, Transmissive_boundary
19from anuga.shallow_water.data_manager import SWW_file
20from csv import reader,writer
21import time
22import string
23
24import numpy as num
25
26
27def test_function(x, y):
28    return x+y
29
30class Test_Util(unittest.TestCase):
31    def setUp(self):
32        pass
33
34    def tearDown(self):
35        pass
36
37
38
39
40    #Geometric
41    #def test_distance(self):
42    #    from anuga.abstract_2d_finite_volumes.util import distance#
43    #
44    #    self.failUnless( distance([4,2],[7,6]) == 5.0,
45    #                     'Distance is wrong!')
46    #    self.failUnless( allclose(distance([7,6],[9,8]), 2.82842712475),
47    #                    'distance is wrong!')
48    #    self.failUnless( allclose(distance([9,8],[4,2]), 7.81024967591),
49    #                    'distance is wrong!')
50    #
51    #    self.failUnless( distance([9,8],[4,2]) == distance([4,2],[9,8]),
52    #                    'distance is wrong!')
53
54
55    def test_file_function_time1(self):
56        """Test that File function interpolates correctly
57        between given times. No x,y dependency here.
58        """
59
60        #Write file
61        import os, time
62        from anuga.config import time_format
63        from math import sin, pi
64
65        #Typical ASCII file
66        finaltime = 1200
67        filename = 'test_file_function'
68        fid = open(filename + '.txt', 'w')
69        start = time.mktime(time.strptime('2000', '%Y'))
70        dt = 60  #One minute intervals
71        t = 0.0
72        while t <= finaltime:
73            t_string = time.strftime(time_format, time.gmtime(t+start))
74            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
75            t += dt
76
77        fid.close()
78
79        #Convert ASCII file to NetCDF (Which is what we really like!)
80        timefile2netcdf(filename)
81
82
83        #Create file function from time series
84        F = file_function(filename + '.tms',
85                          quantities = ['Attribute0',
86                                        'Attribute1',
87                                        'Attribute2'])
88       
89        #Now try interpolation
90        for i in range(20):
91            t = i*10
92            q = F(t)
93
94            #Exact linear intpolation
95            assert num.allclose(q[0], 2*t)
96            if i%6 == 0:
97                assert num.allclose(q[1], t**2)
98                assert num.allclose(q[2], sin(t*pi/600))
99
100        #Check non-exact
101
102        t = 90 #Halfway between 60 and 120
103        q = F(t)
104        assert num.allclose( (120**2 + 60**2)/2, q[1] )
105        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
106
107
108        t = 100 #Two thirds of the way between between 60 and 120
109        q = F(t)
110        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
111        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
112
113        os.remove(filename + '.txt')
114        os.remove(filename + '.tms')       
115
116
117       
118    def test_spatio_temporal_file_function_basic(self):
119        """Test that spatio temporal file function performs the correct
120        interpolations in both time and space
121        NetCDF version (x,y,t dependency)       
122        """
123        import time
124
125        #Create sww file of simple propagation from left to right
126        #through rectangular domain
127        from shallow_water import Domain, Dirichlet_boundary
128        from mesh_factory import rectangular
129
130        #Create basic mesh and shallow water domain
131        points, vertices, boundary = rectangular(3, 3)
132        domain1 = Domain(points, vertices, boundary)
133
134        from anuga.utilities.numerical_tools import mean
135        domain1.reduction = mean
136        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
137                              # only one value.
138
139        domain1.default_order = 2
140        domain1.store = True
141        domain1.set_datadir('.')
142        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
143
144        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
145        domain1.set_quantity('elevation', 0)
146        domain1.set_quantity('friction', 0)
147        domain1.set_quantity('stage', 0)
148
149        # Boundary conditions
150        B0 = Dirichlet_boundary([0,0,0])
151        B6 = Dirichlet_boundary([0.6,0,0])
152        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
153        domain1.check_integrity()
154
155        finaltime = 8
156        #Evolution
157        t0 = -1
158        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
159            #print 'Timesteps: %.16f, %.16f' %(t0, t)
160            #if t == t0:
161            #    msg = 'Duplicate timestep found: %f, %f' %(t0, t)
162            #   raise msg
163            t0 = t
164             
165            #domain1.write_time()
166
167
168        #Now read data from sww and check
169        from Scientific.IO.NetCDF import NetCDFFile
170        filename = domain1.get_name() + '.sww'
171        fid = NetCDFFile(filename)
172
173        x = fid.variables['x'][:]
174        y = fid.variables['y'][:]
175        stage = fid.variables['stage'][:]
176        xmomentum = fid.variables['xmomentum'][:]
177        ymomentum = fid.variables['ymomentum'][:]
178        time = fid.variables['time'][:]
179
180        #Take stage vertex values at last timestep on diagonal
181        #Diagonal is identified by vertices: 0, 5, 10, 15
182
183        last_time_index = len(time)-1 #Last last_time_index
184        d_stage = num.reshape(num.take(stage[last_time_index, :],
185                                       [0,5,10,15],
186                                       axis=0),
187                              (4,1))
188        d_uh = num.reshape(num.take(xmomentum[last_time_index, :],
189                                    [0,5,10,15],
190                                   axis=0),
191                           (4,1))
192        d_vh = num.reshape(num.take(ymomentum[last_time_index, :],
193                                    [0,5,10,15],
194                                   axis=0),
195                           (4,1))
196        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
197
198        #Reference interpolated values at midpoints on diagonal at
199        #this timestep are
200        r0 = (D[0] + D[1])/2
201        r1 = (D[1] + D[2])/2
202        r2 = (D[2] + D[3])/2
203
204        #And the midpoints are found now
205        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15], axis=0)
206        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15], axis=0)
207
208        diag = num.concatenate( (Dx, Dy), axis=1)
209        d_midpoints = (diag[1:] + diag[:-1])/2
210
211        #Let us see if the file function can find the correct
212        #values at the midpoints at the last timestep:
213        f = file_function(filename, domain1,
214                          interpolation_points = d_midpoints)
215
216        T = f.get_time()
217        msg = 'duplicate timesteps: %.16f and %.16f' %(T[-1], T[-2])
218        assert not T[-1] == T[-2], msg
219        t = time[last_time_index]
220        q = f(t, point_id=0); assert num.allclose(r0, q)
221        q = f(t, point_id=1); assert num.allclose(r1, q)
222        q = f(t, point_id=2); assert num.allclose(r2, q)
223
224
225        ##################
226        #Now do the same for the first timestep
227
228        timestep = 0 #First timestep
229        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
230        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
231        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
232        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
233
234        #Reference interpolated values at midpoints on diagonal at
235        #this timestep are
236        r0 = (D[0] + D[1])/2
237        r1 = (D[1] + D[2])/2
238        r2 = (D[2] + D[3])/2
239
240        #Let us see if the file function can find the correct
241        #values
242        q = f(0, point_id=0); assert num.allclose(r0, q)
243        q = f(0, point_id=1); assert num.allclose(r1, q)
244        q = f(0, point_id=2); assert num.allclose(r2, q)
245
246
247        ##################
248        #Now do it again for a timestep in the middle
249
250        timestep = 33
251        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
252        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
253        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
254        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
255
256        #Reference interpolated values at midpoints on diagonal at
257        #this timestep are
258        r0 = (D[0] + D[1])/2
259        r1 = (D[1] + D[2])/2
260        r2 = (D[2] + D[3])/2
261
262        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
263        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
264        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
265
266
267        ##################
268        #Now check temporal interpolation
269        #Halfway between timestep 15 and 16
270
271        timestep = 15
272        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
273        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
274        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
275        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
276
277        #Reference interpolated values at midpoints on diagonal at
278        #this timestep are
279        r0_0 = (D[0] + D[1])/2
280        r1_0 = (D[1] + D[2])/2
281        r2_0 = (D[2] + D[3])/2
282
283        #
284        timestep = 16
285        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
286        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
287        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
288        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
289
290        #Reference interpolated values at midpoints on diagonal at
291        #this timestep are
292        r0_1 = (D[0] + D[1])/2
293        r1_1 = (D[1] + D[2])/2
294        r2_1 = (D[2] + D[3])/2
295
296        # The reference values are
297        r0 = (r0_0 + r0_1)/2
298        r1 = (r1_0 + r1_1)/2
299        r2 = (r2_0 + r2_1)/2
300
301        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
302        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
303        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
304
305        ##################
306        #Finally check interpolation 2 thirds of the way
307        #between timestep 15 and 16
308
309        # The reference values are
310        r0 = (r0_0 + 2*r0_1)/3
311        r1 = (r1_0 + 2*r1_1)/3
312        r2 = (r2_0 + 2*r2_1)/3
313
314        #And the file function gives
315        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
316        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
317        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
318
319        fid.close()
320        import os
321        os.remove(filename)
322
323
324
325    def test_spatio_temporal_file_function_different_origin(self):
326        """Test that spatio temporal file function performs the correct
327        interpolations in both time and space where space is offset by
328        xllcorner and yllcorner
329        NetCDF version (x,y,t dependency)       
330        """
331        import time
332
333        #Create sww file of simple propagation from left to right
334        #through rectangular domain
335        from shallow_water import Domain, Dirichlet_boundary
336        from mesh_factory import rectangular
337
338
339        from anuga.coordinate_transforms.geo_reference import Geo_reference
340        xllcorner = 2048
341        yllcorner = 11000
342        zone = 2
343
344        #Create basic mesh and shallow water domain
345        points, vertices, boundary = rectangular(3, 3)
346        domain1 = Domain(points, vertices, boundary,
347                         geo_reference = Geo_reference(xllcorner = xllcorner,
348                                                       yllcorner = yllcorner))
349       
350
351        from anuga.utilities.numerical_tools import mean       
352        domain1.reduction = mean
353        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
354                              # only one value.
355
356        domain1.default_order = 2
357        domain1.store = True
358        domain1.set_datadir('.')
359        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
360
361        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
362        domain1.set_quantity('elevation', 0)
363        domain1.set_quantity('friction', 0)
364        domain1.set_quantity('stage', 0)
365
366        # Boundary conditions
367        B0 = Dirichlet_boundary([0,0,0])
368        B6 = Dirichlet_boundary([0.6,0,0])
369        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
370        domain1.check_integrity()
371
372        finaltime = 8
373        #Evolution
374        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
375            pass
376            #domain1.write_time()
377
378
379        #Now read data from sww and check
380        from Scientific.IO.NetCDF import NetCDFFile
381        filename = domain1.get_name() + '.sww'
382        fid = NetCDFFile(filename)
383
384        x = fid.variables['x'][:]
385        y = fid.variables['y'][:]
386        # we 'cast' to 64 bit floats to pass this test
387        # SWW file quantities are stored as 32 bits
388        x = num.array(x, num.float)
389        y = num.array(y, num.float)
390
391        stage = fid.variables['stage'][:]
392        xmomentum = fid.variables['xmomentum'][:]
393        ymomentum = fid.variables['ymomentum'][:]
394        time = fid.variables['time'][:]
395
396        #Take stage vertex values at last timestep on diagonal
397        #Diagonal is identified by vertices: 0, 5, 10, 15
398
399        last_time_index = len(time)-1 #Last last_time_index     
400        d_stage = num.reshape(num.take(stage[last_time_index, :],
401                                       [0,5,10,15],
402                                       axis=0),
403                              (4,1))
404        d_uh = num.reshape(num.take(xmomentum[last_time_index, :],
405                                    [0,5,10,15],
406                                   axis=0),
407                           (4,1))
408        d_vh = num.reshape(num.take(ymomentum[last_time_index, :],
409                                    [0,5,10,15],
410                                    axis=0),
411                           (4,1))
412        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
413
414        #Reference interpolated values at midpoints on diagonal at
415        #this timestep are
416        r0 = (D[0] + D[1])/2
417        r1 = (D[1] + D[2])/2
418        r2 = (D[2] + D[3])/2
419
420        #And the midpoints are found now
421        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15], axis=0)
422        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15], axis=0)
423
424        diag = num.concatenate((Dx, Dy), axis=1)
425        d_midpoints = (diag[1:] + diag[:-1])/2
426
427
428        #Adjust for georef - make interpolation points absolute
429        d_midpoints[:,0] += xllcorner
430        d_midpoints[:,1] += yllcorner               
431
432        #Let us see if the file function can find the correct
433        #values at the midpoints at the last timestep:
434        f = file_function(filename, domain1,
435                          interpolation_points = d_midpoints)
436
437        t = time[last_time_index]                         
438
439        q = f(t, point_id=0)
440        msg = '\nr0=%s\nq=%s' % (str(r0), str(q))
441        assert num.allclose(r0, q), msg
442
443        q = f(t, point_id=1)
444        msg = '\nr1=%s\nq=%s' % (str(r1), str(q))
445        assert num.allclose(r1, q), msg
446
447        q = f(t, point_id=2)
448        msg = '\nr2=%s\nq=%s' % (str(r2), str(q))
449        assert num.allclose(r2, q), msg
450
451
452        ##################
453        #Now do the same for the first timestep
454
455        timestep = 0 #First timestep
456        d_stage = num.reshape(num.take(stage[timestep, :],
457                                       [0,5,10,15],
458                                       axis=0),
459                              (4,1))
460        d_uh = num.reshape(num.take(xmomentum[timestep, :],
461                                    [0,5,10,15],
462                                    axis=0),
463                           (4,1))
464        d_vh = num.reshape(num.take(ymomentum[timestep, :],
465                                    [0,5,10,15],
466                                    axis=0),
467                           (4,1))
468        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
469
470        #Reference interpolated values at midpoints on diagonal at
471        #this timestep are
472        r0 = (D[0] + D[1])/2
473        r1 = (D[1] + D[2])/2
474        r2 = (D[2] + D[3])/2
475
476        #Let us see if the file function can find the correct
477        #values
478        q = f(0, point_id=0); assert num.allclose(r0, q)
479        q = f(0, point_id=1); assert num.allclose(r1, q)
480        q = f(0, point_id=2); assert num.allclose(r2, q)
481
482
483        ##################
484        #Now do it again for a timestep in the middle
485
486        timestep = 33
487        d_stage = num.reshape(num.take(stage[timestep, :],
488                                       [0,5,10,15],
489                                       axis=0),
490                              (4,1))
491        d_uh = num.reshape(num.take(xmomentum[timestep, :],
492                                    [0,5,10,15],
493                                    axis=0),
494                           (4,1))
495        d_vh = num.reshape(num.take(ymomentum[timestep, :],
496                                    [0,5,10,15],
497                                    axis=0),
498                           (4,1))
499        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
500
501        #Reference interpolated values at midpoints on diagonal at
502        #this timestep are
503        r0 = (D[0] + D[1])/2
504        r1 = (D[1] + D[2])/2
505        r2 = (D[2] + D[3])/2
506
507        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
508        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
509        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
510
511
512        ##################
513        #Now check temporal interpolation
514        #Halfway between timestep 15 and 16
515
516        timestep = 15
517        d_stage = num.reshape(num.take(stage[timestep, :],
518                                       [0,5,10,15],
519                                       axis=0),
520                              (4,1))
521        d_uh = num.reshape(num.take(xmomentum[timestep, :],
522                                    [0,5,10,15],
523                                    axis=0),
524                           (4,1))
525        d_vh = num.reshape(num.take(ymomentum[timestep, :],
526                                    [0,5,10,15],
527                                    axis=0),
528                           (4,1))
529        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
530
531        #Reference interpolated values at midpoints on diagonal at
532        #this timestep are
533        r0_0 = (D[0] + D[1])/2
534        r1_0 = (D[1] + D[2])/2
535        r2_0 = (D[2] + D[3])/2
536
537        #
538        timestep = 16
539        d_stage = num.reshape(num.take(stage[timestep, :],
540                                       [0,5,10,15],
541                                       axis=0),
542                              (4,1))
543        d_uh = num.reshape(num.take(xmomentum[timestep, :],
544                                    [0,5,10,15],
545                                    axis=0),
546                           (4,1))
547        d_vh = num.reshape(num.take(ymomentum[timestep, :],
548                                    [0,5,10,15],
549                                    axis=0),
550                           (4,1))
551        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
552
553        #Reference interpolated values at midpoints on diagonal at
554        #this timestep are
555        r0_1 = (D[0] + D[1])/2
556        r1_1 = (D[1] + D[2])/2
557        r2_1 = (D[2] + D[3])/2
558
559        # The reference values are
560        r0 = (r0_0 + r0_1)/2
561        r1 = (r1_0 + r1_1)/2
562        r2 = (r2_0 + r2_1)/2
563
564        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
565        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
566        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
567
568        ##################
569        #Finally check interpolation 2 thirds of the way
570        #between timestep 15 and 16
571
572        # The reference values are
573        r0 = (r0_0 + 2*r0_1)/3
574        r1 = (r1_0 + 2*r1_1)/3
575        r2 = (r2_0 + 2*r2_1)/3
576
577        #And the file function gives
578        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
579        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
580        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
581
582        fid.close()
583        import os
584        os.remove(filename)
585
586       
587
588
589    def test_spatio_temporal_file_function_time(self):
590        """Test that File function interpolates correctly
591        between given times.
592        NetCDF version (x,y,t dependency)
593        """
594
595        #Create NetCDF (sww) file to be read
596        # x: 0, 5, 10, 15
597        # y: -20, -10, 0, 10
598        # t: 0, 60, 120, ...., 1200
599        #
600        # test quantities (arbitrary but non-trivial expressions):
601        #
602        #   stage     = 3*x - y**2 + 2*t
603        #   xmomentum = exp( -((x-7)**2 + (y+5)**2)/20 ) * t**2
604        #   ymomentum = x**2 + y**2 * sin(t*pi/600)
605
606        #NOTE: Nice test that may render some of the others redundant.
607
608        import os, time
609        from anuga.config import time_format
610        from mesh_factory import rectangular
611        from shallow_water import Domain
612        import anuga.shallow_water.data_manager
613
614        finaltime = 1200
615        filename = 'test_file_function'
616
617        #Create a domain to hold test grid
618        #(0:15, -20:10)
619        points, vertices, boundary =\
620                rectangular(4, 4, 15, 30, origin = (0, -20))
621        #print "points", points
622
623        #print 'Number of elements', len(vertices)
624        domain = Domain(points, vertices, boundary)
625        domain.smooth = False
626        domain.default_order = 2
627        domain.set_datadir('.')
628        domain.set_name(filename)
629        domain.store = True
630
631        #print points
632        start = time.mktime(time.strptime('2000', '%Y'))
633        domain.starttime = start
634
635
636        #Store structure
637        domain.initialise_storage()
638
639        #Compute artificial time steps and store
640        dt = 60  #One minute intervals
641        t = 0.0
642        while t <= finaltime:
643            #Compute quantities
644            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
645            domain.set_quantity('stage', f1)
646
647            f2 = lambda x,y: x+y+t**2
648            domain.set_quantity('xmomentum', f2)
649
650            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
651            domain.set_quantity('ymomentum', f3)
652
653            #Store and advance time
654            domain.time = t
655            domain.store_timestep()
656            t += dt
657
658
659        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5]]
660     
661        #Deliberately set domain.starttime to too early
662        domain.starttime = start - 1
663
664        #Create file function
665        F = file_function(filename + '.sww', domain,
666                          quantities = domain.conserved_quantities,
667                          interpolation_points = interpolation_points)
668
669        #Check that FF updates fixes domain starttime
670        assert num.allclose(domain.starttime, start)
671
672        #Check that domain.starttime isn't updated if later
673        domain.starttime = start + 1
674        F = file_function(filename + '.sww', domain,
675                          quantities = domain.conserved_quantities,
676                          interpolation_points = interpolation_points)
677        assert num.allclose(domain.starttime, start+1)
678        domain.starttime = start
679
680
681        #Check linear interpolation in time
682        F = file_function(filename + '.sww', domain,
683                          quantities = domain.conserved_quantities,
684                          interpolation_points = interpolation_points)               
685        for id in range(len(interpolation_points)):
686            x = interpolation_points[id][0]
687            y = interpolation_points[id][1]
688
689            for i in range(20):
690                t = i*10
691                k = i%6
692
693                if k == 0:
694                    q0 = F(t, point_id=id)
695                    q1 = F(t+60, point_id=id)
696
697                if num.alltrue(q0 == NAN):
698                    actual = q0
699                else:
700                    actual = (k*q1 + (6-k)*q0)/6
701                q = F(t, point_id=id)
702                #print i, k, t, q
703                #print ' ', q0
704                #print ' ', q1
705                #print "q",q
706                #print "actual", actual
707                #print
708                if num.alltrue(q0 == NAN):
709                     self.failUnless(num.alltrue(q == actual), 'Fail!')
710                else:
711                    assert num.allclose(q, actual)
712
713
714        #Another check of linear interpolation in time
715        for id in range(len(interpolation_points)):
716            q60 = F(60, point_id=id)
717            q120 = F(120, point_id=id)
718
719            t = 90 #Halfway between 60 and 120
720            q = F(t, point_id=id)
721            assert num.allclose( (q120+q60)/2, q )
722
723            t = 100 #Two thirds of the way between between 60 and 120
724            q = F(t, point_id=id)
725            assert num.allclose(q60/3 + 2*q120/3, q)
726
727
728
729        #Check that domain.starttime isn't updated if later than file starttime but earlier
730        #than file end time
731        delta = 23
732        domain.starttime = start + delta
733        F = file_function(filename + '.sww', domain,
734                          quantities = domain.conserved_quantities,
735                          interpolation_points = interpolation_points)
736        assert num.allclose(domain.starttime, start+delta)
737
738
739
740
741        #Now try interpolation with delta offset
742        for id in range(len(interpolation_points)):           
743            x = interpolation_points[id][0]
744            y = interpolation_points[id][1]
745
746            for i in range(20):
747                t = i*10
748                k = i%6
749
750                if k == 0:
751                    q0 = F(t-delta, point_id=id)
752                    q1 = F(t+60-delta, point_id=id)
753
754                q = F(t-delta, point_id=id)
755                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
756
757
758        os.remove(filename + '.sww')
759
760
761
762    def Xtest_spatio_temporal_file_function_time(self):
763        # FIXME: This passes but needs some TLC
764        # Test that File function interpolates correctly
765        # When some points are outside the mesh
766
767        import os, time
768        from anuga.config import time_format
769        from mesh_factory import rectangular
770        from shallow_water import Domain
771        import anuga.shallow_water.data_manager 
772        from anuga.pmesh.mesh_interface import create_mesh_from_regions
773        finaltime = 1200
774       
775        filename = tempfile.mktemp()
776        #print "filename",filename
777        filename = 'test_file_function'
778
779        meshfilename = tempfile.mktemp(".tsh")
780
781        boundary_tags = {'walls':[0,1],'bom':[2]}
782       
783        polygon_absolute = [[0,-20],[10,-20],[10,15],[-20,15]]
784       
785        create_mesh_from_regions(polygon_absolute,
786                                 boundary_tags,
787                                 10000000,
788                                 filename=meshfilename)
789        domain = Domain(mesh_filename=meshfilename)
790        domain.smooth = False
791        domain.default_order = 2
792        domain.set_datadir('.')
793        domain.set_name(filename)
794        domain.store = True
795
796        #print points
797        start = time.mktime(time.strptime('2000', '%Y'))
798        domain.starttime = start
799       
800
801        #Store structure
802        domain.initialise_storage()
803
804        #Compute artificial time steps and store
805        dt = 60  #One minute intervals
806        t = 0.0
807        while t <= finaltime:
808            #Compute quantities
809            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
810            domain.set_quantity('stage', f1)
811
812            f2 = lambda x,y: x+y+t**2
813            domain.set_quantity('xmomentum', f2)
814
815            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
816            domain.set_quantity('ymomentum', f3)
817
818            #Store and advance time
819            domain.time = t
820            domain.store_timestep()
821            t += dt
822
823        interpolation_points = [[1,0]]
824        interpolation_points = [[100,1000]]
825       
826        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5],
827                                [78787,78787],[7878,3432]]
828           
829        #Deliberately set domain.starttime to too early
830        domain.starttime = start - 1
831
832        #Create file function
833        F = file_function(filename + '.sww', domain,
834                          quantities = domain.conserved_quantities,
835                          interpolation_points = interpolation_points)
836
837        #Check that FF updates fixes domain starttime
838        assert num.allclose(domain.starttime, start)
839
840        #Check that domain.starttime isn't updated if later
841        domain.starttime = start + 1
842        F = file_function(filename + '.sww', domain,
843                          quantities = domain.conserved_quantities,
844                          interpolation_points = interpolation_points)
845        assert num.allclose(domain.starttime, start+1)
846        domain.starttime = start
847
848
849        #Check linear interpolation in time
850        # checking points inside and outside the mesh
851        F = file_function(filename + '.sww', domain,
852                          quantities = domain.conserved_quantities,
853                          interpolation_points = interpolation_points)
854       
855        for id in range(len(interpolation_points)):
856            x = interpolation_points[id][0]
857            y = interpolation_points[id][1]
858
859            for i in range(20):
860                t = i*10
861                k = i%6
862
863                if k == 0:
864                    q0 = F(t, point_id=id)
865                    q1 = F(t+60, point_id=id)
866
867                if q0 == NAN:
868                    actual = q0
869                else:
870                    actual = (k*q1 + (6-k)*q0)/6
871                q = F(t, point_id=id)
872                #print i, k, t, q
873                #print ' ', q0
874                #print ' ', q1
875                #print "q",q
876                #print "actual", actual
877                #print
878                if q0 == NAN:
879                     self.failUnless( q == actual, 'Fail!')
880                else:
881                    assert num.allclose(q, actual)
882
883        # now lets check points inside the mesh
884        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14]] #, [10,-12.5]] - this point doesn't work WHY?
885        interpolation_points = [[10,-12.5]]
886           
887        print "len(interpolation_points)",len(interpolation_points) 
888        F = file_function(filename + '.sww', domain,
889                          quantities = domain.conserved_quantities,
890                          interpolation_points = interpolation_points)
891
892        domain.starttime = start
893
894
895        #Check linear interpolation in time
896        F = file_function(filename + '.sww', domain,
897                          quantities = domain.conserved_quantities,
898                          interpolation_points = interpolation_points)               
899        for id in range(len(interpolation_points)):
900            x = interpolation_points[id][0]
901            y = interpolation_points[id][1]
902
903            for i in range(20):
904                t = i*10
905                k = i%6
906
907                if k == 0:
908                    q0 = F(t, point_id=id)
909                    q1 = F(t+60, point_id=id)
910
911                if q0 == NAN:
912                    actual = q0
913                else:
914                    actual = (k*q1 + (6-k)*q0)/6
915                q = F(t, point_id=id)
916                print "############"
917                print "id, x, y ", id, x, y #k, t, q
918                print "t", t
919                #print ' ', q0
920                #print ' ', q1
921                print "q",q
922                print "actual", actual
923                #print
924                if q0 == NAN:
925                     self.failUnless( q == actual, 'Fail!')
926                else:
927                    assert num.allclose(q, actual)
928
929
930        #Another check of linear interpolation in time
931        for id in range(len(interpolation_points)):
932            q60 = F(60, point_id=id)
933            q120 = F(120, point_id=id)
934
935            t = 90 #Halfway between 60 and 120
936            q = F(t, point_id=id)
937            assert num.allclose( (q120+q60)/2, q )
938
939            t = 100 #Two thirds of the way between between 60 and 120
940            q = F(t, point_id=id)
941            assert num.allclose(q60/3 + 2*q120/3, q)
942
943
944
945        #Check that domain.starttime isn't updated if later than file starttime but earlier
946        #than file end time
947        delta = 23
948        domain.starttime = start + delta
949        F = file_function(filename + '.sww', domain,
950                          quantities = domain.conserved_quantities,
951                          interpolation_points = interpolation_points)
952        assert num.allclose(domain.starttime, start+delta)
953
954
955
956
957        #Now try interpolation with delta offset
958        for id in range(len(interpolation_points)):           
959            x = interpolation_points[id][0]
960            y = interpolation_points[id][1]
961
962            for i in range(20):
963                t = i*10
964                k = i%6
965
966                if k == 0:
967                    q0 = F(t-delta, point_id=id)
968                    q1 = F(t+60-delta, point_id=id)
969
970                q = F(t-delta, point_id=id)
971                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
972
973
974        os.remove(filename + '.sww')
975
976    def test_file_function_time_with_domain(self):
977        """Test that File function interpolates correctly
978        between given times. No x,y dependency here.
979        Use domain with starttime
980        """
981
982        #Write file
983        import os, time, calendar
984        from anuga.config import time_format
985        from math import sin, pi
986        from domain import Domain
987
988        finaltime = 1200
989        filename = 'test_file_function'
990        fid = open(filename + '.txt', 'w')
991        start = time.mktime(time.strptime('2000', '%Y'))
992        dt = 60  #One minute intervals
993        t = 0.0
994        while t <= finaltime:
995            t_string = time.strftime(time_format, time.gmtime(t+start))
996            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
997            t += dt
998
999        fid.close()
1000
1001
1002        #Convert ASCII file to NetCDF (Which is what we really like!)
1003        timefile2netcdf(filename)
1004
1005
1006
1007        a = [0.0, 0.0]
1008        b = [4.0, 0.0]
1009        c = [0.0, 3.0]
1010
1011        points = [a, b, c]
1012        vertices = [[0,1,2]]
1013        domain = Domain(points, vertices)
1014
1015        # Check that domain.starttime is updated if non-existing
1016        F = file_function(filename + '.tms',
1017                          domain,
1018                          quantities = ['Attribute0', 'Attribute1', 'Attribute2']) 
1019        assert num.allclose(domain.starttime, start)
1020
1021        # Check that domain.starttime is updated if too early
1022        domain.starttime = start - 1
1023        F = file_function(filename + '.tms',
1024                          domain,
1025                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
1026        assert num.allclose(domain.starttime, start)
1027
1028        # Check that domain.starttime isn't updated if later
1029        domain.starttime = start + 1
1030        F = file_function(filename + '.tms',
1031                          domain,
1032                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
1033        assert num.allclose(domain.starttime, start+1)
1034
1035        domain.starttime = start
1036        F = file_function(filename + '.tms',
1037                          domain,
1038                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'],
1039                          use_cache=True)
1040       
1041
1042        #print F.precomputed_values
1043        #print 'F(60)', F(60)
1044       
1045        #Now try interpolation
1046        for i in range(20):
1047            t = i*10
1048            q = F(t)
1049
1050            #Exact linear intpolation
1051            assert num.allclose(q[0], 2*t)
1052            if i%6 == 0:
1053                assert num.allclose(q[1], t**2)
1054                assert num.allclose(q[2], sin(t*pi/600))
1055
1056        #Check non-exact
1057
1058        t = 90 #Halfway between 60 and 120
1059        q = F(t)
1060        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1061        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1062
1063
1064        t = 100 #Two thirds of the way between between 60 and 120
1065        q = F(t)
1066        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1067        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1068
1069        os.remove(filename + '.tms')
1070        os.remove(filename + '.txt')       
1071
1072    def test_file_function_time_with_domain_different_start(self):
1073        """Test that File function interpolates correctly
1074        between given times. No x,y dependency here.
1075        Use domain with a starttime later than that of file
1076
1077        ASCII version
1078        """
1079
1080        #Write file
1081        import os, time, calendar
1082        from anuga.config import time_format
1083        from math import sin, pi
1084        from domain import Domain
1085
1086        finaltime = 1200
1087        filename = 'test_file_function'
1088        fid = open(filename + '.txt', 'w')
1089        start = time.mktime(time.strptime('2000', '%Y'))
1090        dt = 60  #One minute intervals
1091        t = 0.0
1092        while t <= finaltime:
1093            t_string = time.strftime(time_format, time.gmtime(t+start))
1094            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1095            t += dt
1096
1097        fid.close()
1098
1099        #Convert ASCII file to NetCDF (Which is what we really like!)
1100        timefile2netcdf(filename)       
1101
1102        a = [0.0, 0.0]
1103        b = [4.0, 0.0]
1104        c = [0.0, 3.0]
1105
1106        points = [a, b, c]
1107        vertices = [[0,1,2]]
1108        domain = Domain(points, vertices)
1109
1110        #Check that domain.starttime isn't updated if later than file starttime but earlier
1111        #than file end time
1112        delta = 23
1113        domain.starttime = start + delta
1114        F = file_function(filename + '.tms', domain,
1115                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])       
1116        assert num.allclose(domain.starttime, start+delta)
1117
1118        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1119                                            277., 337., 397., 457., 517.,
1120                                            577., 637., 697., 757., 817.,
1121                                            877., 937., 997., 1057., 1117.,
1122                                            1177.])
1123
1124
1125        #Now try interpolation with delta offset
1126        for i in range(20):
1127            t = i*10
1128            q = F(t-delta)
1129
1130            #Exact linear intpolation
1131            assert num.allclose(q[0], 2*t)
1132            if i%6 == 0:
1133                assert num.allclose(q[1], t**2)
1134                assert num.allclose(q[2], sin(t*pi/600))
1135
1136        #Check non-exact
1137
1138        t = 90 #Halfway between 60 and 120
1139        q = F(t-delta)
1140        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1141        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1142
1143
1144        t = 100 #Two thirds of the way between between 60 and 120
1145        q = F(t-delta)
1146        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1147        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1148
1149
1150        os.remove(filename + '.tms')
1151        os.remove(filename + '.txt')               
1152
1153       
1154
1155    def test_file_function_time_with_domain_different_start_and_time_limit(self):
1156        """Test that File function interpolates correctly
1157        between given times. No x,y dependency here.
1158        Use domain with a starttime later than that of file
1159
1160        ASCII version
1161       
1162        This test also tests that time can be truncated.
1163        """
1164
1165        # Write file
1166        import os, time, calendar
1167        from anuga.config import time_format
1168        from math import sin, pi
1169        from domain import Domain
1170
1171        finaltime = 1200
1172        filename = 'test_file_function'
1173        fid = open(filename + '.txt', 'w')
1174        start = time.mktime(time.strptime('2000', '%Y'))
1175        dt = 60  #One minute intervals
1176        t = 0.0
1177        while t <= finaltime:
1178            t_string = time.strftime(time_format, time.gmtime(t+start))
1179            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1180            t += dt
1181
1182        fid.close()
1183
1184        # Convert ASCII file to NetCDF (Which is what we really like!)
1185        timefile2netcdf(filename)       
1186
1187        a = [0.0, 0.0]
1188        b = [4.0, 0.0]
1189        c = [0.0, 3.0]
1190
1191        points = [a, b, c]
1192        vertices = [[0,1,2]]
1193        domain = Domain(points, vertices)
1194
1195        # Check that domain.starttime isn't updated if later than file starttime but earlier
1196        # than file end time
1197        delta = 23
1198        domain.starttime = start + delta
1199        time_limit = domain.starttime + 600
1200        F = file_function(filename + '.tms', domain,
1201                          time_limit=time_limit,
1202                          quantities=['Attribute0', 'Attribute1', 'Attribute2'])       
1203        assert num.allclose(domain.starttime, start+delta)
1204
1205        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1206                                            277., 337., 397., 457., 517.,
1207                                            577.])       
1208
1209
1210
1211        # Now try interpolation with delta offset
1212        for i in range(20):
1213            t = i*10
1214            q = F(t-delta)
1215
1216            #Exact linear intpolation
1217            assert num.allclose(q[0], 2*t)
1218            if i%6 == 0:
1219                assert num.allclose(q[1], t**2)
1220                assert num.allclose(q[2], sin(t*pi/600))
1221
1222        # Check non-exact
1223        t = 90 #Halfway between 60 and 120
1224        q = F(t-delta)
1225        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1226        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1227
1228
1229        t = 100 # Two thirds of the way between between 60 and 120
1230        q = F(t-delta)
1231        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1232        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1233
1234
1235        os.remove(filename + '.tms')
1236        os.remove(filename + '.txt')               
1237
1238       
1239       
1240       
1241
1242
1243    def test_apply_expression_to_dictionary(self):
1244
1245        #FIXME: Division is not expected to work for integers.
1246        #This must be caught.
1247        foo = num.array([[1,2,3], [4,5,6]], num.float)
1248
1249        bar = num.array([[-1,0,5], [6,1,1]], num.float)                 
1250
1251        D = {'X': foo, 'Y': bar}
1252
1253        Z = apply_expression_to_dictionary('X+Y', D)       
1254        assert num.allclose(Z, foo+bar)
1255
1256        Z = apply_expression_to_dictionary('X*Y', D)       
1257        assert num.allclose(Z, foo*bar)       
1258
1259        Z = apply_expression_to_dictionary('4*X+Y', D)       
1260        assert num.allclose(Z, 4*foo+bar)       
1261
1262        # test zero division is OK
1263        Z = apply_expression_to_dictionary('X/Y', D)
1264        assert num.allclose(1/Z, 1/(foo/bar)) # can't compare inf to inf
1265
1266        # make an error for zero on zero
1267        # this is really an error in numeric, SciPy core can handle it
1268        # Z = apply_expression_to_dictionary('0/Y', D)
1269
1270        #Check exceptions
1271        try:
1272            #Wrong name
1273            Z = apply_expression_to_dictionary('4*X+A', D)       
1274        except NameError:
1275            pass
1276        else:
1277            msg = 'Should have raised a NameError Exception'
1278            raise msg
1279
1280
1281        try:
1282            #Wrong order
1283            Z = apply_expression_to_dictionary(D, '4*X+A')       
1284        except AssertionError:
1285            pass
1286        else:
1287            msg = 'Should have raised a AssertionError Exception'
1288            raise msg       
1289       
1290
1291    def test_multiple_replace(self):
1292        """Hard test that checks a true word-by-word simultaneous replace
1293        """
1294       
1295        D = {'x': 'xi', 'y': 'eta', 'xi':'lam'}
1296        exp = '3*x+y + xi'
1297       
1298        new = multiple_replace(exp, D)
1299       
1300        assert new == '3*xi+eta + lam'
1301                         
1302
1303
1304    def test_point_on_line_obsolete(self):
1305        """Test that obsolete call issues appropriate warning"""
1306
1307        #Turn warning into an exception
1308        import warnings
1309        warnings.filterwarnings('error')
1310
1311        try:
1312            assert point_on_line( 0, 0.5, 0,1, 0,0 )
1313        except DeprecationWarning:
1314            pass
1315        else:
1316            msg = 'point_on_line should have issued a DeprecationWarning'
1317            raise Exception(msg)   
1318
1319        warnings.resetwarnings()
1320   
1321    def test_get_revision_number(self):
1322        """test_get_revision_number(self):
1323
1324        Test that revision number can be retrieved.
1325        """
1326        if os.environ.has_key('USER') and os.environ['USER'] == 'dgray':
1327            # I have a known snv incompatability issue,
1328            # so I'm skipping this test.
1329            # FIXME when SVN is upgraded on our clusters
1330            pass
1331        else:   
1332            n = get_revision_number()
1333            assert n>=0
1334
1335
1336       
1337    def test_add_directories(self):
1338       
1339        import tempfile
1340        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1341        directories = ['ja','ne','ke']
1342        kens_dir = add_directories(root_dir, directories)
1343        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1344               sep + 'ke'
1345        assert access(root_dir,F_OK)
1346
1347        add_directories(root_dir, directories)
1348        assert access(root_dir,F_OK)
1349       
1350        #clean up!
1351        os.rmdir(kens_dir)
1352        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1353        os.rmdir(root_dir + sep + 'ja')
1354        os.rmdir(root_dir)
1355
1356    def test_add_directories_bad(self):
1357       
1358        import tempfile
1359        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1360        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1361       
1362        try:
1363            kens_dir = add_directories(root_dir, directories)
1364        except OSError:
1365            pass
1366        else:
1367            msg = 'bad dir name should give OSError'
1368            raise Exception(msg)   
1369           
1370        #clean up!
1371        os.rmdir(root_dir)
1372
1373    def test_check_list(self):
1374
1375        check_list(['stage','xmomentum'])
1376
1377       
1378    def test_add_directories(self):
1379       
1380        import tempfile
1381        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1382        directories = ['ja','ne','ke']
1383        kens_dir = add_directories(root_dir, directories)
1384        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1385               sep + 'ke'
1386        assert access(root_dir,F_OK)
1387
1388        add_directories(root_dir, directories)
1389        assert access(root_dir,F_OK)
1390       
1391        #clean up!
1392        os.rmdir(kens_dir)
1393        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1394        os.rmdir(root_dir + sep + 'ja')
1395        os.rmdir(root_dir)
1396
1397    def test_add_directories_bad(self):
1398       
1399        import tempfile
1400        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1401        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1402       
1403        try:
1404            kens_dir = add_directories(root_dir, directories)
1405        except OSError:
1406            pass
1407        else:
1408            msg = 'bad dir name should give OSError'
1409            raise Exception(msg)   
1410           
1411        #clean up!
1412        os.rmdir(root_dir)
1413
1414    def test_check_list(self):
1415
1416        check_list(['stage','xmomentum'])
1417
1418######
1419# Test the remove_lone_verts() function
1420######
1421       
1422    def test_remove_lone_verts_a(self):
1423        verts = [[0,0],[1,0],[0,1]]
1424        tris = [[0,1,2]]
1425        new_verts, new_tris = remove_lone_verts(verts, tris)
1426        self.failUnless(new_verts.tolist() == verts)
1427        self.failUnless(new_tris.tolist() == tris)
1428
1429    def test_remove_lone_verts_b(self):
1430        verts = [[0,0],[1,0],[0,1],[99,99]]
1431        tris = [[0,1,2]]
1432        new_verts, new_tris = remove_lone_verts(verts, tris)
1433        self.failUnless(new_verts.tolist() == verts[0:3])
1434        self.failUnless(new_tris.tolist() == tris)
1435       
1436    def test_remove_lone_verts_c(self):
1437        verts = [[99,99],[0,0],[1,0],[99,99],[0,1],[99,99]]
1438        tris = [[1,2,4]]
1439        new_verts, new_tris = remove_lone_verts(verts, tris)
1440        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1441        self.failUnless(new_tris.tolist() == [[0,1,2]])
1442     
1443    def test_remove_lone_verts_d(self):
1444        verts = [[0,0],[1,0],[99,99],[0,1]]
1445        tris = [[0,1,3]]
1446        new_verts, new_tris = remove_lone_verts(verts, tris)
1447        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1448        self.failUnless(new_tris.tolist() == [[0,1,2]])
1449       
1450    def test_remove_lone_verts_e(self):
1451        verts = [[0,0],[1,0],[0,1],[99,99],[99,99],[99,99]]
1452        tris = [[0,1,2]]
1453        new_verts, new_tris = remove_lone_verts(verts, tris)
1454        self.failUnless(new_verts.tolist() == verts[0:3])
1455        self.failUnless(new_tris.tolist() == tris)
1456     
1457    def test_remove_lone_verts_f(self):
1458        verts = [[0,0],[1,0],[99,99],[0,1],[99,99],[1,1],[99,99]]
1459        tris = [[0,1,3],[0,1,5]]
1460        new_verts, new_tris = remove_lone_verts(verts, tris)
1461        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1],[1,1]])
1462        self.failUnless(new_tris.tolist() == [[0,1,2],[0,1,3]])
1463       
1464######
1465#
1466######
1467       
1468    def test_get_min_max_values(self):
1469       
1470        list=[8,9,6,1,4]
1471        min1, max1 = get_min_max_values(list)
1472       
1473        assert min1==1 
1474        assert max1==9
1475       
1476    def test_get_min_max_values1(self):
1477       
1478        list=[-8,-9,-6,-1,-4]
1479        min1, max1 = get_min_max_values(list)
1480       
1481#        print 'min1,max1',min1,max1
1482        assert min1==-9 
1483        assert max1==-1
1484
1485#    def test_get_min_max_values2(self):
1486#        '''
1487#        The min and max supplied are greater than the ones in the
1488#        list and therefore are the ones returned
1489#        '''
1490#        list=[-8,-9,-6,-1,-4]
1491#        min1, max1 = get_min_max_values(list,-10,10)
1492#       
1493##        print 'min1,max1',min1,max1
1494#        assert min1==-10
1495#        assert max1==10
1496       
1497    def test_make_plots_from_csv_files(self):
1498       
1499        #if sys.platform == 'win32':  #Windows
1500            try: 
1501                import pylab
1502            except ImportError:
1503                #ANUGA don't need pylab to work so the system doesn't
1504                #rely on pylab being installed
1505                return
1506           
1507       
1508            current_dir=getcwd()+sep+'abstract_2d_finite_volumes'
1509            temp_dir = tempfile.mkdtemp('','figures')
1510    #        print 'temp_dir',temp_dir
1511            fileName = temp_dir+sep+'time_series_3.csv'
1512            file = open(fileName,"w")
1513            file.write("time,stage,speed,momentum,elevation\n\
15141.0, 0, 0, 0, 10 \n\
15152.0, 5, 2, 4, 10 \n\
15163.0, 3, 3, 5, 10 \n")
1517            file.close()
1518   
1519            fileName1 = temp_dir+sep+'time_series_4.csv'
1520            file1 = open(fileName1,"w")
1521            file1.write("time,stage,speed,momentum,elevation\n\
15221.0, 0, 0, 0, 5 \n\
15232.0, -5, -2, -4, 5 \n\
15243.0, -4, -3, -5, 5 \n")
1525            file1.close()
1526   
1527            fileName2 = temp_dir+sep+'time_series_5.csv'
1528            file2 = open(fileName2,"w")
1529            file2.write("time,stage,speed,momentum,elevation\n\
15301.0, 0, 0, 0, 7 \n\
15312.0, 4, -0.45, 57, 7 \n\
15323.0, 6, -0.5, 56, 7 \n")
1533            file2.close()
1534           
1535            dir, name=os.path.split(fileName)
1536            csv2timeseries_graphs(directories_dic={dir:['gauge', 0, 0]},
1537                                  output_dir=temp_dir,
1538                                  base_name='time_series_',
1539                                  plot_numbers=['3-5'],
1540                                  quantities=['speed','stage','momentum'],
1541                                  assess_all_csv_files=True,
1542                                  extra_plot_name='test')
1543           
1544            #print dir+sep+name[:-4]+'_stage_test.png'
1545            assert(access(dir+sep+name[:-4]+'_stage_test.png',F_OK)==True)
1546            assert(access(dir+sep+name[:-4]+'_speed_test.png',F_OK)==True)
1547            assert(access(dir+sep+name[:-4]+'_momentum_test.png',F_OK)==True)
1548   
1549            dir1, name1=os.path.split(fileName1)
1550            assert(access(dir+sep+name1[:-4]+'_stage_test.png',F_OK)==True)
1551            assert(access(dir+sep+name1[:-4]+'_speed_test.png',F_OK)==True)
1552            assert(access(dir+sep+name1[:-4]+'_momentum_test.png',F_OK)==True)
1553   
1554   
1555            dir2, name2=os.path.split(fileName2)
1556            assert(access(dir+sep+name2[:-4]+'_stage_test.png',F_OK)==True)
1557            assert(access(dir+sep+name2[:-4]+'_speed_test.png',F_OK)==True)
1558            assert(access(dir+sep+name2[:-4]+'_momentum_test.png',F_OK)==True)
1559   
1560            del_dir(temp_dir)
1561       
1562
1563    def test_sww2csv_gauges(self):
1564
1565        def elevation_function(x, y):
1566            return -x
1567       
1568        """Most of this test was copied from test_interpolate
1569        test_interpole_sww2csv
1570       
1571        This is testing the gauge_sww2csv function, by creating a sww file and
1572        then exporting the gauges and checking the results.
1573        """
1574       
1575        # Create mesh
1576        mesh_file = tempfile.mktemp(".tsh")   
1577        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1578        m = Mesh()
1579        m.add_vertices(points)
1580        m.auto_segment()
1581        m.generate_mesh(verbose=False)
1582        m.export_mesh_file(mesh_file)
1583       
1584        # Create shallow water domain
1585        domain = Domain(mesh_file)
1586        os.remove(mesh_file)
1587       
1588        domain.default_order=2
1589       
1590        # This test was made before tight_slope_limiters were introduced
1591        # Since were are testing interpolation values this is OK
1592        domain.tight_slope_limiters = 0 
1593       
1594
1595        # Set some field values
1596        domain.set_quantity('elevation', elevation_function)
1597        domain.set_quantity('friction', 0.03)
1598        domain.set_quantity('xmomentum', 3.0)
1599        domain.set_quantity('ymomentum', 4.0)
1600
1601        ######################
1602        # Boundary conditions
1603        B = Transmissive_boundary(domain)
1604        domain.set_boundary( {'exterior': B})
1605
1606        # This call mangles the stage values.
1607        domain.distribute_to_vertices_and_edges()
1608        domain.set_quantity('stage', 1.0)
1609
1610
1611        domain.set_name('datatest' + str(time.time()))
1612        domain.smooth = True
1613        domain.reduction = mean
1614
1615
1616        sww = SWW_file(domain)
1617        sww.store_connectivity()
1618        sww.store_timestep()
1619        domain.set_quantity('stage', 10.0) # This is automatically limited
1620        # so it will not be less than the elevation
1621        domain.time = 2.
1622        sww.store_timestep()
1623
1624        # test the function
1625        points = [[5.0,1.],[0.5,2.]]
1626
1627        points_file = tempfile.mktemp(".csv")
1628#        points_file = 'test_point.csv'
1629        file_id = open(points_file,"w")
1630        file_id.write("name, easting, northing, elevation \n\
1631point1, 5.0, 1.0, 3.0\n\
1632point2, 0.5, 2.0, 9.0\n")
1633        file_id.close()
1634
1635       
1636        sww2csv_gauges(sww.filename, 
1637                       points_file,
1638                       verbose=False,
1639                       use_cache=False)
1640
1641#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1642        point1_answers_array = [[0.0,0.0,1.0,6.0,-5.0,3.0,4.0], [2.0,2.0/3600.,10.0,15.0,-5.0,3.0,4.0]]
1643        point1_filename = 'gauge_point1.csv'
1644        point1_handle = file(point1_filename)
1645        point1_reader = reader(point1_handle)
1646        point1_reader.next()
1647
1648        line=[]
1649        for i,row in enumerate(point1_reader):
1650#            print 'i',i,'row',row
1651            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1652                         float(row[4]),float(row[5]),float(row[6])])
1653#            print 'assert line',line[i],'point1',point1_answers_array[i]
1654            assert num.allclose(line[i], point1_answers_array[i])
1655
1656        point2_answers_array = [[0.0,0.0,1.0,1.5,-0.5,3.0,4.0], [2.0,2.0/3600.,10.0,10.5,-0.5,3.0,4.0]]
1657        point2_filename = 'gauge_point2.csv' 
1658        point2_handle = file(point2_filename)
1659        point2_reader = reader(point2_handle)
1660        point2_reader.next()
1661                       
1662        line=[]
1663        for i,row in enumerate(point2_reader):
1664#            print 'i',i,'row',row
1665            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1666                         float(row[4]),float(row[5]),float(row[6])])
1667#            print 'assert line',line[i],'point1',point1_answers_array[i]
1668            assert num.allclose(line[i], point2_answers_array[i])
1669                         
1670        # clean up
1671        point1_handle.close()
1672        point2_handle.close()
1673        #print "sww.filename",sww.filename
1674        os.remove(sww.filename)
1675        os.remove(points_file)
1676        os.remove(point1_filename)
1677        os.remove(point2_filename)
1678
1679
1680
1681    def test_sww2csv_gauges1(self):
1682        from anuga.pmesh.mesh import Mesh
1683        from anuga.shallow_water import Domain, Transmissive_boundary
1684        from csv import reader,writer
1685        import time
1686        import string
1687
1688        def elevation_function(x, y):
1689            return -x
1690       
1691        """Most of this test was copied from test_interpolate
1692        test_interpole_sww2csv
1693       
1694        This is testing the gauge_sww2csv function, by creating a sww file and
1695        then exporting the gauges and checking the results.
1696       
1697        This tests the ablity not to have elevation in the points file and
1698        not store xmomentum and ymomentum
1699        """
1700       
1701        # Create mesh
1702        mesh_file = tempfile.mktemp(".tsh")   
1703        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1704        m = Mesh()
1705        m.add_vertices(points)
1706        m.auto_segment()
1707        m.generate_mesh(verbose=False)
1708        m.export_mesh_file(mesh_file)
1709       
1710        # Create shallow water domain
1711        domain = Domain(mesh_file)
1712        os.remove(mesh_file)
1713       
1714        domain.default_order=2
1715
1716        # Set some field values
1717        domain.set_quantity('elevation', elevation_function)
1718        domain.set_quantity('friction', 0.03)
1719        domain.set_quantity('xmomentum', 3.0)
1720        domain.set_quantity('ymomentum', 4.0)
1721
1722        ######################
1723        # Boundary conditions
1724        B = Transmissive_boundary(domain)
1725        domain.set_boundary( {'exterior': B})
1726
1727        # This call mangles the stage values.
1728        domain.distribute_to_vertices_and_edges()
1729        domain.set_quantity('stage', 1.0)
1730
1731
1732        domain.set_name('datatest' + str(time.time()))
1733        domain.smooth = True
1734        domain.reduction = mean
1735
1736        sww = SWW_file(domain)
1737        sww.store_connectivity()
1738        sww.store_timestep()
1739        domain.set_quantity('stage', 10.0) # This is automatically limited
1740        # so it will not be less than the elevation
1741        domain.time = 2.
1742        sww.store_timestep()
1743
1744        # test the function
1745        points = [[5.0,1.],[0.5,2.]]
1746
1747        points_file = tempfile.mktemp(".csv")
1748#        points_file = 'test_point.csv'
1749        file_id = open(points_file,"w")
1750        file_id.write("name,easting,northing \n\
1751point1, 5.0, 1.0\n\
1752point2, 0.5, 2.0\n")
1753        file_id.close()
1754
1755        sww2csv_gauges(sww.filename, 
1756                            points_file,
1757                            quantities=['stage', 'elevation'],
1758                            use_cache=False,
1759                            verbose=False)
1760
1761        point1_answers_array = [[0.0,1.0,-5.0], [2.0,10.0,-5.0]]
1762        point1_filename = 'gauge_point1.csv'
1763        point1_handle = file(point1_filename)
1764        point1_reader = reader(point1_handle)
1765        point1_reader.next()
1766
1767        line=[]
1768        for i,row in enumerate(point1_reader):
1769#            print 'i',i,'row',row
1770            # note the 'hole' (element 1) below - skip the new 'hours' field
1771            line.append([float(row[0]),float(row[2]),float(row[3])])
1772            #print 'line',line[i],'point1',point1_answers_array[i]
1773            assert num.allclose(line[i], point1_answers_array[i])
1774
1775        point2_answers_array = [[0.0,1.0,-0.5], [2.0,10.0,-0.5]]
1776        point2_filename = 'gauge_point2.csv' 
1777        point2_handle = file(point2_filename)
1778        point2_reader = reader(point2_handle)
1779        point2_reader.next()
1780                       
1781        line=[]
1782        for i,row in enumerate(point2_reader):
1783#            print 'i',i,'row',row
1784            # note the 'hole' (element 1) below - skip the new 'hours' field
1785            line.append([float(row[0]),float(row[2]),float(row[3])])
1786#            print 'line',line[i],'point1',point1_answers_array[i]
1787            assert num.allclose(line[i], point2_answers_array[i])
1788                         
1789        # clean up
1790        point1_handle.close()
1791        point2_handle.close()
1792        #print "sww.filename",sww.filename
1793        os.remove(sww.filename)
1794        os.remove(points_file)
1795        os.remove(point1_filename)
1796        os.remove(point2_filename)
1797
1798
1799    def test_sww2csv_gauges2(self):
1800
1801        def elevation_function(x, y):
1802            return -x
1803       
1804        """Most of this test was copied from test_interpolate
1805        test_interpole_sww2csv
1806       
1807        This is testing the gauge_sww2csv function, by creating a sww file and
1808        then exporting the gauges and checking the results.
1809       
1810        This is the same as sww2csv_gauges except set domain.set_starttime to 5.
1811        Therefore testing the storing of the absolute time in the csv files
1812        """
1813       
1814        # Create mesh
1815        mesh_file = tempfile.mktemp(".tsh")   
1816        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1817        m = Mesh()
1818        m.add_vertices(points)
1819        m.auto_segment()
1820        m.generate_mesh(verbose=False)
1821        m.export_mesh_file(mesh_file)
1822       
1823        # Create shallow water domain
1824        domain = Domain(mesh_file)
1825        os.remove(mesh_file)
1826       
1827        domain.default_order=2
1828
1829        # This test was made before tight_slope_limiters were introduced
1830        # Since were are testing interpolation values this is OK
1831        domain.tight_slope_limiters = 0         
1832
1833        # Set some field values
1834        domain.set_quantity('elevation', elevation_function)
1835        domain.set_quantity('friction', 0.03)
1836        domain.set_quantity('xmomentum', 3.0)
1837        domain.set_quantity('ymomentum', 4.0)
1838        domain.set_starttime(5)
1839
1840        ######################
1841        # Boundary conditions
1842        B = Transmissive_boundary(domain)
1843        domain.set_boundary( {'exterior': B})
1844
1845        # This call mangles the stage values.
1846        domain.distribute_to_vertices_and_edges()
1847        domain.set_quantity('stage', 1.0)
1848       
1849
1850
1851        domain.set_name('datatest' + str(time.time()))
1852        domain.smooth = True
1853        domain.reduction = mean
1854
1855        sww = SWW_file(domain)
1856        sww.store_connectivity()
1857        sww.store_timestep()
1858        domain.set_quantity('stage', 10.0) # This is automatically limited
1859        # so it will not be less than the elevation
1860        domain.time = 2.
1861        sww.store_timestep()
1862
1863        # test the function
1864        points = [[5.0,1.],[0.5,2.]]
1865
1866        points_file = tempfile.mktemp(".csv")
1867#        points_file = 'test_point.csv'
1868        file_id = open(points_file,"w")
1869        file_id.write("name, easting, northing, elevation \n\
1870point1, 5.0, 1.0, 3.0\n\
1871point2, 0.5, 2.0, 9.0\n")
1872        file_id.close()
1873
1874       
1875        sww2csv_gauges(sww.filename, 
1876                            points_file,
1877                            verbose=False,
1878                            use_cache=False)
1879
1880#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1881        point1_answers_array = [[5.0,5.0/3600.,1.0,6.0,-5.0,3.0,4.0], [7.0,7.0/3600.,10.0,15.0,-5.0,3.0,4.0]]
1882        point1_filename = 'gauge_point1.csv'
1883        point1_handle = file(point1_filename)
1884        point1_reader = reader(point1_handle)
1885        point1_reader.next()
1886
1887        line=[]
1888        for i,row in enumerate(point1_reader):
1889            #print 'i',i,'row',row
1890            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1891                         float(row[4]), float(row[5]), float(row[6])])
1892            #print 'assert line',line[i],'point1',point1_answers_array[i]
1893            assert num.allclose(line[i], point1_answers_array[i])
1894
1895        point2_answers_array = [[5.0,5.0/3600.,1.0,1.5,-0.5,3.0,4.0], [7.0,7.0/3600.,10.0,10.5,-0.5,3.0,4.0]]
1896        point2_filename = 'gauge_point2.csv' 
1897        point2_handle = file(point2_filename)
1898        point2_reader = reader(point2_handle)
1899        point2_reader.next()
1900                       
1901        line=[]
1902        for i,row in enumerate(point2_reader):
1903            #print 'i',i,'row',row
1904            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1905                         float(row[4]),float(row[5]), float(row[6])])
1906            #print 'assert line',line[i],'point1',point1_answers_array[i]
1907            assert num.allclose(line[i], point2_answers_array[i])
1908                         
1909        # clean up
1910        point1_handle.close()
1911        point2_handle.close()
1912        #print "sww.filename",sww.filename
1913        os.remove(sww.filename)
1914        os.remove(points_file)
1915        os.remove(point1_filename)
1916        os.remove(point2_filename)
1917
1918
1919    def test_greens_law(self):
1920
1921        from math import sqrt
1922       
1923        d1 = 80.0
1924        d2 = 20.0
1925        h1 = 1.0
1926        h2 = greens_law(d1,d2,h1)
1927
1928        assert h2==sqrt(2.0)
1929       
1930    def test_calc_bearings(self):
1931 
1932        from math import atan, degrees
1933        #Test East
1934        uh = 1
1935        vh = 1.e-15
1936        angle = calc_bearing(uh, vh)
1937        if 89 < angle < 91: v=1
1938        assert v==1
1939        #Test West
1940        uh = -1
1941        vh = 1.e-15
1942        angle = calc_bearing(uh, vh)
1943        if 269 < angle < 271: v=1
1944        assert v==1
1945        #Test North
1946        uh = 1.e-15
1947        vh = 1
1948        angle = calc_bearing(uh, vh)
1949        if -1 < angle < 1: v=1
1950        assert v==1
1951        #Test South
1952        uh = 1.e-15
1953        vh = -1
1954        angle = calc_bearing(uh, vh)
1955        if 179 < angle < 181: v=1
1956        assert v==1
1957        #Test South-East
1958        uh = 1
1959        vh = -1
1960        angle = calc_bearing(uh, vh)
1961        if 134 < angle < 136: v=1
1962        assert v==1
1963        #Test North-East
1964        uh = 1
1965        vh = 1
1966        angle = calc_bearing(uh, vh)
1967        if 44 < angle < 46: v=1
1968        assert v==1
1969        #Test South-West
1970        uh = -1
1971        vh = -1
1972        angle = calc_bearing(uh, vh)
1973        if 224 < angle < 226: v=1
1974        assert v==1
1975        #Test North-West
1976        uh = -1
1977        vh = 1
1978        angle = calc_bearing(uh, vh)
1979        if 314 < angle < 316: v=1
1980        assert v==1
1981       
1982
1983#-------------------------------------------------------------
1984
1985if __name__ == "__main__":
1986    suite = unittest.makeSuite(Test_Util, 'test')
1987#    runner = unittest.TextTestRunner(verbosity=2)
1988    runner = unittest.TextTestRunner(verbosity=1)
1989    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.