source: anuga_core/source/anuga/abstract_2d_finite_volumes/test_util.py @ 6318

Last change on this file since 6318 was 6318, checked in by kristy, 15 years ago

Updated test_util.py so that it understands that hours is a new column in the created csv file.
In util.py I took out the print statements.

File size: 64.5 KB
Line 
1#!/usr/bin/env python
2
3
4import unittest
5from math import sqrt, pi
6import tempfile, os
7from os import access, F_OK,sep, removedirs,remove,mkdir,getcwd
8
9from anuga.abstract_2d_finite_volumes.util import *
10from anuga.config import epsilon
11from anuga.shallow_water.data_manager import timefile2netcdf,del_dir
12
13from anuga.utilities.numerical_tools import NAN
14
15from sys import platform
16
17from anuga.pmesh.mesh import Mesh
18from anuga.shallow_water import Domain, Transmissive_boundary
19from anuga.shallow_water.data_manager import get_dataobject
20from csv import reader,writer
21import time
22import string
23
24import Numeric as num
25
26
27def test_function(x, y):
28    return x+y
29
30class Test_Util(unittest.TestCase):
31    def setUp(self):
32        pass
33
34    def tearDown(self):
35        pass
36
37
38
39
40    #Geometric
41    #def test_distance(self):
42    #    from anuga.abstract_2d_finite_volumes.util import distance#
43    #
44    #    self.failUnless( distance([4,2],[7,6]) == 5.0,
45    #                     'Distance is wrong!')
46    #    self.failUnless( allclose(distance([7,6],[9,8]), 2.82842712475),
47    #                    'distance is wrong!')
48    #    self.failUnless( allclose(distance([9,8],[4,2]), 7.81024967591),
49    #                    'distance is wrong!')
50    #
51    #    self.failUnless( distance([9,8],[4,2]) == distance([4,2],[9,8]),
52    #                    'distance is wrong!')
53
54
55    def test_file_function_time1(self):
56        """Test that File function interpolates correctly
57        between given times. No x,y dependency here.
58        """
59
60        #Write file
61        import os, time
62        from anuga.config import time_format
63        from math import sin, pi
64
65        #Typical ASCII file
66        finaltime = 1200
67        filename = 'test_file_function'
68        fid = open(filename + '.txt', 'w')
69        start = time.mktime(time.strptime('2000', '%Y'))
70        dt = 60  #One minute intervals
71        t = 0.0
72        while t <= finaltime:
73            t_string = time.strftime(time_format, time.gmtime(t+start))
74            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
75            t += dt
76
77        fid.close()
78
79        #Convert ASCII file to NetCDF (Which is what we really like!)
80        timefile2netcdf(filename)
81
82
83        #Create file function from time series
84        F = file_function(filename + '.tms',
85                          quantities = ['Attribute0',
86                                        'Attribute1',
87                                        'Attribute2'])
88       
89        #Now try interpolation
90        for i in range(20):
91            t = i*10
92            q = F(t)
93
94            #Exact linear intpolation
95            assert num.allclose(q[0], 2*t)
96            if i%6 == 0:
97                assert num.allclose(q[1], t**2)
98                assert num.allclose(q[2], sin(t*pi/600))
99
100        #Check non-exact
101
102        t = 90 #Halfway between 60 and 120
103        q = F(t)
104        assert num.allclose( (120**2 + 60**2)/2, q[1] )
105        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
106
107
108        t = 100 #Two thirds of the way between between 60 and 120
109        q = F(t)
110        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
111        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
112
113        os.remove(filename + '.txt')
114        os.remove(filename + '.tms')       
115
116
117       
118    def test_spatio_temporal_file_function_basic(self):
119        """Test that spatio temporal file function performs the correct
120        interpolations in both time and space
121        NetCDF version (x,y,t dependency)       
122        """
123        import time
124
125        #Create sww file of simple propagation from left to right
126        #through rectangular domain
127        from shallow_water import Domain, Dirichlet_boundary
128        from mesh_factory import rectangular
129
130        #Create basic mesh and shallow water domain
131        points, vertices, boundary = rectangular(3, 3)
132        domain1 = Domain(points, vertices, boundary)
133
134        from anuga.utilities.numerical_tools import mean
135        domain1.reduction = mean
136        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
137                              # only one value.
138
139        domain1.default_order = 2
140        domain1.store = True
141        domain1.set_datadir('.')
142        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
143        domain1.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']
144
145        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
146        domain1.set_quantity('elevation', 0)
147        domain1.set_quantity('friction', 0)
148        domain1.set_quantity('stage', 0)
149
150        # Boundary conditions
151        B0 = Dirichlet_boundary([0,0,0])
152        B6 = Dirichlet_boundary([0.6,0,0])
153        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
154        domain1.check_integrity()
155
156        finaltime = 8
157        #Evolution
158        t0 = -1
159        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
160            #print 'Timesteps: %.16f, %.16f' %(t0, t)
161            #if t == t0:
162            #    msg = 'Duplicate timestep found: %f, %f' %(t0, t)
163            #   raise msg
164            t0 = t
165             
166            #domain1.write_time()
167
168
169        #Now read data from sww and check
170        from Scientific.IO.NetCDF import NetCDFFile
171        filename = domain1.get_name() + '.' + domain1.format
172        fid = NetCDFFile(filename)
173
174        x = fid.variables['x'][:]
175        y = fid.variables['y'][:]
176        stage = fid.variables['stage'][:]
177        xmomentum = fid.variables['xmomentum'][:]
178        ymomentum = fid.variables['ymomentum'][:]
179        time = fid.variables['time'][:]
180
181        #Take stage vertex values at last timestep on diagonal
182        #Diagonal is identified by vertices: 0, 5, 10, 15
183
184        last_time_index = len(time)-1 #Last last_time_index
185        d_stage = num.reshape(num.take(stage[last_time_index, :], [0,5,10,15]), (4,1))
186        d_uh = num.reshape(num.take(xmomentum[last_time_index, :], [0,5,10,15]), (4,1))
187        d_vh = num.reshape(num.take(ymomentum[last_time_index, :], [0,5,10,15]), (4,1))
188        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
189
190        #Reference interpolated values at midpoints on diagonal at
191        #this timestep are
192        r0 = (D[0] + D[1])/2
193        r1 = (D[1] + D[2])/2
194        r2 = (D[2] + D[3])/2
195
196        #And the midpoints are found now
197        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15])
198        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15])
199
200        diag = num.concatenate( (Dx, Dy), axis=1)
201        d_midpoints = (diag[1:] + diag[:-1])/2
202
203        #Let us see if the file function can find the correct
204        #values at the midpoints at the last timestep:
205        f = file_function(filename, domain1,
206                          interpolation_points = d_midpoints)
207
208        T = f.get_time()
209        msg = 'duplicate timesteps: %.16f and %.16f' %(T[-1], T[-2])
210        assert not T[-1] == T[-2], msg
211        t = time[last_time_index]
212        q = f(t, point_id=0); assert num.allclose(r0, q)
213        q = f(t, point_id=1); assert num.allclose(r1, q)
214        q = f(t, point_id=2); assert num.allclose(r2, q)
215
216
217        ##################
218        #Now do the same for the first timestep
219
220        timestep = 0 #First timestep
221        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
222        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
223        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
224        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
225
226        #Reference interpolated values at midpoints on diagonal at
227        #this timestep are
228        r0 = (D[0] + D[1])/2
229        r1 = (D[1] + D[2])/2
230        r2 = (D[2] + D[3])/2
231
232        #Let us see if the file function can find the correct
233        #values
234        q = f(0, point_id=0); assert num.allclose(r0, q)
235        q = f(0, point_id=1); assert num.allclose(r1, q)
236        q = f(0, point_id=2); assert num.allclose(r2, q)
237
238
239        ##################
240        #Now do it again for a timestep in the middle
241
242        timestep = 33
243        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
244        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
245        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
246        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
247
248        #Reference interpolated values at midpoints on diagonal at
249        #this timestep are
250        r0 = (D[0] + D[1])/2
251        r1 = (D[1] + D[2])/2
252        r2 = (D[2] + D[3])/2
253
254        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
255        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
256        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
257
258
259        ##################
260        #Now check temporal interpolation
261        #Halfway between timestep 15 and 16
262
263        timestep = 15
264        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
265        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
266        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
267        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
268
269        #Reference interpolated values at midpoints on diagonal at
270        #this timestep are
271        r0_0 = (D[0] + D[1])/2
272        r1_0 = (D[1] + D[2])/2
273        r2_0 = (D[2] + D[3])/2
274
275        #
276        timestep = 16
277        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
278        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
279        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
280        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
281
282        #Reference interpolated values at midpoints on diagonal at
283        #this timestep are
284        r0_1 = (D[0] + D[1])/2
285        r1_1 = (D[1] + D[2])/2
286        r2_1 = (D[2] + D[3])/2
287
288        # The reference values are
289        r0 = (r0_0 + r0_1)/2
290        r1 = (r1_0 + r1_1)/2
291        r2 = (r2_0 + r2_1)/2
292
293        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
294        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
295        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
296
297        ##################
298        #Finally check interpolation 2 thirds of the way
299        #between timestep 15 and 16
300
301        # The reference values are
302        r0 = (r0_0 + 2*r0_1)/3
303        r1 = (r1_0 + 2*r1_1)/3
304        r2 = (r2_0 + 2*r2_1)/3
305
306        #And the file function gives
307        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
308        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
309        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
310
311        fid.close()
312        import os
313        os.remove(filename)
314
315
316
317    def test_spatio_temporal_file_function_different_origin(self):
318        """Test that spatio temporal file function performs the correct
319        interpolations in both time and space where space is offset by
320        xllcorner and yllcorner
321        NetCDF version (x,y,t dependency)       
322        """
323        import time
324
325        #Create sww file of simple propagation from left to right
326        #through rectangular domain
327        from shallow_water import Domain, Dirichlet_boundary
328        from mesh_factory import rectangular
329
330
331        from anuga.coordinate_transforms.geo_reference import Geo_reference
332        xllcorner = 2048
333        yllcorner = 11000
334        zone = 2
335
336        #Create basic mesh and shallow water domain
337        points, vertices, boundary = rectangular(3, 3)
338        domain1 = Domain(points, vertices, boundary,
339                         geo_reference = Geo_reference(xllcorner = xllcorner,
340                                                       yllcorner = yllcorner))
341       
342
343        from anuga.utilities.numerical_tools import mean       
344        domain1.reduction = mean
345        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
346                              # only one value.
347
348        domain1.default_order = 2
349        domain1.store = True
350        domain1.set_datadir('.')
351        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
352        domain1.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']
353
354        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
355        domain1.set_quantity('elevation', 0)
356        domain1.set_quantity('friction', 0)
357        domain1.set_quantity('stage', 0)
358
359        # Boundary conditions
360        B0 = Dirichlet_boundary([0,0,0])
361        B6 = Dirichlet_boundary([0.6,0,0])
362        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
363        domain1.check_integrity()
364
365        finaltime = 8
366        #Evolution
367        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
368            pass
369            #domain1.write_time()
370
371
372        #Now read data from sww and check
373        from Scientific.IO.NetCDF import NetCDFFile
374        filename = domain1.get_name() + '.' + domain1.format
375        fid = NetCDFFile(filename)
376
377        x = fid.variables['x'][:]
378        y = fid.variables['y'][:]
379        stage = fid.variables['stage'][:]
380        xmomentum = fid.variables['xmomentum'][:]
381        ymomentum = fid.variables['ymomentum'][:]
382        time = fid.variables['time'][:]
383
384        #Take stage vertex values at last timestep on diagonal
385        #Diagonal is identified by vertices: 0, 5, 10, 15
386
387        last_time_index = len(time)-1 #Last last_time_index     
388        d_stage = num.reshape(num.take(stage[last_time_index, :], [0,5,10,15]), (4,1))
389        d_uh = num.reshape(num.take(xmomentum[last_time_index, :], [0,5,10,15]), (4,1))
390        d_vh = num.reshape(num.take(ymomentum[last_time_index, :], [0,5,10,15]), (4,1))
391        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
392
393        #Reference interpolated values at midpoints on diagonal at
394        #this timestep are
395        r0 = (D[0] + D[1])/2
396        r1 = (D[1] + D[2])/2
397        r2 = (D[2] + D[3])/2
398
399        #And the midpoints are found now
400        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15])
401        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15])
402
403        diag = num.concatenate( (Dx, Dy), axis=1)
404        d_midpoints = (diag[1:] + diag[:-1])/2
405
406
407        #Adjust for georef - make interpolation points absolute
408        d_midpoints[:,0] += xllcorner
409        d_midpoints[:,1] += yllcorner               
410
411        #Let us see if the file function can find the correct
412        #values at the midpoints at the last timestep:
413        f = file_function(filename, domain1,
414                          interpolation_points = d_midpoints)
415
416        t = time[last_time_index]                         
417        q = f(t, point_id=0); assert num.allclose(r0, q)
418        q = f(t, point_id=1); assert num.allclose(r1, q)
419        q = f(t, point_id=2); assert num.allclose(r2, q)
420
421
422        ##################
423        #Now do the same for the first timestep
424
425        timestep = 0 #First timestep
426        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
427        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
428        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
429        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
430
431        #Reference interpolated values at midpoints on diagonal at
432        #this timestep are
433        r0 = (D[0] + D[1])/2
434        r1 = (D[1] + D[2])/2
435        r2 = (D[2] + D[3])/2
436
437        #Let us see if the file function can find the correct
438        #values
439        q = f(0, point_id=0); assert num.allclose(r0, q)
440        q = f(0, point_id=1); assert num.allclose(r1, q)
441        q = f(0, point_id=2); assert num.allclose(r2, q)
442
443
444        ##################
445        #Now do it again for a timestep in the middle
446
447        timestep = 33
448        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
449        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
450        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
451        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
452
453        #Reference interpolated values at midpoints on diagonal at
454        #this timestep are
455        r0 = (D[0] + D[1])/2
456        r1 = (D[1] + D[2])/2
457        r2 = (D[2] + D[3])/2
458
459        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
460        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
461        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
462
463
464        ##################
465        #Now check temporal interpolation
466        #Halfway between timestep 15 and 16
467
468        timestep = 15
469        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
470        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
471        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
472        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
473
474        #Reference interpolated values at midpoints on diagonal at
475        #this timestep are
476        r0_0 = (D[0] + D[1])/2
477        r1_0 = (D[1] + D[2])/2
478        r2_0 = (D[2] + D[3])/2
479
480        #
481        timestep = 16
482        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
483        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
484        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
485        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
486
487        #Reference interpolated values at midpoints on diagonal at
488        #this timestep are
489        r0_1 = (D[0] + D[1])/2
490        r1_1 = (D[1] + D[2])/2
491        r2_1 = (D[2] + D[3])/2
492
493        # The reference values are
494        r0 = (r0_0 + r0_1)/2
495        r1 = (r1_0 + r1_1)/2
496        r2 = (r2_0 + r2_1)/2
497
498        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
499        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
500        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
501
502        ##################
503        #Finally check interpolation 2 thirds of the way
504        #between timestep 15 and 16
505
506        # The reference values are
507        r0 = (r0_0 + 2*r0_1)/3
508        r1 = (r1_0 + 2*r1_1)/3
509        r2 = (r2_0 + 2*r2_1)/3
510
511        #And the file function gives
512        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
513        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
514        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
515
516        fid.close()
517        import os
518        os.remove(filename)
519
520       
521
522
523    def test_spatio_temporal_file_function_time(self):
524        """Test that File function interpolates correctly
525        between given times.
526        NetCDF version (x,y,t dependency)
527        """
528
529        #Create NetCDF (sww) file to be read
530        # x: 0, 5, 10, 15
531        # y: -20, -10, 0, 10
532        # t: 0, 60, 120, ...., 1200
533        #
534        # test quantities (arbitrary but non-trivial expressions):
535        #
536        #   stage     = 3*x - y**2 + 2*t
537        #   xmomentum = exp( -((x-7)**2 + (y+5)**2)/20 ) * t**2
538        #   ymomentum = x**2 + y**2 * sin(t*pi/600)
539
540        #NOTE: Nice test that may render some of the others redundant.
541
542        import os, time
543        from anuga.config import time_format
544        from mesh_factory import rectangular
545        from shallow_water import Domain
546        import anuga.shallow_water.data_manager
547
548        finaltime = 1200
549        filename = 'test_file_function'
550
551        #Create a domain to hold test grid
552        #(0:15, -20:10)
553        points, vertices, boundary =\
554                rectangular(4, 4, 15, 30, origin = (0, -20))
555        #print "points", points
556
557        #print 'Number of elements', len(vertices)
558        domain = Domain(points, vertices, boundary)
559        domain.smooth = False
560        domain.default_order = 2
561        domain.set_datadir('.')
562        domain.set_name(filename)
563        domain.store = True
564        domain.format = 'sww'   #Native netcdf visualisation format
565
566        #print points
567        start = time.mktime(time.strptime('2000', '%Y'))
568        domain.starttime = start
569
570
571        #Store structure
572        domain.initialise_storage()
573
574        #Compute artificial time steps and store
575        dt = 60  #One minute intervals
576        t = 0.0
577        while t <= finaltime:
578            #Compute quantities
579            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
580            domain.set_quantity('stage', f1)
581
582            f2 = lambda x,y: x+y+t**2
583            domain.set_quantity('xmomentum', f2)
584
585            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
586            domain.set_quantity('ymomentum', f3)
587
588            #Store and advance time
589            domain.time = t
590            domain.store_timestep(domain.conserved_quantities)
591            t += dt
592
593
594        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5]]
595     
596        #Deliberately set domain.starttime to too early
597        domain.starttime = start - 1
598
599        #Create file function
600        F = file_function(filename + '.sww', domain,
601                          quantities = domain.conserved_quantities,
602                          interpolation_points = interpolation_points)
603
604        #Check that FF updates fixes domain starttime
605        assert num.allclose(domain.starttime, start)
606
607        #Check that domain.starttime isn't updated if later
608        domain.starttime = start + 1
609        F = file_function(filename + '.sww', domain,
610                          quantities = domain.conserved_quantities,
611                          interpolation_points = interpolation_points)
612        assert num.allclose(domain.starttime, start+1)
613        domain.starttime = start
614
615
616        #Check linear interpolation in time
617        F = file_function(filename + '.sww', domain,
618                          quantities = domain.conserved_quantities,
619                          interpolation_points = interpolation_points)               
620        for id in range(len(interpolation_points)):
621            x = interpolation_points[id][0]
622            y = interpolation_points[id][1]
623
624            for i in range(20):
625                t = i*10
626                k = i%6
627
628                if k == 0:
629                    q0 = F(t, point_id=id)
630                    q1 = F(t+60, point_id=id)
631
632                if q0 == NAN:
633                    actual = q0
634                else:
635                    actual = (k*q1 + (6-k)*q0)/6
636                q = F(t, point_id=id)
637                #print i, k, t, q
638                #print ' ', q0
639                #print ' ', q1
640                #print "q",q
641                #print "actual", actual
642                #print
643                if q0 == NAN:
644                     self.failUnless( q == actual, 'Fail!')
645                else:
646                    assert num.allclose(q, actual)
647
648
649        #Another check of linear interpolation in time
650        for id in range(len(interpolation_points)):
651            q60 = F(60, point_id=id)
652            q120 = F(120, point_id=id)
653
654            t = 90 #Halfway between 60 and 120
655            q = F(t, point_id=id)
656            assert num.allclose( (q120+q60)/2, q )
657
658            t = 100 #Two thirds of the way between between 60 and 120
659            q = F(t, point_id=id)
660            assert num.allclose(q60/3 + 2*q120/3, q)
661
662
663
664        #Check that domain.starttime isn't updated if later than file starttime but earlier
665        #than file end time
666        delta = 23
667        domain.starttime = start + delta
668        F = file_function(filename + '.sww', domain,
669                          quantities = domain.conserved_quantities,
670                          interpolation_points = interpolation_points)
671        assert num.allclose(domain.starttime, start+delta)
672
673
674
675
676        #Now try interpolation with delta offset
677        for id in range(len(interpolation_points)):           
678            x = interpolation_points[id][0]
679            y = interpolation_points[id][1]
680
681            for i in range(20):
682                t = i*10
683                k = i%6
684
685                if k == 0:
686                    q0 = F(t-delta, point_id=id)
687                    q1 = F(t+60-delta, point_id=id)
688
689                q = F(t-delta, point_id=id)
690                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
691
692
693        os.remove(filename + '.sww')
694
695
696
697    def Xtest_spatio_temporal_file_function_time(self):
698        # FIXME: This passes but needs some TLC
699        # Test that File function interpolates correctly
700        # When some points are outside the mesh
701
702        import os, time
703        from anuga.config import time_format
704        from mesh_factory import rectangular
705        from shallow_water import Domain
706        import anuga.shallow_water.data_manager 
707        from anuga.pmesh.mesh_interface import create_mesh_from_regions
708        finaltime = 1200
709       
710        filename = tempfile.mktemp()
711        #print "filename",filename
712        filename = 'test_file_function'
713
714        meshfilename = tempfile.mktemp(".tsh")
715
716        boundary_tags = {'walls':[0,1],'bom':[2]}
717       
718        polygon_absolute = [[0,-20],[10,-20],[10,15],[-20,15]]
719       
720        create_mesh_from_regions(polygon_absolute,
721                                 boundary_tags,
722                                 10000000,
723                                 filename=meshfilename)
724        domain = Domain(mesh_filename=meshfilename)
725        domain.smooth = False
726        domain.default_order = 2
727        domain.set_datadir('.')
728        domain.set_name(filename)
729        domain.store = True
730        domain.format = 'sww'   #Native netcdf visualisation format
731
732        #print points
733        start = time.mktime(time.strptime('2000', '%Y'))
734        domain.starttime = start
735       
736
737        #Store structure
738        domain.initialise_storage()
739
740        #Compute artificial time steps and store
741        dt = 60  #One minute intervals
742        t = 0.0
743        while t <= finaltime:
744            #Compute quantities
745            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
746            domain.set_quantity('stage', f1)
747
748            f2 = lambda x,y: x+y+t**2
749            domain.set_quantity('xmomentum', f2)
750
751            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
752            domain.set_quantity('ymomentum', f3)
753
754            #Store and advance time
755            domain.time = t
756            domain.store_timestep(domain.conserved_quantities)
757            t += dt
758
759        interpolation_points = [[1,0]]
760        interpolation_points = [[100,1000]]
761       
762        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5],
763                                [78787,78787],[7878,3432]]
764           
765        #Deliberately set domain.starttime to too early
766        domain.starttime = start - 1
767
768        #Create file function
769        F = file_function(filename + '.sww', domain,
770                          quantities = domain.conserved_quantities,
771                          interpolation_points = interpolation_points)
772
773        #Check that FF updates fixes domain starttime
774        assert num.allclose(domain.starttime, start)
775
776        #Check that domain.starttime isn't updated if later
777        domain.starttime = start + 1
778        F = file_function(filename + '.sww', domain,
779                          quantities = domain.conserved_quantities,
780                          interpolation_points = interpolation_points)
781        assert num.allclose(domain.starttime, start+1)
782        domain.starttime = start
783
784
785        #Check linear interpolation in time
786        # checking points inside and outside the mesh
787        F = file_function(filename + '.sww', domain,
788                          quantities = domain.conserved_quantities,
789                          interpolation_points = interpolation_points)
790       
791        for id in range(len(interpolation_points)):
792            x = interpolation_points[id][0]
793            y = interpolation_points[id][1]
794
795            for i in range(20):
796                t = i*10
797                k = i%6
798
799                if k == 0:
800                    q0 = F(t, point_id=id)
801                    q1 = F(t+60, point_id=id)
802
803                if q0 == NAN:
804                    actual = q0
805                else:
806                    actual = (k*q1 + (6-k)*q0)/6
807                q = F(t, point_id=id)
808                #print i, k, t, q
809                #print ' ', q0
810                #print ' ', q1
811                #print "q",q
812                #print "actual", actual
813                #print
814                if q0 == NAN:
815                     self.failUnless( q == actual, 'Fail!')
816                else:
817                    assert num.allclose(q, actual)
818
819        # now lets check points inside the mesh
820        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14]] #, [10,-12.5]] - this point doesn't work WHY?
821        interpolation_points = [[10,-12.5]]
822           
823        print "len(interpolation_points)",len(interpolation_points) 
824        F = file_function(filename + '.sww', domain,
825                          quantities = domain.conserved_quantities,
826                          interpolation_points = interpolation_points)
827
828        domain.starttime = start
829
830
831        #Check linear interpolation in time
832        F = file_function(filename + '.sww', domain,
833                          quantities = domain.conserved_quantities,
834                          interpolation_points = interpolation_points)               
835        for id in range(len(interpolation_points)):
836            x = interpolation_points[id][0]
837            y = interpolation_points[id][1]
838
839            for i in range(20):
840                t = i*10
841                k = i%6
842
843                if k == 0:
844                    q0 = F(t, point_id=id)
845                    q1 = F(t+60, point_id=id)
846
847                if q0 == NAN:
848                    actual = q0
849                else:
850                    actual = (k*q1 + (6-k)*q0)/6
851                q = F(t, point_id=id)
852                print "############"
853                print "id, x, y ", id, x, y #k, t, q
854                print "t", t
855                #print ' ', q0
856                #print ' ', q1
857                print "q",q
858                print "actual", actual
859                #print
860                if q0 == NAN:
861                     self.failUnless( q == actual, 'Fail!')
862                else:
863                    assert num.allclose(q, actual)
864
865
866        #Another check of linear interpolation in time
867        for id in range(len(interpolation_points)):
868            q60 = F(60, point_id=id)
869            q120 = F(120, point_id=id)
870
871            t = 90 #Halfway between 60 and 120
872            q = F(t, point_id=id)
873            assert num.allclose( (q120+q60)/2, q )
874
875            t = 100 #Two thirds of the way between between 60 and 120
876            q = F(t, point_id=id)
877            assert num.allclose(q60/3 + 2*q120/3, q)
878
879
880
881        #Check that domain.starttime isn't updated if later than file starttime but earlier
882        #than file end time
883        delta = 23
884        domain.starttime = start + delta
885        F = file_function(filename + '.sww', domain,
886                          quantities = domain.conserved_quantities,
887                          interpolation_points = interpolation_points)
888        assert num.allclose(domain.starttime, start+delta)
889
890
891
892
893        #Now try interpolation with delta offset
894        for id in range(len(interpolation_points)):           
895            x = interpolation_points[id][0]
896            y = interpolation_points[id][1]
897
898            for i in range(20):
899                t = i*10
900                k = i%6
901
902                if k == 0:
903                    q0 = F(t-delta, point_id=id)
904                    q1 = F(t+60-delta, point_id=id)
905
906                q = F(t-delta, point_id=id)
907                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
908
909
910        os.remove(filename + '.sww')
911
912    def test_file_function_time_with_domain(self):
913        """Test that File function interpolates correctly
914        between given times. No x,y dependency here.
915        Use domain with starttime
916        """
917
918        #Write file
919        import os, time, calendar
920        from anuga.config import time_format
921        from math import sin, pi
922        from domain import Domain
923
924        finaltime = 1200
925        filename = 'test_file_function'
926        fid = open(filename + '.txt', 'w')
927        start = time.mktime(time.strptime('2000', '%Y'))
928        dt = 60  #One minute intervals
929        t = 0.0
930        while t <= finaltime:
931            t_string = time.strftime(time_format, time.gmtime(t+start))
932            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
933            t += dt
934
935        fid.close()
936
937
938        #Convert ASCII file to NetCDF (Which is what we really like!)
939        timefile2netcdf(filename)
940
941
942
943        a = [0.0, 0.0]
944        b = [4.0, 0.0]
945        c = [0.0, 3.0]
946
947        points = [a, b, c]
948        vertices = [[0,1,2]]
949        domain = Domain(points, vertices)
950
951        # Check that domain.starttime is updated if non-existing
952        F = file_function(filename + '.tms',
953                          domain,
954                          quantities = ['Attribute0', 'Attribute1', 'Attribute2']) 
955        assert num.allclose(domain.starttime, start)
956
957        # Check that domain.starttime is updated if too early
958        domain.starttime = start - 1
959        F = file_function(filename + '.tms',
960                          domain,
961                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
962        assert num.allclose(domain.starttime, start)
963
964        # Check that domain.starttime isn't updated if later
965        domain.starttime = start + 1
966        F = file_function(filename + '.tms',
967                          domain,
968                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
969        assert num.allclose(domain.starttime, start+1)
970
971        domain.starttime = start
972        F = file_function(filename + '.tms',
973                          domain,
974                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'],
975                          use_cache=True)
976       
977
978        #print F.precomputed_values
979        #print 'F(60)', F(60)
980       
981        #Now try interpolation
982        for i in range(20):
983            t = i*10
984            q = F(t)
985
986            #Exact linear intpolation
987            assert num.allclose(q[0], 2*t)
988            if i%6 == 0:
989                assert num.allclose(q[1], t**2)
990                assert num.allclose(q[2], sin(t*pi/600))
991
992        #Check non-exact
993
994        t = 90 #Halfway between 60 and 120
995        q = F(t)
996        assert num.allclose( (120**2 + 60**2)/2, q[1] )
997        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
998
999
1000        t = 100 #Two thirds of the way between between 60 and 120
1001        q = F(t)
1002        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1003        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1004
1005        os.remove(filename + '.tms')
1006        os.remove(filename + '.txt')       
1007
1008    def test_file_function_time_with_domain_different_start(self):
1009        """Test that File function interpolates correctly
1010        between given times. No x,y dependency here.
1011        Use domain with a starttime later than that of file
1012
1013        ASCII version
1014        """
1015
1016        #Write file
1017        import os, time, calendar
1018        from anuga.config import time_format
1019        from math import sin, pi
1020        from domain import Domain
1021
1022        finaltime = 1200
1023        filename = 'test_file_function'
1024        fid = open(filename + '.txt', 'w')
1025        start = time.mktime(time.strptime('2000', '%Y'))
1026        dt = 60  #One minute intervals
1027        t = 0.0
1028        while t <= finaltime:
1029            t_string = time.strftime(time_format, time.gmtime(t+start))
1030            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1031            t += dt
1032
1033        fid.close()
1034
1035        #Convert ASCII file to NetCDF (Which is what we really like!)
1036        timefile2netcdf(filename)       
1037
1038        a = [0.0, 0.0]
1039        b = [4.0, 0.0]
1040        c = [0.0, 3.0]
1041
1042        points = [a, b, c]
1043        vertices = [[0,1,2]]
1044        domain = Domain(points, vertices)
1045
1046        #Check that domain.starttime isn't updated if later than file starttime but earlier
1047        #than file end time
1048        delta = 23
1049        domain.starttime = start + delta
1050        F = file_function(filename + '.tms', domain,
1051                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])       
1052        assert num.allclose(domain.starttime, start+delta)
1053
1054        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1055                                            277., 337., 397., 457., 517.,
1056                                            577., 637., 697., 757., 817.,
1057                                            877., 937., 997., 1057., 1117.,
1058                                            1177.])
1059
1060
1061        #Now try interpolation with delta offset
1062        for i in range(20):
1063            t = i*10
1064            q = F(t-delta)
1065
1066            #Exact linear intpolation
1067            assert num.allclose(q[0], 2*t)
1068            if i%6 == 0:
1069                assert num.allclose(q[1], t**2)
1070                assert num.allclose(q[2], sin(t*pi/600))
1071
1072        #Check non-exact
1073
1074        t = 90 #Halfway between 60 and 120
1075        q = F(t-delta)
1076        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1077        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1078
1079
1080        t = 100 #Two thirds of the way between between 60 and 120
1081        q = F(t-delta)
1082        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1083        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1084
1085
1086        os.remove(filename + '.tms')
1087        os.remove(filename + '.txt')               
1088
1089       
1090
1091    def test_file_function_time_with_domain_different_start_and_time_limit(self):
1092        """Test that File function interpolates correctly
1093        between given times. No x,y dependency here.
1094        Use domain with a starttime later than that of file
1095
1096        ASCII version
1097       
1098        This test also tests that time can be truncated.
1099        """
1100
1101        # Write file
1102        import os, time, calendar
1103        from anuga.config import time_format
1104        from math import sin, pi
1105        from domain import Domain
1106
1107        finaltime = 1200
1108        filename = 'test_file_function'
1109        fid = open(filename + '.txt', 'w')
1110        start = time.mktime(time.strptime('2000', '%Y'))
1111        dt = 60  #One minute intervals
1112        t = 0.0
1113        while t <= finaltime:
1114            t_string = time.strftime(time_format, time.gmtime(t+start))
1115            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1116            t += dt
1117
1118        fid.close()
1119
1120        # Convert ASCII file to NetCDF (Which is what we really like!)
1121        timefile2netcdf(filename)       
1122
1123        a = [0.0, 0.0]
1124        b = [4.0, 0.0]
1125        c = [0.0, 3.0]
1126
1127        points = [a, b, c]
1128        vertices = [[0,1,2]]
1129        domain = Domain(points, vertices)
1130
1131        # Check that domain.starttime isn't updated if later than file starttime but earlier
1132        # than file end time
1133        delta = 23
1134        domain.starttime = start + delta
1135        time_limit = domain.starttime + 600
1136        F = file_function(filename + '.tms', domain,
1137                          time_limit=time_limit,
1138                          quantities=['Attribute0', 'Attribute1', 'Attribute2'])       
1139        assert num.allclose(domain.starttime, start+delta)
1140
1141        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1142                                            277., 337., 397., 457., 517.,
1143                                            577.])       
1144
1145
1146
1147        # Now try interpolation with delta offset
1148        for i in range(20):
1149            t = i*10
1150            q = F(t-delta)
1151
1152            #Exact linear intpolation
1153            assert num.allclose(q[0], 2*t)
1154            if i%6 == 0:
1155                assert num.allclose(q[1], t**2)
1156                assert num.allclose(q[2], sin(t*pi/600))
1157
1158        # Check non-exact
1159        t = 90 #Halfway between 60 and 120
1160        q = F(t-delta)
1161        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1162        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1163
1164
1165        t = 100 # Two thirds of the way between between 60 and 120
1166        q = F(t-delta)
1167        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1168        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1169
1170
1171        os.remove(filename + '.tms')
1172        os.remove(filename + '.txt')               
1173
1174       
1175       
1176       
1177
1178
1179    def test_apply_expression_to_dictionary(self):
1180
1181        #FIXME: Division is not expected to work for integers.
1182        #This must be caught.
1183        foo = num.array([[1,2,3], [4,5,6]], num.Float)
1184
1185        bar = num.array([[-1,0,5], [6,1,1]], num.Float)                 
1186
1187        D = {'X': foo, 'Y': bar}
1188
1189        Z = apply_expression_to_dictionary('X+Y', D)       
1190        assert num.allclose(Z, foo+bar)
1191
1192        Z = apply_expression_to_dictionary('X*Y', D)       
1193        assert num.allclose(Z, foo*bar)       
1194
1195        Z = apply_expression_to_dictionary('4*X+Y', D)       
1196        assert num.allclose(Z, 4*foo+bar)       
1197
1198        # test zero division is OK
1199        Z = apply_expression_to_dictionary('X/Y', D)
1200        assert num.allclose(1/Z, 1/(foo/bar)) # can't compare inf to inf
1201
1202        # make an error for zero on zero
1203        # this is really an error in Numeric, SciPy core can handle it
1204        # Z = apply_expression_to_dictionary('0/Y', D)
1205
1206        #Check exceptions
1207        try:
1208            #Wrong name
1209            Z = apply_expression_to_dictionary('4*X+A', D)       
1210        except NameError:
1211            pass
1212        else:
1213            msg = 'Should have raised a NameError Exception'
1214            raise msg
1215
1216
1217        try:
1218            #Wrong order
1219            Z = apply_expression_to_dictionary(D, '4*X+A')       
1220        except AssertionError:
1221            pass
1222        else:
1223            msg = 'Should have raised a AssertionError Exception'
1224            raise msg       
1225       
1226
1227    def test_multiple_replace(self):
1228        """Hard test that checks a true word-by-word simultaneous replace
1229        """
1230       
1231        D = {'x': 'xi', 'y': 'eta', 'xi':'lam'}
1232        exp = '3*x+y + xi'
1233       
1234        new = multiple_replace(exp, D)
1235       
1236        assert new == '3*xi+eta + lam'
1237                         
1238
1239
1240    def test_point_on_line_obsolete(self):
1241        """Test that obsolete call issues appropriate warning"""
1242
1243        #Turn warning into an exception
1244        import warnings
1245        warnings.filterwarnings('error')
1246
1247        try:
1248            assert point_on_line( 0, 0.5, 0,1, 0,0 )
1249        except DeprecationWarning:
1250            pass
1251        else:
1252            msg = 'point_on_line should have issued a DeprecationWarning'
1253            raise Exception(msg)   
1254
1255        warnings.resetwarnings()
1256   
1257    def test_get_revision_number(self):
1258        """test_get_revision_number(self):
1259
1260        Test that revision number can be retrieved.
1261        """
1262        if os.environ.has_key('USER') and os.environ['USER'] == 'dgray':
1263            # I have a known snv incompatability issue,
1264            # so I'm skipping this test.
1265            # FIXME when SVN is upgraded on our clusters
1266            pass
1267        else:   
1268            n = get_revision_number()
1269            assert n>=0
1270
1271
1272       
1273    def test_add_directories(self):
1274       
1275        import tempfile
1276        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1277        directories = ['ja','ne','ke']
1278        kens_dir = add_directories(root_dir, directories)
1279        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1280               sep + 'ke'
1281        assert access(root_dir,F_OK)
1282
1283        add_directories(root_dir, directories)
1284        assert access(root_dir,F_OK)
1285       
1286        #clean up!
1287        os.rmdir(kens_dir)
1288        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1289        os.rmdir(root_dir + sep + 'ja')
1290        os.rmdir(root_dir)
1291
1292    def test_add_directories_bad(self):
1293       
1294        import tempfile
1295        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1296        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1297       
1298        try:
1299            kens_dir = add_directories(root_dir, directories)
1300        except OSError:
1301            pass
1302        else:
1303            msg = 'bad dir name should give OSError'
1304            raise Exception(msg)   
1305           
1306        #clean up!
1307        os.rmdir(root_dir)
1308
1309    def test_check_list(self):
1310
1311        check_list(['stage','xmomentum'])
1312
1313       
1314    def test_add_directories(self):
1315       
1316        import tempfile
1317        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1318        directories = ['ja','ne','ke']
1319        kens_dir = add_directories(root_dir, directories)
1320        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1321               sep + 'ke'
1322        assert access(root_dir,F_OK)
1323
1324        add_directories(root_dir, directories)
1325        assert access(root_dir,F_OK)
1326       
1327        #clean up!
1328        os.rmdir(kens_dir)
1329        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1330        os.rmdir(root_dir + sep + 'ja')
1331        os.rmdir(root_dir)
1332
1333    def test_add_directories_bad(self):
1334       
1335        import tempfile
1336        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1337        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1338       
1339        try:
1340            kens_dir = add_directories(root_dir, directories)
1341        except OSError:
1342            pass
1343        else:
1344            msg = 'bad dir name should give OSError'
1345            raise Exception(msg)   
1346           
1347        #clean up!
1348        os.rmdir(root_dir)
1349
1350    def test_check_list(self):
1351
1352        check_list(['stage','xmomentum'])
1353
1354######
1355# Test the remove_lone_verts() function
1356######
1357       
1358    def test_remove_lone_verts_a(self):
1359        verts = [[0,0],[1,0],[0,1]]
1360        tris = [[0,1,2]]
1361        new_verts, new_tris = remove_lone_verts(verts, tris)
1362        self.failUnless(new_verts.tolist() == verts)
1363        self.failUnless(new_tris.tolist() == tris)
1364
1365    def test_remove_lone_verts_b(self):
1366        verts = [[0,0],[1,0],[0,1],[99,99]]
1367        tris = [[0,1,2]]
1368        new_verts, new_tris = remove_lone_verts(verts, tris)
1369        self.failUnless(new_verts.tolist() == verts[0:3])
1370        self.failUnless(new_tris.tolist() == tris)
1371       
1372    def test_remove_lone_verts_c(self):
1373        verts = [[99,99],[0,0],[1,0],[99,99],[0,1],[99,99]]
1374        tris = [[1,2,4]]
1375        new_verts, new_tris = remove_lone_verts(verts, tris)
1376        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1377        self.failUnless(new_tris.tolist() == [[0,1,2]])
1378     
1379    def test_remove_lone_verts_d(self):
1380        verts = [[0,0],[1,0],[99,99],[0,1]]
1381        tris = [[0,1,3]]
1382        new_verts, new_tris = remove_lone_verts(verts, tris)
1383        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1384        self.failUnless(new_tris.tolist() == [[0,1,2]])
1385       
1386    def test_remove_lone_verts_e(self):
1387        verts = [[0,0],[1,0],[0,1],[99,99],[99,99],[99,99]]
1388        tris = [[0,1,2]]
1389        new_verts, new_tris = remove_lone_verts(verts, tris)
1390        self.failUnless(new_verts.tolist() == verts[0:3])
1391        self.failUnless(new_tris.tolist() == tris)
1392     
1393    def test_remove_lone_verts_f(self):
1394        verts = [[0,0],[1,0],[99,99],[0,1],[99,99],[1,1],[99,99]]
1395        tris = [[0,1,3],[0,1,5]]
1396        new_verts, new_tris = remove_lone_verts(verts, tris)
1397        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1],[1,1]])
1398        self.failUnless(new_tris.tolist() == [[0,1,2],[0,1,3]])
1399       
1400######
1401#
1402######
1403       
1404    def test_get_min_max_values(self):
1405       
1406        list=[8,9,6,1,4]
1407        min1, max1 = get_min_max_values(list)
1408       
1409        assert min1==1 
1410        assert max1==9
1411       
1412    def test_get_min_max_values1(self):
1413       
1414        list=[-8,-9,-6,-1,-4]
1415        min1, max1 = get_min_max_values(list)
1416       
1417#        print 'min1,max1',min1,max1
1418        assert min1==-9 
1419        assert max1==-1
1420
1421#    def test_get_min_max_values2(self):
1422#        '''
1423#        The min and max supplied are greater than the ones in the
1424#        list and therefore are the ones returned
1425#        '''
1426#        list=[-8,-9,-6,-1,-4]
1427#        min1, max1 = get_min_max_values(list,-10,10)
1428#       
1429##        print 'min1,max1',min1,max1
1430#        assert min1==-10
1431#        assert max1==10
1432       
1433    def test_make_plots_from_csv_files(self):
1434       
1435        #if sys.platform == 'win32':  #Windows
1436            try: 
1437                import pylab
1438            except ImportError:
1439                #ANUGA don't need pylab to work so the system doesn't
1440                #rely on pylab being installed
1441                return
1442           
1443       
1444            current_dir=getcwd()+sep+'abstract_2d_finite_volumes'
1445            temp_dir = tempfile.mkdtemp('','figures')
1446    #        print 'temp_dir',temp_dir
1447            fileName = temp_dir+sep+'time_series_3.csv'
1448            file = open(fileName,"w")
1449            file.write("time,stage,speed,momentum,elevation\n\
14501.0, 0, 0, 0, 10 \n\
14512.0, 5, 2, 4, 10 \n\
14523.0, 3, 3, 5, 10 \n")
1453            file.close()
1454   
1455            fileName1 = temp_dir+sep+'time_series_4.csv'
1456            file1 = open(fileName1,"w")
1457            file1.write("time,stage,speed,momentum,elevation\n\
14581.0, 0, 0, 0, 5 \n\
14592.0, -5, -2, -4, 5 \n\
14603.0, -4, -3, -5, 5 \n")
1461            file1.close()
1462   
1463            fileName2 = temp_dir+sep+'time_series_5.csv'
1464            file2 = open(fileName2,"w")
1465            file2.write("time,stage,speed,momentum,elevation\n\
14661.0, 0, 0, 0, 7 \n\
14672.0, 4, -0.45, 57, 7 \n\
14683.0, 6, -0.5, 56, 7 \n")
1469            file2.close()
1470           
1471            dir, name=os.path.split(fileName)
1472            csv2timeseries_graphs(directories_dic={dir:['gauge', 0, 0]},
1473                                  output_dir=temp_dir,
1474                                  base_name='time_series_',
1475                                  plot_numbers=['3-5'],
1476                                  quantities=['speed','stage','momentum'],
1477                                  assess_all_csv_files=True,
1478                                  extra_plot_name='test')
1479           
1480            #print dir+sep+name[:-4]+'_stage_test.png'
1481            assert(access(dir+sep+name[:-4]+'_stage_test.png',F_OK)==True)
1482            assert(access(dir+sep+name[:-4]+'_speed_test.png',F_OK)==True)
1483            assert(access(dir+sep+name[:-4]+'_momentum_test.png',F_OK)==True)
1484   
1485            dir1, name1=os.path.split(fileName1)
1486            assert(access(dir+sep+name1[:-4]+'_stage_test.png',F_OK)==True)
1487            assert(access(dir+sep+name1[:-4]+'_speed_test.png',F_OK)==True)
1488            assert(access(dir+sep+name1[:-4]+'_momentum_test.png',F_OK)==True)
1489   
1490   
1491            dir2, name2=os.path.split(fileName2)
1492            assert(access(dir+sep+name2[:-4]+'_stage_test.png',F_OK)==True)
1493            assert(access(dir+sep+name2[:-4]+'_speed_test.png',F_OK)==True)
1494            assert(access(dir+sep+name2[:-4]+'_momentum_test.png',F_OK)==True)
1495   
1496            del_dir(temp_dir)
1497       
1498
1499    def test_sww2csv_gauges(self):
1500
1501        def elevation_function(x, y):
1502            return -x
1503       
1504        """Most of this test was copied from test_interpolate
1505        test_interpole_sww2csv
1506       
1507        This is testing the gauge_sww2csv function, by creating a sww file and
1508        then exporting the gauges and checking the results.
1509        """
1510       
1511        # Create mesh
1512        mesh_file = tempfile.mktemp(".tsh")   
1513        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1514        m = Mesh()
1515        m.add_vertices(points)
1516        m.auto_segment()
1517        m.generate_mesh(verbose=False)
1518        m.export_mesh_file(mesh_file)
1519       
1520        # Create shallow water domain
1521        domain = Domain(mesh_file)
1522        os.remove(mesh_file)
1523       
1524        domain.default_order=2
1525       
1526        # This test was made before tight_slope_limiters were introduced
1527        # Since were are testing interpolation values this is OK
1528        domain.tight_slope_limiters = 0 
1529       
1530
1531        # Set some field values
1532        domain.set_quantity('elevation', elevation_function)
1533        domain.set_quantity('friction', 0.03)
1534        domain.set_quantity('xmomentum', 3.0)
1535        domain.set_quantity('ymomentum', 4.0)
1536
1537        ######################
1538        # Boundary conditions
1539        B = Transmissive_boundary(domain)
1540        domain.set_boundary( {'exterior': B})
1541
1542        # This call mangles the stage values.
1543        domain.distribute_to_vertices_and_edges()
1544        domain.set_quantity('stage', 1.0)
1545
1546
1547        domain.set_name('datatest' + str(time.time()))
1548        domain.format = 'sww'
1549        domain.smooth = True
1550        domain.reduction = mean
1551
1552
1553        sww = get_dataobject(domain)
1554        sww.store_connectivity()
1555        sww.store_timestep(['stage', 'xmomentum', 'ymomentum','elevation'])
1556        domain.set_quantity('stage', 10.0) # This is automatically limited
1557        # so it will not be less than the elevation
1558        domain.time = 2.
1559        sww.store_timestep(['stage','elevation', 'xmomentum', 'ymomentum'])
1560
1561        # test the function
1562        points = [[5.0,1.],[0.5,2.]]
1563
1564        points_file = tempfile.mktemp(".csv")
1565#        points_file = 'test_point.csv'
1566        file_id = open(points_file,"w")
1567        file_id.write("name, easting, northing, elevation \n\
1568point1, 5.0, 1.0, 3.0\n\
1569point2, 0.5, 2.0, 9.0\n")
1570        file_id.close()
1571
1572       
1573        sww2csv_gauges(sww.filename, 
1574                       points_file,
1575                       verbose=False,
1576                       use_cache=False)
1577
1578#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1579        point1_answers_array = [[0.0,0.0,1.0,6.0,-5.0,3.0,4.0], [2.0,2.0/3600.,10.0,15.0,-5.0,3.0,4.0]]
1580        point1_filename = 'gauge_point1.csv'
1581        point1_handle = file(point1_filename)
1582        point1_reader = reader(point1_handle)
1583        point1_reader.next()
1584
1585        line=[]
1586        for i,row in enumerate(point1_reader):
1587#            print 'i',i,'row',row
1588            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1589                         float(row[4]),float(row[5]),float(row[6])])
1590#            print 'assert line',line[i],'point1',point1_answers_array[i]
1591            assert num.allclose(line[i], point1_answers_array[i])
1592
1593        point2_answers_array = [[0.0,0.0,1.0,1.5,-0.5,3.0,4.0], [2.0,2.0/3600.,10.0,10.5,-0.5,3.0,4.0]]
1594        point2_filename = 'gauge_point2.csv' 
1595        point2_handle = file(point2_filename)
1596        point2_reader = reader(point2_handle)
1597        point2_reader.next()
1598                       
1599        line=[]
1600        for i,row in enumerate(point2_reader):
1601#            print 'i',i,'row',row
1602            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1603                         float(row[4]),float(row[5]),float(row[6])])
1604#            print 'assert line',line[i],'point1',point1_answers_array[i]
1605            assert num.allclose(line[i], point2_answers_array[i])
1606                         
1607        # clean up
1608        point1_handle.close()
1609        point2_handle.close()
1610        #print "sww.filename",sww.filename
1611        os.remove(sww.filename)
1612        os.remove(points_file)
1613        os.remove(point1_filename)
1614        os.remove(point2_filename)
1615
1616
1617
1618    def test_sww2csv_gauges1(self):
1619        from anuga.pmesh.mesh import Mesh
1620        from anuga.shallow_water import Domain, Transmissive_boundary
1621        from anuga.shallow_water.data_manager import get_dataobject
1622        from csv import reader,writer
1623        import time
1624        import string
1625
1626        def elevation_function(x, y):
1627            return -x
1628       
1629        """Most of this test was copied from test_interpolate
1630        test_interpole_sww2csv
1631       
1632        This is testing the gauge_sww2csv function, by creating a sww file and
1633        then exporting the gauges and checking the results.
1634       
1635        This tests the ablity not to have elevation in the points file and
1636        not store xmomentum and ymomentum
1637        """
1638       
1639        # Create mesh
1640        mesh_file = tempfile.mktemp(".tsh")   
1641        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1642        m = Mesh()
1643        m.add_vertices(points)
1644        m.auto_segment()
1645        m.generate_mesh(verbose=False)
1646        m.export_mesh_file(mesh_file)
1647       
1648        # Create shallow water domain
1649        domain = Domain(mesh_file)
1650        os.remove(mesh_file)
1651       
1652        domain.default_order=2
1653
1654        # Set some field values
1655        domain.set_quantity('elevation', elevation_function)
1656        domain.set_quantity('friction', 0.03)
1657        domain.set_quantity('xmomentum', 3.0)
1658        domain.set_quantity('ymomentum', 4.0)
1659
1660        ######################
1661        # Boundary conditions
1662        B = Transmissive_boundary(domain)
1663        domain.set_boundary( {'exterior': B})
1664
1665        # This call mangles the stage values.
1666        domain.distribute_to_vertices_and_edges()
1667        domain.set_quantity('stage', 1.0)
1668
1669
1670        domain.set_name('datatest' + str(time.time()))
1671        domain.format = 'sww'
1672        domain.smooth = True
1673        domain.reduction = mean
1674
1675        sww = get_dataobject(domain)
1676        sww.store_connectivity()
1677        sww.store_timestep(['stage', 'xmomentum', 'ymomentum'])
1678        domain.set_quantity('stage', 10.0) # This is automatically limited
1679        # so it will not be less than the elevation
1680        domain.time = 2.
1681        sww.store_timestep(['stage', 'xmomentum', 'ymomentum'])
1682
1683        # test the function
1684        points = [[5.0,1.],[0.5,2.]]
1685
1686        points_file = tempfile.mktemp(".csv")
1687#        points_file = 'test_point.csv'
1688        file_id = open(points_file,"w")
1689        file_id.write("name,easting,northing \n\
1690point1, 5.0, 1.0\n\
1691point2, 0.5, 2.0\n")
1692        file_id.close()
1693
1694        sww2csv_gauges(sww.filename, 
1695                            points_file,
1696                            quantities=['stage', 'elevation'],
1697                            use_cache=False,
1698                            verbose=False)
1699
1700        point1_answers_array = [[0.0,1.0,-5.0], [2.0,10.0,-5.0]]
1701        point1_filename = 'gauge_point1.csv'
1702        point1_handle = file(point1_filename)
1703        point1_reader = reader(point1_handle)
1704        point1_reader.next()
1705
1706        line=[]
1707        for i,row in enumerate(point1_reader):
1708#            print 'i',i,'row',row
1709            # note the 'hole' (element 1) below - skip the new 'hours' field
1710            line.append([float(row[0]),float(row[2]),float(row[3])])
1711            #print 'line',line[i],'point1',point1_answers_array[i]
1712            assert num.allclose(line[i], point1_answers_array[i])
1713
1714        point2_answers_array = [[0.0,1.0,-0.5], [2.0,10.0,-0.5]]
1715        point2_filename = 'gauge_point2.csv' 
1716        point2_handle = file(point2_filename)
1717        point2_reader = reader(point2_handle)
1718        point2_reader.next()
1719                       
1720        line=[]
1721        for i,row in enumerate(point2_reader):
1722#            print 'i',i,'row',row
1723            # note the 'hole' (element 1) below - skip the new 'hours' field
1724            line.append([float(row[0]),float(row[2]),float(row[3])])
1725#            print 'line',line[i],'point1',point1_answers_array[i]
1726            assert num.allclose(line[i], point2_answers_array[i])
1727                         
1728        # clean up
1729        point1_handle.close()
1730        point2_handle.close()
1731        #print "sww.filename",sww.filename
1732        os.remove(sww.filename)
1733        os.remove(points_file)
1734        os.remove(point1_filename)
1735        os.remove(point2_filename)
1736
1737
1738    def test_sww2csv_gauges2(self):
1739
1740        def elevation_function(x, y):
1741            return -x
1742       
1743        """Most of this test was copied from test_interpolate
1744        test_interpole_sww2csv
1745       
1746        This is testing the gauge_sww2csv function, by creating a sww file and
1747        then exporting the gauges and checking the results.
1748       
1749        This is the same as sww2csv_gauges except set domain.set_starttime to 5.
1750        Therefore testing the storing of the absolute time in the csv files
1751        """
1752       
1753        # Create mesh
1754        mesh_file = tempfile.mktemp(".tsh")   
1755        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1756        m = Mesh()
1757        m.add_vertices(points)
1758        m.auto_segment()
1759        m.generate_mesh(verbose=False)
1760        m.export_mesh_file(mesh_file)
1761       
1762        # Create shallow water domain
1763        domain = Domain(mesh_file)
1764        os.remove(mesh_file)
1765       
1766        domain.default_order=2
1767
1768        # This test was made before tight_slope_limiters were introduced
1769        # Since were are testing interpolation values this is OK
1770        domain.tight_slope_limiters = 0         
1771
1772        # Set some field values
1773        domain.set_quantity('elevation', elevation_function)
1774        domain.set_quantity('friction', 0.03)
1775        domain.set_quantity('xmomentum', 3.0)
1776        domain.set_quantity('ymomentum', 4.0)
1777        domain.set_starttime(5)
1778
1779        ######################
1780        # Boundary conditions
1781        B = Transmissive_boundary(domain)
1782        domain.set_boundary( {'exterior': B})
1783
1784        # This call mangles the stage values.
1785        domain.distribute_to_vertices_and_edges()
1786        domain.set_quantity('stage', 1.0)
1787       
1788
1789
1790        domain.set_name('datatest' + str(time.time()))
1791        domain.format = 'sww'
1792        domain.smooth = True
1793        domain.reduction = mean
1794
1795        sww = get_dataobject(domain)
1796        sww.store_connectivity()
1797        sww.store_timestep(['stage', 'xmomentum', 'ymomentum','elevation'])
1798        domain.set_quantity('stage', 10.0) # This is automatically limited
1799        # so it will not be less than the elevation
1800        domain.time = 2.
1801        sww.store_timestep(['stage','elevation', 'xmomentum', 'ymomentum'])
1802
1803        # test the function
1804        points = [[5.0,1.],[0.5,2.]]
1805
1806        points_file = tempfile.mktemp(".csv")
1807#        points_file = 'test_point.csv'
1808        file_id = open(points_file,"w")
1809        file_id.write("name, easting, northing, elevation \n\
1810point1, 5.0, 1.0, 3.0\n\
1811point2, 0.5, 2.0, 9.0\n")
1812        file_id.close()
1813
1814       
1815        sww2csv_gauges(sww.filename, 
1816                            points_file,
1817                            verbose=False,
1818                            use_cache=False)
1819
1820#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1821        point1_answers_array = [[5.0,5.0/3600.,1.0,6.0,-5.0,3.0,4.0], [7.0,7.0/3600.,10.0,15.0,-5.0,3.0,4.0]]
1822        point1_filename = 'gauge_point1.csv'
1823        point1_handle = file(point1_filename)
1824        point1_reader = reader(point1_handle)
1825        point1_reader.next()
1826
1827        line=[]
1828        for i,row in enumerate(point1_reader):
1829            #print 'i',i,'row',row
1830            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1831                         float(row[4]), float(row[5]), float(row[6])])
1832            #print 'assert line',line[i],'point1',point1_answers_array[i]
1833            assert num.allclose(line[i], point1_answers_array[i])
1834
1835        point2_answers_array = [[5.0,5.0/3600.,1.0,1.5,-0.5,3.0,4.0], [7.0,7.0/3600.,10.0,10.5,-0.5,3.0,4.0]]
1836        point2_filename = 'gauge_point2.csv' 
1837        point2_handle = file(point2_filename)
1838        point2_reader = reader(point2_handle)
1839        point2_reader.next()
1840                       
1841        line=[]
1842        for i,row in enumerate(point2_reader):
1843            #print 'i',i,'row',row
1844            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1845                         float(row[4]),float(row[5]), float(row[6])])
1846            #print 'assert line',line[i],'point1',point1_answers_array[i]
1847            assert num.allclose(line[i], point2_answers_array[i])
1848                         
1849        # clean up
1850        point1_handle.close()
1851        point2_handle.close()
1852        #print "sww.filename",sww.filename
1853        os.remove(sww.filename)
1854        os.remove(points_file)
1855        os.remove(point1_filename)
1856        os.remove(point2_filename)
1857
1858
1859    def test_greens_law(self):
1860
1861        from math import sqrt
1862       
1863        d1 = 80.0
1864        d2 = 20.0
1865        h1 = 1.0
1866        h2 = greens_law(d1,d2,h1)
1867
1868        assert h2==sqrt(2.0)
1869       
1870    def test_calc_bearings(self):
1871 
1872        from math import atan, degrees
1873        #Test East
1874        uh = 1
1875        vh = 1.e-15
1876        angle = calc_bearing(uh, vh)
1877        if 89 < angle < 91: v=1
1878        assert v==1
1879        #Test West
1880        uh = -1
1881        vh = 1.e-15
1882        angle = calc_bearing(uh, vh)
1883        if 269 < angle < 271: v=1
1884        assert v==1
1885        #Test North
1886        uh = 1.e-15
1887        vh = 1
1888        angle = calc_bearing(uh, vh)
1889        if -1 < angle < 1: v=1
1890        assert v==1
1891        #Test South
1892        uh = 1.e-15
1893        vh = -1
1894        angle = calc_bearing(uh, vh)
1895        if 179 < angle < 181: v=1
1896        assert v==1
1897        #Test South-East
1898        uh = 1
1899        vh = -1
1900        angle = calc_bearing(uh, vh)
1901        if 134 < angle < 136: v=1
1902        assert v==1
1903        #Test North-East
1904        uh = 1
1905        vh = 1
1906        angle = calc_bearing(uh, vh)
1907        if 44 < angle < 46: v=1
1908        assert v==1
1909        #Test South-West
1910        uh = -1
1911        vh = -1
1912        angle = calc_bearing(uh, vh)
1913        if 224 < angle < 226: v=1
1914        assert v==1
1915        #Test North-West
1916        uh = -1
1917        vh = 1
1918        angle = calc_bearing(uh, vh)
1919        if 314 < angle < 316: v=1
1920        assert v==1
1921
1922 
1923
1924
1925       
1926
1927#-------------------------------------------------------------
1928if __name__ == "__main__":
1929    suite = unittest.makeSuite(Test_Util,'test')
1930#    suite = unittest.makeSuite(Test_Util,'test_sww2csv_gauges')
1931#    runner = unittest.TextTestRunner(verbosity=2)
1932    runner = unittest.TextTestRunner(verbosity=1)
1933    runner.run(suite)
1934
1935
1936
1937
Note: See TracBrowser for help on using the repository browser.