source: anuga_core/source/anuga/abstract_2d_finite_volumes/test_util.py @ 7672

Last change on this file since 7672 was 7672, checked in by hudson, 14 years ago

Wrote 1 failing test for the future gauge_sww2csv centroid intersection option.

File size: 53.9 KB
RevLine 
[5897]1#!/usr/bin/env python
2
3
4import unittest
5from math import sqrt, pi
6import tempfile, os
7from os import access, F_OK,sep, removedirs,remove,mkdir,getcwd
8
9from anuga.abstract_2d_finite_volumes.util import *
10from anuga.config import epsilon
11from anuga.shallow_water.data_manager import timefile2netcdf,del_dir
12
13from anuga.utilities.numerical_tools import NAN
14
15from sys import platform
16
17from anuga.pmesh.mesh import Mesh
18from anuga.shallow_water import Domain, Transmissive_boundary
[7340]19from anuga.shallow_water.data_manager import SWW_file
[5897]20from csv import reader,writer
21import time
22import string
23
[7276]24import numpy as num
[6145]25
26
[5897]27def test_function(x, y):
28    return x+y
29
30class Test_Util(unittest.TestCase):
31    def setUp(self):
32        pass
33
34    def tearDown(self):
35        pass
36
37
38
39
40    #Geometric
41    #def test_distance(self):
42    #    from anuga.abstract_2d_finite_volumes.util import distance#
43    #
44    #    self.failUnless( distance([4,2],[7,6]) == 5.0,
45    #                     'Distance is wrong!')
46    #    self.failUnless( allclose(distance([7,6],[9,8]), 2.82842712475),
47    #                    'distance is wrong!')
48    #    self.failUnless( allclose(distance([9,8],[4,2]), 7.81024967591),
49    #                    'distance is wrong!')
50    #
51    #    self.failUnless( distance([9,8],[4,2]) == distance([4,2],[9,8]),
52    #                    'distance is wrong!')
53
54
55    def test_file_function_time1(self):
56        """Test that File function interpolates correctly
57        between given times. No x,y dependency here.
58        """
59
60        #Write file
61        import os, time
62        from anuga.config import time_format
63        from math import sin, pi
64
65        #Typical ASCII file
66        finaltime = 1200
67        filename = 'test_file_function'
68        fid = open(filename + '.txt', 'w')
69        start = time.mktime(time.strptime('2000', '%Y'))
70        dt = 60  #One minute intervals
71        t = 0.0
72        while t <= finaltime:
73            t_string = time.strftime(time_format, time.gmtime(t+start))
74            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
75            t += dt
76
77        fid.close()
78
79        #Convert ASCII file to NetCDF (Which is what we really like!)
80        timefile2netcdf(filename)
81
82
83        #Create file function from time series
84        F = file_function(filename + '.tms',
85                          quantities = ['Attribute0',
86                                        'Attribute1',
87                                        'Attribute2'])
88       
89        #Now try interpolation
90        for i in range(20):
91            t = i*10
92            q = F(t)
93
94            #Exact linear intpolation
[6145]95            assert num.allclose(q[0], 2*t)
[5897]96            if i%6 == 0:
[6145]97                assert num.allclose(q[1], t**2)
98                assert num.allclose(q[2], sin(t*pi/600))
[5897]99
100        #Check non-exact
101
102        t = 90 #Halfway between 60 and 120
103        q = F(t)
[6145]104        assert num.allclose( (120**2 + 60**2)/2, q[1] )
105        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
[5897]106
107
108        t = 100 #Two thirds of the way between between 60 and 120
109        q = F(t)
[6145]110        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
111        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
[5897]112
113        os.remove(filename + '.txt')
114        os.remove(filename + '.tms')       
115
116
117       
118    def test_spatio_temporal_file_function_basic(self):
119        """Test that spatio temporal file function performs the correct
120        interpolations in both time and space
121        NetCDF version (x,y,t dependency)       
122        """
123        import time
124
125        #Create sww file of simple propagation from left to right
126        #through rectangular domain
127        from shallow_water import Domain, Dirichlet_boundary
128        from mesh_factory import rectangular
129
130        #Create basic mesh and shallow water domain
131        points, vertices, boundary = rectangular(3, 3)
132        domain1 = Domain(points, vertices, boundary)
133
134        from anuga.utilities.numerical_tools import mean
135        domain1.reduction = mean
136        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
137                              # only one value.
138
139        domain1.default_order = 2
140        domain1.store = True
141        domain1.set_datadir('.')
[7562]142        sww_file = 'spatio_temporal_boundary_source_%d' %(id(self))
143        domain1.set_name(sww_file)
[5897]144
145        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
146        domain1.set_quantity('elevation', 0)
147        domain1.set_quantity('friction', 0)
148        domain1.set_quantity('stage', 0)
149
150        # Boundary conditions
151        B0 = Dirichlet_boundary([0,0,0])
152        B6 = Dirichlet_boundary([0.6,0,0])
153        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
154        domain1.check_integrity()
155
156        finaltime = 8
157        #Evolution
158        t0 = -1
159        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
160            #print 'Timesteps: %.16f, %.16f' %(t0, t)
161            #if t == t0:
162            #    msg = 'Duplicate timestep found: %f, %f' %(t0, t)
163            #   raise msg
164            t0 = t
165             
166            #domain1.write_time()
167
168
169        #Now read data from sww and check
170        from Scientific.IO.NetCDF import NetCDFFile
[7340]171        filename = domain1.get_name() + '.sww'
[5897]172        fid = NetCDFFile(filename)
173
174        x = fid.variables['x'][:]
175        y = fid.variables['y'][:]
176        stage = fid.variables['stage'][:]
177        xmomentum = fid.variables['xmomentum'][:]
178        ymomentum = fid.variables['ymomentum'][:]
179        time = fid.variables['time'][:]
180
181        #Take stage vertex values at last timestep on diagonal
182        #Diagonal is identified by vertices: 0, 5, 10, 15
183
184        last_time_index = len(time)-1 #Last last_time_index
[7276]185        d_stage = num.reshape(num.take(stage[last_time_index, :],
186                                       [0,5,10,15],
187                                       axis=0),
188                              (4,1))
189        d_uh = num.reshape(num.take(xmomentum[last_time_index, :],
190                                    [0,5,10,15],
191                                   axis=0),
192                           (4,1))
193        d_vh = num.reshape(num.take(ymomentum[last_time_index, :],
194                                    [0,5,10,15],
195                                   axis=0),
196                           (4,1))
[6171]197        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
[5897]198
199        #Reference interpolated values at midpoints on diagonal at
200        #this timestep are
201        r0 = (D[0] + D[1])/2
202        r1 = (D[1] + D[2])/2
203        r2 = (D[2] + D[3])/2
204
205        #And the midpoints are found now
[7276]206        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15], axis=0)
207        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15], axis=0)
[5897]208
[6145]209        diag = num.concatenate( (Dx, Dy), axis=1)
[5897]210        d_midpoints = (diag[1:] + diag[:-1])/2
211
212        #Let us see if the file function can find the correct
213        #values at the midpoints at the last timestep:
214        f = file_function(filename, domain1,
215                          interpolation_points = d_midpoints)
216
217        T = f.get_time()
218        msg = 'duplicate timesteps: %.16f and %.16f' %(T[-1], T[-2])
219        assert not T[-1] == T[-2], msg
220        t = time[last_time_index]
[6145]221        q = f(t, point_id=0); assert num.allclose(r0, q)
222        q = f(t, point_id=1); assert num.allclose(r1, q)
223        q = f(t, point_id=2); assert num.allclose(r2, q)
[5897]224
225
226        ##################
227        #Now do the same for the first timestep
228
229        timestep = 0 #First timestep
[7276]230        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
231        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
232        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
[6171]233        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
[5897]234
235        #Reference interpolated values at midpoints on diagonal at
236        #this timestep are
237        r0 = (D[0] + D[1])/2
238        r1 = (D[1] + D[2])/2
239        r2 = (D[2] + D[3])/2
240
241        #Let us see if the file function can find the correct
242        #values
[6145]243        q = f(0, point_id=0); assert num.allclose(r0, q)
244        q = f(0, point_id=1); assert num.allclose(r1, q)
245        q = f(0, point_id=2); assert num.allclose(r2, q)
[5897]246
247
248        ##################
249        #Now do it again for a timestep in the middle
250
251        timestep = 33
[7276]252        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
253        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
254        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
[6171]255        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
[5897]256
257        #Reference interpolated values at midpoints on diagonal at
258        #this timestep are
259        r0 = (D[0] + D[1])/2
260        r1 = (D[1] + D[2])/2
261        r2 = (D[2] + D[3])/2
262
[6145]263        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
264        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
265        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
[5897]266
267
268        ##################
269        #Now check temporal interpolation
270        #Halfway between timestep 15 and 16
271
272        timestep = 15
[7276]273        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
274        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
275        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
[6171]276        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
[5897]277
278        #Reference interpolated values at midpoints on diagonal at
279        #this timestep are
280        r0_0 = (D[0] + D[1])/2
281        r1_0 = (D[1] + D[2])/2
282        r2_0 = (D[2] + D[3])/2
283
284        #
285        timestep = 16
[7276]286        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
287        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
288        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
[6171]289        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
[5897]290
291        #Reference interpolated values at midpoints on diagonal at
292        #this timestep are
293        r0_1 = (D[0] + D[1])/2
294        r1_1 = (D[1] + D[2])/2
295        r2_1 = (D[2] + D[3])/2
296
297        # The reference values are
298        r0 = (r0_0 + r0_1)/2
299        r1 = (r1_0 + r1_1)/2
300        r2 = (r2_0 + r2_1)/2
301
[6145]302        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
303        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
304        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
[5897]305
306        ##################
307        #Finally check interpolation 2 thirds of the way
308        #between timestep 15 and 16
309
310        # The reference values are
311        r0 = (r0_0 + 2*r0_1)/3
312        r1 = (r1_0 + 2*r1_1)/3
313        r2 = (r2_0 + 2*r2_1)/3
314
315        #And the file function gives
[6145]316        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
317        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
318        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
[5897]319
320        fid.close()
321        import os
322        os.remove(filename)
323
324
325
326    def test_spatio_temporal_file_function_different_origin(self):
327        """Test that spatio temporal file function performs the correct
328        interpolations in both time and space where space is offset by
329        xllcorner and yllcorner
330        NetCDF version (x,y,t dependency)       
331        """
332        import time
333
334        #Create sww file of simple propagation from left to right
335        #through rectangular domain
336        from shallow_water import Domain, Dirichlet_boundary
337        from mesh_factory import rectangular
338
339
340        from anuga.coordinate_transforms.geo_reference import Geo_reference
341        xllcorner = 2048
342        yllcorner = 11000
343        zone = 2
344
345        #Create basic mesh and shallow water domain
346        points, vertices, boundary = rectangular(3, 3)
347        domain1 = Domain(points, vertices, boundary,
348                         geo_reference = Geo_reference(xllcorner = xllcorner,
349                                                       yllcorner = yllcorner))
350       
351
352        from anuga.utilities.numerical_tools import mean       
353        domain1.reduction = mean
354        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
355                              # only one value.
356
357        domain1.default_order = 2
358        domain1.store = True
359        domain1.set_datadir('.')
360        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
361
362        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
363        domain1.set_quantity('elevation', 0)
364        domain1.set_quantity('friction', 0)
365        domain1.set_quantity('stage', 0)
366
367        # Boundary conditions
368        B0 = Dirichlet_boundary([0,0,0])
369        B6 = Dirichlet_boundary([0.6,0,0])
370        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
371        domain1.check_integrity()
372
373        finaltime = 8
374        #Evolution
375        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
376            pass
377            #domain1.write_time()
378
379
380        #Now read data from sww and check
381        from Scientific.IO.NetCDF import NetCDFFile
[7340]382        filename = domain1.get_name() + '.sww'
[5897]383        fid = NetCDFFile(filename)
384
385        x = fid.variables['x'][:]
386        y = fid.variables['y'][:]
[7276]387        # we 'cast' to 64 bit floats to pass this test
388        # SWW file quantities are stored as 32 bits
389        x = num.array(x, num.float)
390        y = num.array(y, num.float)
391
[5897]392        stage = fid.variables['stage'][:]
393        xmomentum = fid.variables['xmomentum'][:]
394        ymomentum = fid.variables['ymomentum'][:]
395        time = fid.variables['time'][:]
396
397        #Take stage vertex values at last timestep on diagonal
398        #Diagonal is identified by vertices: 0, 5, 10, 15
399
400        last_time_index = len(time)-1 #Last last_time_index     
[7276]401        d_stage = num.reshape(num.take(stage[last_time_index, :],
402                                       [0,5,10,15],
403                                       axis=0),
404                              (4,1))
405        d_uh = num.reshape(num.take(xmomentum[last_time_index, :],
406                                    [0,5,10,15],
407                                   axis=0),
408                           (4,1))
409        d_vh = num.reshape(num.take(ymomentum[last_time_index, :],
410                                    [0,5,10,15],
411                                    axis=0),
412                           (4,1))
[6171]413        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
[5897]414
415        #Reference interpolated values at midpoints on diagonal at
416        #this timestep are
417        r0 = (D[0] + D[1])/2
418        r1 = (D[1] + D[2])/2
419        r2 = (D[2] + D[3])/2
420
421        #And the midpoints are found now
[7276]422        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15], axis=0)
423        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15], axis=0)
[5897]424
[7276]425        diag = num.concatenate((Dx, Dy), axis=1)
[5897]426        d_midpoints = (diag[1:] + diag[:-1])/2
427
428
429        #Adjust for georef - make interpolation points absolute
430        d_midpoints[:,0] += xllcorner
431        d_midpoints[:,1] += yllcorner               
432
433        #Let us see if the file function can find the correct
434        #values at the midpoints at the last timestep:
435        f = file_function(filename, domain1,
436                          interpolation_points = d_midpoints)
437
438        t = time[last_time_index]                         
439
[7276]440        q = f(t, point_id=0)
441        msg = '\nr0=%s\nq=%s' % (str(r0), str(q))
442        assert num.allclose(r0, q), msg
[5897]443
[7276]444        q = f(t, point_id=1)
445        msg = '\nr1=%s\nq=%s' % (str(r1), str(q))
446        assert num.allclose(r1, q), msg
447
448        q = f(t, point_id=2)
449        msg = '\nr2=%s\nq=%s' % (str(r2), str(q))
450        assert num.allclose(r2, q), msg
451
452
[5897]453        ##################
454        #Now do the same for the first timestep
455
456        timestep = 0 #First timestep
[7276]457        d_stage = num.reshape(num.take(stage[timestep, :],
458                                       [0,5,10,15],
459                                       axis=0),
460                              (4,1))
461        d_uh = num.reshape(num.take(xmomentum[timestep, :],
462                                    [0,5,10,15],
463                                    axis=0),
464                           (4,1))
465        d_vh = num.reshape(num.take(ymomentum[timestep, :],
466                                    [0,5,10,15],
467                                    axis=0),
468                           (4,1))
[6145]469        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
[5897]470
471        #Reference interpolated values at midpoints on diagonal at
472        #this timestep are
473        r0 = (D[0] + D[1])/2
474        r1 = (D[1] + D[2])/2
475        r2 = (D[2] + D[3])/2
476
477        #Let us see if the file function can find the correct
478        #values
[6145]479        q = f(0, point_id=0); assert num.allclose(r0, q)
480        q = f(0, point_id=1); assert num.allclose(r1, q)
481        q = f(0, point_id=2); assert num.allclose(r2, q)
[5897]482
483
484        ##################
485        #Now do it again for a timestep in the middle
486
487        timestep = 33
[7276]488        d_stage = num.reshape(num.take(stage[timestep, :],
489                                       [0,5,10,15],
490                                       axis=0),
491                              (4,1))
492        d_uh = num.reshape(num.take(xmomentum[timestep, :],
493                                    [0,5,10,15],
494                                    axis=0),
495                           (4,1))
496        d_vh = num.reshape(num.take(ymomentum[timestep, :],
497                                    [0,5,10,15],
498                                    axis=0),
499                           (4,1))
[6145]500        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
[5897]501
502        #Reference interpolated values at midpoints on diagonal at
503        #this timestep are
504        r0 = (D[0] + D[1])/2
505        r1 = (D[1] + D[2])/2
506        r2 = (D[2] + D[3])/2
507
[6145]508        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
509        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
510        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
[5897]511
512
513        ##################
514        #Now check temporal interpolation
515        #Halfway between timestep 15 and 16
516
517        timestep = 15
[7276]518        d_stage = num.reshape(num.take(stage[timestep, :],
519                                       [0,5,10,15],
520                                       axis=0),
521                              (4,1))
522        d_uh = num.reshape(num.take(xmomentum[timestep, :],
523                                    [0,5,10,15],
524                                    axis=0),
525                           (4,1))
526        d_vh = num.reshape(num.take(ymomentum[timestep, :],
527                                    [0,5,10,15],
528                                    axis=0),
529                           (4,1))
[6145]530        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
[5897]531
532        #Reference interpolated values at midpoints on diagonal at
533        #this timestep are
534        r0_0 = (D[0] + D[1])/2
535        r1_0 = (D[1] + D[2])/2
536        r2_0 = (D[2] + D[3])/2
537
538        #
539        timestep = 16
[7276]540        d_stage = num.reshape(num.take(stage[timestep, :],
541                                       [0,5,10,15],
542                                       axis=0),
543                              (4,1))
544        d_uh = num.reshape(num.take(xmomentum[timestep, :],
545                                    [0,5,10,15],
546                                    axis=0),
547                           (4,1))
548        d_vh = num.reshape(num.take(ymomentum[timestep, :],
549                                    [0,5,10,15],
550                                    axis=0),
551                           (4,1))
[6145]552        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
[5897]553
554        #Reference interpolated values at midpoints on diagonal at
555        #this timestep are
556        r0_1 = (D[0] + D[1])/2
557        r1_1 = (D[1] + D[2])/2
558        r2_1 = (D[2] + D[3])/2
559
560        # The reference values are
561        r0 = (r0_0 + r0_1)/2
562        r1 = (r1_0 + r1_1)/2
563        r2 = (r2_0 + r2_1)/2
564
[6145]565        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
566        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
567        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
[5897]568
569        ##################
570        #Finally check interpolation 2 thirds of the way
571        #between timestep 15 and 16
572
573        # The reference values are
574        r0 = (r0_0 + 2*r0_1)/3
575        r1 = (r1_0 + 2*r1_1)/3
576        r2 = (r2_0 + 2*r2_1)/3
577
578        #And the file function gives
[6145]579        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
580        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
581        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
[5897]582
583        fid.close()
584        import os
585        os.remove(filename)
586
587       
588
589
590    def test_spatio_temporal_file_function_time(self):
591        """Test that File function interpolates correctly
592        between given times.
593        NetCDF version (x,y,t dependency)
594        """
595
596        #Create NetCDF (sww) file to be read
597        # x: 0, 5, 10, 15
598        # y: -20, -10, 0, 10
599        # t: 0, 60, 120, ...., 1200
600        #
601        # test quantities (arbitrary but non-trivial expressions):
602        #
603        #   stage     = 3*x - y**2 + 2*t
604        #   xmomentum = exp( -((x-7)**2 + (y+5)**2)/20 ) * t**2
605        #   ymomentum = x**2 + y**2 * sin(t*pi/600)
606
607        #NOTE: Nice test that may render some of the others redundant.
608
609        import os, time
610        from anuga.config import time_format
611        from mesh_factory import rectangular
612        from shallow_water import Domain
613        import anuga.shallow_water.data_manager
614
615        finaltime = 1200
616        filename = 'test_file_function'
617
618        #Create a domain to hold test grid
619        #(0:15, -20:10)
620        points, vertices, boundary =\
621                rectangular(4, 4, 15, 30, origin = (0, -20))
622        #print "points", points
623
624        #print 'Number of elements', len(vertices)
625        domain = Domain(points, vertices, boundary)
626        domain.smooth = False
627        domain.default_order = 2
628        domain.set_datadir('.')
629        domain.set_name(filename)
630        domain.store = True
631
632        #print points
633        start = time.mktime(time.strptime('2000', '%Y'))
634        domain.starttime = start
635
636
637        #Store structure
638        domain.initialise_storage()
639
640        #Compute artificial time steps and store
641        dt = 60  #One minute intervals
642        t = 0.0
643        while t <= finaltime:
644            #Compute quantities
645            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
646            domain.set_quantity('stage', f1)
647
648            f2 = lambda x,y: x+y+t**2
649            domain.set_quantity('xmomentum', f2)
650
[6145]651            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
[5897]652            domain.set_quantity('ymomentum', f3)
653
654            #Store and advance time
655            domain.time = t
[7340]656            domain.store_timestep()
[5897]657            t += dt
658
659
660        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5]]
661     
662        #Deliberately set domain.starttime to too early
663        domain.starttime = start - 1
664
665        #Create file function
666        F = file_function(filename + '.sww', domain,
667                          quantities = domain.conserved_quantities,
668                          interpolation_points = interpolation_points)
669
670        #Check that FF updates fixes domain starttime
[6145]671        assert num.allclose(domain.starttime, start)
[5897]672
673        #Check that domain.starttime isn't updated if later
674        domain.starttime = start + 1
675        F = file_function(filename + '.sww', domain,
676                          quantities = domain.conserved_quantities,
677                          interpolation_points = interpolation_points)
[6145]678        assert num.allclose(domain.starttime, start+1)
[5897]679        domain.starttime = start
680
681
682        #Check linear interpolation in time
683        F = file_function(filename + '.sww', domain,
684                          quantities = domain.conserved_quantities,
685                          interpolation_points = interpolation_points)               
686        for id in range(len(interpolation_points)):
687            x = interpolation_points[id][0]
688            y = interpolation_points[id][1]
689
690            for i in range(20):
691                t = i*10
692                k = i%6
693
694                if k == 0:
695                    q0 = F(t, point_id=id)
696                    q1 = F(t+60, point_id=id)
697
[7276]698                if num.alltrue(q0 == NAN):
[5897]699                    actual = q0
700                else:
701                    actual = (k*q1 + (6-k)*q0)/6
702                q = F(t, point_id=id)
703                #print i, k, t, q
704                #print ' ', q0
705                #print ' ', q1
706                #print "q",q
707                #print "actual", actual
708                #print
[7276]709                if num.alltrue(q0 == NAN):
710                     self.failUnless(num.alltrue(q == actual), 'Fail!')
[5897]711                else:
[6145]712                    assert num.allclose(q, actual)
[5897]713
714
715        #Another check of linear interpolation in time
716        for id in range(len(interpolation_points)):
717            q60 = F(60, point_id=id)
718            q120 = F(120, point_id=id)
719
720            t = 90 #Halfway between 60 and 120
721            q = F(t, point_id=id)
[6145]722            assert num.allclose( (q120+q60)/2, q )
[5897]723
724            t = 100 #Two thirds of the way between between 60 and 120
725            q = F(t, point_id=id)
[6145]726            assert num.allclose(q60/3 + 2*q120/3, q)
[5897]727
728
729
730        #Check that domain.starttime isn't updated if later than file starttime but earlier
731        #than file end time
732        delta = 23
733        domain.starttime = start + delta
734        F = file_function(filename + '.sww', domain,
735                          quantities = domain.conserved_quantities,
736                          interpolation_points = interpolation_points)
[6145]737        assert num.allclose(domain.starttime, start+delta)
[5897]738
739
740
741
742        #Now try interpolation with delta offset
743        for id in range(len(interpolation_points)):           
744            x = interpolation_points[id][0]
745            y = interpolation_points[id][1]
746
747            for i in range(20):
748                t = i*10
749                k = i%6
750
751                if k == 0:
752                    q0 = F(t-delta, point_id=id)
753                    q1 = F(t+60-delta, point_id=id)
754
755                q = F(t-delta, point_id=id)
[6145]756                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
[5897]757
758
759        os.remove(filename + '.sww')
760
761
762
763    def Xtest_spatio_temporal_file_function_time(self):
764        # FIXME: This passes but needs some TLC
765        # Test that File function interpolates correctly
766        # When some points are outside the mesh
767
768        import os, time
769        from anuga.config import time_format
770        from mesh_factory import rectangular
771        from shallow_water import Domain
772        import anuga.shallow_water.data_manager 
773        from anuga.pmesh.mesh_interface import create_mesh_from_regions
774        finaltime = 1200
775       
776        filename = tempfile.mktemp()
777        #print "filename",filename
778        filename = 'test_file_function'
779
780        meshfilename = tempfile.mktemp(".tsh")
781
782        boundary_tags = {'walls':[0,1],'bom':[2]}
783       
784        polygon_absolute = [[0,-20],[10,-20],[10,15],[-20,15]]
785       
786        create_mesh_from_regions(polygon_absolute,
787                                 boundary_tags,
788                                 10000000,
789                                 filename=meshfilename)
790        domain = Domain(mesh_filename=meshfilename)
791        domain.smooth = False
792        domain.default_order = 2
793        domain.set_datadir('.')
794        domain.set_name(filename)
795        domain.store = True
796
797        #print points
798        start = time.mktime(time.strptime('2000', '%Y'))
799        domain.starttime = start
800       
801
802        #Store structure
803        domain.initialise_storage()
804
805        #Compute artificial time steps and store
806        dt = 60  #One minute intervals
807        t = 0.0
808        while t <= finaltime:
809            #Compute quantities
810            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
811            domain.set_quantity('stage', f1)
812
813            f2 = lambda x,y: x+y+t**2
814            domain.set_quantity('xmomentum', f2)
815
[6145]816            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
[5897]817            domain.set_quantity('ymomentum', f3)
818
819            #Store and advance time
820            domain.time = t
[7342]821            domain.store_timestep()
[5897]822            t += dt
823
824        interpolation_points = [[1,0]]
825        interpolation_points = [[100,1000]]
826       
827        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5],
828                                [78787,78787],[7878,3432]]
829           
830        #Deliberately set domain.starttime to too early
831        domain.starttime = start - 1
832
833        #Create file function
834        F = file_function(filename + '.sww', domain,
835                          quantities = domain.conserved_quantities,
836                          interpolation_points = interpolation_points)
837
838        #Check that FF updates fixes domain starttime
[6145]839        assert num.allclose(domain.starttime, start)
[5897]840
841        #Check that domain.starttime isn't updated if later
842        domain.starttime = start + 1
843        F = file_function(filename + '.sww', domain,
844                          quantities = domain.conserved_quantities,
845                          interpolation_points = interpolation_points)
[6145]846        assert num.allclose(domain.starttime, start+1)
[5897]847        domain.starttime = start
848
849
850        #Check linear interpolation in time
851        # checking points inside and outside the mesh
852        F = file_function(filename + '.sww', domain,
853                          quantities = domain.conserved_quantities,
854                          interpolation_points = interpolation_points)
855       
856        for id in range(len(interpolation_points)):
857            x = interpolation_points[id][0]
858            y = interpolation_points[id][1]
859
860            for i in range(20):
861                t = i*10
862                k = i%6
863
864                if k == 0:
865                    q0 = F(t, point_id=id)
866                    q1 = F(t+60, point_id=id)
867
868                if q0 == NAN:
869                    actual = q0
870                else:
871                    actual = (k*q1 + (6-k)*q0)/6
872                q = F(t, point_id=id)
873                #print i, k, t, q
874                #print ' ', q0
875                #print ' ', q1
876                #print "q",q
877                #print "actual", actual
878                #print
879                if q0 == NAN:
880                     self.failUnless( q == actual, 'Fail!')
881                else:
[6145]882                    assert num.allclose(q, actual)
[5897]883
884        # now lets check points inside the mesh
885        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14]] #, [10,-12.5]] - this point doesn't work WHY?
886        interpolation_points = [[10,-12.5]]
887           
888        print "len(interpolation_points)",len(interpolation_points) 
889        F = file_function(filename + '.sww', domain,
890                          quantities = domain.conserved_quantities,
891                          interpolation_points = interpolation_points)
892
893        domain.starttime = start
894
895
896        #Check linear interpolation in time
897        F = file_function(filename + '.sww', domain,
898                          quantities = domain.conserved_quantities,
899                          interpolation_points = interpolation_points)               
900        for id in range(len(interpolation_points)):
901            x = interpolation_points[id][0]
902            y = interpolation_points[id][1]
903
904            for i in range(20):
905                t = i*10
906                k = i%6
907
908                if k == 0:
909                    q0 = F(t, point_id=id)
910                    q1 = F(t+60, point_id=id)
911
912                if q0 == NAN:
913                    actual = q0
914                else:
915                    actual = (k*q1 + (6-k)*q0)/6
916                q = F(t, point_id=id)
917                print "############"
918                print "id, x, y ", id, x, y #k, t, q
919                print "t", t
920                #print ' ', q0
921                #print ' ', q1
922                print "q",q
923                print "actual", actual
924                #print
925                if q0 == NAN:
926                     self.failUnless( q == actual, 'Fail!')
927                else:
[6145]928                    assert num.allclose(q, actual)
[5897]929
930
931        #Another check of linear interpolation in time
932        for id in range(len(interpolation_points)):
933            q60 = F(60, point_id=id)
934            q120 = F(120, point_id=id)
935
936            t = 90 #Halfway between 60 and 120
937            q = F(t, point_id=id)
[6145]938            assert num.allclose( (q120+q60)/2, q )
[5897]939
940            t = 100 #Two thirds of the way between between 60 and 120
941            q = F(t, point_id=id)
[6145]942            assert num.allclose(q60/3 + 2*q120/3, q)
[5897]943
944
945
946        #Check that domain.starttime isn't updated if later than file starttime but earlier
947        #than file end time
948        delta = 23
949        domain.starttime = start + delta
950        F = file_function(filename + '.sww', domain,
951                          quantities = domain.conserved_quantities,
952                          interpolation_points = interpolation_points)
[6145]953        assert num.allclose(domain.starttime, start+delta)
[5897]954
955
956
957
958        #Now try interpolation with delta offset
959        for id in range(len(interpolation_points)):           
960            x = interpolation_points[id][0]
961            y = interpolation_points[id][1]
962
963            for i in range(20):
964                t = i*10
965                k = i%6
966
967                if k == 0:
968                    q0 = F(t-delta, point_id=id)
969                    q1 = F(t+60-delta, point_id=id)
970
971                q = F(t-delta, point_id=id)
[6145]972                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
[5897]973
974
975        os.remove(filename + '.sww')
976
977    def test_file_function_time_with_domain(self):
978        """Test that File function interpolates correctly
979        between given times. No x,y dependency here.
980        Use domain with starttime
981        """
982
983        #Write file
984        import os, time, calendar
985        from anuga.config import time_format
986        from math import sin, pi
987        from domain import Domain
988
989        finaltime = 1200
990        filename = 'test_file_function'
991        fid = open(filename + '.txt', 'w')
992        start = time.mktime(time.strptime('2000', '%Y'))
993        dt = 60  #One minute intervals
994        t = 0.0
995        while t <= finaltime:
996            t_string = time.strftime(time_format, time.gmtime(t+start))
997            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
998            t += dt
999
1000        fid.close()
1001
1002
1003        #Convert ASCII file to NetCDF (Which is what we really like!)
1004        timefile2netcdf(filename)
1005
1006
1007
1008        a = [0.0, 0.0]
1009        b = [4.0, 0.0]
1010        c = [0.0, 3.0]
1011
1012        points = [a, b, c]
1013        vertices = [[0,1,2]]
1014        domain = Domain(points, vertices)
1015
1016        # Check that domain.starttime is updated if non-existing
1017        F = file_function(filename + '.tms',
1018                          domain,
1019                          quantities = ['Attribute0', 'Attribute1', 'Attribute2']) 
[6145]1020        assert num.allclose(domain.starttime, start)
[5897]1021
1022        # Check that domain.starttime is updated if too early
1023        domain.starttime = start - 1
1024        F = file_function(filename + '.tms',
1025                          domain,
1026                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
[6145]1027        assert num.allclose(domain.starttime, start)
[5897]1028
1029        # Check that domain.starttime isn't updated if later
1030        domain.starttime = start + 1
1031        F = file_function(filename + '.tms',
1032                          domain,
1033                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
[6145]1034        assert num.allclose(domain.starttime, start+1)
[5897]1035
1036        domain.starttime = start
1037        F = file_function(filename + '.tms',
1038                          domain,
1039                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'],
1040                          use_cache=True)
1041       
1042
1043        #print F.precomputed_values
1044        #print 'F(60)', F(60)
1045       
1046        #Now try interpolation
1047        for i in range(20):
1048            t = i*10
1049            q = F(t)
1050
1051            #Exact linear intpolation
[6145]1052            assert num.allclose(q[0], 2*t)
[5897]1053            if i%6 == 0:
[6145]1054                assert num.allclose(q[1], t**2)
1055                assert num.allclose(q[2], sin(t*pi/600))
[5897]1056
1057        #Check non-exact
1058
1059        t = 90 #Halfway between 60 and 120
1060        q = F(t)
[6145]1061        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1062        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
[5897]1063
1064
1065        t = 100 #Two thirds of the way between between 60 and 120
1066        q = F(t)
[6145]1067        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1068        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
[5897]1069
1070        os.remove(filename + '.tms')
1071        os.remove(filename + '.txt')       
1072
1073    def test_file_function_time_with_domain_different_start(self):
1074        """Test that File function interpolates correctly
1075        between given times. No x,y dependency here.
1076        Use domain with a starttime later than that of file
1077
1078        ASCII version
1079        """
1080
1081        #Write file
1082        import os, time, calendar
1083        from anuga.config import time_format
1084        from math import sin, pi
1085        from domain import Domain
1086
1087        finaltime = 1200
1088        filename = 'test_file_function'
1089        fid = open(filename + '.txt', 'w')
1090        start = time.mktime(time.strptime('2000', '%Y'))
1091        dt = 60  #One minute intervals
1092        t = 0.0
1093        while t <= finaltime:
1094            t_string = time.strftime(time_format, time.gmtime(t+start))
1095            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1096            t += dt
1097
1098        fid.close()
1099
1100        #Convert ASCII file to NetCDF (Which is what we really like!)
1101        timefile2netcdf(filename)       
1102
1103        a = [0.0, 0.0]
1104        b = [4.0, 0.0]
1105        c = [0.0, 3.0]
1106
1107        points = [a, b, c]
1108        vertices = [[0,1,2]]
1109        domain = Domain(points, vertices)
1110
1111        #Check that domain.starttime isn't updated if later than file starttime but earlier
1112        #than file end time
1113        delta = 23
1114        domain.starttime = start + delta
1115        F = file_function(filename + '.tms', domain,
1116                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])       
[6145]1117        assert num.allclose(domain.starttime, start+delta)
[5897]1118
[6173]1119        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1120                                            277., 337., 397., 457., 517.,
1121                                            577., 637., 697., 757., 817.,
1122                                            877., 937., 997., 1057., 1117.,
1123                                            1177.])
[5897]1124
1125
1126        #Now try interpolation with delta offset
1127        for i in range(20):
1128            t = i*10
1129            q = F(t-delta)
1130
1131            #Exact linear intpolation
[6145]1132            assert num.allclose(q[0], 2*t)
[5897]1133            if i%6 == 0:
[6145]1134                assert num.allclose(q[1], t**2)
1135                assert num.allclose(q[2], sin(t*pi/600))
[5897]1136
1137        #Check non-exact
1138
1139        t = 90 #Halfway between 60 and 120
1140        q = F(t-delta)
[6145]1141        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1142        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
[5897]1143
1144
1145        t = 100 #Two thirds of the way between between 60 and 120
1146        q = F(t-delta)
[6145]1147        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1148        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
[5897]1149
1150
1151        os.remove(filename + '.tms')
1152        os.remove(filename + '.txt')               
1153
[6173]1154       
[5897]1155
[6173]1156    def test_file_function_time_with_domain_different_start_and_time_limit(self):
1157        """Test that File function interpolates correctly
1158        between given times. No x,y dependency here.
1159        Use domain with a starttime later than that of file
[5897]1160
[6173]1161        ASCII version
1162       
1163        This test also tests that time can be truncated.
1164        """
1165
1166        # Write file
1167        import os, time, calendar
1168        from anuga.config import time_format
1169        from math import sin, pi
1170        from domain import Domain
1171
1172        finaltime = 1200
1173        filename = 'test_file_function'
1174        fid = open(filename + '.txt', 'w')
1175        start = time.mktime(time.strptime('2000', '%Y'))
1176        dt = 60  #One minute intervals
1177        t = 0.0
1178        while t <= finaltime:
1179            t_string = time.strftime(time_format, time.gmtime(t+start))
1180            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1181            t += dt
1182
1183        fid.close()
1184
1185        # Convert ASCII file to NetCDF (Which is what we really like!)
1186        timefile2netcdf(filename)       
1187
1188        a = [0.0, 0.0]
1189        b = [4.0, 0.0]
1190        c = [0.0, 3.0]
1191
1192        points = [a, b, c]
1193        vertices = [[0,1,2]]
1194        domain = Domain(points, vertices)
1195
1196        # Check that domain.starttime isn't updated if later than file starttime but earlier
1197        # than file end time
1198        delta = 23
1199        domain.starttime = start + delta
[6175]1200        time_limit = domain.starttime + 600
[6173]1201        F = file_function(filename + '.tms', domain,
[6175]1202                          time_limit=time_limit,
[6173]1203                          quantities=['Attribute0', 'Attribute1', 'Attribute2'])       
1204        assert num.allclose(domain.starttime, start+delta)
1205
1206        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1207                                            277., 337., 397., 457., 517.,
1208                                            577.])       
1209
1210
1211
1212        # Now try interpolation with delta offset
1213        for i in range(20):
1214            t = i*10
1215            q = F(t-delta)
1216
1217            #Exact linear intpolation
1218            assert num.allclose(q[0], 2*t)
1219            if i%6 == 0:
1220                assert num.allclose(q[1], t**2)
1221                assert num.allclose(q[2], sin(t*pi/600))
1222
1223        # Check non-exact
1224        t = 90 #Halfway between 60 and 120
1225        q = F(t-delta)
1226        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1227        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1228
1229
1230        t = 100 # Two thirds of the way between between 60 and 120
1231        q = F(t-delta)
1232        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1233        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1234
1235
1236        os.remove(filename + '.tms')
1237        os.remove(filename + '.txt')               
1238
1239       
1240       
1241       
1242
1243
[5897]1244    def test_apply_expression_to_dictionary(self):
1245
1246        #FIXME: Division is not expected to work for integers.
1247        #This must be caught.
[7276]1248        foo = num.array([[1,2,3], [4,5,6]], num.float)
[5897]1249
[7276]1250        bar = num.array([[-1,0,5], [6,1,1]], num.float)                 
[5897]1251
1252        D = {'X': foo, 'Y': bar}
1253
1254        Z = apply_expression_to_dictionary('X+Y', D)       
[6145]1255        assert num.allclose(Z, foo+bar)
[5897]1256
1257        Z = apply_expression_to_dictionary('X*Y', D)       
[6145]1258        assert num.allclose(Z, foo*bar)       
[5897]1259
1260        Z = apply_expression_to_dictionary('4*X+Y', D)       
[6145]1261        assert num.allclose(Z, 4*foo+bar)       
[5897]1262
1263        # test zero division is OK
1264        Z = apply_expression_to_dictionary('X/Y', D)
[6145]1265        assert num.allclose(1/Z, 1/(foo/bar)) # can't compare inf to inf
[5897]1266
1267        # make an error for zero on zero
[7276]1268        # this is really an error in numeric, SciPy core can handle it
[5897]1269        # Z = apply_expression_to_dictionary('0/Y', D)
1270
1271        #Check exceptions
1272        try:
1273            #Wrong name
1274            Z = apply_expression_to_dictionary('4*X+A', D)       
1275        except NameError:
1276            pass
1277        else:
1278            msg = 'Should have raised a NameError Exception'
1279            raise msg
1280
1281
1282        try:
1283            #Wrong order
1284            Z = apply_expression_to_dictionary(D, '4*X+A')       
1285        except AssertionError:
1286            pass
1287        else:
1288            msg = 'Should have raised a AssertionError Exception'
1289            raise msg       
1290       
1291
1292    def test_multiple_replace(self):
1293        """Hard test that checks a true word-by-word simultaneous replace
1294        """
1295       
1296        D = {'x': 'xi', 'y': 'eta', 'xi':'lam'}
1297        exp = '3*x+y + xi'
1298       
1299        new = multiple_replace(exp, D)
1300       
1301        assert new == '3*xi+eta + lam'
1302                         
1303
1304
1305    def test_point_on_line_obsolete(self):
1306        """Test that obsolete call issues appropriate warning"""
1307
1308        #Turn warning into an exception
1309        import warnings
1310        warnings.filterwarnings('error')
1311
1312        try:
1313            assert point_on_line( 0, 0.5, 0,1, 0,0 )
1314        except DeprecationWarning:
1315            pass
1316        else:
1317            msg = 'point_on_line should have issued a DeprecationWarning'
1318            raise Exception(msg)   
1319
1320        warnings.resetwarnings()
1321   
1322    def test_get_revision_number(self):
1323        """test_get_revision_number(self):
1324
1325        Test that revision number can be retrieved.
1326        """
1327        if os.environ.has_key('USER') and os.environ['USER'] == 'dgray':
1328            # I have a known snv incompatability issue,
1329            # so I'm skipping this test.
1330            # FIXME when SVN is upgraded on our clusters
1331            pass
1332        else:   
1333            n = get_revision_number()
1334            assert n>=0
1335
1336
1337       
1338    def test_add_directories(self):
1339       
1340        import tempfile
1341        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1342        directories = ['ja','ne','ke']
1343        kens_dir = add_directories(root_dir, directories)
1344        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1345               sep + 'ke'
1346        assert access(root_dir,F_OK)
1347
1348        add_directories(root_dir, directories)
1349        assert access(root_dir,F_OK)
1350       
1351        #clean up!
1352        os.rmdir(kens_dir)
1353        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1354        os.rmdir(root_dir + sep + 'ja')
1355        os.rmdir(root_dir)
1356
1357    def test_add_directories_bad(self):
1358       
1359        import tempfile
1360        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1361        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1362       
1363        try:
1364            kens_dir = add_directories(root_dir, directories)
1365        except OSError:
1366            pass
1367        else:
1368            msg = 'bad dir name should give OSError'
1369            raise Exception(msg)   
1370           
1371        #clean up!
1372        os.rmdir(root_dir)
1373
1374    def test_check_list(self):
1375
1376        check_list(['stage','xmomentum'])
1377
1378       
1379    def test_add_directories(self):
1380       
1381        import tempfile
1382        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1383        directories = ['ja','ne','ke']
1384        kens_dir = add_directories(root_dir, directories)
1385        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1386               sep + 'ke'
1387        assert access(root_dir,F_OK)
1388
1389        add_directories(root_dir, directories)
1390        assert access(root_dir,F_OK)
1391       
1392        #clean up!
1393        os.rmdir(kens_dir)
1394        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1395        os.rmdir(root_dir + sep + 'ja')
1396        os.rmdir(root_dir)
1397
1398    def test_add_directories_bad(self):
1399       
1400        import tempfile
1401        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1402        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1403       
1404        try:
1405            kens_dir = add_directories(root_dir, directories)
1406        except OSError:
1407            pass
1408        else:
1409            msg = 'bad dir name should give OSError'
1410            raise Exception(msg)   
1411           
1412        #clean up!
1413        os.rmdir(root_dir)
1414
1415    def test_check_list(self):
1416
1417        check_list(['stage','xmomentum'])
[6070]1418
1419######
1420# Test the remove_lone_verts() function
1421######
[5897]1422       
[6070]1423    def test_remove_lone_verts_a(self):
[5897]1424        verts = [[0,0],[1,0],[0,1]]
1425        tris = [[0,1,2]]
1426        new_verts, new_tris = remove_lone_verts(verts, tris)
[6070]1427        self.failUnless(new_verts.tolist() == verts)
1428        self.failUnless(new_tris.tolist() == tris)
[5897]1429
[6070]1430    def test_remove_lone_verts_b(self):
[5897]1431        verts = [[0,0],[1,0],[0,1],[99,99]]
1432        tris = [[0,1,2]]
1433        new_verts, new_tris = remove_lone_verts(verts, tris)
[6070]1434        self.failUnless(new_verts.tolist() == verts[0:3])
1435        self.failUnless(new_tris.tolist() == tris)
[5897]1436       
[6070]1437    def test_remove_lone_verts_c(self):
[5897]1438        verts = [[99,99],[0,0],[1,0],[99,99],[0,1],[99,99]]
1439        tris = [[1,2,4]]
1440        new_verts, new_tris = remove_lone_verts(verts, tris)
[6070]1441        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1442        self.failUnless(new_tris.tolist() == [[0,1,2]])
[5897]1443     
[6070]1444    def test_remove_lone_verts_d(self):
[5897]1445        verts = [[0,0],[1,0],[99,99],[0,1]]
1446        tris = [[0,1,3]]
1447        new_verts, new_tris = remove_lone_verts(verts, tris)
[6070]1448        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1449        self.failUnless(new_tris.tolist() == [[0,1,2]])
[5897]1450       
[6070]1451    def test_remove_lone_verts_e(self):
[5897]1452        verts = [[0,0],[1,0],[0,1],[99,99],[99,99],[99,99]]
1453        tris = [[0,1,2]]
1454        new_verts, new_tris = remove_lone_verts(verts, tris)
[6070]1455        self.failUnless(new_verts.tolist() == verts[0:3])
1456        self.failUnless(new_tris.tolist() == tris)
[5897]1457     
[6070]1458    def test_remove_lone_verts_f(self):
1459        verts = [[0,0],[1,0],[99,99],[0,1],[99,99],[1,1],[99,99]]
1460        tris = [[0,1,3],[0,1,5]]
[5897]1461        new_verts, new_tris = remove_lone_verts(verts, tris)
[6070]1462        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1],[1,1]])
1463        self.failUnless(new_tris.tolist() == [[0,1,2],[0,1,3]])
[5897]1464       
[6070]1465######
1466#
1467######
1468       
[5897]1469    def test_get_min_max_values(self):
1470       
1471        list=[8,9,6,1,4]
1472        min1, max1 = get_min_max_values(list)
1473       
1474        assert min1==1 
1475        assert max1==9
1476       
1477    def test_get_min_max_values1(self):
1478       
1479        list=[-8,-9,-6,-1,-4]
1480        min1, max1 = get_min_max_values(list)
1481       
1482#        print 'min1,max1',min1,max1
1483        assert min1==-9 
1484        assert max1==-1
1485
1486#    def test_get_min_max_values2(self):
1487#        '''
1488#        The min and max supplied are greater than the ones in the
1489#        list and therefore are the ones returned
1490#        '''
1491#        list=[-8,-9,-6,-1,-4]
1492#        min1, max1 = get_min_max_values(list,-10,10)
1493#       
1494##        print 'min1,max1',min1,max1
1495#        assert min1==-10
1496#        assert max1==10
1497       
1498    def test_make_plots_from_csv_files(self):
1499       
1500        #if sys.platform == 'win32':  #Windows
1501            try: 
1502                import pylab
1503            except ImportError:
1504                #ANUGA don't need pylab to work so the system doesn't
1505                #rely on pylab being installed
1506                return
1507           
1508       
1509            current_dir=getcwd()+sep+'abstract_2d_finite_volumes'
1510            temp_dir = tempfile.mkdtemp('','figures')
1511    #        print 'temp_dir',temp_dir
1512            fileName = temp_dir+sep+'time_series_3.csv'
1513            file = open(fileName,"w")
1514            file.write("time,stage,speed,momentum,elevation\n\
15151.0, 0, 0, 0, 10 \n\
15162.0, 5, 2, 4, 10 \n\
15173.0, 3, 3, 5, 10 \n")
1518            file.close()
1519   
1520            fileName1 = temp_dir+sep+'time_series_4.csv'
1521            file1 = open(fileName1,"w")
1522            file1.write("time,stage,speed,momentum,elevation\n\
15231.0, 0, 0, 0, 5 \n\
15242.0, -5, -2, -4, 5 \n\
15253.0, -4, -3, -5, 5 \n")
1526            file1.close()
1527   
1528            fileName2 = temp_dir+sep+'time_series_5.csv'
1529            file2 = open(fileName2,"w")
1530            file2.write("time,stage,speed,momentum,elevation\n\
15311.0, 0, 0, 0, 7 \n\
15322.0, 4, -0.45, 57, 7 \n\
15333.0, 6, -0.5, 56, 7 \n")
1534            file2.close()
1535           
1536            dir, name=os.path.split(fileName)
1537            csv2timeseries_graphs(directories_dic={dir:['gauge', 0, 0]},
1538                                  output_dir=temp_dir,
1539                                  base_name='time_series_',
1540                                  plot_numbers=['3-5'],
1541                                  quantities=['speed','stage','momentum'],
1542                                  assess_all_csv_files=True,
1543                                  extra_plot_name='test')
1544           
1545            #print dir+sep+name[:-4]+'_stage_test.png'
1546            assert(access(dir+sep+name[:-4]+'_stage_test.png',F_OK)==True)
1547            assert(access(dir+sep+name[:-4]+'_speed_test.png',F_OK)==True)
1548            assert(access(dir+sep+name[:-4]+'_momentum_test.png',F_OK)==True)
1549   
1550            dir1, name1=os.path.split(fileName1)
1551            assert(access(dir+sep+name1[:-4]+'_stage_test.png',F_OK)==True)
1552            assert(access(dir+sep+name1[:-4]+'_speed_test.png',F_OK)==True)
1553            assert(access(dir+sep+name1[:-4]+'_momentum_test.png',F_OK)==True)
1554   
1555   
1556            dir2, name2=os.path.split(fileName2)
1557            assert(access(dir+sep+name2[:-4]+'_stage_test.png',F_OK)==True)
1558            assert(access(dir+sep+name2[:-4]+'_speed_test.png',F_OK)==True)
1559            assert(access(dir+sep+name2[:-4]+'_momentum_test.png',F_OK)==True)
1560   
1561            del_dir(temp_dir)
1562       
1563
1564
1565    def test_greens_law(self):
1566
1567        from math import sqrt
1568       
1569        d1 = 80.0
1570        d2 = 20.0
1571        h1 = 1.0
1572        h2 = greens_law(d1,d2,h1)
1573
1574        assert h2==sqrt(2.0)
1575       
1576    def test_calc_bearings(self):
1577 
1578        from math import atan, degrees
1579        #Test East
1580        uh = 1
1581        vh = 1.e-15
1582        angle = calc_bearing(uh, vh)
1583        if 89 < angle < 91: v=1
1584        assert v==1
1585        #Test West
1586        uh = -1
1587        vh = 1.e-15
1588        angle = calc_bearing(uh, vh)
1589        if 269 < angle < 271: v=1
1590        assert v==1
1591        #Test North
1592        uh = 1.e-15
1593        vh = 1
1594        angle = calc_bearing(uh, vh)
1595        if -1 < angle < 1: v=1
1596        assert v==1
1597        #Test South
1598        uh = 1.e-15
1599        vh = -1
1600        angle = calc_bearing(uh, vh)
1601        if 179 < angle < 181: v=1
1602        assert v==1
1603        #Test South-East
1604        uh = 1
1605        vh = -1
1606        angle = calc_bearing(uh, vh)
1607        if 134 < angle < 136: v=1
1608        assert v==1
1609        #Test North-East
1610        uh = 1
1611        vh = 1
1612        angle = calc_bearing(uh, vh)
1613        if 44 < angle < 46: v=1
1614        assert v==1
1615        #Test South-West
1616        uh = -1
1617        vh = -1
1618        angle = calc_bearing(uh, vh)
1619        if 224 < angle < 226: v=1
1620        assert v==1
1621        #Test North-West
1622        uh = -1
1623        vh = 1
1624        angle = calc_bearing(uh, vh)
1625        if 314 < angle < 316: v=1
1626        assert v==1
1627       
1628
1629#-------------------------------------------------------------
[7276]1630
[5897]1631if __name__ == "__main__":
[7276]1632    suite = unittest.makeSuite(Test_Util, 'test')
[5897]1633#    runner = unittest.TextTestRunner(verbosity=2)
1634    runner = unittest.TextTestRunner(verbosity=1)
1635    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.