source: anuga_core/source/anuga/abstract_2d_finite_volumes/test_util.py @ 7340

Last change on this file since 7340 was 7340, checked in by ole, 15 years ago

More refactoring in preparation for ticket:191
These are mostly simplifications of sww file creation

File size: 66.6 KB
Line 
1#!/usr/bin/env python
2
3
4import unittest
5from math import sqrt, pi
6import tempfile, os
7from os import access, F_OK,sep, removedirs,remove,mkdir,getcwd
8
9from anuga.abstract_2d_finite_volumes.util import *
10from anuga.config import epsilon
11from anuga.shallow_water.data_manager import timefile2netcdf,del_dir
12
13from anuga.utilities.numerical_tools import NAN
14
15from sys import platform
16
17from anuga.pmesh.mesh import Mesh
18from anuga.shallow_water import Domain, Transmissive_boundary
19from anuga.shallow_water.data_manager import SWW_file
20from csv import reader,writer
21import time
22import string
23
24import numpy as num
25
26
27def test_function(x, y):
28    return x+y
29
30class Test_Util(unittest.TestCase):
31    def setUp(self):
32        pass
33
34    def tearDown(self):
35        pass
36
37
38
39
40    #Geometric
41    #def test_distance(self):
42    #    from anuga.abstract_2d_finite_volumes.util import distance#
43    #
44    #    self.failUnless( distance([4,2],[7,6]) == 5.0,
45    #                     'Distance is wrong!')
46    #    self.failUnless( allclose(distance([7,6],[9,8]), 2.82842712475),
47    #                    'distance is wrong!')
48    #    self.failUnless( allclose(distance([9,8],[4,2]), 7.81024967591),
49    #                    'distance is wrong!')
50    #
51    #    self.failUnless( distance([9,8],[4,2]) == distance([4,2],[9,8]),
52    #                    'distance is wrong!')
53
54
55    def test_file_function_time1(self):
56        """Test that File function interpolates correctly
57        between given times. No x,y dependency here.
58        """
59
60        #Write file
61        import os, time
62        from anuga.config import time_format
63        from math import sin, pi
64
65        #Typical ASCII file
66        finaltime = 1200
67        filename = 'test_file_function'
68        fid = open(filename + '.txt', 'w')
69        start = time.mktime(time.strptime('2000', '%Y'))
70        dt = 60  #One minute intervals
71        t = 0.0
72        while t <= finaltime:
73            t_string = time.strftime(time_format, time.gmtime(t+start))
74            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
75            t += dt
76
77        fid.close()
78
79        #Convert ASCII file to NetCDF (Which is what we really like!)
80        timefile2netcdf(filename)
81
82
83        #Create file function from time series
84        F = file_function(filename + '.tms',
85                          quantities = ['Attribute0',
86                                        'Attribute1',
87                                        'Attribute2'])
88       
89        #Now try interpolation
90        for i in range(20):
91            t = i*10
92            q = F(t)
93
94            #Exact linear intpolation
95            assert num.allclose(q[0], 2*t)
96            if i%6 == 0:
97                assert num.allclose(q[1], t**2)
98                assert num.allclose(q[2], sin(t*pi/600))
99
100        #Check non-exact
101
102        t = 90 #Halfway between 60 and 120
103        q = F(t)
104        assert num.allclose( (120**2 + 60**2)/2, q[1] )
105        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
106
107
108        t = 100 #Two thirds of the way between between 60 and 120
109        q = F(t)
110        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
111        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
112
113        os.remove(filename + '.txt')
114        os.remove(filename + '.tms')       
115
116
117       
118    def test_spatio_temporal_file_function_basic(self):
119        """Test that spatio temporal file function performs the correct
120        interpolations in both time and space
121        NetCDF version (x,y,t dependency)       
122        """
123        import time
124
125        #Create sww file of simple propagation from left to right
126        #through rectangular domain
127        from shallow_water import Domain, Dirichlet_boundary
128        from mesh_factory import rectangular
129
130        #Create basic mesh and shallow water domain
131        points, vertices, boundary = rectangular(3, 3)
132        domain1 = Domain(points, vertices, boundary)
133
134        from anuga.utilities.numerical_tools import mean
135        domain1.reduction = mean
136        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
137                              # only one value.
138
139        domain1.default_order = 2
140        domain1.store = True
141        domain1.set_datadir('.')
142        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
143        domain1.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']
144
145        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
146        domain1.set_quantity('elevation', 0)
147        domain1.set_quantity('friction', 0)
148        domain1.set_quantity('stage', 0)
149
150        # Boundary conditions
151        B0 = Dirichlet_boundary([0,0,0])
152        B6 = Dirichlet_boundary([0.6,0,0])
153        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
154        domain1.check_integrity()
155
156        finaltime = 8
157        #Evolution
158        t0 = -1
159        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
160            #print 'Timesteps: %.16f, %.16f' %(t0, t)
161            #if t == t0:
162            #    msg = 'Duplicate timestep found: %f, %f' %(t0, t)
163            #   raise msg
164            t0 = t
165             
166            #domain1.write_time()
167
168
169        #Now read data from sww and check
170        from Scientific.IO.NetCDF import NetCDFFile
171        filename = domain1.get_name() + '.sww'
172        fid = NetCDFFile(filename)
173
174        x = fid.variables['x'][:]
175        y = fid.variables['y'][:]
176        stage = fid.variables['stage'][:]
177        xmomentum = fid.variables['xmomentum'][:]
178        ymomentum = fid.variables['ymomentum'][:]
179        time = fid.variables['time'][:]
180
181        #Take stage vertex values at last timestep on diagonal
182        #Diagonal is identified by vertices: 0, 5, 10, 15
183
184        last_time_index = len(time)-1 #Last last_time_index
185        d_stage = num.reshape(num.take(stage[last_time_index, :],
186                                       [0,5,10,15],
187                                       axis=0),
188                              (4,1))
189        d_uh = num.reshape(num.take(xmomentum[last_time_index, :],
190                                    [0,5,10,15],
191                                   axis=0),
192                           (4,1))
193        d_vh = num.reshape(num.take(ymomentum[last_time_index, :],
194                                    [0,5,10,15],
195                                   axis=0),
196                           (4,1))
197        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
198
199        #Reference interpolated values at midpoints on diagonal at
200        #this timestep are
201        r0 = (D[0] + D[1])/2
202        r1 = (D[1] + D[2])/2
203        r2 = (D[2] + D[3])/2
204
205        #And the midpoints are found now
206        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15], axis=0)
207        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15], axis=0)
208
209        diag = num.concatenate( (Dx, Dy), axis=1)
210        d_midpoints = (diag[1:] + diag[:-1])/2
211
212        #Let us see if the file function can find the correct
213        #values at the midpoints at the last timestep:
214        f = file_function(filename, domain1,
215                          interpolation_points = d_midpoints)
216
217        T = f.get_time()
218        msg = 'duplicate timesteps: %.16f and %.16f' %(T[-1], T[-2])
219        assert not T[-1] == T[-2], msg
220        t = time[last_time_index]
221        q = f(t, point_id=0); assert num.allclose(r0, q)
222        q = f(t, point_id=1); assert num.allclose(r1, q)
223        q = f(t, point_id=2); assert num.allclose(r2, q)
224
225
226        ##################
227        #Now do the same for the first timestep
228
229        timestep = 0 #First timestep
230        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
231        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
232        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
233        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
234
235        #Reference interpolated values at midpoints on diagonal at
236        #this timestep are
237        r0 = (D[0] + D[1])/2
238        r1 = (D[1] + D[2])/2
239        r2 = (D[2] + D[3])/2
240
241        #Let us see if the file function can find the correct
242        #values
243        q = f(0, point_id=0); assert num.allclose(r0, q)
244        q = f(0, point_id=1); assert num.allclose(r1, q)
245        q = f(0, point_id=2); assert num.allclose(r2, q)
246
247
248        ##################
249        #Now do it again for a timestep in the middle
250
251        timestep = 33
252        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
253        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
254        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
255        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
256
257        #Reference interpolated values at midpoints on diagonal at
258        #this timestep are
259        r0 = (D[0] + D[1])/2
260        r1 = (D[1] + D[2])/2
261        r2 = (D[2] + D[3])/2
262
263        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
264        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
265        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
266
267
268        ##################
269        #Now check temporal interpolation
270        #Halfway between timestep 15 and 16
271
272        timestep = 15
273        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
274        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
275        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
276        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
277
278        #Reference interpolated values at midpoints on diagonal at
279        #this timestep are
280        r0_0 = (D[0] + D[1])/2
281        r1_0 = (D[1] + D[2])/2
282        r2_0 = (D[2] + D[3])/2
283
284        #
285        timestep = 16
286        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
287        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
288        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
289        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
290
291        #Reference interpolated values at midpoints on diagonal at
292        #this timestep are
293        r0_1 = (D[0] + D[1])/2
294        r1_1 = (D[1] + D[2])/2
295        r2_1 = (D[2] + D[3])/2
296
297        # The reference values are
298        r0 = (r0_0 + r0_1)/2
299        r1 = (r1_0 + r1_1)/2
300        r2 = (r2_0 + r2_1)/2
301
302        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
303        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
304        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
305
306        ##################
307        #Finally check interpolation 2 thirds of the way
308        #between timestep 15 and 16
309
310        # The reference values are
311        r0 = (r0_0 + 2*r0_1)/3
312        r1 = (r1_0 + 2*r1_1)/3
313        r2 = (r2_0 + 2*r2_1)/3
314
315        #And the file function gives
316        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
317        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
318        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
319
320        fid.close()
321        import os
322        os.remove(filename)
323
324
325
326    def test_spatio_temporal_file_function_different_origin(self):
327        """Test that spatio temporal file function performs the correct
328        interpolations in both time and space where space is offset by
329        xllcorner and yllcorner
330        NetCDF version (x,y,t dependency)       
331        """
332        import time
333
334        #Create sww file of simple propagation from left to right
335        #through rectangular domain
336        from shallow_water import Domain, Dirichlet_boundary
337        from mesh_factory import rectangular
338
339
340        from anuga.coordinate_transforms.geo_reference import Geo_reference
341        xllcorner = 2048
342        yllcorner = 11000
343        zone = 2
344
345        #Create basic mesh and shallow water domain
346        points, vertices, boundary = rectangular(3, 3)
347        domain1 = Domain(points, vertices, boundary,
348                         geo_reference = Geo_reference(xllcorner = xllcorner,
349                                                       yllcorner = yllcorner))
350       
351
352        from anuga.utilities.numerical_tools import mean       
353        domain1.reduction = mean
354        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
355                              # only one value.
356
357        domain1.default_order = 2
358        domain1.store = True
359        domain1.set_datadir('.')
360        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
361        domain1.quantities_to_be_stored = ['stage', 'xmomentum', 'ymomentum']
362
363        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
364        domain1.set_quantity('elevation', 0)
365        domain1.set_quantity('friction', 0)
366        domain1.set_quantity('stage', 0)
367
368        # Boundary conditions
369        B0 = Dirichlet_boundary([0,0,0])
370        B6 = Dirichlet_boundary([0.6,0,0])
371        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
372        domain1.check_integrity()
373
374        finaltime = 8
375        #Evolution
376        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
377            pass
378            #domain1.write_time()
379
380
381        #Now read data from sww and check
382        from Scientific.IO.NetCDF import NetCDFFile
383        filename = domain1.get_name() + '.sww'
384        fid = NetCDFFile(filename)
385
386        x = fid.variables['x'][:]
387        y = fid.variables['y'][:]
388        # we 'cast' to 64 bit floats to pass this test
389        # SWW file quantities are stored as 32 bits
390        x = num.array(x, num.float)
391        y = num.array(y, num.float)
392
393        stage = fid.variables['stage'][:]
394        xmomentum = fid.variables['xmomentum'][:]
395        ymomentum = fid.variables['ymomentum'][:]
396        time = fid.variables['time'][:]
397
398        #Take stage vertex values at last timestep on diagonal
399        #Diagonal is identified by vertices: 0, 5, 10, 15
400
401        last_time_index = len(time)-1 #Last last_time_index     
402        d_stage = num.reshape(num.take(stage[last_time_index, :],
403                                       [0,5,10,15],
404                                       axis=0),
405                              (4,1))
406        d_uh = num.reshape(num.take(xmomentum[last_time_index, :],
407                                    [0,5,10,15],
408                                   axis=0),
409                           (4,1))
410        d_vh = num.reshape(num.take(ymomentum[last_time_index, :],
411                                    [0,5,10,15],
412                                    axis=0),
413                           (4,1))
414        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
415
416        #Reference interpolated values at midpoints on diagonal at
417        #this timestep are
418        r0 = (D[0] + D[1])/2
419        r1 = (D[1] + D[2])/2
420        r2 = (D[2] + D[3])/2
421
422        #And the midpoints are found now
423        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15], axis=0)
424        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15], axis=0)
425
426        diag = num.concatenate((Dx, Dy), axis=1)
427        d_midpoints = (diag[1:] + diag[:-1])/2
428
429
430        #Adjust for georef - make interpolation points absolute
431        d_midpoints[:,0] += xllcorner
432        d_midpoints[:,1] += yllcorner               
433
434        #Let us see if the file function can find the correct
435        #values at the midpoints at the last timestep:
436        f = file_function(filename, domain1,
437                          interpolation_points = d_midpoints)
438
439        t = time[last_time_index]                         
440
441        q = f(t, point_id=0)
442        msg = '\nr0=%s\nq=%s' % (str(r0), str(q))
443        assert num.allclose(r0, q), msg
444
445        q = f(t, point_id=1)
446        msg = '\nr1=%s\nq=%s' % (str(r1), str(q))
447        assert num.allclose(r1, q), msg
448
449        q = f(t, point_id=2)
450        msg = '\nr2=%s\nq=%s' % (str(r2), str(q))
451        assert num.allclose(r2, q), msg
452
453
454        ##################
455        #Now do the same for the first timestep
456
457        timestep = 0 #First timestep
458        d_stage = num.reshape(num.take(stage[timestep, :],
459                                       [0,5,10,15],
460                                       axis=0),
461                              (4,1))
462        d_uh = num.reshape(num.take(xmomentum[timestep, :],
463                                    [0,5,10,15],
464                                    axis=0),
465                           (4,1))
466        d_vh = num.reshape(num.take(ymomentum[timestep, :],
467                                    [0,5,10,15],
468                                    axis=0),
469                           (4,1))
470        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
471
472        #Reference interpolated values at midpoints on diagonal at
473        #this timestep are
474        r0 = (D[0] + D[1])/2
475        r1 = (D[1] + D[2])/2
476        r2 = (D[2] + D[3])/2
477
478        #Let us see if the file function can find the correct
479        #values
480        q = f(0, point_id=0); assert num.allclose(r0, q)
481        q = f(0, point_id=1); assert num.allclose(r1, q)
482        q = f(0, point_id=2); assert num.allclose(r2, q)
483
484
485        ##################
486        #Now do it again for a timestep in the middle
487
488        timestep = 33
489        d_stage = num.reshape(num.take(stage[timestep, :],
490                                       [0,5,10,15],
491                                       axis=0),
492                              (4,1))
493        d_uh = num.reshape(num.take(xmomentum[timestep, :],
494                                    [0,5,10,15],
495                                    axis=0),
496                           (4,1))
497        d_vh = num.reshape(num.take(ymomentum[timestep, :],
498                                    [0,5,10,15],
499                                    axis=0),
500                           (4,1))
501        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
502
503        #Reference interpolated values at midpoints on diagonal at
504        #this timestep are
505        r0 = (D[0] + D[1])/2
506        r1 = (D[1] + D[2])/2
507        r2 = (D[2] + D[3])/2
508
509        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
510        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
511        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
512
513
514        ##################
515        #Now check temporal interpolation
516        #Halfway between timestep 15 and 16
517
518        timestep = 15
519        d_stage = num.reshape(num.take(stage[timestep, :],
520                                       [0,5,10,15],
521                                       axis=0),
522                              (4,1))
523        d_uh = num.reshape(num.take(xmomentum[timestep, :],
524                                    [0,5,10,15],
525                                    axis=0),
526                           (4,1))
527        d_vh = num.reshape(num.take(ymomentum[timestep, :],
528                                    [0,5,10,15],
529                                    axis=0),
530                           (4,1))
531        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
532
533        #Reference interpolated values at midpoints on diagonal at
534        #this timestep are
535        r0_0 = (D[0] + D[1])/2
536        r1_0 = (D[1] + D[2])/2
537        r2_0 = (D[2] + D[3])/2
538
539        #
540        timestep = 16
541        d_stage = num.reshape(num.take(stage[timestep, :],
542                                       [0,5,10,15],
543                                       axis=0),
544                              (4,1))
545        d_uh = num.reshape(num.take(xmomentum[timestep, :],
546                                    [0,5,10,15],
547                                    axis=0),
548                           (4,1))
549        d_vh = num.reshape(num.take(ymomentum[timestep, :],
550                                    [0,5,10,15],
551                                    axis=0),
552                           (4,1))
553        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
554
555        #Reference interpolated values at midpoints on diagonal at
556        #this timestep are
557        r0_1 = (D[0] + D[1])/2
558        r1_1 = (D[1] + D[2])/2
559        r2_1 = (D[2] + D[3])/2
560
561        # The reference values are
562        r0 = (r0_0 + r0_1)/2
563        r1 = (r1_0 + r1_1)/2
564        r2 = (r2_0 + r2_1)/2
565
566        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
567        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
568        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
569
570        ##################
571        #Finally check interpolation 2 thirds of the way
572        #between timestep 15 and 16
573
574        # The reference values are
575        r0 = (r0_0 + 2*r0_1)/3
576        r1 = (r1_0 + 2*r1_1)/3
577        r2 = (r2_0 + 2*r2_1)/3
578
579        #And the file function gives
580        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
581        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
582        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
583
584        fid.close()
585        import os
586        os.remove(filename)
587
588       
589
590
591    def test_spatio_temporal_file_function_time(self):
592        """Test that File function interpolates correctly
593        between given times.
594        NetCDF version (x,y,t dependency)
595        """
596
597        #Create NetCDF (sww) file to be read
598        # x: 0, 5, 10, 15
599        # y: -20, -10, 0, 10
600        # t: 0, 60, 120, ...., 1200
601        #
602        # test quantities (arbitrary but non-trivial expressions):
603        #
604        #   stage     = 3*x - y**2 + 2*t
605        #   xmomentum = exp( -((x-7)**2 + (y+5)**2)/20 ) * t**2
606        #   ymomentum = x**2 + y**2 * sin(t*pi/600)
607
608        #NOTE: Nice test that may render some of the others redundant.
609
610        import os, time
611        from anuga.config import time_format
612        from mesh_factory import rectangular
613        from shallow_water import Domain
614        import anuga.shallow_water.data_manager
615
616        finaltime = 1200
617        filename = 'test_file_function'
618
619        #Create a domain to hold test grid
620        #(0:15, -20:10)
621        points, vertices, boundary =\
622                rectangular(4, 4, 15, 30, origin = (0, -20))
623        #print "points", points
624
625        #print 'Number of elements', len(vertices)
626        domain = Domain(points, vertices, boundary)
627        domain.smooth = False
628        domain.default_order = 2
629        domain.set_datadir('.')
630        domain.set_name(filename)
631        domain.store = True
632
633        #print points
634        start = time.mktime(time.strptime('2000', '%Y'))
635        domain.starttime = start
636
637
638        #Store structure
639        domain.initialise_storage()
640
641        #Compute artificial time steps and store
642        dt = 60  #One minute intervals
643        t = 0.0
644        while t <= finaltime:
645            #Compute quantities
646            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
647            domain.set_quantity('stage', f1)
648
649            f2 = lambda x,y: x+y+t**2
650            domain.set_quantity('xmomentum', f2)
651
652            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
653            domain.set_quantity('ymomentum', f3)
654
655            #Store and advance time
656            domain.time = t
657            domain.store_timestep()
658            t += dt
659
660
661        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5]]
662     
663        #Deliberately set domain.starttime to too early
664        domain.starttime = start - 1
665
666        #Create file function
667        F = file_function(filename + '.sww', domain,
668                          quantities = domain.conserved_quantities,
669                          interpolation_points = interpolation_points)
670
671        #Check that FF updates fixes domain starttime
672        assert num.allclose(domain.starttime, start)
673
674        #Check that domain.starttime isn't updated if later
675        domain.starttime = start + 1
676        F = file_function(filename + '.sww', domain,
677                          quantities = domain.conserved_quantities,
678                          interpolation_points = interpolation_points)
679        assert num.allclose(domain.starttime, start+1)
680        domain.starttime = start
681
682
683        #Check linear interpolation in time
684        F = file_function(filename + '.sww', domain,
685                          quantities = domain.conserved_quantities,
686                          interpolation_points = interpolation_points)               
687        for id in range(len(interpolation_points)):
688            x = interpolation_points[id][0]
689            y = interpolation_points[id][1]
690
691            for i in range(20):
692                t = i*10
693                k = i%6
694
695                if k == 0:
696                    q0 = F(t, point_id=id)
697                    q1 = F(t+60, point_id=id)
698
699                if num.alltrue(q0 == NAN):
700                    actual = q0
701                else:
702                    actual = (k*q1 + (6-k)*q0)/6
703                q = F(t, point_id=id)
704                #print i, k, t, q
705                #print ' ', q0
706                #print ' ', q1
707                #print "q",q
708                #print "actual", actual
709                #print
710                if num.alltrue(q0 == NAN):
711                     self.failUnless(num.alltrue(q == actual), 'Fail!')
712                else:
713                    assert num.allclose(q, actual)
714
715
716        #Another check of linear interpolation in time
717        for id in range(len(interpolation_points)):
718            q60 = F(60, point_id=id)
719            q120 = F(120, point_id=id)
720
721            t = 90 #Halfway between 60 and 120
722            q = F(t, point_id=id)
723            assert num.allclose( (q120+q60)/2, q )
724
725            t = 100 #Two thirds of the way between between 60 and 120
726            q = F(t, point_id=id)
727            assert num.allclose(q60/3 + 2*q120/3, q)
728
729
730
731        #Check that domain.starttime isn't updated if later than file starttime but earlier
732        #than file end time
733        delta = 23
734        domain.starttime = start + delta
735        F = file_function(filename + '.sww', domain,
736                          quantities = domain.conserved_quantities,
737                          interpolation_points = interpolation_points)
738        assert num.allclose(domain.starttime, start+delta)
739
740
741
742
743        #Now try interpolation with delta offset
744        for id in range(len(interpolation_points)):           
745            x = interpolation_points[id][0]
746            y = interpolation_points[id][1]
747
748            for i in range(20):
749                t = i*10
750                k = i%6
751
752                if k == 0:
753                    q0 = F(t-delta, point_id=id)
754                    q1 = F(t+60-delta, point_id=id)
755
756                q = F(t-delta, point_id=id)
757                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
758
759
760        os.remove(filename + '.sww')
761
762
763
764    def Xtest_spatio_temporal_file_function_time(self):
765        # FIXME: This passes but needs some TLC
766        # Test that File function interpolates correctly
767        # When some points are outside the mesh
768
769        import os, time
770        from anuga.config import time_format
771        from mesh_factory import rectangular
772        from shallow_water import Domain
773        import anuga.shallow_water.data_manager 
774        from anuga.pmesh.mesh_interface import create_mesh_from_regions
775        finaltime = 1200
776       
777        filename = tempfile.mktemp()
778        #print "filename",filename
779        filename = 'test_file_function'
780
781        meshfilename = tempfile.mktemp(".tsh")
782
783        boundary_tags = {'walls':[0,1],'bom':[2]}
784       
785        polygon_absolute = [[0,-20],[10,-20],[10,15],[-20,15]]
786       
787        create_mesh_from_regions(polygon_absolute,
788                                 boundary_tags,
789                                 10000000,
790                                 filename=meshfilename)
791        domain = Domain(mesh_filename=meshfilename)
792        domain.smooth = False
793        domain.default_order = 2
794        domain.set_datadir('.')
795        domain.set_name(filename)
796        domain.store = True
797
798        #print points
799        start = time.mktime(time.strptime('2000', '%Y'))
800        domain.starttime = start
801       
802
803        #Store structure
804        domain.initialise_storage()
805
806        #Compute artificial time steps and store
807        dt = 60  #One minute intervals
808        t = 0.0
809        while t <= finaltime:
810            #Compute quantities
811            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
812            domain.set_quantity('stage', f1)
813
814            f2 = lambda x,y: x+y+t**2
815            domain.set_quantity('xmomentum', f2)
816
817            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
818            domain.set_quantity('ymomentum', f3)
819
820            #Store and advance time
821            domain.time = t
822            domain.store_timestep(domain.conserved_quantities)
823            t += dt
824
825        interpolation_points = [[1,0]]
826        interpolation_points = [[100,1000]]
827       
828        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5],
829                                [78787,78787],[7878,3432]]
830           
831        #Deliberately set domain.starttime to too early
832        domain.starttime = start - 1
833
834        #Create file function
835        F = file_function(filename + '.sww', domain,
836                          quantities = domain.conserved_quantities,
837                          interpolation_points = interpolation_points)
838
839        #Check that FF updates fixes domain starttime
840        assert num.allclose(domain.starttime, start)
841
842        #Check that domain.starttime isn't updated if later
843        domain.starttime = start + 1
844        F = file_function(filename + '.sww', domain,
845                          quantities = domain.conserved_quantities,
846                          interpolation_points = interpolation_points)
847        assert num.allclose(domain.starttime, start+1)
848        domain.starttime = start
849
850
851        #Check linear interpolation in time
852        # checking points inside and outside the mesh
853        F = file_function(filename + '.sww', domain,
854                          quantities = domain.conserved_quantities,
855                          interpolation_points = interpolation_points)
856       
857        for id in range(len(interpolation_points)):
858            x = interpolation_points[id][0]
859            y = interpolation_points[id][1]
860
861            for i in range(20):
862                t = i*10
863                k = i%6
864
865                if k == 0:
866                    q0 = F(t, point_id=id)
867                    q1 = F(t+60, point_id=id)
868
869                if q0 == NAN:
870                    actual = q0
871                else:
872                    actual = (k*q1 + (6-k)*q0)/6
873                q = F(t, point_id=id)
874                #print i, k, t, q
875                #print ' ', q0
876                #print ' ', q1
877                #print "q",q
878                #print "actual", actual
879                #print
880                if q0 == NAN:
881                     self.failUnless( q == actual, 'Fail!')
882                else:
883                    assert num.allclose(q, actual)
884
885        # now lets check points inside the mesh
886        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14]] #, [10,-12.5]] - this point doesn't work WHY?
887        interpolation_points = [[10,-12.5]]
888           
889        print "len(interpolation_points)",len(interpolation_points) 
890        F = file_function(filename + '.sww', domain,
891                          quantities = domain.conserved_quantities,
892                          interpolation_points = interpolation_points)
893
894        domain.starttime = start
895
896
897        #Check linear interpolation in time
898        F = file_function(filename + '.sww', domain,
899                          quantities = domain.conserved_quantities,
900                          interpolation_points = interpolation_points)               
901        for id in range(len(interpolation_points)):
902            x = interpolation_points[id][0]
903            y = interpolation_points[id][1]
904
905            for i in range(20):
906                t = i*10
907                k = i%6
908
909                if k == 0:
910                    q0 = F(t, point_id=id)
911                    q1 = F(t+60, point_id=id)
912
913                if q0 == NAN:
914                    actual = q0
915                else:
916                    actual = (k*q1 + (6-k)*q0)/6
917                q = F(t, point_id=id)
918                print "############"
919                print "id, x, y ", id, x, y #k, t, q
920                print "t", t
921                #print ' ', q0
922                #print ' ', q1
923                print "q",q
924                print "actual", actual
925                #print
926                if q0 == NAN:
927                     self.failUnless( q == actual, 'Fail!')
928                else:
929                    assert num.allclose(q, actual)
930
931
932        #Another check of linear interpolation in time
933        for id in range(len(interpolation_points)):
934            q60 = F(60, point_id=id)
935            q120 = F(120, point_id=id)
936
937            t = 90 #Halfway between 60 and 120
938            q = F(t, point_id=id)
939            assert num.allclose( (q120+q60)/2, q )
940
941            t = 100 #Two thirds of the way between between 60 and 120
942            q = F(t, point_id=id)
943            assert num.allclose(q60/3 + 2*q120/3, q)
944
945
946
947        #Check that domain.starttime isn't updated if later than file starttime but earlier
948        #than file end time
949        delta = 23
950        domain.starttime = start + delta
951        F = file_function(filename + '.sww', domain,
952                          quantities = domain.conserved_quantities,
953                          interpolation_points = interpolation_points)
954        assert num.allclose(domain.starttime, start+delta)
955
956
957
958
959        #Now try interpolation with delta offset
960        for id in range(len(interpolation_points)):           
961            x = interpolation_points[id][0]
962            y = interpolation_points[id][1]
963
964            for i in range(20):
965                t = i*10
966                k = i%6
967
968                if k == 0:
969                    q0 = F(t-delta, point_id=id)
970                    q1 = F(t+60-delta, point_id=id)
971
972                q = F(t-delta, point_id=id)
973                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
974
975
976        os.remove(filename + '.sww')
977
978    def test_file_function_time_with_domain(self):
979        """Test that File function interpolates correctly
980        between given times. No x,y dependency here.
981        Use domain with starttime
982        """
983
984        #Write file
985        import os, time, calendar
986        from anuga.config import time_format
987        from math import sin, pi
988        from domain import Domain
989
990        finaltime = 1200
991        filename = 'test_file_function'
992        fid = open(filename + '.txt', 'w')
993        start = time.mktime(time.strptime('2000', '%Y'))
994        dt = 60  #One minute intervals
995        t = 0.0
996        while t <= finaltime:
997            t_string = time.strftime(time_format, time.gmtime(t+start))
998            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
999            t += dt
1000
1001        fid.close()
1002
1003
1004        #Convert ASCII file to NetCDF (Which is what we really like!)
1005        timefile2netcdf(filename)
1006
1007
1008
1009        a = [0.0, 0.0]
1010        b = [4.0, 0.0]
1011        c = [0.0, 3.0]
1012
1013        points = [a, b, c]
1014        vertices = [[0,1,2]]
1015        domain = Domain(points, vertices)
1016
1017        # Check that domain.starttime is updated if non-existing
1018        F = file_function(filename + '.tms',
1019                          domain,
1020                          quantities = ['Attribute0', 'Attribute1', 'Attribute2']) 
1021        assert num.allclose(domain.starttime, start)
1022
1023        # Check that domain.starttime is updated if too early
1024        domain.starttime = start - 1
1025        F = file_function(filename + '.tms',
1026                          domain,
1027                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
1028        assert num.allclose(domain.starttime, start)
1029
1030        # Check that domain.starttime isn't updated if later
1031        domain.starttime = start + 1
1032        F = file_function(filename + '.tms',
1033                          domain,
1034                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
1035        assert num.allclose(domain.starttime, start+1)
1036
1037        domain.starttime = start
1038        F = file_function(filename + '.tms',
1039                          domain,
1040                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'],
1041                          use_cache=True)
1042       
1043
1044        #print F.precomputed_values
1045        #print 'F(60)', F(60)
1046       
1047        #Now try interpolation
1048        for i in range(20):
1049            t = i*10
1050            q = F(t)
1051
1052            #Exact linear intpolation
1053            assert num.allclose(q[0], 2*t)
1054            if i%6 == 0:
1055                assert num.allclose(q[1], t**2)
1056                assert num.allclose(q[2], sin(t*pi/600))
1057
1058        #Check non-exact
1059
1060        t = 90 #Halfway between 60 and 120
1061        q = F(t)
1062        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1063        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1064
1065
1066        t = 100 #Two thirds of the way between between 60 and 120
1067        q = F(t)
1068        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1069        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1070
1071        os.remove(filename + '.tms')
1072        os.remove(filename + '.txt')       
1073
1074    def test_file_function_time_with_domain_different_start(self):
1075        """Test that File function interpolates correctly
1076        between given times. No x,y dependency here.
1077        Use domain with a starttime later than that of file
1078
1079        ASCII version
1080        """
1081
1082        #Write file
1083        import os, time, calendar
1084        from anuga.config import time_format
1085        from math import sin, pi
1086        from domain import Domain
1087
1088        finaltime = 1200
1089        filename = 'test_file_function'
1090        fid = open(filename + '.txt', 'w')
1091        start = time.mktime(time.strptime('2000', '%Y'))
1092        dt = 60  #One minute intervals
1093        t = 0.0
1094        while t <= finaltime:
1095            t_string = time.strftime(time_format, time.gmtime(t+start))
1096            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1097            t += dt
1098
1099        fid.close()
1100
1101        #Convert ASCII file to NetCDF (Which is what we really like!)
1102        timefile2netcdf(filename)       
1103
1104        a = [0.0, 0.0]
1105        b = [4.0, 0.0]
1106        c = [0.0, 3.0]
1107
1108        points = [a, b, c]
1109        vertices = [[0,1,2]]
1110        domain = Domain(points, vertices)
1111
1112        #Check that domain.starttime isn't updated if later than file starttime but earlier
1113        #than file end time
1114        delta = 23
1115        domain.starttime = start + delta
1116        F = file_function(filename + '.tms', domain,
1117                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])       
1118        assert num.allclose(domain.starttime, start+delta)
1119
1120        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1121                                            277., 337., 397., 457., 517.,
1122                                            577., 637., 697., 757., 817.,
1123                                            877., 937., 997., 1057., 1117.,
1124                                            1177.])
1125
1126
1127        #Now try interpolation with delta offset
1128        for i in range(20):
1129            t = i*10
1130            q = F(t-delta)
1131
1132            #Exact linear intpolation
1133            assert num.allclose(q[0], 2*t)
1134            if i%6 == 0:
1135                assert num.allclose(q[1], t**2)
1136                assert num.allclose(q[2], sin(t*pi/600))
1137
1138        #Check non-exact
1139
1140        t = 90 #Halfway between 60 and 120
1141        q = F(t-delta)
1142        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1143        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1144
1145
1146        t = 100 #Two thirds of the way between between 60 and 120
1147        q = F(t-delta)
1148        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1149        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1150
1151
1152        os.remove(filename + '.tms')
1153        os.remove(filename + '.txt')               
1154
1155       
1156
1157    def test_file_function_time_with_domain_different_start_and_time_limit(self):
1158        """Test that File function interpolates correctly
1159        between given times. No x,y dependency here.
1160        Use domain with a starttime later than that of file
1161
1162        ASCII version
1163       
1164        This test also tests that time can be truncated.
1165        """
1166
1167        # Write file
1168        import os, time, calendar
1169        from anuga.config import time_format
1170        from math import sin, pi
1171        from domain import Domain
1172
1173        finaltime = 1200
1174        filename = 'test_file_function'
1175        fid = open(filename + '.txt', 'w')
1176        start = time.mktime(time.strptime('2000', '%Y'))
1177        dt = 60  #One minute intervals
1178        t = 0.0
1179        while t <= finaltime:
1180            t_string = time.strftime(time_format, time.gmtime(t+start))
1181            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1182            t += dt
1183
1184        fid.close()
1185
1186        # Convert ASCII file to NetCDF (Which is what we really like!)
1187        timefile2netcdf(filename)       
1188
1189        a = [0.0, 0.0]
1190        b = [4.0, 0.0]
1191        c = [0.0, 3.0]
1192
1193        points = [a, b, c]
1194        vertices = [[0,1,2]]
1195        domain = Domain(points, vertices)
1196
1197        # Check that domain.starttime isn't updated if later than file starttime but earlier
1198        # than file end time
1199        delta = 23
1200        domain.starttime = start + delta
1201        time_limit = domain.starttime + 600
1202        F = file_function(filename + '.tms', domain,
1203                          time_limit=time_limit,
1204                          quantities=['Attribute0', 'Attribute1', 'Attribute2'])       
1205        assert num.allclose(domain.starttime, start+delta)
1206
1207        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1208                                            277., 337., 397., 457., 517.,
1209                                            577.])       
1210
1211
1212
1213        # Now try interpolation with delta offset
1214        for i in range(20):
1215            t = i*10
1216            q = F(t-delta)
1217
1218            #Exact linear intpolation
1219            assert num.allclose(q[0], 2*t)
1220            if i%6 == 0:
1221                assert num.allclose(q[1], t**2)
1222                assert num.allclose(q[2], sin(t*pi/600))
1223
1224        # Check non-exact
1225        t = 90 #Halfway between 60 and 120
1226        q = F(t-delta)
1227        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1228        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1229
1230
1231        t = 100 # Two thirds of the way between between 60 and 120
1232        q = F(t-delta)
1233        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1234        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1235
1236
1237        os.remove(filename + '.tms')
1238        os.remove(filename + '.txt')               
1239
1240       
1241       
1242       
1243
1244
1245    def test_apply_expression_to_dictionary(self):
1246
1247        #FIXME: Division is not expected to work for integers.
1248        #This must be caught.
1249        foo = num.array([[1,2,3], [4,5,6]], num.float)
1250
1251        bar = num.array([[-1,0,5], [6,1,1]], num.float)                 
1252
1253        D = {'X': foo, 'Y': bar}
1254
1255        Z = apply_expression_to_dictionary('X+Y', D)       
1256        assert num.allclose(Z, foo+bar)
1257
1258        Z = apply_expression_to_dictionary('X*Y', D)       
1259        assert num.allclose(Z, foo*bar)       
1260
1261        Z = apply_expression_to_dictionary('4*X+Y', D)       
1262        assert num.allclose(Z, 4*foo+bar)       
1263
1264        # test zero division is OK
1265        Z = apply_expression_to_dictionary('X/Y', D)
1266        assert num.allclose(1/Z, 1/(foo/bar)) # can't compare inf to inf
1267
1268        # make an error for zero on zero
1269        # this is really an error in numeric, SciPy core can handle it
1270        # Z = apply_expression_to_dictionary('0/Y', D)
1271
1272        #Check exceptions
1273        try:
1274            #Wrong name
1275            Z = apply_expression_to_dictionary('4*X+A', D)       
1276        except NameError:
1277            pass
1278        else:
1279            msg = 'Should have raised a NameError Exception'
1280            raise msg
1281
1282
1283        try:
1284            #Wrong order
1285            Z = apply_expression_to_dictionary(D, '4*X+A')       
1286        except AssertionError:
1287            pass
1288        else:
1289            msg = 'Should have raised a AssertionError Exception'
1290            raise msg       
1291       
1292
1293    def test_multiple_replace(self):
1294        """Hard test that checks a true word-by-word simultaneous replace
1295        """
1296       
1297        D = {'x': 'xi', 'y': 'eta', 'xi':'lam'}
1298        exp = '3*x+y + xi'
1299       
1300        new = multiple_replace(exp, D)
1301       
1302        assert new == '3*xi+eta + lam'
1303                         
1304
1305
1306    def test_point_on_line_obsolete(self):
1307        """Test that obsolete call issues appropriate warning"""
1308
1309        #Turn warning into an exception
1310        import warnings
1311        warnings.filterwarnings('error')
1312
1313        try:
1314            assert point_on_line( 0, 0.5, 0,1, 0,0 )
1315        except DeprecationWarning:
1316            pass
1317        else:
1318            msg = 'point_on_line should have issued a DeprecationWarning'
1319            raise Exception(msg)   
1320
1321        warnings.resetwarnings()
1322   
1323    def test_get_revision_number(self):
1324        """test_get_revision_number(self):
1325
1326        Test that revision number can be retrieved.
1327        """
1328        if os.environ.has_key('USER') and os.environ['USER'] == 'dgray':
1329            # I have a known snv incompatability issue,
1330            # so I'm skipping this test.
1331            # FIXME when SVN is upgraded on our clusters
1332            pass
1333        else:   
1334            n = get_revision_number()
1335            assert n>=0
1336
1337
1338       
1339    def test_add_directories(self):
1340       
1341        import tempfile
1342        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1343        directories = ['ja','ne','ke']
1344        kens_dir = add_directories(root_dir, directories)
1345        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1346               sep + 'ke'
1347        assert access(root_dir,F_OK)
1348
1349        add_directories(root_dir, directories)
1350        assert access(root_dir,F_OK)
1351       
1352        #clean up!
1353        os.rmdir(kens_dir)
1354        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1355        os.rmdir(root_dir + sep + 'ja')
1356        os.rmdir(root_dir)
1357
1358    def test_add_directories_bad(self):
1359       
1360        import tempfile
1361        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1362        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1363       
1364        try:
1365            kens_dir = add_directories(root_dir, directories)
1366        except OSError:
1367            pass
1368        else:
1369            msg = 'bad dir name should give OSError'
1370            raise Exception(msg)   
1371           
1372        #clean up!
1373        os.rmdir(root_dir)
1374
1375    def test_check_list(self):
1376
1377        check_list(['stage','xmomentum'])
1378
1379       
1380    def test_add_directories(self):
1381       
1382        import tempfile
1383        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1384        directories = ['ja','ne','ke']
1385        kens_dir = add_directories(root_dir, directories)
1386        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1387               sep + 'ke'
1388        assert access(root_dir,F_OK)
1389
1390        add_directories(root_dir, directories)
1391        assert access(root_dir,F_OK)
1392       
1393        #clean up!
1394        os.rmdir(kens_dir)
1395        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1396        os.rmdir(root_dir + sep + 'ja')
1397        os.rmdir(root_dir)
1398
1399    def test_add_directories_bad(self):
1400       
1401        import tempfile
1402        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1403        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1404       
1405        try:
1406            kens_dir = add_directories(root_dir, directories)
1407        except OSError:
1408            pass
1409        else:
1410            msg = 'bad dir name should give OSError'
1411            raise Exception(msg)   
1412           
1413        #clean up!
1414        os.rmdir(root_dir)
1415
1416    def test_check_list(self):
1417
1418        check_list(['stage','xmomentum'])
1419
1420######
1421# Test the remove_lone_verts() function
1422######
1423       
1424    def test_remove_lone_verts_a(self):
1425        verts = [[0,0],[1,0],[0,1]]
1426        tris = [[0,1,2]]
1427        new_verts, new_tris = remove_lone_verts(verts, tris)
1428        self.failUnless(new_verts.tolist() == verts)
1429        self.failUnless(new_tris.tolist() == tris)
1430
1431    def test_remove_lone_verts_b(self):
1432        verts = [[0,0],[1,0],[0,1],[99,99]]
1433        tris = [[0,1,2]]
1434        new_verts, new_tris = remove_lone_verts(verts, tris)
1435        self.failUnless(new_verts.tolist() == verts[0:3])
1436        self.failUnless(new_tris.tolist() == tris)
1437       
1438    def test_remove_lone_verts_c(self):
1439        verts = [[99,99],[0,0],[1,0],[99,99],[0,1],[99,99]]
1440        tris = [[1,2,4]]
1441        new_verts, new_tris = remove_lone_verts(verts, tris)
1442        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1443        self.failUnless(new_tris.tolist() == [[0,1,2]])
1444     
1445    def test_remove_lone_verts_d(self):
1446        verts = [[0,0],[1,0],[99,99],[0,1]]
1447        tris = [[0,1,3]]
1448        new_verts, new_tris = remove_lone_verts(verts, tris)
1449        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1450        self.failUnless(new_tris.tolist() == [[0,1,2]])
1451       
1452    def test_remove_lone_verts_e(self):
1453        verts = [[0,0],[1,0],[0,1],[99,99],[99,99],[99,99]]
1454        tris = [[0,1,2]]
1455        new_verts, new_tris = remove_lone_verts(verts, tris)
1456        self.failUnless(new_verts.tolist() == verts[0:3])
1457        self.failUnless(new_tris.tolist() == tris)
1458     
1459    def test_remove_lone_verts_f(self):
1460        verts = [[0,0],[1,0],[99,99],[0,1],[99,99],[1,1],[99,99]]
1461        tris = [[0,1,3],[0,1,5]]
1462        new_verts, new_tris = remove_lone_verts(verts, tris)
1463        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1],[1,1]])
1464        self.failUnless(new_tris.tolist() == [[0,1,2],[0,1,3]])
1465       
1466######
1467#
1468######
1469       
1470    def test_get_min_max_values(self):
1471       
1472        list=[8,9,6,1,4]
1473        min1, max1 = get_min_max_values(list)
1474       
1475        assert min1==1 
1476        assert max1==9
1477       
1478    def test_get_min_max_values1(self):
1479       
1480        list=[-8,-9,-6,-1,-4]
1481        min1, max1 = get_min_max_values(list)
1482       
1483#        print 'min1,max1',min1,max1
1484        assert min1==-9 
1485        assert max1==-1
1486
1487#    def test_get_min_max_values2(self):
1488#        '''
1489#        The min and max supplied are greater than the ones in the
1490#        list and therefore are the ones returned
1491#        '''
1492#        list=[-8,-9,-6,-1,-4]
1493#        min1, max1 = get_min_max_values(list,-10,10)
1494#       
1495##        print 'min1,max1',min1,max1
1496#        assert min1==-10
1497#        assert max1==10
1498       
1499    def test_make_plots_from_csv_files(self):
1500       
1501        #if sys.platform == 'win32':  #Windows
1502            try: 
1503                import pylab
1504            except ImportError:
1505                #ANUGA don't need pylab to work so the system doesn't
1506                #rely on pylab being installed
1507                return
1508           
1509       
1510            current_dir=getcwd()+sep+'abstract_2d_finite_volumes'
1511            temp_dir = tempfile.mkdtemp('','figures')
1512    #        print 'temp_dir',temp_dir
1513            fileName = temp_dir+sep+'time_series_3.csv'
1514            file = open(fileName,"w")
1515            file.write("time,stage,speed,momentum,elevation\n\
15161.0, 0, 0, 0, 10 \n\
15172.0, 5, 2, 4, 10 \n\
15183.0, 3, 3, 5, 10 \n")
1519            file.close()
1520   
1521            fileName1 = temp_dir+sep+'time_series_4.csv'
1522            file1 = open(fileName1,"w")
1523            file1.write("time,stage,speed,momentum,elevation\n\
15241.0, 0, 0, 0, 5 \n\
15252.0, -5, -2, -4, 5 \n\
15263.0, -4, -3, -5, 5 \n")
1527            file1.close()
1528   
1529            fileName2 = temp_dir+sep+'time_series_5.csv'
1530            file2 = open(fileName2,"w")
1531            file2.write("time,stage,speed,momentum,elevation\n\
15321.0, 0, 0, 0, 7 \n\
15332.0, 4, -0.45, 57, 7 \n\
15343.0, 6, -0.5, 56, 7 \n")
1535            file2.close()
1536           
1537            dir, name=os.path.split(fileName)
1538            csv2timeseries_graphs(directories_dic={dir:['gauge', 0, 0]},
1539                                  output_dir=temp_dir,
1540                                  base_name='time_series_',
1541                                  plot_numbers=['3-5'],
1542                                  quantities=['speed','stage','momentum'],
1543                                  assess_all_csv_files=True,
1544                                  extra_plot_name='test')
1545           
1546            #print dir+sep+name[:-4]+'_stage_test.png'
1547            assert(access(dir+sep+name[:-4]+'_stage_test.png',F_OK)==True)
1548            assert(access(dir+sep+name[:-4]+'_speed_test.png',F_OK)==True)
1549            assert(access(dir+sep+name[:-4]+'_momentum_test.png',F_OK)==True)
1550   
1551            dir1, name1=os.path.split(fileName1)
1552            assert(access(dir+sep+name1[:-4]+'_stage_test.png',F_OK)==True)
1553            assert(access(dir+sep+name1[:-4]+'_speed_test.png',F_OK)==True)
1554            assert(access(dir+sep+name1[:-4]+'_momentum_test.png',F_OK)==True)
1555   
1556   
1557            dir2, name2=os.path.split(fileName2)
1558            assert(access(dir+sep+name2[:-4]+'_stage_test.png',F_OK)==True)
1559            assert(access(dir+sep+name2[:-4]+'_speed_test.png',F_OK)==True)
1560            assert(access(dir+sep+name2[:-4]+'_momentum_test.png',F_OK)==True)
1561   
1562            del_dir(temp_dir)
1563       
1564
1565    def test_sww2csv_gauges(self):
1566
1567        def elevation_function(x, y):
1568            return -x
1569       
1570        """Most of this test was copied from test_interpolate
1571        test_interpole_sww2csv
1572       
1573        This is testing the gauge_sww2csv function, by creating a sww file and
1574        then exporting the gauges and checking the results.
1575        """
1576       
1577        # Create mesh
1578        mesh_file = tempfile.mktemp(".tsh")   
1579        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1580        m = Mesh()
1581        m.add_vertices(points)
1582        m.auto_segment()
1583        m.generate_mesh(verbose=False)
1584        m.export_mesh_file(mesh_file)
1585       
1586        # Create shallow water domain
1587        domain = Domain(mesh_file)
1588        os.remove(mesh_file)
1589       
1590        domain.default_order=2
1591       
1592        # This test was made before tight_slope_limiters were introduced
1593        # Since were are testing interpolation values this is OK
1594        domain.tight_slope_limiters = 0 
1595       
1596
1597        # Set some field values
1598        domain.set_quantity('elevation', elevation_function)
1599        domain.set_quantity('friction', 0.03)
1600        domain.set_quantity('xmomentum', 3.0)
1601        domain.set_quantity('ymomentum', 4.0)
1602
1603        ######################
1604        # Boundary conditions
1605        B = Transmissive_boundary(domain)
1606        domain.set_boundary( {'exterior': B})
1607
1608        # This call mangles the stage values.
1609        domain.distribute_to_vertices_and_edges()
1610        domain.set_quantity('stage', 1.0)
1611
1612
1613        domain.set_name('datatest' + str(time.time()))
1614        domain.smooth = True
1615        domain.reduction = mean
1616
1617
1618        sww = SWW_file(domain)
1619        sww.store_connectivity()
1620        sww.store_timestep(['stage', 'xmomentum', 'ymomentum','elevation'])
1621        domain.set_quantity('stage', 10.0) # This is automatically limited
1622        # so it will not be less than the elevation
1623        domain.time = 2.
1624        sww.store_timestep(['stage','elevation', 'xmomentum', 'ymomentum'])
1625
1626        # test the function
1627        points = [[5.0,1.],[0.5,2.]]
1628
1629        points_file = tempfile.mktemp(".csv")
1630#        points_file = 'test_point.csv'
1631        file_id = open(points_file,"w")
1632        file_id.write("name, easting, northing, elevation \n\
1633point1, 5.0, 1.0, 3.0\n\
1634point2, 0.5, 2.0, 9.0\n")
1635        file_id.close()
1636
1637       
1638        sww2csv_gauges(sww.filename, 
1639                       points_file,
1640                       verbose=False,
1641                       use_cache=False)
1642
1643#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1644        point1_answers_array = [[0.0,0.0,1.0,6.0,-5.0,3.0,4.0], [2.0,2.0/3600.,10.0,15.0,-5.0,3.0,4.0]]
1645        point1_filename = 'gauge_point1.csv'
1646        point1_handle = file(point1_filename)
1647        point1_reader = reader(point1_handle)
1648        point1_reader.next()
1649
1650        line=[]
1651        for i,row in enumerate(point1_reader):
1652#            print 'i',i,'row',row
1653            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1654                         float(row[4]),float(row[5]),float(row[6])])
1655#            print 'assert line',line[i],'point1',point1_answers_array[i]
1656            assert num.allclose(line[i], point1_answers_array[i])
1657
1658        point2_answers_array = [[0.0,0.0,1.0,1.5,-0.5,3.0,4.0], [2.0,2.0/3600.,10.0,10.5,-0.5,3.0,4.0]]
1659        point2_filename = 'gauge_point2.csv' 
1660        point2_handle = file(point2_filename)
1661        point2_reader = reader(point2_handle)
1662        point2_reader.next()
1663                       
1664        line=[]
1665        for i,row in enumerate(point2_reader):
1666#            print 'i',i,'row',row
1667            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1668                         float(row[4]),float(row[5]),float(row[6])])
1669#            print 'assert line',line[i],'point1',point1_answers_array[i]
1670            assert num.allclose(line[i], point2_answers_array[i])
1671                         
1672        # clean up
1673        point1_handle.close()
1674        point2_handle.close()
1675        #print "sww.filename",sww.filename
1676        os.remove(sww.filename)
1677        os.remove(points_file)
1678        os.remove(point1_filename)
1679        os.remove(point2_filename)
1680
1681
1682
1683    def test_sww2csv_gauges1(self):
1684        from anuga.pmesh.mesh import Mesh
1685        from anuga.shallow_water import Domain, Transmissive_boundary
1686        from csv import reader,writer
1687        import time
1688        import string
1689
1690        def elevation_function(x, y):
1691            return -x
1692       
1693        """Most of this test was copied from test_interpolate
1694        test_interpole_sww2csv
1695       
1696        This is testing the gauge_sww2csv function, by creating a sww file and
1697        then exporting the gauges and checking the results.
1698       
1699        This tests the ablity not to have elevation in the points file and
1700        not store xmomentum and ymomentum
1701        """
1702       
1703        # Create mesh
1704        mesh_file = tempfile.mktemp(".tsh")   
1705        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1706        m = Mesh()
1707        m.add_vertices(points)
1708        m.auto_segment()
1709        m.generate_mesh(verbose=False)
1710        m.export_mesh_file(mesh_file)
1711       
1712        # Create shallow water domain
1713        domain = Domain(mesh_file)
1714        os.remove(mesh_file)
1715       
1716        domain.default_order=2
1717
1718        # Set some field values
1719        domain.set_quantity('elevation', elevation_function)
1720        domain.set_quantity('friction', 0.03)
1721        domain.set_quantity('xmomentum', 3.0)
1722        domain.set_quantity('ymomentum', 4.0)
1723
1724        ######################
1725        # Boundary conditions
1726        B = Transmissive_boundary(domain)
1727        domain.set_boundary( {'exterior': B})
1728
1729        # This call mangles the stage values.
1730        domain.distribute_to_vertices_and_edges()
1731        domain.set_quantity('stage', 1.0)
1732
1733
1734        domain.set_name('datatest' + str(time.time()))
1735        domain.smooth = True
1736        domain.reduction = mean
1737
1738        sww = SWW_file(domain)
1739        sww.store_connectivity()
1740        sww.store_timestep(['stage', 'xmomentum', 'ymomentum'])
1741        domain.set_quantity('stage', 10.0) # This is automatically limited
1742        # so it will not be less than the elevation
1743        domain.time = 2.
1744        sww.store_timestep(['stage', 'xmomentum', 'ymomentum'])
1745
1746        # test the function
1747        points = [[5.0,1.],[0.5,2.]]
1748
1749        points_file = tempfile.mktemp(".csv")
1750#        points_file = 'test_point.csv'
1751        file_id = open(points_file,"w")
1752        file_id.write("name,easting,northing \n\
1753point1, 5.0, 1.0\n\
1754point2, 0.5, 2.0\n")
1755        file_id.close()
1756
1757        sww2csv_gauges(sww.filename, 
1758                            points_file,
1759                            quantities=['stage', 'elevation'],
1760                            use_cache=False,
1761                            verbose=False)
1762
1763        point1_answers_array = [[0.0,1.0,-5.0], [2.0,10.0,-5.0]]
1764        point1_filename = 'gauge_point1.csv'
1765        point1_handle = file(point1_filename)
1766        point1_reader = reader(point1_handle)
1767        point1_reader.next()
1768
1769        line=[]
1770        for i,row in enumerate(point1_reader):
1771#            print 'i',i,'row',row
1772            # note the 'hole' (element 1) below - skip the new 'hours' field
1773            line.append([float(row[0]),float(row[2]),float(row[3])])
1774            #print 'line',line[i],'point1',point1_answers_array[i]
1775            assert num.allclose(line[i], point1_answers_array[i])
1776
1777        point2_answers_array = [[0.0,1.0,-0.5], [2.0,10.0,-0.5]]
1778        point2_filename = 'gauge_point2.csv' 
1779        point2_handle = file(point2_filename)
1780        point2_reader = reader(point2_handle)
1781        point2_reader.next()
1782                       
1783        line=[]
1784        for i,row in enumerate(point2_reader):
1785#            print 'i',i,'row',row
1786            # note the 'hole' (element 1) below - skip the new 'hours' field
1787            line.append([float(row[0]),float(row[2]),float(row[3])])
1788#            print 'line',line[i],'point1',point1_answers_array[i]
1789            assert num.allclose(line[i], point2_answers_array[i])
1790                         
1791        # clean up
1792        point1_handle.close()
1793        point2_handle.close()
1794        #print "sww.filename",sww.filename
1795        os.remove(sww.filename)
1796        os.remove(points_file)
1797        os.remove(point1_filename)
1798        os.remove(point2_filename)
1799
1800
1801    def test_sww2csv_gauges2(self):
1802
1803        def elevation_function(x, y):
1804            return -x
1805       
1806        """Most of this test was copied from test_interpolate
1807        test_interpole_sww2csv
1808       
1809        This is testing the gauge_sww2csv function, by creating a sww file and
1810        then exporting the gauges and checking the results.
1811       
1812        This is the same as sww2csv_gauges except set domain.set_starttime to 5.
1813        Therefore testing the storing of the absolute time in the csv files
1814        """
1815       
1816        # Create mesh
1817        mesh_file = tempfile.mktemp(".tsh")   
1818        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1819        m = Mesh()
1820        m.add_vertices(points)
1821        m.auto_segment()
1822        m.generate_mesh(verbose=False)
1823        m.export_mesh_file(mesh_file)
1824       
1825        # Create shallow water domain
1826        domain = Domain(mesh_file)
1827        os.remove(mesh_file)
1828       
1829        domain.default_order=2
1830
1831        # This test was made before tight_slope_limiters were introduced
1832        # Since were are testing interpolation values this is OK
1833        domain.tight_slope_limiters = 0         
1834
1835        # Set some field values
1836        domain.set_quantity('elevation', elevation_function)
1837        domain.set_quantity('friction', 0.03)
1838        domain.set_quantity('xmomentum', 3.0)
1839        domain.set_quantity('ymomentum', 4.0)
1840        domain.set_starttime(5)
1841
1842        ######################
1843        # Boundary conditions
1844        B = Transmissive_boundary(domain)
1845        domain.set_boundary( {'exterior': B})
1846
1847        # This call mangles the stage values.
1848        domain.distribute_to_vertices_and_edges()
1849        domain.set_quantity('stage', 1.0)
1850       
1851
1852
1853        domain.set_name('datatest' + str(time.time()))
1854        domain.smooth = True
1855        domain.reduction = mean
1856
1857        sww = SWW_file(domain)
1858        sww.store_connectivity()
1859        sww.store_timestep(['stage', 'xmomentum', 'ymomentum','elevation'])
1860        domain.set_quantity('stage', 10.0) # This is automatically limited
1861        # so it will not be less than the elevation
1862        domain.time = 2.
1863        sww.store_timestep(['stage','elevation', 'xmomentum', 'ymomentum'])
1864
1865        # test the function
1866        points = [[5.0,1.],[0.5,2.]]
1867
1868        points_file = tempfile.mktemp(".csv")
1869#        points_file = 'test_point.csv'
1870        file_id = open(points_file,"w")
1871        file_id.write("name, easting, northing, elevation \n\
1872point1, 5.0, 1.0, 3.0\n\
1873point2, 0.5, 2.0, 9.0\n")
1874        file_id.close()
1875
1876       
1877        sww2csv_gauges(sww.filename, 
1878                            points_file,
1879                            verbose=False,
1880                            use_cache=False)
1881
1882#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1883        point1_answers_array = [[5.0,5.0/3600.,1.0,6.0,-5.0,3.0,4.0], [7.0,7.0/3600.,10.0,15.0,-5.0,3.0,4.0]]
1884        point1_filename = 'gauge_point1.csv'
1885        point1_handle = file(point1_filename)
1886        point1_reader = reader(point1_handle)
1887        point1_reader.next()
1888
1889        line=[]
1890        for i,row in enumerate(point1_reader):
1891            #print 'i',i,'row',row
1892            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1893                         float(row[4]), float(row[5]), float(row[6])])
1894            #print 'assert line',line[i],'point1',point1_answers_array[i]
1895            assert num.allclose(line[i], point1_answers_array[i])
1896
1897        point2_answers_array = [[5.0,5.0/3600.,1.0,1.5,-0.5,3.0,4.0], [7.0,7.0/3600.,10.0,10.5,-0.5,3.0,4.0]]
1898        point2_filename = 'gauge_point2.csv' 
1899        point2_handle = file(point2_filename)
1900        point2_reader = reader(point2_handle)
1901        point2_reader.next()
1902                       
1903        line=[]
1904        for i,row in enumerate(point2_reader):
1905            #print 'i',i,'row',row
1906            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1907                         float(row[4]),float(row[5]), float(row[6])])
1908            #print 'assert line',line[i],'point1',point1_answers_array[i]
1909            assert num.allclose(line[i], point2_answers_array[i])
1910                         
1911        # clean up
1912        point1_handle.close()
1913        point2_handle.close()
1914        #print "sww.filename",sww.filename
1915        os.remove(sww.filename)
1916        os.remove(points_file)
1917        os.remove(point1_filename)
1918        os.remove(point2_filename)
1919
1920
1921    def test_greens_law(self):
1922
1923        from math import sqrt
1924       
1925        d1 = 80.0
1926        d2 = 20.0
1927        h1 = 1.0
1928        h2 = greens_law(d1,d2,h1)
1929
1930        assert h2==sqrt(2.0)
1931       
1932    def test_calc_bearings(self):
1933 
1934        from math import atan, degrees
1935        #Test East
1936        uh = 1
1937        vh = 1.e-15
1938        angle = calc_bearing(uh, vh)
1939        if 89 < angle < 91: v=1
1940        assert v==1
1941        #Test West
1942        uh = -1
1943        vh = 1.e-15
1944        angle = calc_bearing(uh, vh)
1945        if 269 < angle < 271: v=1
1946        assert v==1
1947        #Test North
1948        uh = 1.e-15
1949        vh = 1
1950        angle = calc_bearing(uh, vh)
1951        if -1 < angle < 1: v=1
1952        assert v==1
1953        #Test South
1954        uh = 1.e-15
1955        vh = -1
1956        angle = calc_bearing(uh, vh)
1957        if 179 < angle < 181: v=1
1958        assert v==1
1959        #Test South-East
1960        uh = 1
1961        vh = -1
1962        angle = calc_bearing(uh, vh)
1963        if 134 < angle < 136: v=1
1964        assert v==1
1965        #Test North-East
1966        uh = 1
1967        vh = 1
1968        angle = calc_bearing(uh, vh)
1969        if 44 < angle < 46: v=1
1970        assert v==1
1971        #Test South-West
1972        uh = -1
1973        vh = -1
1974        angle = calc_bearing(uh, vh)
1975        if 224 < angle < 226: v=1
1976        assert v==1
1977        #Test North-West
1978        uh = -1
1979        vh = 1
1980        angle = calc_bearing(uh, vh)
1981        if 314 < angle < 316: v=1
1982        assert v==1
1983       
1984
1985#-------------------------------------------------------------
1986
1987if __name__ == "__main__":
1988    suite = unittest.makeSuite(Test_Util, 'test')
1989#    runner = unittest.TextTestRunner(verbosity=2)
1990    runner = unittest.TextTestRunner(verbosity=1)
1991    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.