source: branches/numpy/anuga/abstract_2d_finite_volumes/test_util.py @ 6883

Last change on this file since 6883 was 6553, checked in by rwilson, 16 years ago

Merged trunk into numpy, all tests and validations work.

File size: 66.9 KB
Line 
1#!/usr/bin/env python
2
3
4import unittest
5from math import sqrt, pi
6import tempfile, os
7from os import access, F_OK,sep, removedirs,remove,mkdir,getcwd
8
9from anuga.abstract_2d_finite_volumes.util import *
10from anuga.config import epsilon
11from anuga.shallow_water.data_manager import timefile2netcdf,del_dir
12
13from anuga.utilities.numerical_tools import NAN
14
15from sys import platform
16
17from anuga.pmesh.mesh import Mesh
18from anuga.shallow_water import Domain, Transmissive_boundary
19from anuga.shallow_water.data_manager import get_dataobject
20from csv import reader,writer
21import time
22import string
23
24import numpy as num
25
26
27def test_function(x, y):
28    return x+y
29
30class Test_Util(unittest.TestCase):
31    def setUp(self):
32        pass
33
34    def tearDown(self):
35        pass
36
37
38
39
40    #Geometric
41    #def test_distance(self):
42    #    from anuga.abstract_2d_finite_volumes.util import distance#
43    #
44    #    self.failUnless( distance([4,2],[7,6]) == 5.0,
45    #                     'Distance is wrong!')
46    #    self.failUnless( allclose(distance([7,6],[9,8]), 2.82842712475),
47    #                    'distance is wrong!')
48    #    self.failUnless( allclose(distance([9,8],[4,2]), 7.81024967591),
49    #                    'distance is wrong!')
50    #
51    #    self.failUnless( distance([9,8],[4,2]) == distance([4,2],[9,8]),
52    #                    'distance is wrong!')
53
54
55    def test_file_function_time1(self):
56        """Test that File function interpolates correctly
57        between given times. No x,y dependency here.
58        """
59
60        #Write file
61        import os, time
62        from anuga.config import time_format
63        from math import sin, pi
64
65        #Typical ASCII file
66        finaltime = 1200
67        filename = 'test_file_function'
68        fid = open(filename + '.txt', 'w')
69        start = time.mktime(time.strptime('2000', '%Y'))
70        dt = 60  #One minute intervals
71        t = 0.0
72        while t <= finaltime:
73            t_string = time.strftime(time_format, time.gmtime(t+start))
74            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
75            t += dt
76
77        fid.close()
78
79        #Convert ASCII file to NetCDF (Which is what we really like!)
80        timefile2netcdf(filename)
81
82
83        #Create file function from time series
84        F = file_function(filename + '.tms',
85                          quantities = ['Attribute0',
86                                        'Attribute1',
87                                        'Attribute2'])
88       
89        #Now try interpolation
90        for i in range(20):
91            t = i*10
92            q = F(t)
93
94            #Exact linear intpolation
95            assert num.allclose(q[0], 2*t)
96            if i%6 == 0:
97                assert num.allclose(q[1], t**2)
98                assert num.allclose(q[2], sin(t*pi/600))
99
100        #Check non-exact
101
102        t = 90 #Halfway between 60 and 120
103        q = F(t)
104        assert num.allclose( (120**2 + 60**2)/2, q[1] )
105        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
106
107
108        t = 100 #Two thirds of the way between between 60 and 120
109        q = F(t)
110        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
111        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
112
113        os.remove(filename + '.txt')
114        os.remove(filename + '.tms')       
115
116
117       
118    def test_spatio_temporal_file_function_basic(self):
119        """Test that spatio temporal file function performs the correct
120        interpolations in both time and space
121        NetCDF version (x,y,t dependency)       
122        """
123        import time
124
125        #Create sww file of simple propagation from left to right
126        #through rectangular domain
127        from shallow_water import Domain, Dirichlet_boundary
128        from mesh_factory import rectangular
129
130        #Create basic mesh and shallow water domain
131        points, vertices, boundary = rectangular(3, 3)
132        domain1 = Domain(points, vertices, boundary)
133
134        from anuga.utilities.numerical_tools import mean
135        domain1.reduction = mean
136        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
137                              # only one value.
138
139        domain1.default_order = 2
140        domain1.store = True
141        domain1.set_datadir('.')
142        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
143        domain1.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']
144
145        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
146        domain1.set_quantity('elevation', 0)
147        domain1.set_quantity('friction', 0)
148        domain1.set_quantity('stage', 0)
149
150        # Boundary conditions
151        B0 = Dirichlet_boundary([0,0,0])
152        B6 = Dirichlet_boundary([0.6,0,0])
153        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
154        domain1.check_integrity()
155
156        finaltime = 8
157        #Evolution
158        t0 = -1
159        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
160            #print 'Timesteps: %.16f, %.16f' %(t0, t)
161            #if t == t0:
162            #    msg = 'Duplicate timestep found: %f, %f' %(t0, t)
163            #   raise msg
164            t0 = t
165             
166            #domain1.write_time()
167
168
169        #Now read data from sww and check
170        from Scientific.IO.NetCDF import NetCDFFile
171        filename = domain1.get_name() + '.' + domain1.format
172        fid = NetCDFFile(filename)
173
174        x = fid.variables['x'][:]
175        y = fid.variables['y'][:]
176        stage = fid.variables['stage'][:]
177        xmomentum = fid.variables['xmomentum'][:]
178        ymomentum = fid.variables['ymomentum'][:]
179        time = fid.variables['time'][:]
180
181        #Take stage vertex values at last timestep on diagonal
182        #Diagonal is identified by vertices: 0, 5, 10, 15
183
184        last_time_index = len(time)-1 #Last last_time_index
185        d_stage = num.reshape(num.take(stage[last_time_index, :],
186                                       [0,5,10,15],
187                                       axis=0),
188                              (4,1))
189        d_uh = num.reshape(num.take(xmomentum[last_time_index, :],
190                                    [0,5,10,15],
191                                   axis=0),
192                           (4,1))
193        d_vh = num.reshape(num.take(ymomentum[last_time_index, :],
194                                    [0,5,10,15],
195                                   axis=0),
196                           (4,1))
197        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
198
199        #Reference interpolated values at midpoints on diagonal at
200        #this timestep are
201        r0 = (D[0] + D[1])/2
202        r1 = (D[1] + D[2])/2
203        r2 = (D[2] + D[3])/2
204
205        #And the midpoints are found now
206        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15], axis=0)
207        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15], axis=0)
208
209        diag = num.concatenate( (Dx, Dy), axis=1)
210        d_midpoints = (diag[1:] + diag[:-1])/2
211
212        #Let us see if the file function can find the correct
213        #values at the midpoints at the last timestep:
214        f = file_function(filename, domain1,
215                          interpolation_points = d_midpoints)
216
217        T = f.get_time()
218        msg = 'duplicate timesteps: %.16f and %.16f' %(T[-1], T[-2])
219        assert not T[-1] == T[-2], msg
220        t = time[last_time_index]
221        q = f(t, point_id=0); assert num.allclose(r0, q)
222        q = f(t, point_id=1); assert num.allclose(r1, q)
223        q = f(t, point_id=2); assert num.allclose(r2, q)
224
225
226        ##################
227        #Now do the same for the first timestep
228
229        timestep = 0 #First timestep
230        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
231        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
232        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
233        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
234
235        #Reference interpolated values at midpoints on diagonal at
236        #this timestep are
237        r0 = (D[0] + D[1])/2
238        r1 = (D[1] + D[2])/2
239        r2 = (D[2] + D[3])/2
240
241        #Let us see if the file function can find the correct
242        #values
243        q = f(0, point_id=0); assert num.allclose(r0, q)
244        q = f(0, point_id=1); assert num.allclose(r1, q)
245        q = f(0, point_id=2); assert num.allclose(r2, q)
246
247
248        ##################
249        #Now do it again for a timestep in the middle
250
251        timestep = 33
252        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
253        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
254        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
255        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
256
257        #Reference interpolated values at midpoints on diagonal at
258        #this timestep are
259        r0 = (D[0] + D[1])/2
260        r1 = (D[1] + D[2])/2
261        r2 = (D[2] + D[3])/2
262
263        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
264        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
265        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
266
267
268        ##################
269        #Now check temporal interpolation
270        #Halfway between timestep 15 and 16
271
272        timestep = 15
273        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
274        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
275        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
276        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
277
278        #Reference interpolated values at midpoints on diagonal at
279        #this timestep are
280        r0_0 = (D[0] + D[1])/2
281        r1_0 = (D[1] + D[2])/2
282        r2_0 = (D[2] + D[3])/2
283
284        #
285        timestep = 16
286        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
287        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
288        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
289        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
290
291        #Reference interpolated values at midpoints on diagonal at
292        #this timestep are
293        r0_1 = (D[0] + D[1])/2
294        r1_1 = (D[1] + D[2])/2
295        r2_1 = (D[2] + D[3])/2
296
297        # The reference values are
298        r0 = (r0_0 + r0_1)/2
299        r1 = (r1_0 + r1_1)/2
300        r2 = (r2_0 + r2_1)/2
301
302        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
303        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
304        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
305
306        ##################
307        #Finally check interpolation 2 thirds of the way
308        #between timestep 15 and 16
309
310        # The reference values are
311        r0 = (r0_0 + 2*r0_1)/3
312        r1 = (r1_0 + 2*r1_1)/3
313        r2 = (r2_0 + 2*r2_1)/3
314
315        #And the file function gives
316        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
317        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
318        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
319
320        fid.close()
321        import os
322        os.remove(filename)
323
324
325
326    def test_spatio_temporal_file_function_different_origin(self):
327        """Test that spatio temporal file function performs the correct
328        interpolations in both time and space where space is offset by
329        xllcorner and yllcorner
330        NetCDF version (x,y,t dependency)       
331        """
332        import time
333
334        #Create sww file of simple propagation from left to right
335        #through rectangular domain
336        from shallow_water import Domain, Dirichlet_boundary
337        from mesh_factory import rectangular
338
339
340        from anuga.coordinate_transforms.geo_reference import Geo_reference
341        xllcorner = 2048
342        yllcorner = 11000
343        zone = 2
344
345        #Create basic mesh and shallow water domain
346        points, vertices, boundary = rectangular(3, 3)
347        domain1 = Domain(points, vertices, boundary,
348                         geo_reference = Geo_reference(xllcorner = xllcorner,
349                                                       yllcorner = yllcorner))
350       
351
352        from anuga.utilities.numerical_tools import mean       
353        domain1.reduction = mean
354        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
355                              # only one value.
356
357        domain1.default_order = 2
358        domain1.store = True
359        domain1.set_datadir('.')
360        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
361        domain1.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']
362
363        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
364        domain1.set_quantity('elevation', 0)
365        domain1.set_quantity('friction', 0)
366        domain1.set_quantity('stage', 0)
367
368        # Boundary conditions
369        B0 = Dirichlet_boundary([0,0,0])
370        B6 = Dirichlet_boundary([0.6,0,0])
371        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
372        domain1.check_integrity()
373
374        finaltime = 8
375        #Evolution
376        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
377            pass
378            #domain1.write_time()
379
380
381        #Now read data from sww and check
382        from Scientific.IO.NetCDF import NetCDFFile
383        filename = domain1.get_name() + '.' + domain1.format
384        fid = NetCDFFile(filename)
385
386        x = fid.variables['x'][:]
387        y = fid.variables['y'][:]
388        # we 'cast' to 64 bit floats to pass this test
389        # SWW file quantities are stored as 32 bits
390        x = num.array(x, num.float)
391        y = num.array(y, num.float)
392
393        stage = fid.variables['stage'][:]
394        xmomentum = fid.variables['xmomentum'][:]
395        ymomentum = fid.variables['ymomentum'][:]
396        time = fid.variables['time'][:]
397
398        #Take stage vertex values at last timestep on diagonal
399        #Diagonal is identified by vertices: 0, 5, 10, 15
400
401        last_time_index = len(time)-1 #Last last_time_index     
402        d_stage = num.reshape(num.take(stage[last_time_index, :],
403                                       [0,5,10,15],
404                                       axis=0),
405                              (4,1))
406        d_uh = num.reshape(num.take(xmomentum[last_time_index, :],
407                                    [0,5,10,15],
408                                   axis=0),
409                           (4,1))
410        d_vh = num.reshape(num.take(ymomentum[last_time_index, :],
411                                    [0,5,10,15],
412                                    axis=0),
413                           (4,1))
414        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
415
416        #Reference interpolated values at midpoints on diagonal at
417        #this timestep are
418        r0 = (D[0] + D[1])/2
419        r1 = (D[1] + D[2])/2
420        r2 = (D[2] + D[3])/2
421
422        #And the midpoints are found now
423        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15], axis=0)
424        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15], axis=0)
425
426        diag = num.concatenate((Dx, Dy), axis=1)
427        d_midpoints = (diag[1:] + diag[:-1])/2
428
429
430        #Adjust for georef - make interpolation points absolute
431        d_midpoints[:,0] += xllcorner
432        d_midpoints[:,1] += yllcorner               
433
434        #Let us see if the file function can find the correct
435        #values at the midpoints at the last timestep:
436        f = file_function(filename, domain1,
437                          interpolation_points = d_midpoints)
438
439        t = time[last_time_index]                         
440
441        q = f(t, point_id=0)
442        msg = '\nr0=%s\nq=%s' % (str(r0), str(q))
443        assert num.allclose(r0, q), msg
444
445        q = f(t, point_id=1)
446        msg = '\nr1=%s\nq=%s' % (str(r1), str(q))
447        assert num.allclose(r1, q), msg
448
449        q = f(t, point_id=2)
450        msg = '\nr2=%s\nq=%s' % (str(r2), str(q))
451        assert num.allclose(r2, q), msg
452
453
454        ##################
455        #Now do the same for the first timestep
456
457        timestep = 0 #First timestep
458        d_stage = num.reshape(num.take(stage[timestep, :],
459                                       [0,5,10,15],
460                                       axis=0),
461                              (4,1))
462        d_uh = num.reshape(num.take(xmomentum[timestep, :],
463                                    [0,5,10,15],
464                                    axis=0),
465                           (4,1))
466        d_vh = num.reshape(num.take(ymomentum[timestep, :],
467                                    [0,5,10,15],
468                                    axis=0),
469                           (4,1))
470        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
471
472        #Reference interpolated values at midpoints on diagonal at
473        #this timestep are
474        r0 = (D[0] + D[1])/2
475        r1 = (D[1] + D[2])/2
476        r2 = (D[2] + D[3])/2
477
478        #Let us see if the file function can find the correct
479        #values
480        q = f(0, point_id=0); assert num.allclose(r0, q)
481        q = f(0, point_id=1); assert num.allclose(r1, q)
482        q = f(0, point_id=2); assert num.allclose(r2, q)
483
484
485        ##################
486        #Now do it again for a timestep in the middle
487
488        timestep = 33
489        d_stage = num.reshape(num.take(stage[timestep, :],
490                                       [0,5,10,15],
491                                       axis=0),
492                              (4,1))
493        d_uh = num.reshape(num.take(xmomentum[timestep, :],
494                                    [0,5,10,15],
495                                    axis=0),
496                           (4,1))
497        d_vh = num.reshape(num.take(ymomentum[timestep, :],
498                                    [0,5,10,15],
499                                    axis=0),
500                           (4,1))
501        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
502
503        #Reference interpolated values at midpoints on diagonal at
504        #this timestep are
505        r0 = (D[0] + D[1])/2
506        r1 = (D[1] + D[2])/2
507        r2 = (D[2] + D[3])/2
508
509        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
510        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
511        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
512
513
514        ##################
515        #Now check temporal interpolation
516        #Halfway between timestep 15 and 16
517
518        timestep = 15
519        d_stage = num.reshape(num.take(stage[timestep, :],
520                                       [0,5,10,15],
521                                       axis=0),
522                              (4,1))
523        d_uh = num.reshape(num.take(xmomentum[timestep, :],
524                                    [0,5,10,15],
525                                    axis=0),
526                           (4,1))
527        d_vh = num.reshape(num.take(ymomentum[timestep, :],
528                                    [0,5,10,15],
529                                    axis=0),
530                           (4,1))
531        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
532
533        #Reference interpolated values at midpoints on diagonal at
534        #this timestep are
535        r0_0 = (D[0] + D[1])/2
536        r1_0 = (D[1] + D[2])/2
537        r2_0 = (D[2] + D[3])/2
538
539        #
540        timestep = 16
541        d_stage = num.reshape(num.take(stage[timestep, :],
542                                       [0,5,10,15],
543                                       axis=0),
544                              (4,1))
545        d_uh = num.reshape(num.take(xmomentum[timestep, :],
546                                    [0,5,10,15],
547                                    axis=0),
548                           (4,1))
549        d_vh = num.reshape(num.take(ymomentum[timestep, :],
550                                    [0,5,10,15],
551                                    axis=0),
552                           (4,1))
553        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
554
555        #Reference interpolated values at midpoints on diagonal at
556        #this timestep are
557        r0_1 = (D[0] + D[1])/2
558        r1_1 = (D[1] + D[2])/2
559        r2_1 = (D[2] + D[3])/2
560
561        # The reference values are
562        r0 = (r0_0 + r0_1)/2
563        r1 = (r1_0 + r1_1)/2
564        r2 = (r2_0 + r2_1)/2
565
566        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
567        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
568        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
569
570        ##################
571        #Finally check interpolation 2 thirds of the way
572        #between timestep 15 and 16
573
574        # The reference values are
575        r0 = (r0_0 + 2*r0_1)/3
576        r1 = (r1_0 + 2*r1_1)/3
577        r2 = (r2_0 + 2*r2_1)/3
578
579        #And the file function gives
580        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
581        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
582        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
583
584        fid.close()
585        import os
586        os.remove(filename)
587
588       
589
590
591    def test_spatio_temporal_file_function_time(self):
592        """Test that File function interpolates correctly
593        between given times.
594        NetCDF version (x,y,t dependency)
595        """
596
597        #Create NetCDF (sww) file to be read
598        # x: 0, 5, 10, 15
599        # y: -20, -10, 0, 10
600        # t: 0, 60, 120, ...., 1200
601        #
602        # test quantities (arbitrary but non-trivial expressions):
603        #
604        #   stage     = 3*x - y**2 + 2*t
605        #   xmomentum = exp( -((x-7)**2 + (y+5)**2)/20 ) * t**2
606        #   ymomentum = x**2 + y**2 * sin(t*pi/600)
607
608        #NOTE: Nice test that may render some of the others redundant.
609
610        import os, time
611        from anuga.config import time_format
612        from mesh_factory import rectangular
613        from shallow_water import Domain
614        import anuga.shallow_water.data_manager
615
616        finaltime = 1200
617        filename = 'test_file_function'
618
619        #Create a domain to hold test grid
620        #(0:15, -20:10)
621        points, vertices, boundary =\
622                rectangular(4, 4, 15, 30, origin = (0, -20))
623        #print "points", points
624
625        #print 'Number of elements', len(vertices)
626        domain = Domain(points, vertices, boundary)
627        domain.smooth = False
628        domain.default_order = 2
629        domain.set_datadir('.')
630        domain.set_name(filename)
631        domain.store = True
632        domain.format = 'sww'   #Native netcdf visualisation format
633
634        #print points
635        start = time.mktime(time.strptime('2000', '%Y'))
636        domain.starttime = start
637
638
639        #Store structure
640        domain.initialise_storage()
641
642        #Compute artificial time steps and store
643        dt = 60  #One minute intervals
644        t = 0.0
645        while t <= finaltime:
646            #Compute quantities
647            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
648            domain.set_quantity('stage', f1)
649
650            f2 = lambda x,y: x+y+t**2
651            domain.set_quantity('xmomentum', f2)
652
653            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
654            domain.set_quantity('ymomentum', f3)
655
656            #Store and advance time
657            domain.time = t
658            domain.store_timestep(domain.conserved_quantities)
659            t += dt
660
661
662        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5]]
663     
664        #Deliberately set domain.starttime to too early
665        domain.starttime = start - 1
666
667        #Create file function
668        F = file_function(filename + '.sww', domain,
669                          quantities = domain.conserved_quantities,
670                          interpolation_points = interpolation_points)
671
672        #Check that FF updates fixes domain starttime
673        assert num.allclose(domain.starttime, start)
674
675        #Check that domain.starttime isn't updated if later
676        domain.starttime = start + 1
677        F = file_function(filename + '.sww', domain,
678                          quantities = domain.conserved_quantities,
679                          interpolation_points = interpolation_points)
680        assert num.allclose(domain.starttime, start+1)
681        domain.starttime = start
682
683
684        #Check linear interpolation in time
685        F = file_function(filename + '.sww', domain,
686                          quantities = domain.conserved_quantities,
687                          interpolation_points = interpolation_points)               
688        for id in range(len(interpolation_points)):
689            x = interpolation_points[id][0]
690            y = interpolation_points[id][1]
691
692            for i in range(20):
693                t = i*10
694                k = i%6
695
696                if k == 0:
697                    q0 = F(t, point_id=id)
698                    q1 = F(t+60, point_id=id)
699
700                if num.alltrue(q0 == NAN):
701                    actual = q0
702                else:
703                    actual = (k*q1 + (6-k)*q0)/6
704                q = F(t, point_id=id)
705                #print i, k, t, q
706                #print ' ', q0
707                #print ' ', q1
708                #print "q",q
709                #print "actual", actual
710                #print
711                if num.alltrue(q0 == NAN):
712                     self.failUnless(num.alltrue(q == actual), 'Fail!')
713                else:
714                    assert num.allclose(q, actual)
715
716
717        #Another check of linear interpolation in time
718        for id in range(len(interpolation_points)):
719            q60 = F(60, point_id=id)
720            q120 = F(120, point_id=id)
721
722            t = 90 #Halfway between 60 and 120
723            q = F(t, point_id=id)
724            assert num.allclose( (q120+q60)/2, q )
725
726            t = 100 #Two thirds of the way between between 60 and 120
727            q = F(t, point_id=id)
728            assert num.allclose(q60/3 + 2*q120/3, q)
729
730
731
732        #Check that domain.starttime isn't updated if later than file starttime but earlier
733        #than file end time
734        delta = 23
735        domain.starttime = start + delta
736        F = file_function(filename + '.sww', domain,
737                          quantities = domain.conserved_quantities,
738                          interpolation_points = interpolation_points)
739        assert num.allclose(domain.starttime, start+delta)
740
741
742
743
744        #Now try interpolation with delta offset
745        for id in range(len(interpolation_points)):           
746            x = interpolation_points[id][0]
747            y = interpolation_points[id][1]
748
749            for i in range(20):
750                t = i*10
751                k = i%6
752
753                if k == 0:
754                    q0 = F(t-delta, point_id=id)
755                    q1 = F(t+60-delta, point_id=id)
756
757                q = F(t-delta, point_id=id)
758                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
759
760
761        os.remove(filename + '.sww')
762
763
764
765    def Xtest_spatio_temporal_file_function_time(self):
766        # FIXME: This passes but needs some TLC
767        # Test that File function interpolates correctly
768        # When some points are outside the mesh
769
770        import os, time
771        from anuga.config import time_format
772        from mesh_factory import rectangular
773        from shallow_water import Domain
774        import anuga.shallow_water.data_manager 
775        from anuga.pmesh.mesh_interface import create_mesh_from_regions
776        finaltime = 1200
777       
778        filename = tempfile.mktemp()
779        #print "filename",filename
780        filename = 'test_file_function'
781
782        meshfilename = tempfile.mktemp(".tsh")
783
784        boundary_tags = {'walls':[0,1],'bom':[2]}
785       
786        polygon_absolute = [[0,-20],[10,-20],[10,15],[-20,15]]
787       
788        create_mesh_from_regions(polygon_absolute,
789                                 boundary_tags,
790                                 10000000,
791                                 filename=meshfilename)
792        domain = Domain(mesh_filename=meshfilename)
793        domain.smooth = False
794        domain.default_order = 2
795        domain.set_datadir('.')
796        domain.set_name(filename)
797        domain.store = True
798        domain.format = 'sww'   #Native netcdf visualisation format
799
800        #print points
801        start = time.mktime(time.strptime('2000', '%Y'))
802        domain.starttime = start
803       
804
805        #Store structure
806        domain.initialise_storage()
807
808        #Compute artificial time steps and store
809        dt = 60  #One minute intervals
810        t = 0.0
811        while t <= finaltime:
812            #Compute quantities
813            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
814            domain.set_quantity('stage', f1)
815
816            f2 = lambda x,y: x+y+t**2
817            domain.set_quantity('xmomentum', f2)
818
819            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
820            domain.set_quantity('ymomentum', f3)
821
822            #Store and advance time
823            domain.time = t
824            domain.store_timestep(domain.conserved_quantities)
825            t += dt
826
827        interpolation_points = [[1,0]]
828        interpolation_points = [[100,1000]]
829       
830        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5],
831                                [78787,78787],[7878,3432]]
832           
833        #Deliberately set domain.starttime to too early
834        domain.starttime = start - 1
835
836        #Create file function
837        F = file_function(filename + '.sww', domain,
838                          quantities = domain.conserved_quantities,
839                          interpolation_points = interpolation_points)
840
841        #Check that FF updates fixes domain starttime
842        assert num.allclose(domain.starttime, start)
843
844        #Check that domain.starttime isn't updated if later
845        domain.starttime = start + 1
846        F = file_function(filename + '.sww', domain,
847                          quantities = domain.conserved_quantities,
848                          interpolation_points = interpolation_points)
849        assert num.allclose(domain.starttime, start+1)
850        domain.starttime = start
851
852
853        #Check linear interpolation in time
854        # checking points inside and outside the mesh
855        F = file_function(filename + '.sww', domain,
856                          quantities = domain.conserved_quantities,
857                          interpolation_points = interpolation_points)
858       
859        for id in range(len(interpolation_points)):
860            x = interpolation_points[id][0]
861            y = interpolation_points[id][1]
862
863            for i in range(20):
864                t = i*10
865                k = i%6
866
867                if k == 0:
868                    q0 = F(t, point_id=id)
869                    q1 = F(t+60, point_id=id)
870
871                if q0 == NAN:
872                    actual = q0
873                else:
874                    actual = (k*q1 + (6-k)*q0)/6
875                q = F(t, point_id=id)
876                #print i, k, t, q
877                #print ' ', q0
878                #print ' ', q1
879                #print "q",q
880                #print "actual", actual
881                #print
882                if q0 == NAN:
883                     self.failUnless( q == actual, 'Fail!')
884                else:
885                    assert num.allclose(q, actual)
886
887        # now lets check points inside the mesh
888        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14]] #, [10,-12.5]] - this point doesn't work WHY?
889        interpolation_points = [[10,-12.5]]
890           
891        print "len(interpolation_points)",len(interpolation_points) 
892        F = file_function(filename + '.sww', domain,
893                          quantities = domain.conserved_quantities,
894                          interpolation_points = interpolation_points)
895
896        domain.starttime = start
897
898
899        #Check linear interpolation in time
900        F = file_function(filename + '.sww', domain,
901                          quantities = domain.conserved_quantities,
902                          interpolation_points = interpolation_points)               
903        for id in range(len(interpolation_points)):
904            x = interpolation_points[id][0]
905            y = interpolation_points[id][1]
906
907            for i in range(20):
908                t = i*10
909                k = i%6
910
911                if k == 0:
912                    q0 = F(t, point_id=id)
913                    q1 = F(t+60, point_id=id)
914
915                if q0 == NAN:
916                    actual = q0
917                else:
918                    actual = (k*q1 + (6-k)*q0)/6
919                q = F(t, point_id=id)
920                print "############"
921                print "id, x, y ", id, x, y #k, t, q
922                print "t", t
923                #print ' ', q0
924                #print ' ', q1
925                print "q",q
926                print "actual", actual
927                #print
928                if q0 == NAN:
929                     self.failUnless( q == actual, 'Fail!')
930                else:
931                    assert num.allclose(q, actual)
932
933
934        #Another check of linear interpolation in time
935        for id in range(len(interpolation_points)):
936            q60 = F(60, point_id=id)
937            q120 = F(120, point_id=id)
938
939            t = 90 #Halfway between 60 and 120
940            q = F(t, point_id=id)
941            assert num.allclose( (q120+q60)/2, q )
942
943            t = 100 #Two thirds of the way between between 60 and 120
944            q = F(t, point_id=id)
945            assert num.allclose(q60/3 + 2*q120/3, q)
946
947
948
949        #Check that domain.starttime isn't updated if later than file starttime but earlier
950        #than file end time
951        delta = 23
952        domain.starttime = start + delta
953        F = file_function(filename + '.sww', domain,
954                          quantities = domain.conserved_quantities,
955                          interpolation_points = interpolation_points)
956        assert num.allclose(domain.starttime, start+delta)
957
958
959
960
961        #Now try interpolation with delta offset
962        for id in range(len(interpolation_points)):           
963            x = interpolation_points[id][0]
964            y = interpolation_points[id][1]
965
966            for i in range(20):
967                t = i*10
968                k = i%6
969
970                if k == 0:
971                    q0 = F(t-delta, point_id=id)
972                    q1 = F(t+60-delta, point_id=id)
973
974                q = F(t-delta, point_id=id)
975                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
976
977
978        os.remove(filename + '.sww')
979
980    def test_file_function_time_with_domain(self):
981        """Test that File function interpolates correctly
982        between given times. No x,y dependency here.
983        Use domain with starttime
984        """
985
986        #Write file
987        import os, time, calendar
988        from anuga.config import time_format
989        from math import sin, pi
990        from domain import Domain
991
992        finaltime = 1200
993        filename = 'test_file_function'
994        fid = open(filename + '.txt', 'w')
995        start = time.mktime(time.strptime('2000', '%Y'))
996        dt = 60  #One minute intervals
997        t = 0.0
998        while t <= finaltime:
999            t_string = time.strftime(time_format, time.gmtime(t+start))
1000            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1001            t += dt
1002
1003        fid.close()
1004
1005
1006        #Convert ASCII file to NetCDF (Which is what we really like!)
1007        timefile2netcdf(filename)
1008
1009
1010
1011        a = [0.0, 0.0]
1012        b = [4.0, 0.0]
1013        c = [0.0, 3.0]
1014
1015        points = [a, b, c]
1016        vertices = [[0,1,2]]
1017        domain = Domain(points, vertices)
1018
1019        # Check that domain.starttime is updated if non-existing
1020        F = file_function(filename + '.tms',
1021                          domain,
1022                          quantities = ['Attribute0', 'Attribute1', 'Attribute2']) 
1023        assert num.allclose(domain.starttime, start)
1024
1025        # Check that domain.starttime is updated if too early
1026        domain.starttime = start - 1
1027        F = file_function(filename + '.tms',
1028                          domain,
1029                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
1030        assert num.allclose(domain.starttime, start)
1031
1032        # Check that domain.starttime isn't updated if later
1033        domain.starttime = start + 1
1034        F = file_function(filename + '.tms',
1035                          domain,
1036                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
1037        assert num.allclose(domain.starttime, start+1)
1038
1039        domain.starttime = start
1040        F = file_function(filename + '.tms',
1041                          domain,
1042                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'],
1043                          use_cache=True)
1044       
1045
1046        #print F.precomputed_values
1047        #print 'F(60)', F(60)
1048       
1049        #Now try interpolation
1050        for i in range(20):
1051            t = i*10
1052            q = F(t)
1053
1054            #Exact linear intpolation
1055            assert num.allclose(q[0], 2*t)
1056            if i%6 == 0:
1057                assert num.allclose(q[1], t**2)
1058                assert num.allclose(q[2], sin(t*pi/600))
1059
1060        #Check non-exact
1061
1062        t = 90 #Halfway between 60 and 120
1063        q = F(t)
1064        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1065        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1066
1067
1068        t = 100 #Two thirds of the way between between 60 and 120
1069        q = F(t)
1070        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1071        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1072
1073        os.remove(filename + '.tms')
1074        os.remove(filename + '.txt')       
1075
1076    def test_file_function_time_with_domain_different_start(self):
1077        """Test that File function interpolates correctly
1078        between given times. No x,y dependency here.
1079        Use domain with a starttime later than that of file
1080
1081        ASCII version
1082        """
1083
1084        #Write file
1085        import os, time, calendar
1086        from anuga.config import time_format
1087        from math import sin, pi
1088        from domain import Domain
1089
1090        finaltime = 1200
1091        filename = 'test_file_function'
1092        fid = open(filename + '.txt', 'w')
1093        start = time.mktime(time.strptime('2000', '%Y'))
1094        dt = 60  #One minute intervals
1095        t = 0.0
1096        while t <= finaltime:
1097            t_string = time.strftime(time_format, time.gmtime(t+start))
1098            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1099            t += dt
1100
1101        fid.close()
1102
1103        #Convert ASCII file to NetCDF (Which is what we really like!)
1104        timefile2netcdf(filename)       
1105
1106        a = [0.0, 0.0]
1107        b = [4.0, 0.0]
1108        c = [0.0, 3.0]
1109
1110        points = [a, b, c]
1111        vertices = [[0,1,2]]
1112        domain = Domain(points, vertices)
1113
1114        #Check that domain.starttime isn't updated if later than file starttime but earlier
1115        #than file end time
1116        delta = 23
1117        domain.starttime = start + delta
1118        F = file_function(filename + '.tms', domain,
1119                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])       
1120        assert num.allclose(domain.starttime, start+delta)
1121
1122        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1123                                            277., 337., 397., 457., 517.,
1124                                            577., 637., 697., 757., 817.,
1125                                            877., 937., 997., 1057., 1117.,
1126                                            1177.])
1127
1128
1129        #Now try interpolation with delta offset
1130        for i in range(20):
1131            t = i*10
1132            q = F(t-delta)
1133
1134            #Exact linear intpolation
1135            assert num.allclose(q[0], 2*t)
1136            if i%6 == 0:
1137                assert num.allclose(q[1], t**2)
1138                assert num.allclose(q[2], sin(t*pi/600))
1139
1140        #Check non-exact
1141
1142        t = 90 #Halfway between 60 and 120
1143        q = F(t-delta)
1144        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1145        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1146
1147
1148        t = 100 #Two thirds of the way between between 60 and 120
1149        q = F(t-delta)
1150        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1151        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1152
1153
1154        os.remove(filename + '.tms')
1155        os.remove(filename + '.txt')               
1156
1157       
1158
1159    def test_file_function_time_with_domain_different_start_and_time_limit(self):
1160        """Test that File function interpolates correctly
1161        between given times. No x,y dependency here.
1162        Use domain with a starttime later than that of file
1163
1164        ASCII version
1165       
1166        This test also tests that time can be truncated.
1167        """
1168
1169        # Write file
1170        import os, time, calendar
1171        from anuga.config import time_format
1172        from math import sin, pi
1173        from domain import Domain
1174
1175        finaltime = 1200
1176        filename = 'test_file_function'
1177        fid = open(filename + '.txt', 'w')
1178        start = time.mktime(time.strptime('2000', '%Y'))
1179        dt = 60  #One minute intervals
1180        t = 0.0
1181        while t <= finaltime:
1182            t_string = time.strftime(time_format, time.gmtime(t+start))
1183            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1184            t += dt
1185
1186        fid.close()
1187
1188        # Convert ASCII file to NetCDF (Which is what we really like!)
1189        timefile2netcdf(filename)       
1190
1191        a = [0.0, 0.0]
1192        b = [4.0, 0.0]
1193        c = [0.0, 3.0]
1194
1195        points = [a, b, c]
1196        vertices = [[0,1,2]]
1197        domain = Domain(points, vertices)
1198
1199        # Check that domain.starttime isn't updated if later than file starttime but earlier
1200        # than file end time
1201        delta = 23
1202        domain.starttime = start + delta
1203        time_limit = domain.starttime + 600
1204        F = file_function(filename + '.tms', domain,
1205                          time_limit=time_limit,
1206                          quantities=['Attribute0', 'Attribute1', 'Attribute2'])       
1207        assert num.allclose(domain.starttime, start+delta)
1208
1209        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1210                                            277., 337., 397., 457., 517.,
1211                                            577.])       
1212
1213
1214
1215        # Now try interpolation with delta offset
1216        for i in range(20):
1217            t = i*10
1218            q = F(t-delta)
1219
1220            #Exact linear intpolation
1221            assert num.allclose(q[0], 2*t)
1222            if i%6 == 0:
1223                assert num.allclose(q[1], t**2)
1224                assert num.allclose(q[2], sin(t*pi/600))
1225
1226        # Check non-exact
1227        t = 90 #Halfway between 60 and 120
1228        q = F(t-delta)
1229        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1230        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1231
1232
1233        t = 100 # Two thirds of the way between between 60 and 120
1234        q = F(t-delta)
1235        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1236        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1237
1238
1239        os.remove(filename + '.tms')
1240        os.remove(filename + '.txt')               
1241
1242       
1243       
1244       
1245
1246
1247    def test_apply_expression_to_dictionary(self):
1248
1249        #FIXME: Division is not expected to work for integers.
1250        #This must be caught.
1251        foo = num.array([[1,2,3], [4,5,6]], num.float)
1252
1253        bar = num.array([[-1,0,5], [6,1,1]], num.float)                 
1254
1255        D = {'X': foo, 'Y': bar}
1256
1257        Z = apply_expression_to_dictionary('X+Y', D)       
1258        assert num.allclose(Z, foo+bar)
1259
1260        Z = apply_expression_to_dictionary('X*Y', D)       
1261        assert num.allclose(Z, foo*bar)       
1262
1263        Z = apply_expression_to_dictionary('4*X+Y', D)       
1264        assert num.allclose(Z, 4*foo+bar)       
1265
1266        # test zero division is OK
1267        Z = apply_expression_to_dictionary('X/Y', D)
1268        assert num.allclose(1/Z, 1/(foo/bar)) # can't compare inf to inf
1269
1270        # make an error for zero on zero
1271        # this is really an error in numeric, SciPy core can handle it
1272        # Z = apply_expression_to_dictionary('0/Y', D)
1273
1274        #Check exceptions
1275        try:
1276            #Wrong name
1277            Z = apply_expression_to_dictionary('4*X+A', D)       
1278        except NameError:
1279            pass
1280        else:
1281            msg = 'Should have raised a NameError Exception'
1282            raise msg
1283
1284
1285        try:
1286            #Wrong order
1287            Z = apply_expression_to_dictionary(D, '4*X+A')       
1288        except AssertionError:
1289            pass
1290        else:
1291            msg = 'Should have raised a AssertionError Exception'
1292            raise msg       
1293       
1294
1295    def test_multiple_replace(self):
1296        """Hard test that checks a true word-by-word simultaneous replace
1297        """
1298       
1299        D = {'x': 'xi', 'y': 'eta', 'xi':'lam'}
1300        exp = '3*x+y + xi'
1301       
1302        new = multiple_replace(exp, D)
1303       
1304        assert new == '3*xi+eta + lam'
1305                         
1306
1307
1308    def test_point_on_line_obsolete(self):
1309        """Test that obsolete call issues appropriate warning"""
1310
1311        #Turn warning into an exception
1312        import warnings
1313        warnings.filterwarnings('error')
1314
1315        try:
1316            assert point_on_line( 0, 0.5, 0,1, 0,0 )
1317        except DeprecationWarning:
1318            pass
1319        else:
1320            msg = 'point_on_line should have issued a DeprecationWarning'
1321            raise Exception(msg)   
1322
1323        warnings.resetwarnings()
1324   
1325    def test_get_revision_number(self):
1326        """test_get_revision_number(self):
1327
1328        Test that revision number can be retrieved.
1329        """
1330        if os.environ.has_key('USER') and os.environ['USER'] == 'dgray':
1331            # I have a known snv incompatability issue,
1332            # so I'm skipping this test.
1333            # FIXME when SVN is upgraded on our clusters
1334            pass
1335        else:   
1336            n = get_revision_number()
1337            assert n>=0
1338
1339
1340       
1341    def test_add_directories(self):
1342       
1343        import tempfile
1344        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1345        directories = ['ja','ne','ke']
1346        kens_dir = add_directories(root_dir, directories)
1347        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1348               sep + 'ke'
1349        assert access(root_dir,F_OK)
1350
1351        add_directories(root_dir, directories)
1352        assert access(root_dir,F_OK)
1353       
1354        #clean up!
1355        os.rmdir(kens_dir)
1356        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1357        os.rmdir(root_dir + sep + 'ja')
1358        os.rmdir(root_dir)
1359
1360    def test_add_directories_bad(self):
1361       
1362        import tempfile
1363        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1364        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1365       
1366        try:
1367            kens_dir = add_directories(root_dir, directories)
1368        except OSError:
1369            pass
1370        else:
1371            msg = 'bad dir name should give OSError'
1372            raise Exception(msg)   
1373           
1374        #clean up!
1375        os.rmdir(root_dir)
1376
1377    def test_check_list(self):
1378
1379        check_list(['stage','xmomentum'])
1380
1381       
1382    def test_add_directories(self):
1383       
1384        import tempfile
1385        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1386        directories = ['ja','ne','ke']
1387        kens_dir = add_directories(root_dir, directories)
1388        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1389               sep + 'ke'
1390        assert access(root_dir,F_OK)
1391
1392        add_directories(root_dir, directories)
1393        assert access(root_dir,F_OK)
1394       
1395        #clean up!
1396        os.rmdir(kens_dir)
1397        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1398        os.rmdir(root_dir + sep + 'ja')
1399        os.rmdir(root_dir)
1400
1401    def test_add_directories_bad(self):
1402       
1403        import tempfile
1404        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1405        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1406       
1407        try:
1408            kens_dir = add_directories(root_dir, directories)
1409        except OSError:
1410            pass
1411        else:
1412            msg = 'bad dir name should give OSError'
1413            raise Exception(msg)   
1414           
1415        #clean up!
1416        os.rmdir(root_dir)
1417
1418    def test_check_list(self):
1419
1420        check_list(['stage','xmomentum'])
1421
1422######
1423# Test the remove_lone_verts() function
1424######
1425       
1426    def test_remove_lone_verts_a(self):
1427        verts = [[0,0],[1,0],[0,1]]
1428        tris = [[0,1,2]]
1429        new_verts, new_tris = remove_lone_verts(verts, tris)
1430        self.failUnless(new_verts.tolist() == verts)
1431        self.failUnless(new_tris.tolist() == tris)
1432
1433    def test_remove_lone_verts_b(self):
1434        verts = [[0,0],[1,0],[0,1],[99,99]]
1435        tris = [[0,1,2]]
1436        new_verts, new_tris = remove_lone_verts(verts, tris)
1437        self.failUnless(new_verts.tolist() == verts[0:3])
1438        self.failUnless(new_tris.tolist() == tris)
1439       
1440    def test_remove_lone_verts_c(self):
1441        verts = [[99,99],[0,0],[1,0],[99,99],[0,1],[99,99]]
1442        tris = [[1,2,4]]
1443        new_verts, new_tris = remove_lone_verts(verts, tris)
1444        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1445        self.failUnless(new_tris.tolist() == [[0,1,2]])
1446     
1447    def test_remove_lone_verts_d(self):
1448        verts = [[0,0],[1,0],[99,99],[0,1]]
1449        tris = [[0,1,3]]
1450        new_verts, new_tris = remove_lone_verts(verts, tris)
1451        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1452        self.failUnless(new_tris.tolist() == [[0,1,2]])
1453       
1454    def test_remove_lone_verts_e(self):
1455        verts = [[0,0],[1,0],[0,1],[99,99],[99,99],[99,99]]
1456        tris = [[0,1,2]]
1457        new_verts, new_tris = remove_lone_verts(verts, tris)
1458        self.failUnless(new_verts.tolist() == verts[0:3])
1459        self.failUnless(new_tris.tolist() == tris)
1460     
1461    def test_remove_lone_verts_f(self):
1462        verts = [[0,0],[1,0],[99,99],[0,1],[99,99],[1,1],[99,99]]
1463        tris = [[0,1,3],[0,1,5]]
1464        new_verts, new_tris = remove_lone_verts(verts, tris)
1465        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1],[1,1]])
1466        self.failUnless(new_tris.tolist() == [[0,1,2],[0,1,3]])
1467       
1468######
1469#
1470######
1471       
1472    def test_get_min_max_values(self):
1473       
1474        list=[8,9,6,1,4]
1475        min1, max1 = get_min_max_values(list)
1476       
1477        assert min1==1 
1478        assert max1==9
1479       
1480    def test_get_min_max_values1(self):
1481       
1482        list=[-8,-9,-6,-1,-4]
1483        min1, max1 = get_min_max_values(list)
1484       
1485#        print 'min1,max1',min1,max1
1486        assert min1==-9 
1487        assert max1==-1
1488
1489#    def test_get_min_max_values2(self):
1490#        '''
1491#        The min and max supplied are greater than the ones in the
1492#        list and therefore are the ones returned
1493#        '''
1494#        list=[-8,-9,-6,-1,-4]
1495#        min1, max1 = get_min_max_values(list,-10,10)
1496#       
1497##        print 'min1,max1',min1,max1
1498#        assert min1==-10
1499#        assert max1==10
1500       
1501    def test_make_plots_from_csv_files(self):
1502       
1503        #if sys.platform == 'win32':  #Windows
1504            try: 
1505                import pylab
1506            except ImportError:
1507                #ANUGA don't need pylab to work so the system doesn't
1508                #rely on pylab being installed
1509                return
1510           
1511       
1512            current_dir=getcwd()+sep+'abstract_2d_finite_volumes'
1513            temp_dir = tempfile.mkdtemp('','figures')
1514    #        print 'temp_dir',temp_dir
1515            fileName = temp_dir+sep+'time_series_3.csv'
1516            file = open(fileName,"w")
1517            file.write("time,stage,speed,momentum,elevation\n\
15181.0, 0, 0, 0, 10 \n\
15192.0, 5, 2, 4, 10 \n\
15203.0, 3, 3, 5, 10 \n")
1521            file.close()
1522   
1523            fileName1 = temp_dir+sep+'time_series_4.csv'
1524            file1 = open(fileName1,"w")
1525            file1.write("time,stage,speed,momentum,elevation\n\
15261.0, 0, 0, 0, 5 \n\
15272.0, -5, -2, -4, 5 \n\
15283.0, -4, -3, -5, 5 \n")
1529            file1.close()
1530   
1531            fileName2 = temp_dir+sep+'time_series_5.csv'
1532            file2 = open(fileName2,"w")
1533            file2.write("time,stage,speed,momentum,elevation\n\
15341.0, 0, 0, 0, 7 \n\
15352.0, 4, -0.45, 57, 7 \n\
15363.0, 6, -0.5, 56, 7 \n")
1537            file2.close()
1538           
1539            dir, name=os.path.split(fileName)
1540            csv2timeseries_graphs(directories_dic={dir:['gauge', 0, 0]},
1541                                  output_dir=temp_dir,
1542                                  base_name='time_series_',
1543                                  plot_numbers=['3-5'],
1544                                  quantities=['speed','stage','momentum'],
1545                                  assess_all_csv_files=True,
1546                                  extra_plot_name='test')
1547           
1548            #print dir+sep+name[:-4]+'_stage_test.png'
1549            assert(access(dir+sep+name[:-4]+'_stage_test.png',F_OK)==True)
1550            assert(access(dir+sep+name[:-4]+'_speed_test.png',F_OK)==True)
1551            assert(access(dir+sep+name[:-4]+'_momentum_test.png',F_OK)==True)
1552   
1553            dir1, name1=os.path.split(fileName1)
1554            assert(access(dir+sep+name1[:-4]+'_stage_test.png',F_OK)==True)
1555            assert(access(dir+sep+name1[:-4]+'_speed_test.png',F_OK)==True)
1556            assert(access(dir+sep+name1[:-4]+'_momentum_test.png',F_OK)==True)
1557   
1558   
1559            dir2, name2=os.path.split(fileName2)
1560            assert(access(dir+sep+name2[:-4]+'_stage_test.png',F_OK)==True)
1561            assert(access(dir+sep+name2[:-4]+'_speed_test.png',F_OK)==True)
1562            assert(access(dir+sep+name2[:-4]+'_momentum_test.png',F_OK)==True)
1563   
1564            del_dir(temp_dir)
1565       
1566
1567    def test_sww2csv_gauges(self):
1568
1569        def elevation_function(x, y):
1570            return -x
1571       
1572        """Most of this test was copied from test_interpolate
1573        test_interpole_sww2csv
1574       
1575        This is testing the gauge_sww2csv function, by creating a sww file and
1576        then exporting the gauges and checking the results.
1577        """
1578       
1579        # Create mesh
1580        mesh_file = tempfile.mktemp(".tsh")   
1581        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1582        m = Mesh()
1583        m.add_vertices(points)
1584        m.auto_segment()
1585        m.generate_mesh(verbose=False)
1586        m.export_mesh_file(mesh_file)
1587       
1588        # Create shallow water domain
1589        domain = Domain(mesh_file)
1590        os.remove(mesh_file)
1591       
1592        domain.default_order=2
1593       
1594        # This test was made before tight_slope_limiters were introduced
1595        # Since were are testing interpolation values this is OK
1596        domain.tight_slope_limiters = 0 
1597       
1598
1599        # Set some field values
1600        domain.set_quantity('elevation', elevation_function)
1601        domain.set_quantity('friction', 0.03)
1602        domain.set_quantity('xmomentum', 3.0)
1603        domain.set_quantity('ymomentum', 4.0)
1604
1605        ######################
1606        # Boundary conditions
1607        B = Transmissive_boundary(domain)
1608        domain.set_boundary( {'exterior': B})
1609
1610        # This call mangles the stage values.
1611        domain.distribute_to_vertices_and_edges()
1612        domain.set_quantity('stage', 1.0)
1613
1614
1615        domain.set_name('datatest' + str(time.time()))
1616        domain.format = 'sww'
1617        domain.smooth = True
1618        domain.reduction = mean
1619
1620
1621        sww = get_dataobject(domain)
1622        sww.store_connectivity()
1623        sww.store_timestep(['stage', 'xmomentum', 'ymomentum','elevation'])
1624        domain.set_quantity('stage', 10.0) # This is automatically limited
1625        # so it will not be less than the elevation
1626        domain.time = 2.
1627        sww.store_timestep(['stage','elevation', 'xmomentum', 'ymomentum'])
1628
1629        # test the function
1630        points = [[5.0,1.],[0.5,2.]]
1631
1632        points_file = tempfile.mktemp(".csv")
1633#        points_file = 'test_point.csv'
1634        file_id = open(points_file,"w")
1635        file_id.write("name, easting, northing, elevation \n\
1636point1, 5.0, 1.0, 3.0\n\
1637point2, 0.5, 2.0, 9.0\n")
1638        file_id.close()
1639
1640       
1641        sww2csv_gauges(sww.filename, 
1642                       points_file,
1643                       verbose=False,
1644                       use_cache=False)
1645
1646#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1647        point1_answers_array = [[0.0,0.0,1.0,6.0,-5.0,3.0,4.0], [2.0,2.0/3600.,10.0,15.0,-5.0,3.0,4.0]]
1648        point1_filename = 'gauge_point1.csv'
1649        point1_handle = file(point1_filename)
1650        point1_reader = reader(point1_handle)
1651        point1_reader.next()
1652
1653        line=[]
1654        for i,row in enumerate(point1_reader):
1655#            print 'i',i,'row',row
1656            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1657                         float(row[4]),float(row[5]),float(row[6])])
1658#            print 'assert line',line[i],'point1',point1_answers_array[i]
1659            assert num.allclose(line[i], point1_answers_array[i])
1660
1661        point2_answers_array = [[0.0,0.0,1.0,1.5,-0.5,3.0,4.0], [2.0,2.0/3600.,10.0,10.5,-0.5,3.0,4.0]]
1662        point2_filename = 'gauge_point2.csv' 
1663        point2_handle = file(point2_filename)
1664        point2_reader = reader(point2_handle)
1665        point2_reader.next()
1666                       
1667        line=[]
1668        for i,row in enumerate(point2_reader):
1669#            print 'i',i,'row',row
1670            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1671                         float(row[4]),float(row[5]),float(row[6])])
1672#            print 'assert line',line[i],'point1',point1_answers_array[i]
1673            assert num.allclose(line[i], point2_answers_array[i])
1674                         
1675        # clean up
1676        point1_handle.close()
1677        point2_handle.close()
1678        #print "sww.filename",sww.filename
1679        os.remove(sww.filename)
1680        os.remove(points_file)
1681        os.remove(point1_filename)
1682        os.remove(point2_filename)
1683
1684
1685
1686    def test_sww2csv_gauges1(self):
1687        from anuga.pmesh.mesh import Mesh
1688        from anuga.shallow_water import Domain, Transmissive_boundary
1689        from anuga.shallow_water.data_manager import get_dataobject
1690        from csv import reader,writer
1691        import time
1692        import string
1693
1694        def elevation_function(x, y):
1695            return -x
1696       
1697        """Most of this test was copied from test_interpolate
1698        test_interpole_sww2csv
1699       
1700        This is testing the gauge_sww2csv function, by creating a sww file and
1701        then exporting the gauges and checking the results.
1702       
1703        This tests the ablity not to have elevation in the points file and
1704        not store xmomentum and ymomentum
1705        """
1706       
1707        # Create mesh
1708        mesh_file = tempfile.mktemp(".tsh")   
1709        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1710        m = Mesh()
1711        m.add_vertices(points)
1712        m.auto_segment()
1713        m.generate_mesh(verbose=False)
1714        m.export_mesh_file(mesh_file)
1715       
1716        # Create shallow water domain
1717        domain = Domain(mesh_file)
1718        os.remove(mesh_file)
1719       
1720        domain.default_order=2
1721
1722        # Set some field values
1723        domain.set_quantity('elevation', elevation_function)
1724        domain.set_quantity('friction', 0.03)
1725        domain.set_quantity('xmomentum', 3.0)
1726        domain.set_quantity('ymomentum', 4.0)
1727
1728        ######################
1729        # Boundary conditions
1730        B = Transmissive_boundary(domain)
1731        domain.set_boundary( {'exterior': B})
1732
1733        # This call mangles the stage values.
1734        domain.distribute_to_vertices_and_edges()
1735        domain.set_quantity('stage', 1.0)
1736
1737
1738        domain.set_name('datatest' + str(time.time()))
1739        domain.format = 'sww'
1740        domain.smooth = True
1741        domain.reduction = mean
1742
1743        sww = get_dataobject(domain)
1744        sww.store_connectivity()
1745        sww.store_timestep(['stage', 'xmomentum', 'ymomentum'])
1746        domain.set_quantity('stage', 10.0) # This is automatically limited
1747        # so it will not be less than the elevation
1748        domain.time = 2.
1749        sww.store_timestep(['stage', 'xmomentum', 'ymomentum'])
1750
1751        # test the function
1752        points = [[5.0,1.],[0.5,2.]]
1753
1754        points_file = tempfile.mktemp(".csv")
1755#        points_file = 'test_point.csv'
1756        file_id = open(points_file,"w")
1757        file_id.write("name,easting,northing \n\
1758point1, 5.0, 1.0\n\
1759point2, 0.5, 2.0\n")
1760        file_id.close()
1761
1762        sww2csv_gauges(sww.filename, 
1763                            points_file,
1764                            quantities=['stage', 'elevation'],
1765                            use_cache=False,
1766                            verbose=False)
1767
1768        point1_answers_array = [[0.0,1.0,-5.0], [2.0,10.0,-5.0]]
1769        point1_filename = 'gauge_point1.csv'
1770        point1_handle = file(point1_filename)
1771        point1_reader = reader(point1_handle)
1772        point1_reader.next()
1773
1774        line=[]
1775        for i,row in enumerate(point1_reader):
1776#            print 'i',i,'row',row
1777            # note the 'hole' (element 1) below - skip the new 'hours' field
1778            line.append([float(row[0]),float(row[2]),float(row[3])])
1779            #print 'line',line[i],'point1',point1_answers_array[i]
1780            assert num.allclose(line[i], point1_answers_array[i])
1781
1782        point2_answers_array = [[0.0,1.0,-0.5], [2.0,10.0,-0.5]]
1783        point2_filename = 'gauge_point2.csv' 
1784        point2_handle = file(point2_filename)
1785        point2_reader = reader(point2_handle)
1786        point2_reader.next()
1787                       
1788        line=[]
1789        for i,row in enumerate(point2_reader):
1790#            print 'i',i,'row',row
1791            # note the 'hole' (element 1) below - skip the new 'hours' field
1792            line.append([float(row[0]),float(row[2]),float(row[3])])
1793#            print 'line',line[i],'point1',point1_answers_array[i]
1794            assert num.allclose(line[i], point2_answers_array[i])
1795                         
1796        # clean up
1797        point1_handle.close()
1798        point2_handle.close()
1799        #print "sww.filename",sww.filename
1800        os.remove(sww.filename)
1801        os.remove(points_file)
1802        os.remove(point1_filename)
1803        os.remove(point2_filename)
1804
1805
1806    def test_sww2csv_gauges2(self):
1807
1808        def elevation_function(x, y):
1809            return -x
1810       
1811        """Most of this test was copied from test_interpolate
1812        test_interpole_sww2csv
1813       
1814        This is testing the gauge_sww2csv function, by creating a sww file and
1815        then exporting the gauges and checking the results.
1816       
1817        This is the same as sww2csv_gauges except set domain.set_starttime to 5.
1818        Therefore testing the storing of the absolute time in the csv files
1819        """
1820       
1821        # Create mesh
1822        mesh_file = tempfile.mktemp(".tsh")   
1823        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1824        m = Mesh()
1825        m.add_vertices(points)
1826        m.auto_segment()
1827        m.generate_mesh(verbose=False)
1828        m.export_mesh_file(mesh_file)
1829       
1830        # Create shallow water domain
1831        domain = Domain(mesh_file)
1832        os.remove(mesh_file)
1833       
1834        domain.default_order=2
1835
1836        # This test was made before tight_slope_limiters were introduced
1837        # Since were are testing interpolation values this is OK
1838        domain.tight_slope_limiters = 0         
1839
1840        # Set some field values
1841        domain.set_quantity('elevation', elevation_function)
1842        domain.set_quantity('friction', 0.03)
1843        domain.set_quantity('xmomentum', 3.0)
1844        domain.set_quantity('ymomentum', 4.0)
1845        domain.set_starttime(5)
1846
1847        ######################
1848        # Boundary conditions
1849        B = Transmissive_boundary(domain)
1850        domain.set_boundary( {'exterior': B})
1851
1852        # This call mangles the stage values.
1853        domain.distribute_to_vertices_and_edges()
1854        domain.set_quantity('stage', 1.0)
1855       
1856
1857
1858        domain.set_name('datatest' + str(time.time()))
1859        domain.format = 'sww'
1860        domain.smooth = True
1861        domain.reduction = mean
1862
1863        sww = get_dataobject(domain)
1864        sww.store_connectivity()
1865        sww.store_timestep(['stage', 'xmomentum', 'ymomentum','elevation'])
1866        domain.set_quantity('stage', 10.0) # This is automatically limited
1867        # so it will not be less than the elevation
1868        domain.time = 2.
1869        sww.store_timestep(['stage','elevation', 'xmomentum', 'ymomentum'])
1870
1871        # test the function
1872        points = [[5.0,1.],[0.5,2.]]
1873
1874        points_file = tempfile.mktemp(".csv")
1875#        points_file = 'test_point.csv'
1876        file_id = open(points_file,"w")
1877        file_id.write("name, easting, northing, elevation \n\
1878point1, 5.0, 1.0, 3.0\n\
1879point2, 0.5, 2.0, 9.0\n")
1880        file_id.close()
1881
1882       
1883        sww2csv_gauges(sww.filename, 
1884                            points_file,
1885                            verbose=False,
1886                            use_cache=False)
1887
1888#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1889        point1_answers_array = [[5.0,5.0/3600.,1.0,6.0,-5.0,3.0,4.0], [7.0,7.0/3600.,10.0,15.0,-5.0,3.0,4.0]]
1890        point1_filename = 'gauge_point1.csv'
1891        point1_handle = file(point1_filename)
1892        point1_reader = reader(point1_handle)
1893        point1_reader.next()
1894
1895        line=[]
1896        for i,row in enumerate(point1_reader):
1897            #print 'i',i,'row',row
1898            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1899                         float(row[4]), float(row[5]), float(row[6])])
1900            #print 'assert line',line[i],'point1',point1_answers_array[i]
1901            assert num.allclose(line[i], point1_answers_array[i])
1902
1903        point2_answers_array = [[5.0,5.0/3600.,1.0,1.5,-0.5,3.0,4.0], [7.0,7.0/3600.,10.0,10.5,-0.5,3.0,4.0]]
1904        point2_filename = 'gauge_point2.csv' 
1905        point2_handle = file(point2_filename)
1906        point2_reader = reader(point2_handle)
1907        point2_reader.next()
1908                       
1909        line=[]
1910        for i,row in enumerate(point2_reader):
1911            #print 'i',i,'row',row
1912            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1913                         float(row[4]),float(row[5]), float(row[6])])
1914            #print 'assert line',line[i],'point1',point1_answers_array[i]
1915            assert num.allclose(line[i], point2_answers_array[i])
1916                         
1917        # clean up
1918        point1_handle.close()
1919        point2_handle.close()
1920        #print "sww.filename",sww.filename
1921        os.remove(sww.filename)
1922        os.remove(points_file)
1923        os.remove(point1_filename)
1924        os.remove(point2_filename)
1925
1926
1927    def test_greens_law(self):
1928
1929        from math import sqrt
1930       
1931        d1 = 80.0
1932        d2 = 20.0
1933        h1 = 1.0
1934        h2 = greens_law(d1,d2,h1)
1935
1936        assert h2==sqrt(2.0)
1937       
1938    def test_calc_bearings(self):
1939 
1940        from math import atan, degrees
1941        #Test East
1942        uh = 1
1943        vh = 1.e-15
1944        angle = calc_bearing(uh, vh)
1945        if 89 < angle < 91: v=1
1946        assert v==1
1947        #Test West
1948        uh = -1
1949        vh = 1.e-15
1950        angle = calc_bearing(uh, vh)
1951        if 269 < angle < 271: v=1
1952        assert v==1
1953        #Test North
1954        uh = 1.e-15
1955        vh = 1
1956        angle = calc_bearing(uh, vh)
1957        if -1 < angle < 1: v=1
1958        assert v==1
1959        #Test South
1960        uh = 1.e-15
1961        vh = -1
1962        angle = calc_bearing(uh, vh)
1963        if 179 < angle < 181: v=1
1964        assert v==1
1965        #Test South-East
1966        uh = 1
1967        vh = -1
1968        angle = calc_bearing(uh, vh)
1969        if 134 < angle < 136: v=1
1970        assert v==1
1971        #Test North-East
1972        uh = 1
1973        vh = 1
1974        angle = calc_bearing(uh, vh)
1975        if 44 < angle < 46: v=1
1976        assert v==1
1977        #Test South-West
1978        uh = -1
1979        vh = -1
1980        angle = calc_bearing(uh, vh)
1981        if 224 < angle < 226: v=1
1982        assert v==1
1983        #Test North-West
1984        uh = -1
1985        vh = 1
1986        angle = calc_bearing(uh, vh)
1987        if 314 < angle < 316: v=1
1988        assert v==1
1989       
1990
1991#-------------------------------------------------------------
1992
1993if __name__ == "__main__":
1994    suite = unittest.makeSuite(Test_Util, 'test')
1995#    runner = unittest.TextTestRunner(verbosity=2)
1996    runner = unittest.TextTestRunner(verbosity=1)
1997    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.