source: anuga_core/source/anuga/abstract_2d_finite_volumes/test_util.py @ 6173

Last change on this file since 6173 was 6173, checked in by ole, 15 years ago

Introduced time_limit in Field_boundary, File_boundary and file_function

File size: 64.1 KB
Line 
1#!/usr/bin/env python
2
3
4import unittest
5from math import sqrt, pi
6import tempfile, os
7from os import access, F_OK,sep, removedirs,remove,mkdir,getcwd
8
9from anuga.abstract_2d_finite_volumes.util import *
10from anuga.config import epsilon
11from anuga.shallow_water.data_manager import timefile2netcdf,del_dir
12
13from anuga.utilities.numerical_tools import NAN
14
15from sys import platform
16
17from anuga.pmesh.mesh import Mesh
18from anuga.shallow_water import Domain, Transmissive_boundary
19from anuga.shallow_water.data_manager import get_dataobject
20from csv import reader,writer
21import time
22import string
23
24import Numeric as num
25
26
27def test_function(x, y):
28    return x+y
29
30class Test_Util(unittest.TestCase):
31    def setUp(self):
32        pass
33
34    def tearDown(self):
35        pass
36
37
38
39
40    #Geometric
41    #def test_distance(self):
42    #    from anuga.abstract_2d_finite_volumes.util import distance#
43    #
44    #    self.failUnless( distance([4,2],[7,6]) == 5.0,
45    #                     'Distance is wrong!')
46    #    self.failUnless( allclose(distance([7,6],[9,8]), 2.82842712475),
47    #                    'distance is wrong!')
48    #    self.failUnless( allclose(distance([9,8],[4,2]), 7.81024967591),
49    #                    'distance is wrong!')
50    #
51    #    self.failUnless( distance([9,8],[4,2]) == distance([4,2],[9,8]),
52    #                    'distance is wrong!')
53
54
55    def test_file_function_time1(self):
56        """Test that File function interpolates correctly
57        between given times. No x,y dependency here.
58        """
59
60        #Write file
61        import os, time
62        from anuga.config import time_format
63        from math import sin, pi
64
65        #Typical ASCII file
66        finaltime = 1200
67        filename = 'test_file_function'
68        fid = open(filename + '.txt', 'w')
69        start = time.mktime(time.strptime('2000', '%Y'))
70        dt = 60  #One minute intervals
71        t = 0.0
72        while t <= finaltime:
73            t_string = time.strftime(time_format, time.gmtime(t+start))
74            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
75            t += dt
76
77        fid.close()
78
79        #Convert ASCII file to NetCDF (Which is what we really like!)
80        timefile2netcdf(filename)
81
82
83        #Create file function from time series
84        F = file_function(filename + '.tms',
85                          quantities = ['Attribute0',
86                                        'Attribute1',
87                                        'Attribute2'])
88       
89        #Now try interpolation
90        for i in range(20):
91            t = i*10
92            q = F(t)
93
94            #Exact linear intpolation
95            assert num.allclose(q[0], 2*t)
96            if i%6 == 0:
97                assert num.allclose(q[1], t**2)
98                assert num.allclose(q[2], sin(t*pi/600))
99
100        #Check non-exact
101
102        t = 90 #Halfway between 60 and 120
103        q = F(t)
104        assert num.allclose( (120**2 + 60**2)/2, q[1] )
105        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
106
107
108        t = 100 #Two thirds of the way between between 60 and 120
109        q = F(t)
110        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
111        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
112
113        os.remove(filename + '.txt')
114        os.remove(filename + '.tms')       
115
116
117       
118    def test_spatio_temporal_file_function_basic(self):
119        """Test that spatio temporal file function performs the correct
120        interpolations in both time and space
121        NetCDF version (x,y,t dependency)       
122        """
123        import time
124
125        #Create sww file of simple propagation from left to right
126        #through rectangular domain
127        from shallow_water import Domain, Dirichlet_boundary
128        from mesh_factory import rectangular
129
130        #Create basic mesh and shallow water domain
131        points, vertices, boundary = rectangular(3, 3)
132        domain1 = Domain(points, vertices, boundary)
133
134        from anuga.utilities.numerical_tools import mean
135        domain1.reduction = mean
136        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
137                              # only one value.
138
139        domain1.default_order = 2
140        domain1.store = True
141        domain1.set_datadir('.')
142        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
143        domain1.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']
144
145        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
146        domain1.set_quantity('elevation', 0)
147        domain1.set_quantity('friction', 0)
148        domain1.set_quantity('stage', 0)
149
150        # Boundary conditions
151        B0 = Dirichlet_boundary([0,0,0])
152        B6 = Dirichlet_boundary([0.6,0,0])
153        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
154        domain1.check_integrity()
155
156        finaltime = 8
157        #Evolution
158        t0 = -1
159        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
160            #print 'Timesteps: %.16f, %.16f' %(t0, t)
161            #if t == t0:
162            #    msg = 'Duplicate timestep found: %f, %f' %(t0, t)
163            #   raise msg
164            t0 = t
165             
166            #domain1.write_time()
167
168
169        #Now read data from sww and check
170        from Scientific.IO.NetCDF import NetCDFFile
171        filename = domain1.get_name() + '.' + domain1.format
172        fid = NetCDFFile(filename)
173
174        x = fid.variables['x'][:]
175        y = fid.variables['y'][:]
176        stage = fid.variables['stage'][:]
177        xmomentum = fid.variables['xmomentum'][:]
178        ymomentum = fid.variables['ymomentum'][:]
179        time = fid.variables['time'][:]
180
181        #Take stage vertex values at last timestep on diagonal
182        #Diagonal is identified by vertices: 0, 5, 10, 15
183
184        last_time_index = len(time)-1 #Last last_time_index
185        d_stage = num.reshape(num.take(stage[last_time_index, :], [0,5,10,15]), (4,1))
186        d_uh = num.reshape(num.take(xmomentum[last_time_index, :], [0,5,10,15]), (4,1))
187        d_vh = num.reshape(num.take(ymomentum[last_time_index, :], [0,5,10,15]), (4,1))
188        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
189
190        #Reference interpolated values at midpoints on diagonal at
191        #this timestep are
192        r0 = (D[0] + D[1])/2
193        r1 = (D[1] + D[2])/2
194        r2 = (D[2] + D[3])/2
195
196        #And the midpoints are found now
197        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15])
198        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15])
199
200        diag = num.concatenate( (Dx, Dy), axis=1)
201        d_midpoints = (diag[1:] + diag[:-1])/2
202
203        #Let us see if the file function can find the correct
204        #values at the midpoints at the last timestep:
205        f = file_function(filename, domain1,
206                          interpolation_points = d_midpoints)
207
208        T = f.get_time()
209        msg = 'duplicate timesteps: %.16f and %.16f' %(T[-1], T[-2])
210        assert not T[-1] == T[-2], msg
211        t = time[last_time_index]
212        q = f(t, point_id=0); assert num.allclose(r0, q)
213        q = f(t, point_id=1); assert num.allclose(r1, q)
214        q = f(t, point_id=2); assert num.allclose(r2, q)
215
216
217        ##################
218        #Now do the same for the first timestep
219
220        timestep = 0 #First timestep
221        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
222        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
223        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
224        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
225
226        #Reference interpolated values at midpoints on diagonal at
227        #this timestep are
228        r0 = (D[0] + D[1])/2
229        r1 = (D[1] + D[2])/2
230        r2 = (D[2] + D[3])/2
231
232        #Let us see if the file function can find the correct
233        #values
234        q = f(0, point_id=0); assert num.allclose(r0, q)
235        q = f(0, point_id=1); assert num.allclose(r1, q)
236        q = f(0, point_id=2); assert num.allclose(r2, q)
237
238
239        ##################
240        #Now do it again for a timestep in the middle
241
242        timestep = 33
243        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
244        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
245        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
246        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
247
248        #Reference interpolated values at midpoints on diagonal at
249        #this timestep are
250        r0 = (D[0] + D[1])/2
251        r1 = (D[1] + D[2])/2
252        r2 = (D[2] + D[3])/2
253
254        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
255        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
256        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
257
258
259        ##################
260        #Now check temporal interpolation
261        #Halfway between timestep 15 and 16
262
263        timestep = 15
264        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
265        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
266        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
267        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
268
269        #Reference interpolated values at midpoints on diagonal at
270        #this timestep are
271        r0_0 = (D[0] + D[1])/2
272        r1_0 = (D[1] + D[2])/2
273        r2_0 = (D[2] + D[3])/2
274
275        #
276        timestep = 16
277        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
278        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
279        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
280        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
281
282        #Reference interpolated values at midpoints on diagonal at
283        #this timestep are
284        r0_1 = (D[0] + D[1])/2
285        r1_1 = (D[1] + D[2])/2
286        r2_1 = (D[2] + D[3])/2
287
288        # The reference values are
289        r0 = (r0_0 + r0_1)/2
290        r1 = (r1_0 + r1_1)/2
291        r2 = (r2_0 + r2_1)/2
292
293        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
294        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
295        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
296
297        ##################
298        #Finally check interpolation 2 thirds of the way
299        #between timestep 15 and 16
300
301        # The reference values are
302        r0 = (r0_0 + 2*r0_1)/3
303        r1 = (r1_0 + 2*r1_1)/3
304        r2 = (r2_0 + 2*r2_1)/3
305
306        #And the file function gives
307        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
308        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
309        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
310
311        fid.close()
312        import os
313        os.remove(filename)
314
315
316
317    def test_spatio_temporal_file_function_different_origin(self):
318        """Test that spatio temporal file function performs the correct
319        interpolations in both time and space where space is offset by
320        xllcorner and yllcorner
321        NetCDF version (x,y,t dependency)       
322        """
323        import time
324
325        #Create sww file of simple propagation from left to right
326        #through rectangular domain
327        from shallow_water import Domain, Dirichlet_boundary
328        from mesh_factory import rectangular
329
330
331        from anuga.coordinate_transforms.geo_reference import Geo_reference
332        xllcorner = 2048
333        yllcorner = 11000
334        zone = 2
335
336        #Create basic mesh and shallow water domain
337        points, vertices, boundary = rectangular(3, 3)
338        domain1 = Domain(points, vertices, boundary,
339                         geo_reference = Geo_reference(xllcorner = xllcorner,
340                                                       yllcorner = yllcorner))
341       
342
343        from anuga.utilities.numerical_tools import mean       
344        domain1.reduction = mean
345        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
346                              # only one value.
347
348        domain1.default_order = 2
349        domain1.store = True
350        domain1.set_datadir('.')
351        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
352        domain1.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']
353
354        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
355        domain1.set_quantity('elevation', 0)
356        domain1.set_quantity('friction', 0)
357        domain1.set_quantity('stage', 0)
358
359        # Boundary conditions
360        B0 = Dirichlet_boundary([0,0,0])
361        B6 = Dirichlet_boundary([0.6,0,0])
362        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
363        domain1.check_integrity()
364
365        finaltime = 8
366        #Evolution
367        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
368            pass
369            #domain1.write_time()
370
371
372        #Now read data from sww and check
373        from Scientific.IO.NetCDF import NetCDFFile
374        filename = domain1.get_name() + '.' + domain1.format
375        fid = NetCDFFile(filename)
376
377        x = fid.variables['x'][:]
378        y = fid.variables['y'][:]
379        stage = fid.variables['stage'][:]
380        xmomentum = fid.variables['xmomentum'][:]
381        ymomentum = fid.variables['ymomentum'][:]
382        time = fid.variables['time'][:]
383
384        #Take stage vertex values at last timestep on diagonal
385        #Diagonal is identified by vertices: 0, 5, 10, 15
386
387        last_time_index = len(time)-1 #Last last_time_index     
388        d_stage = num.reshape(num.take(stage[last_time_index, :], [0,5,10,15]), (4,1))
389        d_uh = num.reshape(num.take(xmomentum[last_time_index, :], [0,5,10,15]), (4,1))
390        d_vh = num.reshape(num.take(ymomentum[last_time_index, :], [0,5,10,15]), (4,1))
391        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
392
393        #Reference interpolated values at midpoints on diagonal at
394        #this timestep are
395        r0 = (D[0] + D[1])/2
396        r1 = (D[1] + D[2])/2
397        r2 = (D[2] + D[3])/2
398
399        #And the midpoints are found now
400        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15])
401        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15])
402
403        diag = num.concatenate( (Dx, Dy), axis=1)
404        d_midpoints = (diag[1:] + diag[:-1])/2
405
406
407        #Adjust for georef - make interpolation points absolute
408        d_midpoints[:,0] += xllcorner
409        d_midpoints[:,1] += yllcorner               
410
411        #Let us see if the file function can find the correct
412        #values at the midpoints at the last timestep:
413        f = file_function(filename, domain1,
414                          interpolation_points = d_midpoints)
415
416        t = time[last_time_index]                         
417        q = f(t, point_id=0); assert num.allclose(r0, q)
418        q = f(t, point_id=1); assert num.allclose(r1, q)
419        q = f(t, point_id=2); assert num.allclose(r2, q)
420
421
422        ##################
423        #Now do the same for the first timestep
424
425        timestep = 0 #First timestep
426        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
427        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
428        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
429        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
430
431        #Reference interpolated values at midpoints on diagonal at
432        #this timestep are
433        r0 = (D[0] + D[1])/2
434        r1 = (D[1] + D[2])/2
435        r2 = (D[2] + D[3])/2
436
437        #Let us see if the file function can find the correct
438        #values
439        q = f(0, point_id=0); assert num.allclose(r0, q)
440        q = f(0, point_id=1); assert num.allclose(r1, q)
441        q = f(0, point_id=2); assert num.allclose(r2, q)
442
443
444        ##################
445        #Now do it again for a timestep in the middle
446
447        timestep = 33
448        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
449        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
450        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
451        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
452
453        #Reference interpolated values at midpoints on diagonal at
454        #this timestep are
455        r0 = (D[0] + D[1])/2
456        r1 = (D[1] + D[2])/2
457        r2 = (D[2] + D[3])/2
458
459        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
460        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
461        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
462
463
464        ##################
465        #Now check temporal interpolation
466        #Halfway between timestep 15 and 16
467
468        timestep = 15
469        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
470        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
471        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
472        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
473
474        #Reference interpolated values at midpoints on diagonal at
475        #this timestep are
476        r0_0 = (D[0] + D[1])/2
477        r1_0 = (D[1] + D[2])/2
478        r2_0 = (D[2] + D[3])/2
479
480        #
481        timestep = 16
482        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15]), (4,1))
483        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15]), (4,1))
484        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15]), (4,1))
485        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
486
487        #Reference interpolated values at midpoints on diagonal at
488        #this timestep are
489        r0_1 = (D[0] + D[1])/2
490        r1_1 = (D[1] + D[2])/2
491        r2_1 = (D[2] + D[3])/2
492
493        # The reference values are
494        r0 = (r0_0 + r0_1)/2
495        r1 = (r1_0 + r1_1)/2
496        r2 = (r2_0 + r2_1)/2
497
498        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
499        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
500        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
501
502        ##################
503        #Finally check interpolation 2 thirds of the way
504        #between timestep 15 and 16
505
506        # The reference values are
507        r0 = (r0_0 + 2*r0_1)/3
508        r1 = (r1_0 + 2*r1_1)/3
509        r2 = (r2_0 + 2*r2_1)/3
510
511        #And the file function gives
512        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
513        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
514        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
515
516        fid.close()
517        import os
518        os.remove(filename)
519
520       
521
522
523    def test_spatio_temporal_file_function_time(self):
524        """Test that File function interpolates correctly
525        between given times.
526        NetCDF version (x,y,t dependency)
527        """
528
529        #Create NetCDF (sww) file to be read
530        # x: 0, 5, 10, 15
531        # y: -20, -10, 0, 10
532        # t: 0, 60, 120, ...., 1200
533        #
534        # test quantities (arbitrary but non-trivial expressions):
535        #
536        #   stage     = 3*x - y**2 + 2*t
537        #   xmomentum = exp( -((x-7)**2 + (y+5)**2)/20 ) * t**2
538        #   ymomentum = x**2 + y**2 * sin(t*pi/600)
539
540        #NOTE: Nice test that may render some of the others redundant.
541
542        import os, time
543        from anuga.config import time_format
544        from mesh_factory import rectangular
545        from shallow_water import Domain
546        import anuga.shallow_water.data_manager
547
548        finaltime = 1200
549        filename = 'test_file_function'
550
551        #Create a domain to hold test grid
552        #(0:15, -20:10)
553        points, vertices, boundary =\
554                rectangular(4, 4, 15, 30, origin = (0, -20))
555        #print "points", points
556
557        #print 'Number of elements', len(vertices)
558        domain = Domain(points, vertices, boundary)
559        domain.smooth = False
560        domain.default_order = 2
561        domain.set_datadir('.')
562        domain.set_name(filename)
563        domain.store = True
564        domain.format = 'sww'   #Native netcdf visualisation format
565
566        #print points
567        start = time.mktime(time.strptime('2000', '%Y'))
568        domain.starttime = start
569
570
571        #Store structure
572        domain.initialise_storage()
573
574        #Compute artificial time steps and store
575        dt = 60  #One minute intervals
576        t = 0.0
577        while t <= finaltime:
578            #Compute quantities
579            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
580            domain.set_quantity('stage', f1)
581
582            f2 = lambda x,y: x+y+t**2
583            domain.set_quantity('xmomentum', f2)
584
585            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
586            domain.set_quantity('ymomentum', f3)
587
588            #Store and advance time
589            domain.time = t
590            domain.store_timestep(domain.conserved_quantities)
591            t += dt
592
593
594        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5]]
595     
596        #Deliberately set domain.starttime to too early
597        domain.starttime = start - 1
598
599        #Create file function
600        F = file_function(filename + '.sww', domain,
601                          quantities = domain.conserved_quantities,
602                          interpolation_points = interpolation_points)
603
604        #Check that FF updates fixes domain starttime
605        assert num.allclose(domain.starttime, start)
606
607        #Check that domain.starttime isn't updated if later
608        domain.starttime = start + 1
609        F = file_function(filename + '.sww', domain,
610                          quantities = domain.conserved_quantities,
611                          interpolation_points = interpolation_points)
612        assert num.allclose(domain.starttime, start+1)
613        domain.starttime = start
614
615
616        #Check linear interpolation in time
617        F = file_function(filename + '.sww', domain,
618                          quantities = domain.conserved_quantities,
619                          interpolation_points = interpolation_points)               
620        for id in range(len(interpolation_points)):
621            x = interpolation_points[id][0]
622            y = interpolation_points[id][1]
623
624            for i in range(20):
625                t = i*10
626                k = i%6
627
628                if k == 0:
629                    q0 = F(t, point_id=id)
630                    q1 = F(t+60, point_id=id)
631
632                if q0 == NAN:
633                    actual = q0
634                else:
635                    actual = (k*q1 + (6-k)*q0)/6
636                q = F(t, point_id=id)
637                #print i, k, t, q
638                #print ' ', q0
639                #print ' ', q1
640                #print "q",q
641                #print "actual", actual
642                #print
643                if q0 == NAN:
644                     self.failUnless( q == actual, 'Fail!')
645                else:
646                    assert num.allclose(q, actual)
647
648
649        #Another check of linear interpolation in time
650        for id in range(len(interpolation_points)):
651            q60 = F(60, point_id=id)
652            q120 = F(120, point_id=id)
653
654            t = 90 #Halfway between 60 and 120
655            q = F(t, point_id=id)
656            assert num.allclose( (q120+q60)/2, q )
657
658            t = 100 #Two thirds of the way between between 60 and 120
659            q = F(t, point_id=id)
660            assert num.allclose(q60/3 + 2*q120/3, q)
661
662
663
664        #Check that domain.starttime isn't updated if later than file starttime but earlier
665        #than file end time
666        delta = 23
667        domain.starttime = start + delta
668        F = file_function(filename + '.sww', domain,
669                          quantities = domain.conserved_quantities,
670                          interpolation_points = interpolation_points)
671        assert num.allclose(domain.starttime, start+delta)
672
673
674
675
676        #Now try interpolation with delta offset
677        for id in range(len(interpolation_points)):           
678            x = interpolation_points[id][0]
679            y = interpolation_points[id][1]
680
681            for i in range(20):
682                t = i*10
683                k = i%6
684
685                if k == 0:
686                    q0 = F(t-delta, point_id=id)
687                    q1 = F(t+60-delta, point_id=id)
688
689                q = F(t-delta, point_id=id)
690                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
691
692
693        os.remove(filename + '.sww')
694
695
696
697    def Xtest_spatio_temporal_file_function_time(self):
698        # FIXME: This passes but needs some TLC
699        # Test that File function interpolates correctly
700        # When some points are outside the mesh
701
702        import os, time
703        from anuga.config import time_format
704        from mesh_factory import rectangular
705        from shallow_water import Domain
706        import anuga.shallow_water.data_manager 
707        from anuga.pmesh.mesh_interface import create_mesh_from_regions
708        finaltime = 1200
709       
710        filename = tempfile.mktemp()
711        #print "filename",filename
712        filename = 'test_file_function'
713
714        meshfilename = tempfile.mktemp(".tsh")
715
716        boundary_tags = {'walls':[0,1],'bom':[2]}
717       
718        polygon_absolute = [[0,-20],[10,-20],[10,15],[-20,15]]
719       
720        create_mesh_from_regions(polygon_absolute,
721                                 boundary_tags,
722                                 10000000,
723                                 filename=meshfilename)
724        domain = Domain(mesh_filename=meshfilename)
725        domain.smooth = False
726        domain.default_order = 2
727        domain.set_datadir('.')
728        domain.set_name(filename)
729        domain.store = True
730        domain.format = 'sww'   #Native netcdf visualisation format
731
732        #print points
733        start = time.mktime(time.strptime('2000', '%Y'))
734        domain.starttime = start
735       
736
737        #Store structure
738        domain.initialise_storage()
739
740        #Compute artificial time steps and store
741        dt = 60  #One minute intervals
742        t = 0.0
743        while t <= finaltime:
744            #Compute quantities
745            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
746            domain.set_quantity('stage', f1)
747
748            f2 = lambda x,y: x+y+t**2
749            domain.set_quantity('xmomentum', f2)
750
751            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
752            domain.set_quantity('ymomentum', f3)
753
754            #Store and advance time
755            domain.time = t
756            domain.store_timestep(domain.conserved_quantities)
757            t += dt
758
759        interpolation_points = [[1,0]]
760        interpolation_points = [[100,1000]]
761       
762        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5],
763                                [78787,78787],[7878,3432]]
764           
765        #Deliberately set domain.starttime to too early
766        domain.starttime = start - 1
767
768        #Create file function
769        F = file_function(filename + '.sww', domain,
770                          quantities = domain.conserved_quantities,
771                          interpolation_points = interpolation_points)
772
773        #Check that FF updates fixes domain starttime
774        assert num.allclose(domain.starttime, start)
775
776        #Check that domain.starttime isn't updated if later
777        domain.starttime = start + 1
778        F = file_function(filename + '.sww', domain,
779                          quantities = domain.conserved_quantities,
780                          interpolation_points = interpolation_points)
781        assert num.allclose(domain.starttime, start+1)
782        domain.starttime = start
783
784
785        #Check linear interpolation in time
786        # checking points inside and outside the mesh
787        F = file_function(filename + '.sww', domain,
788                          quantities = domain.conserved_quantities,
789                          interpolation_points = interpolation_points)
790       
791        for id in range(len(interpolation_points)):
792            x = interpolation_points[id][0]
793            y = interpolation_points[id][1]
794
795            for i in range(20):
796                t = i*10
797                k = i%6
798
799                if k == 0:
800                    q0 = F(t, point_id=id)
801                    q1 = F(t+60, point_id=id)
802
803                if q0 == NAN:
804                    actual = q0
805                else:
806                    actual = (k*q1 + (6-k)*q0)/6
807                q = F(t, point_id=id)
808                #print i, k, t, q
809                #print ' ', q0
810                #print ' ', q1
811                #print "q",q
812                #print "actual", actual
813                #print
814                if q0 == NAN:
815                     self.failUnless( q == actual, 'Fail!')
816                else:
817                    assert num.allclose(q, actual)
818
819        # now lets check points inside the mesh
820        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14]] #, [10,-12.5]] - this point doesn't work WHY?
821        interpolation_points = [[10,-12.5]]
822           
823        print "len(interpolation_points)",len(interpolation_points) 
824        F = file_function(filename + '.sww', domain,
825                          quantities = domain.conserved_quantities,
826                          interpolation_points = interpolation_points)
827
828        domain.starttime = start
829
830
831        #Check linear interpolation in time
832        F = file_function(filename + '.sww', domain,
833                          quantities = domain.conserved_quantities,
834                          interpolation_points = interpolation_points)               
835        for id in range(len(interpolation_points)):
836            x = interpolation_points[id][0]
837            y = interpolation_points[id][1]
838
839            for i in range(20):
840                t = i*10
841                k = i%6
842
843                if k == 0:
844                    q0 = F(t, point_id=id)
845                    q1 = F(t+60, point_id=id)
846
847                if q0 == NAN:
848                    actual = q0
849                else:
850                    actual = (k*q1 + (6-k)*q0)/6
851                q = F(t, point_id=id)
852                print "############"
853                print "id, x, y ", id, x, y #k, t, q
854                print "t", t
855                #print ' ', q0
856                #print ' ', q1
857                print "q",q
858                print "actual", actual
859                #print
860                if q0 == NAN:
861                     self.failUnless( q == actual, 'Fail!')
862                else:
863                    assert num.allclose(q, actual)
864
865
866        #Another check of linear interpolation in time
867        for id in range(len(interpolation_points)):
868            q60 = F(60, point_id=id)
869            q120 = F(120, point_id=id)
870
871            t = 90 #Halfway between 60 and 120
872            q = F(t, point_id=id)
873            assert num.allclose( (q120+q60)/2, q )
874
875            t = 100 #Two thirds of the way between between 60 and 120
876            q = F(t, point_id=id)
877            assert num.allclose(q60/3 + 2*q120/3, q)
878
879
880
881        #Check that domain.starttime isn't updated if later than file starttime but earlier
882        #than file end time
883        delta = 23
884        domain.starttime = start + delta
885        F = file_function(filename + '.sww', domain,
886                          quantities = domain.conserved_quantities,
887                          interpolation_points = interpolation_points)
888        assert num.allclose(domain.starttime, start+delta)
889
890
891
892
893        #Now try interpolation with delta offset
894        for id in range(len(interpolation_points)):           
895            x = interpolation_points[id][0]
896            y = interpolation_points[id][1]
897
898            for i in range(20):
899                t = i*10
900                k = i%6
901
902                if k == 0:
903                    q0 = F(t-delta, point_id=id)
904                    q1 = F(t+60-delta, point_id=id)
905
906                q = F(t-delta, point_id=id)
907                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
908
909
910        os.remove(filename + '.sww')
911
912    def test_file_function_time_with_domain(self):
913        """Test that File function interpolates correctly
914        between given times. No x,y dependency here.
915        Use domain with starttime
916        """
917
918        #Write file
919        import os, time, calendar
920        from anuga.config import time_format
921        from math import sin, pi
922        from domain import Domain
923
924        finaltime = 1200
925        filename = 'test_file_function'
926        fid = open(filename + '.txt', 'w')
927        start = time.mktime(time.strptime('2000', '%Y'))
928        dt = 60  #One minute intervals
929        t = 0.0
930        while t <= finaltime:
931            t_string = time.strftime(time_format, time.gmtime(t+start))
932            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
933            t += dt
934
935        fid.close()
936
937
938        #Convert ASCII file to NetCDF (Which is what we really like!)
939        timefile2netcdf(filename)
940
941
942
943        a = [0.0, 0.0]
944        b = [4.0, 0.0]
945        c = [0.0, 3.0]
946
947        points = [a, b, c]
948        vertices = [[0,1,2]]
949        domain = Domain(points, vertices)
950
951        # Check that domain.starttime is updated if non-existing
952        F = file_function(filename + '.tms',
953                          domain,
954                          quantities = ['Attribute0', 'Attribute1', 'Attribute2']) 
955        assert num.allclose(domain.starttime, start)
956
957        # Check that domain.starttime is updated if too early
958        domain.starttime = start - 1
959        F = file_function(filename + '.tms',
960                          domain,
961                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
962        assert num.allclose(domain.starttime, start)
963
964        # Check that domain.starttime isn't updated if later
965        domain.starttime = start + 1
966        F = file_function(filename + '.tms',
967                          domain,
968                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
969        assert num.allclose(domain.starttime, start+1)
970
971        domain.starttime = start
972        F = file_function(filename + '.tms',
973                          domain,
974                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'],
975                          use_cache=True)
976       
977
978        #print F.precomputed_values
979        #print 'F(60)', F(60)
980       
981        #Now try interpolation
982        for i in range(20):
983            t = i*10
984            q = F(t)
985
986            #Exact linear intpolation
987            assert num.allclose(q[0], 2*t)
988            if i%6 == 0:
989                assert num.allclose(q[1], t**2)
990                assert num.allclose(q[2], sin(t*pi/600))
991
992        #Check non-exact
993
994        t = 90 #Halfway between 60 and 120
995        q = F(t)
996        assert num.allclose( (120**2 + 60**2)/2, q[1] )
997        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
998
999
1000        t = 100 #Two thirds of the way between between 60 and 120
1001        q = F(t)
1002        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1003        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1004
1005        os.remove(filename + '.tms')
1006        os.remove(filename + '.txt')       
1007
1008    def test_file_function_time_with_domain_different_start(self):
1009        """Test that File function interpolates correctly
1010        between given times. No x,y dependency here.
1011        Use domain with a starttime later than that of file
1012
1013        ASCII version
1014        """
1015
1016        #Write file
1017        import os, time, calendar
1018        from anuga.config import time_format
1019        from math import sin, pi
1020        from domain import Domain
1021
1022        finaltime = 1200
1023        filename = 'test_file_function'
1024        fid = open(filename + '.txt', 'w')
1025        start = time.mktime(time.strptime('2000', '%Y'))
1026        dt = 60  #One minute intervals
1027        t = 0.0
1028        while t <= finaltime:
1029            t_string = time.strftime(time_format, time.gmtime(t+start))
1030            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1031            t += dt
1032
1033        fid.close()
1034
1035        #Convert ASCII file to NetCDF (Which is what we really like!)
1036        timefile2netcdf(filename)       
1037
1038        a = [0.0, 0.0]
1039        b = [4.0, 0.0]
1040        c = [0.0, 3.0]
1041
1042        points = [a, b, c]
1043        vertices = [[0,1,2]]
1044        domain = Domain(points, vertices)
1045
1046        #Check that domain.starttime isn't updated if later than file starttime but earlier
1047        #than file end time
1048        delta = 23
1049        domain.starttime = start + delta
1050        F = file_function(filename + '.tms', domain,
1051                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])       
1052        assert num.allclose(domain.starttime, start+delta)
1053
1054        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1055                                            277., 337., 397., 457., 517.,
1056                                            577., 637., 697., 757., 817.,
1057                                            877., 937., 997., 1057., 1117.,
1058                                            1177.])
1059
1060
1061        #Now try interpolation with delta offset
1062        for i in range(20):
1063            t = i*10
1064            q = F(t-delta)
1065
1066            #Exact linear intpolation
1067            assert num.allclose(q[0], 2*t)
1068            if i%6 == 0:
1069                assert num.allclose(q[1], t**2)
1070                assert num.allclose(q[2], sin(t*pi/600))
1071
1072        #Check non-exact
1073
1074        t = 90 #Halfway between 60 and 120
1075        q = F(t-delta)
1076        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1077        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1078
1079
1080        t = 100 #Two thirds of the way between between 60 and 120
1081        q = F(t-delta)
1082        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1083        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1084
1085
1086        os.remove(filename + '.tms')
1087        os.remove(filename + '.txt')               
1088
1089       
1090
1091    def test_file_function_time_with_domain_different_start_and_time_limit(self):
1092        """Test that File function interpolates correctly
1093        between given times. No x,y dependency here.
1094        Use domain with a starttime later than that of file
1095
1096        ASCII version
1097       
1098        This test also tests that time can be truncated.
1099        """
1100
1101        # Write file
1102        import os, time, calendar
1103        from anuga.config import time_format
1104        from math import sin, pi
1105        from domain import Domain
1106
1107        finaltime = 1200
1108        filename = 'test_file_function'
1109        fid = open(filename + '.txt', 'w')
1110        start = time.mktime(time.strptime('2000', '%Y'))
1111        dt = 60  #One minute intervals
1112        t = 0.0
1113        while t <= finaltime:
1114            t_string = time.strftime(time_format, time.gmtime(t+start))
1115            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1116            t += dt
1117
1118        fid.close()
1119
1120        # Convert ASCII file to NetCDF (Which is what we really like!)
1121        timefile2netcdf(filename)       
1122
1123        a = [0.0, 0.0]
1124        b = [4.0, 0.0]
1125        c = [0.0, 3.0]
1126
1127        points = [a, b, c]
1128        vertices = [[0,1,2]]
1129        domain = Domain(points, vertices)
1130
1131        # Check that domain.starttime isn't updated if later than file starttime but earlier
1132        # than file end time
1133        delta = 23
1134        domain.starttime = start + delta
1135        F = file_function(filename + '.tms', domain,
1136                          time_limit=600,
1137                          quantities=['Attribute0', 'Attribute1', 'Attribute2'])       
1138        assert num.allclose(domain.starttime, start+delta)
1139
1140        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1141                                            277., 337., 397., 457., 517.,
1142                                            577.])       
1143
1144
1145
1146        # Now try interpolation with delta offset
1147        for i in range(20):
1148            t = i*10
1149            q = F(t-delta)
1150
1151            #Exact linear intpolation
1152            assert num.allclose(q[0], 2*t)
1153            if i%6 == 0:
1154                assert num.allclose(q[1], t**2)
1155                assert num.allclose(q[2], sin(t*pi/600))
1156
1157        # Check non-exact
1158        t = 90 #Halfway between 60 and 120
1159        q = F(t-delta)
1160        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1161        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1162
1163
1164        t = 100 # Two thirds of the way between between 60 and 120
1165        q = F(t-delta)
1166        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1167        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1168
1169
1170        os.remove(filename + '.tms')
1171        os.remove(filename + '.txt')               
1172
1173       
1174       
1175       
1176
1177
1178    def test_apply_expression_to_dictionary(self):
1179
1180        #FIXME: Division is not expected to work for integers.
1181        #This must be caught.
1182        foo = num.array([[1,2,3], [4,5,6]], num.Float)
1183
1184        bar = num.array([[-1,0,5], [6,1,1]], num.Float)                 
1185
1186        D = {'X': foo, 'Y': bar}
1187
1188        Z = apply_expression_to_dictionary('X+Y', D)       
1189        assert num.allclose(Z, foo+bar)
1190
1191        Z = apply_expression_to_dictionary('X*Y', D)       
1192        assert num.allclose(Z, foo*bar)       
1193
1194        Z = apply_expression_to_dictionary('4*X+Y', D)       
1195        assert num.allclose(Z, 4*foo+bar)       
1196
1197        # test zero division is OK
1198        Z = apply_expression_to_dictionary('X/Y', D)
1199        assert num.allclose(1/Z, 1/(foo/bar)) # can't compare inf to inf
1200
1201        # make an error for zero on zero
1202        # this is really an error in Numeric, SciPy core can handle it
1203        # Z = apply_expression_to_dictionary('0/Y', D)
1204
1205        #Check exceptions
1206        try:
1207            #Wrong name
1208            Z = apply_expression_to_dictionary('4*X+A', D)       
1209        except NameError:
1210            pass
1211        else:
1212            msg = 'Should have raised a NameError Exception'
1213            raise msg
1214
1215
1216        try:
1217            #Wrong order
1218            Z = apply_expression_to_dictionary(D, '4*X+A')       
1219        except AssertionError:
1220            pass
1221        else:
1222            msg = 'Should have raised a AssertionError Exception'
1223            raise msg       
1224       
1225
1226    def test_multiple_replace(self):
1227        """Hard test that checks a true word-by-word simultaneous replace
1228        """
1229       
1230        D = {'x': 'xi', 'y': 'eta', 'xi':'lam'}
1231        exp = '3*x+y + xi'
1232       
1233        new = multiple_replace(exp, D)
1234       
1235        assert new == '3*xi+eta + lam'
1236                         
1237
1238
1239    def test_point_on_line_obsolete(self):
1240        """Test that obsolete call issues appropriate warning"""
1241
1242        #Turn warning into an exception
1243        import warnings
1244        warnings.filterwarnings('error')
1245
1246        try:
1247            assert point_on_line( 0, 0.5, 0,1, 0,0 )
1248        except DeprecationWarning:
1249            pass
1250        else:
1251            msg = 'point_on_line should have issued a DeprecationWarning'
1252            raise Exception(msg)   
1253
1254        warnings.resetwarnings()
1255   
1256    def test_get_revision_number(self):
1257        """test_get_revision_number(self):
1258
1259        Test that revision number can be retrieved.
1260        """
1261        if os.environ.has_key('USER') and os.environ['USER'] == 'dgray':
1262            # I have a known snv incompatability issue,
1263            # so I'm skipping this test.
1264            # FIXME when SVN is upgraded on our clusters
1265            pass
1266        else:   
1267            n = get_revision_number()
1268            assert n>=0
1269
1270
1271       
1272    def test_add_directories(self):
1273       
1274        import tempfile
1275        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1276        directories = ['ja','ne','ke']
1277        kens_dir = add_directories(root_dir, directories)
1278        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1279               sep + 'ke'
1280        assert access(root_dir,F_OK)
1281
1282        add_directories(root_dir, directories)
1283        assert access(root_dir,F_OK)
1284       
1285        #clean up!
1286        os.rmdir(kens_dir)
1287        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1288        os.rmdir(root_dir + sep + 'ja')
1289        os.rmdir(root_dir)
1290
1291    def test_add_directories_bad(self):
1292       
1293        import tempfile
1294        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1295        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1296       
1297        try:
1298            kens_dir = add_directories(root_dir, directories)
1299        except OSError:
1300            pass
1301        else:
1302            msg = 'bad dir name should give OSError'
1303            raise Exception(msg)   
1304           
1305        #clean up!
1306        os.rmdir(root_dir)
1307
1308    def test_check_list(self):
1309
1310        check_list(['stage','xmomentum'])
1311
1312       
1313    def test_add_directories(self):
1314       
1315        import tempfile
1316        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1317        directories = ['ja','ne','ke']
1318        kens_dir = add_directories(root_dir, directories)
1319        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1320               sep + 'ke'
1321        assert access(root_dir,F_OK)
1322
1323        add_directories(root_dir, directories)
1324        assert access(root_dir,F_OK)
1325       
1326        #clean up!
1327        os.rmdir(kens_dir)
1328        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1329        os.rmdir(root_dir + sep + 'ja')
1330        os.rmdir(root_dir)
1331
1332    def test_add_directories_bad(self):
1333       
1334        import tempfile
1335        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1336        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1337       
1338        try:
1339            kens_dir = add_directories(root_dir, directories)
1340        except OSError:
1341            pass
1342        else:
1343            msg = 'bad dir name should give OSError'
1344            raise Exception(msg)   
1345           
1346        #clean up!
1347        os.rmdir(root_dir)
1348
1349    def test_check_list(self):
1350
1351        check_list(['stage','xmomentum'])
1352
1353######
1354# Test the remove_lone_verts() function
1355######
1356       
1357    def test_remove_lone_verts_a(self):
1358        verts = [[0,0],[1,0],[0,1]]
1359        tris = [[0,1,2]]
1360        new_verts, new_tris = remove_lone_verts(verts, tris)
1361        self.failUnless(new_verts.tolist() == verts)
1362        self.failUnless(new_tris.tolist() == tris)
1363
1364    def test_remove_lone_verts_b(self):
1365        verts = [[0,0],[1,0],[0,1],[99,99]]
1366        tris = [[0,1,2]]
1367        new_verts, new_tris = remove_lone_verts(verts, tris)
1368        self.failUnless(new_verts.tolist() == verts[0:3])
1369        self.failUnless(new_tris.tolist() == tris)
1370       
1371    def test_remove_lone_verts_c(self):
1372        verts = [[99,99],[0,0],[1,0],[99,99],[0,1],[99,99]]
1373        tris = [[1,2,4]]
1374        new_verts, new_tris = remove_lone_verts(verts, tris)
1375        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1376        self.failUnless(new_tris.tolist() == [[0,1,2]])
1377     
1378    def test_remove_lone_verts_d(self):
1379        verts = [[0,0],[1,0],[99,99],[0,1]]
1380        tris = [[0,1,3]]
1381        new_verts, new_tris = remove_lone_verts(verts, tris)
1382        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1383        self.failUnless(new_tris.tolist() == [[0,1,2]])
1384       
1385    def test_remove_lone_verts_e(self):
1386        verts = [[0,0],[1,0],[0,1],[99,99],[99,99],[99,99]]
1387        tris = [[0,1,2]]
1388        new_verts, new_tris = remove_lone_verts(verts, tris)
1389        self.failUnless(new_verts.tolist() == verts[0:3])
1390        self.failUnless(new_tris.tolist() == tris)
1391     
1392    def test_remove_lone_verts_f(self):
1393        verts = [[0,0],[1,0],[99,99],[0,1],[99,99],[1,1],[99,99]]
1394        tris = [[0,1,3],[0,1,5]]
1395        new_verts, new_tris = remove_lone_verts(verts, tris)
1396        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1],[1,1]])
1397        self.failUnless(new_tris.tolist() == [[0,1,2],[0,1,3]])
1398       
1399######
1400#
1401######
1402       
1403    def test_get_min_max_values(self):
1404       
1405        list=[8,9,6,1,4]
1406        min1, max1 = get_min_max_values(list)
1407       
1408        assert min1==1 
1409        assert max1==9
1410       
1411    def test_get_min_max_values1(self):
1412       
1413        list=[-8,-9,-6,-1,-4]
1414        min1, max1 = get_min_max_values(list)
1415       
1416#        print 'min1,max1',min1,max1
1417        assert min1==-9 
1418        assert max1==-1
1419
1420#    def test_get_min_max_values2(self):
1421#        '''
1422#        The min and max supplied are greater than the ones in the
1423#        list and therefore are the ones returned
1424#        '''
1425#        list=[-8,-9,-6,-1,-4]
1426#        min1, max1 = get_min_max_values(list,-10,10)
1427#       
1428##        print 'min1,max1',min1,max1
1429#        assert min1==-10
1430#        assert max1==10
1431       
1432    def test_make_plots_from_csv_files(self):
1433       
1434        #if sys.platform == 'win32':  #Windows
1435            try: 
1436                import pylab
1437            except ImportError:
1438                #ANUGA don't need pylab to work so the system doesn't
1439                #rely on pylab being installed
1440                return
1441           
1442       
1443            current_dir=getcwd()+sep+'abstract_2d_finite_volumes'
1444            temp_dir = tempfile.mkdtemp('','figures')
1445    #        print 'temp_dir',temp_dir
1446            fileName = temp_dir+sep+'time_series_3.csv'
1447            file = open(fileName,"w")
1448            file.write("time,stage,speed,momentum,elevation\n\
14491.0, 0, 0, 0, 10 \n\
14502.0, 5, 2, 4, 10 \n\
14513.0, 3, 3, 5, 10 \n")
1452            file.close()
1453   
1454            fileName1 = temp_dir+sep+'time_series_4.csv'
1455            file1 = open(fileName1,"w")
1456            file1.write("time,stage,speed,momentum,elevation\n\
14571.0, 0, 0, 0, 5 \n\
14582.0, -5, -2, -4, 5 \n\
14593.0, -4, -3, -5, 5 \n")
1460            file1.close()
1461   
1462            fileName2 = temp_dir+sep+'time_series_5.csv'
1463            file2 = open(fileName2,"w")
1464            file2.write("time,stage,speed,momentum,elevation\n\
14651.0, 0, 0, 0, 7 \n\
14662.0, 4, -0.45, 57, 7 \n\
14673.0, 6, -0.5, 56, 7 \n")
1468            file2.close()
1469           
1470            dir, name=os.path.split(fileName)
1471            csv2timeseries_graphs(directories_dic={dir:['gauge', 0, 0]},
1472                                  output_dir=temp_dir,
1473                                  base_name='time_series_',
1474                                  plot_numbers=['3-5'],
1475                                  quantities=['speed','stage','momentum'],
1476                                  assess_all_csv_files=True,
1477                                  extra_plot_name='test')
1478           
1479            #print dir+sep+name[:-4]+'_stage_test.png'
1480            assert(access(dir+sep+name[:-4]+'_stage_test.png',F_OK)==True)
1481            assert(access(dir+sep+name[:-4]+'_speed_test.png',F_OK)==True)
1482            assert(access(dir+sep+name[:-4]+'_momentum_test.png',F_OK)==True)
1483   
1484            dir1, name1=os.path.split(fileName1)
1485            assert(access(dir+sep+name1[:-4]+'_stage_test.png',F_OK)==True)
1486            assert(access(dir+sep+name1[:-4]+'_speed_test.png',F_OK)==True)
1487            assert(access(dir+sep+name1[:-4]+'_momentum_test.png',F_OK)==True)
1488   
1489   
1490            dir2, name2=os.path.split(fileName2)
1491            assert(access(dir+sep+name2[:-4]+'_stage_test.png',F_OK)==True)
1492            assert(access(dir+sep+name2[:-4]+'_speed_test.png',F_OK)==True)
1493            assert(access(dir+sep+name2[:-4]+'_momentum_test.png',F_OK)==True)
1494   
1495            del_dir(temp_dir)
1496       
1497
1498    def test_sww2csv_gauges(self):
1499
1500        def elevation_function(x, y):
1501            return -x
1502       
1503        """Most of this test was copied from test_interpolate
1504        test_interpole_sww2csv
1505       
1506        This is testing the gauge_sww2csv function, by creating a sww file and
1507        then exporting the gauges and checking the results.
1508        """
1509       
1510        # Create mesh
1511        mesh_file = tempfile.mktemp(".tsh")   
1512        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1513        m = Mesh()
1514        m.add_vertices(points)
1515        m.auto_segment()
1516        m.generate_mesh(verbose=False)
1517        m.export_mesh_file(mesh_file)
1518       
1519        # Create shallow water domain
1520        domain = Domain(mesh_file)
1521        os.remove(mesh_file)
1522       
1523        domain.default_order=2
1524       
1525        # This test was made before tight_slope_limiters were introduced
1526        # Since were are testing interpolation values this is OK
1527        domain.tight_slope_limiters = 0 
1528       
1529
1530        # Set some field values
1531        domain.set_quantity('elevation', elevation_function)
1532        domain.set_quantity('friction', 0.03)
1533        domain.set_quantity('xmomentum', 3.0)
1534        domain.set_quantity('ymomentum', 4.0)
1535
1536        ######################
1537        # Boundary conditions
1538        B = Transmissive_boundary(domain)
1539        domain.set_boundary( {'exterior': B})
1540
1541        # This call mangles the stage values.
1542        domain.distribute_to_vertices_and_edges()
1543        domain.set_quantity('stage', 1.0)
1544
1545
1546        domain.set_name('datatest' + str(time.time()))
1547        domain.format = 'sww'
1548        domain.smooth = True
1549        domain.reduction = mean
1550
1551
1552        sww = get_dataobject(domain)
1553        sww.store_connectivity()
1554        sww.store_timestep(['stage', 'xmomentum', 'ymomentum','elevation'])
1555        domain.set_quantity('stage', 10.0) # This is automatically limited
1556        # so it will not be less than the elevation
1557        domain.time = 2.
1558        sww.store_timestep(['stage','elevation', 'xmomentum', 'ymomentum'])
1559
1560        # test the function
1561        points = [[5.0,1.],[0.5,2.]]
1562
1563        points_file = tempfile.mktemp(".csv")
1564#        points_file = 'test_point.csv'
1565        file_id = open(points_file,"w")
1566        file_id.write("name, easting, northing, elevation \n\
1567point1, 5.0, 1.0, 3.0\n\
1568point2, 0.5, 2.0, 9.0\n")
1569        file_id.close()
1570
1571       
1572        sww2csv_gauges(sww.filename, 
1573                       points_file,
1574                       verbose=False,
1575                       use_cache=False)
1576
1577#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1578        point1_answers_array = [[0.0,1.0,6.0,-5.0,3.0,4.0], [2.0,10.0,15.0,-5.0,3.0,4.0]]
1579        point1_filename = 'gauge_point1.csv'
1580        point1_handle = file(point1_filename)
1581        point1_reader = reader(point1_handle)
1582        point1_reader.next()
1583
1584        line=[]
1585        for i,row in enumerate(point1_reader):
1586            #print 'i',i,'row',row
1587            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),float(row[4]),float(row[5])])
1588            #print 'assert line',line[i],'point1',point1_answers_array[i]
1589            assert num.allclose(line[i], point1_answers_array[i])
1590
1591        point2_answers_array = [[0.0,1.0,1.5,-0.5,3.0,4.0], [2.0,10.0,10.5,-0.5,3.0,4.0]]
1592        point2_filename = 'gauge_point2.csv' 
1593        point2_handle = file(point2_filename)
1594        point2_reader = reader(point2_handle)
1595        point2_reader.next()
1596                       
1597        line=[]
1598        for i,row in enumerate(point2_reader):
1599            #print 'i',i,'row',row
1600            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),float(row[4]),float(row[5])])
1601            #print 'assert line',line[i],'point1',point1_answers_array[i]
1602            assert num.allclose(line[i], point2_answers_array[i])
1603                         
1604        # clean up
1605        point1_handle.close()
1606        point2_handle.close()
1607        #print "sww.filename",sww.filename
1608        os.remove(sww.filename)
1609        os.remove(points_file)
1610        os.remove(point1_filename)
1611        os.remove(point2_filename)
1612
1613
1614
1615    def test_sww2csv_gauges1(self):
1616        from anuga.pmesh.mesh import Mesh
1617        from anuga.shallow_water import Domain, Transmissive_boundary
1618        from anuga.shallow_water.data_manager import get_dataobject
1619        from csv import reader,writer
1620        import time
1621        import string
1622
1623        def elevation_function(x, y):
1624            return -x
1625       
1626        """Most of this test was copied from test_interpolate
1627        test_interpole_sww2csv
1628       
1629        This is testing the gauge_sww2csv function, by creating a sww file and
1630        then exporting the gauges and checking the results.
1631       
1632        This tests the ablity not to have elevation in the points file and
1633        not store xmomentum and ymomentum
1634        """
1635       
1636        # Create mesh
1637        mesh_file = tempfile.mktemp(".tsh")   
1638        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1639        m = Mesh()
1640        m.add_vertices(points)
1641        m.auto_segment()
1642        m.generate_mesh(verbose=False)
1643        m.export_mesh_file(mesh_file)
1644       
1645        # Create shallow water domain
1646        domain = Domain(mesh_file)
1647        os.remove(mesh_file)
1648       
1649        domain.default_order=2
1650
1651        # Set some field values
1652        domain.set_quantity('elevation', elevation_function)
1653        domain.set_quantity('friction', 0.03)
1654        domain.set_quantity('xmomentum', 3.0)
1655        domain.set_quantity('ymomentum', 4.0)
1656
1657        ######################
1658        # Boundary conditions
1659        B = Transmissive_boundary(domain)
1660        domain.set_boundary( {'exterior': B})
1661
1662        # This call mangles the stage values.
1663        domain.distribute_to_vertices_and_edges()
1664        domain.set_quantity('stage', 1.0)
1665
1666
1667        domain.set_name('datatest' + str(time.time()))
1668        domain.format = 'sww'
1669        domain.smooth = True
1670        domain.reduction = mean
1671
1672        sww = get_dataobject(domain)
1673        sww.store_connectivity()
1674        sww.store_timestep(['stage', 'xmomentum', 'ymomentum'])
1675        domain.set_quantity('stage', 10.0) # This is automatically limited
1676        # so it will not be less than the elevation
1677        domain.time = 2.
1678        sww.store_timestep(['stage', 'xmomentum', 'ymomentum'])
1679
1680        # test the function
1681        points = [[5.0,1.],[0.5,2.]]
1682
1683        points_file = tempfile.mktemp(".csv")
1684#        points_file = 'test_point.csv'
1685        file_id = open(points_file,"w")
1686        file_id.write("name,easting,northing \n\
1687point1, 5.0, 1.0\n\
1688point2, 0.5, 2.0\n")
1689        file_id.close()
1690
1691        sww2csv_gauges(sww.filename, 
1692                            points_file,
1693                            quantities=['stage', 'elevation'],
1694                            use_cache=False,
1695                            verbose=False)
1696
1697        point1_answers_array = [[0.0,1.0,-5.0], [2.0,10.0,-5.0]]
1698        point1_filename = 'gauge_point1.csv'
1699        point1_handle = file(point1_filename)
1700        point1_reader = reader(point1_handle)
1701        point1_reader.next()
1702
1703        line=[]
1704        for i,row in enumerate(point1_reader):
1705#            print 'i',i,'row',row
1706            line.append([float(row[0]),float(row[1]),float(row[2])])
1707            #print 'line',line[i],'point1',point1_answers_array[i]
1708            assert num.allclose(line[i], point1_answers_array[i])
1709
1710        point2_answers_array = [[0.0,1.0,-0.5], [2.0,10.0,-0.5]]
1711        point2_filename = 'gauge_point2.csv' 
1712        point2_handle = file(point2_filename)
1713        point2_reader = reader(point2_handle)
1714        point2_reader.next()
1715                       
1716        line=[]
1717        for i,row in enumerate(point2_reader):
1718#            print 'i',i,'row',row
1719            line.append([float(row[0]),float(row[1]),float(row[2])])
1720#            print 'line',line[i],'point1',point1_answers_array[i]
1721            assert num.allclose(line[i], point2_answers_array[i])
1722                         
1723        # clean up
1724        point1_handle.close()
1725        point2_handle.close()
1726        #print "sww.filename",sww.filename
1727        os.remove(sww.filename)
1728        os.remove(points_file)
1729        os.remove(point1_filename)
1730        os.remove(point2_filename)
1731
1732
1733    def test_sww2csv_gauges2(self):
1734
1735        def elevation_function(x, y):
1736            return -x
1737       
1738        """Most of this test was copied from test_interpolate
1739        test_interpole_sww2csv
1740       
1741        This is testing the gauge_sww2csv function, by creating a sww file and
1742        then exporting the gauges and checking the results.
1743       
1744        This is the same as sww2csv_gauges except set domain.set_starttime to 5.
1745        Therefore testing the storing of the absolute time in the csv files
1746        """
1747       
1748        # Create mesh
1749        mesh_file = tempfile.mktemp(".tsh")   
1750        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1751        m = Mesh()
1752        m.add_vertices(points)
1753        m.auto_segment()
1754        m.generate_mesh(verbose=False)
1755        m.export_mesh_file(mesh_file)
1756       
1757        # Create shallow water domain
1758        domain = Domain(mesh_file)
1759        os.remove(mesh_file)
1760       
1761        domain.default_order=2
1762
1763        # This test was made before tight_slope_limiters were introduced
1764        # Since were are testing interpolation values this is OK
1765        domain.tight_slope_limiters = 0         
1766
1767        # Set some field values
1768        domain.set_quantity('elevation', elevation_function)
1769        domain.set_quantity('friction', 0.03)
1770        domain.set_quantity('xmomentum', 3.0)
1771        domain.set_quantity('ymomentum', 4.0)
1772        domain.set_starttime(5)
1773
1774        ######################
1775        # Boundary conditions
1776        B = Transmissive_boundary(domain)
1777        domain.set_boundary( {'exterior': B})
1778
1779        # This call mangles the stage values.
1780        domain.distribute_to_vertices_and_edges()
1781        domain.set_quantity('stage', 1.0)
1782       
1783
1784
1785        domain.set_name('datatest' + str(time.time()))
1786        domain.format = 'sww'
1787        domain.smooth = True
1788        domain.reduction = mean
1789
1790        sww = get_dataobject(domain)
1791        sww.store_connectivity()
1792        sww.store_timestep(['stage', 'xmomentum', 'ymomentum','elevation'])
1793        domain.set_quantity('stage', 10.0) # This is automatically limited
1794        # so it will not be less than the elevation
1795        domain.time = 2.
1796        sww.store_timestep(['stage','elevation', 'xmomentum', 'ymomentum'])
1797
1798        # test the function
1799        points = [[5.0,1.],[0.5,2.]]
1800
1801        points_file = tempfile.mktemp(".csv")
1802#        points_file = 'test_point.csv'
1803        file_id = open(points_file,"w")
1804        file_id.write("name, easting, northing, elevation \n\
1805point1, 5.0, 1.0, 3.0\n\
1806point2, 0.5, 2.0, 9.0\n")
1807        file_id.close()
1808
1809       
1810        sww2csv_gauges(sww.filename, 
1811                            points_file,
1812                            verbose=False,
1813                            use_cache=False)
1814
1815#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1816        point1_answers_array = [[5.0,1.0,6.0,-5.0,3.0,4.0], [7.0,10.0,15.0,-5.0,3.0,4.0]]
1817        point1_filename = 'gauge_point1.csv'
1818        point1_handle = file(point1_filename)
1819        point1_reader = reader(point1_handle)
1820        point1_reader.next()
1821
1822        line=[]
1823        for i,row in enumerate(point1_reader):
1824            #print 'i',i,'row',row
1825            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),float(row[4]),float(row[5])])
1826            #print 'assert line',line[i],'point1',point1_answers_array[i]
1827            assert num.allclose(line[i], point1_answers_array[i])
1828
1829        point2_answers_array = [[5.0,1.0,1.5,-0.5,3.0,4.0], [7.0,10.0,10.5,-0.5,3.0,4.0]]
1830        point2_filename = 'gauge_point2.csv' 
1831        point2_handle = file(point2_filename)
1832        point2_reader = reader(point2_handle)
1833        point2_reader.next()
1834                       
1835        line=[]
1836        for i,row in enumerate(point2_reader):
1837            #print 'i',i,'row',row
1838            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),float(row[4]),float(row[5])])
1839            #print 'assert line',line[i],'point1',point1_answers_array[i]
1840            assert num.allclose(line[i], point2_answers_array[i])
1841                         
1842        # clean up
1843        point1_handle.close()
1844        point2_handle.close()
1845        #print "sww.filename",sww.filename
1846        os.remove(sww.filename)
1847        os.remove(points_file)
1848        os.remove(point1_filename)
1849        os.remove(point2_filename)
1850
1851
1852    def test_greens_law(self):
1853
1854        from math import sqrt
1855       
1856        d1 = 80.0
1857        d2 = 20.0
1858        h1 = 1.0
1859        h2 = greens_law(d1,d2,h1)
1860
1861        assert h2==sqrt(2.0)
1862       
1863    def test_calc_bearings(self):
1864 
1865        from math import atan, degrees
1866        #Test East
1867        uh = 1
1868        vh = 1.e-15
1869        angle = calc_bearing(uh, vh)
1870        if 89 < angle < 91: v=1
1871        assert v==1
1872        #Test West
1873        uh = -1
1874        vh = 1.e-15
1875        angle = calc_bearing(uh, vh)
1876        if 269 < angle < 271: v=1
1877        assert v==1
1878        #Test North
1879        uh = 1.e-15
1880        vh = 1
1881        angle = calc_bearing(uh, vh)
1882        if -1 < angle < 1: v=1
1883        assert v==1
1884        #Test South
1885        uh = 1.e-15
1886        vh = -1
1887        angle = calc_bearing(uh, vh)
1888        if 179 < angle < 181: v=1
1889        assert v==1
1890        #Test South-East
1891        uh = 1
1892        vh = -1
1893        angle = calc_bearing(uh, vh)
1894        if 134 < angle < 136: v=1
1895        assert v==1
1896        #Test North-East
1897        uh = 1
1898        vh = 1
1899        angle = calc_bearing(uh, vh)
1900        if 44 < angle < 46: v=1
1901        assert v==1
1902        #Test South-West
1903        uh = -1
1904        vh = -1
1905        angle = calc_bearing(uh, vh)
1906        if 224 < angle < 226: v=1
1907        assert v==1
1908        #Test North-West
1909        uh = -1
1910        vh = 1
1911        angle = calc_bearing(uh, vh)
1912        if 314 < angle < 316: v=1
1913        assert v==1
1914
1915 
1916
1917
1918       
1919
1920#-------------------------------------------------------------
1921if __name__ == "__main__":
1922    suite = unittest.makeSuite(Test_Util,'test')
1923#    suite = unittest.makeSuite(Test_Util,'test_sww2csv_gauges')
1924#    runner = unittest.TextTestRunner(verbosity=2)
1925    runner = unittest.TextTestRunner(verbosity=1)
1926    runner.run(suite)
1927
1928
1929
1930
Note: See TracBrowser for help on using the repository browser.