source: anuga_core/source/anuga/abstract_2d_finite_volumes/test_util.py @ 7562

Last change on this file since 7562 was 7562, checked in by steve, 13 years ago

Updating the balanced and parallel code

File size: 66.2 KB
Line 
1#!/usr/bin/env python
2
3
4import unittest
5from math import sqrt, pi
6import tempfile, os
7from os import access, F_OK,sep, removedirs,remove,mkdir,getcwd
8
9from anuga.abstract_2d_finite_volumes.util import *
10from anuga.config import epsilon
11from anuga.shallow_water.data_manager import timefile2netcdf,del_dir
12
13from anuga.utilities.numerical_tools import NAN
14
15from sys import platform
16
17from anuga.pmesh.mesh import Mesh
18from anuga.shallow_water import Domain, Transmissive_boundary
19from anuga.shallow_water.data_manager import SWW_file
20from csv import reader,writer
21import time
22import string
23
24import numpy as num
25
26
27def test_function(x, y):
28    return x+y
29
30class Test_Util(unittest.TestCase):
31    def setUp(self):
32        pass
33
34    def tearDown(self):
35        pass
36
37
38
39
40    #Geometric
41    #def test_distance(self):
42    #    from anuga.abstract_2d_finite_volumes.util import distance#
43    #
44    #    self.failUnless( distance([4,2],[7,6]) == 5.0,
45    #                     'Distance is wrong!')
46    #    self.failUnless( allclose(distance([7,6],[9,8]), 2.82842712475),
47    #                    'distance is wrong!')
48    #    self.failUnless( allclose(distance([9,8],[4,2]), 7.81024967591),
49    #                    'distance is wrong!')
50    #
51    #    self.failUnless( distance([9,8],[4,2]) == distance([4,2],[9,8]),
52    #                    'distance is wrong!')
53
54
55    def test_file_function_time1(self):
56        """Test that File function interpolates correctly
57        between given times. No x,y dependency here.
58        """
59
60        #Write file
61        import os, time
62        from anuga.config import time_format
63        from math import sin, pi
64
65        #Typical ASCII file
66        finaltime = 1200
67        filename = 'test_file_function'
68        fid = open(filename + '.txt', 'w')
69        start = time.mktime(time.strptime('2000', '%Y'))
70        dt = 60  #One minute intervals
71        t = 0.0
72        while t <= finaltime:
73            t_string = time.strftime(time_format, time.gmtime(t+start))
74            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
75            t += dt
76
77        fid.close()
78
79        #Convert ASCII file to NetCDF (Which is what we really like!)
80        timefile2netcdf(filename)
81
82
83        #Create file function from time series
84        F = file_function(filename + '.tms',
85                          quantities = ['Attribute0',
86                                        'Attribute1',
87                                        'Attribute2'])
88       
89        #Now try interpolation
90        for i in range(20):
91            t = i*10
92            q = F(t)
93
94            #Exact linear intpolation
95            assert num.allclose(q[0], 2*t)
96            if i%6 == 0:
97                assert num.allclose(q[1], t**2)
98                assert num.allclose(q[2], sin(t*pi/600))
99
100        #Check non-exact
101
102        t = 90 #Halfway between 60 and 120
103        q = F(t)
104        assert num.allclose( (120**2 + 60**2)/2, q[1] )
105        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
106
107
108        t = 100 #Two thirds of the way between between 60 and 120
109        q = F(t)
110        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
111        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
112
113        os.remove(filename + '.txt')
114        os.remove(filename + '.tms')       
115
116
117       
118    def test_spatio_temporal_file_function_basic(self):
119        """Test that spatio temporal file function performs the correct
120        interpolations in both time and space
121        NetCDF version (x,y,t dependency)       
122        """
123        import time
124
125        #Create sww file of simple propagation from left to right
126        #through rectangular domain
127        from shallow_water import Domain, Dirichlet_boundary
128        from mesh_factory import rectangular
129
130        #Create basic mesh and shallow water domain
131        points, vertices, boundary = rectangular(3, 3)
132        domain1 = Domain(points, vertices, boundary)
133
134        from anuga.utilities.numerical_tools import mean
135        domain1.reduction = mean
136        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
137                              # only one value.
138
139        domain1.default_order = 2
140        domain1.store = True
141        domain1.set_datadir('.')
142        sww_file = 'spatio_temporal_boundary_source_%d' %(id(self))
143        domain1.set_name(sww_file)
144
145        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
146        domain1.set_quantity('elevation', 0)
147        domain1.set_quantity('friction', 0)
148        domain1.set_quantity('stage', 0)
149
150        # Boundary conditions
151        B0 = Dirichlet_boundary([0,0,0])
152        B6 = Dirichlet_boundary([0.6,0,0])
153        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
154        domain1.check_integrity()
155
156        finaltime = 8
157        #Evolution
158        t0 = -1
159        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
160            #print 'Timesteps: %.16f, %.16f' %(t0, t)
161            #if t == t0:
162            #    msg = 'Duplicate timestep found: %f, %f' %(t0, t)
163            #   raise msg
164            t0 = t
165             
166            #domain1.write_time()
167
168
169        #Now read data from sww and check
170        from Scientific.IO.NetCDF import NetCDFFile
171        filename = domain1.get_name() + '.sww'
172        fid = NetCDFFile(filename)
173
174        x = fid.variables['x'][:]
175        y = fid.variables['y'][:]
176        stage = fid.variables['stage'][:]
177        xmomentum = fid.variables['xmomentum'][:]
178        ymomentum = fid.variables['ymomentum'][:]
179        time = fid.variables['time'][:]
180
181        #Take stage vertex values at last timestep on diagonal
182        #Diagonal is identified by vertices: 0, 5, 10, 15
183
184        last_time_index = len(time)-1 #Last last_time_index
185        d_stage = num.reshape(num.take(stage[last_time_index, :],
186                                       [0,5,10,15],
187                                       axis=0),
188                              (4,1))
189        d_uh = num.reshape(num.take(xmomentum[last_time_index, :],
190                                    [0,5,10,15],
191                                   axis=0),
192                           (4,1))
193        d_vh = num.reshape(num.take(ymomentum[last_time_index, :],
194                                    [0,5,10,15],
195                                   axis=0),
196                           (4,1))
197        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
198
199        #Reference interpolated values at midpoints on diagonal at
200        #this timestep are
201        r0 = (D[0] + D[1])/2
202        r1 = (D[1] + D[2])/2
203        r2 = (D[2] + D[3])/2
204
205        #And the midpoints are found now
206        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15], axis=0)
207        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15], axis=0)
208
209        diag = num.concatenate( (Dx, Dy), axis=1)
210        d_midpoints = (diag[1:] + diag[:-1])/2
211
212        #Let us see if the file function can find the correct
213        #values at the midpoints at the last timestep:
214        f = file_function(filename, domain1,
215                          interpolation_points = d_midpoints)
216
217        T = f.get_time()
218        msg = 'duplicate timesteps: %.16f and %.16f' %(T[-1], T[-2])
219        assert not T[-1] == T[-2], msg
220        t = time[last_time_index]
221        q = f(t, point_id=0); assert num.allclose(r0, q)
222        q = f(t, point_id=1); assert num.allclose(r1, q)
223        q = f(t, point_id=2); assert num.allclose(r2, q)
224
225
226        ##################
227        #Now do the same for the first timestep
228
229        timestep = 0 #First timestep
230        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
231        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
232        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
233        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
234
235        #Reference interpolated values at midpoints on diagonal at
236        #this timestep are
237        r0 = (D[0] + D[1])/2
238        r1 = (D[1] + D[2])/2
239        r2 = (D[2] + D[3])/2
240
241        #Let us see if the file function can find the correct
242        #values
243        q = f(0, point_id=0); assert num.allclose(r0, q)
244        q = f(0, point_id=1); assert num.allclose(r1, q)
245        q = f(0, point_id=2); assert num.allclose(r2, q)
246
247
248        ##################
249        #Now do it again for a timestep in the middle
250
251        timestep = 33
252        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
253        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
254        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
255        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
256
257        #Reference interpolated values at midpoints on diagonal at
258        #this timestep are
259        r0 = (D[0] + D[1])/2
260        r1 = (D[1] + D[2])/2
261        r2 = (D[2] + D[3])/2
262
263        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
264        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
265        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
266
267
268        ##################
269        #Now check temporal interpolation
270        #Halfway between timestep 15 and 16
271
272        timestep = 15
273        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
274        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
275        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
276        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
277
278        #Reference interpolated values at midpoints on diagonal at
279        #this timestep are
280        r0_0 = (D[0] + D[1])/2
281        r1_0 = (D[1] + D[2])/2
282        r2_0 = (D[2] + D[3])/2
283
284        #
285        timestep = 16
286        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
287        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
288        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
289        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
290
291        #Reference interpolated values at midpoints on diagonal at
292        #this timestep are
293        r0_1 = (D[0] + D[1])/2
294        r1_1 = (D[1] + D[2])/2
295        r2_1 = (D[2] + D[3])/2
296
297        # The reference values are
298        r0 = (r0_0 + r0_1)/2
299        r1 = (r1_0 + r1_1)/2
300        r2 = (r2_0 + r2_1)/2
301
302        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
303        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
304        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
305
306        ##################
307        #Finally check interpolation 2 thirds of the way
308        #between timestep 15 and 16
309
310        # The reference values are
311        r0 = (r0_0 + 2*r0_1)/3
312        r1 = (r1_0 + 2*r1_1)/3
313        r2 = (r2_0 + 2*r2_1)/3
314
315        #And the file function gives
316        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
317        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
318        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
319
320        fid.close()
321        import os
322        os.remove(filename)
323
324
325
326    def test_spatio_temporal_file_function_different_origin(self):
327        """Test that spatio temporal file function performs the correct
328        interpolations in both time and space where space is offset by
329        xllcorner and yllcorner
330        NetCDF version (x,y,t dependency)       
331        """
332        import time
333
334        #Create sww file of simple propagation from left to right
335        #through rectangular domain
336        from shallow_water import Domain, Dirichlet_boundary
337        from mesh_factory import rectangular
338
339
340        from anuga.coordinate_transforms.geo_reference import Geo_reference
341        xllcorner = 2048
342        yllcorner = 11000
343        zone = 2
344
345        #Create basic mesh and shallow water domain
346        points, vertices, boundary = rectangular(3, 3)
347        domain1 = Domain(points, vertices, boundary,
348                         geo_reference = Geo_reference(xllcorner = xllcorner,
349                                                       yllcorner = yllcorner))
350       
351
352        from anuga.utilities.numerical_tools import mean       
353        domain1.reduction = mean
354        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
355                              # only one value.
356
357        domain1.default_order = 2
358        domain1.store = True
359        domain1.set_datadir('.')
360        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
361
362        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
363        domain1.set_quantity('elevation', 0)
364        domain1.set_quantity('friction', 0)
365        domain1.set_quantity('stage', 0)
366
367        # Boundary conditions
368        B0 = Dirichlet_boundary([0,0,0])
369        B6 = Dirichlet_boundary([0.6,0,0])
370        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
371        domain1.check_integrity()
372
373        finaltime = 8
374        #Evolution
375        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
376            pass
377            #domain1.write_time()
378
379
380        #Now read data from sww and check
381        from Scientific.IO.NetCDF import NetCDFFile
382        filename = domain1.get_name() + '.sww'
383        fid = NetCDFFile(filename)
384
385        x = fid.variables['x'][:]
386        y = fid.variables['y'][:]
387        # we 'cast' to 64 bit floats to pass this test
388        # SWW file quantities are stored as 32 bits
389        x = num.array(x, num.float)
390        y = num.array(y, num.float)
391
392        stage = fid.variables['stage'][:]
393        xmomentum = fid.variables['xmomentum'][:]
394        ymomentum = fid.variables['ymomentum'][:]
395        time = fid.variables['time'][:]
396
397        #Take stage vertex values at last timestep on diagonal
398        #Diagonal is identified by vertices: 0, 5, 10, 15
399
400        last_time_index = len(time)-1 #Last last_time_index     
401        d_stage = num.reshape(num.take(stage[last_time_index, :],
402                                       [0,5,10,15],
403                                       axis=0),
404                              (4,1))
405        d_uh = num.reshape(num.take(xmomentum[last_time_index, :],
406                                    [0,5,10,15],
407                                   axis=0),
408                           (4,1))
409        d_vh = num.reshape(num.take(ymomentum[last_time_index, :],
410                                    [0,5,10,15],
411                                    axis=0),
412                           (4,1))
413        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
414
415        #Reference interpolated values at midpoints on diagonal at
416        #this timestep are
417        r0 = (D[0] + D[1])/2
418        r1 = (D[1] + D[2])/2
419        r2 = (D[2] + D[3])/2
420
421        #And the midpoints are found now
422        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15], axis=0)
423        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15], axis=0)
424
425        diag = num.concatenate((Dx, Dy), axis=1)
426        d_midpoints = (diag[1:] + diag[:-1])/2
427
428
429        #Adjust for georef - make interpolation points absolute
430        d_midpoints[:,0] += xllcorner
431        d_midpoints[:,1] += yllcorner               
432
433        #Let us see if the file function can find the correct
434        #values at the midpoints at the last timestep:
435        f = file_function(filename, domain1,
436                          interpolation_points = d_midpoints)
437
438        t = time[last_time_index]                         
439
440        q = f(t, point_id=0)
441        msg = '\nr0=%s\nq=%s' % (str(r0), str(q))
442        assert num.allclose(r0, q), msg
443
444        q = f(t, point_id=1)
445        msg = '\nr1=%s\nq=%s' % (str(r1), str(q))
446        assert num.allclose(r1, q), msg
447
448        q = f(t, point_id=2)
449        msg = '\nr2=%s\nq=%s' % (str(r2), str(q))
450        assert num.allclose(r2, q), msg
451
452
453        ##################
454        #Now do the same for the first timestep
455
456        timestep = 0 #First timestep
457        d_stage = num.reshape(num.take(stage[timestep, :],
458                                       [0,5,10,15],
459                                       axis=0),
460                              (4,1))
461        d_uh = num.reshape(num.take(xmomentum[timestep, :],
462                                    [0,5,10,15],
463                                    axis=0),
464                           (4,1))
465        d_vh = num.reshape(num.take(ymomentum[timestep, :],
466                                    [0,5,10,15],
467                                    axis=0),
468                           (4,1))
469        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
470
471        #Reference interpolated values at midpoints on diagonal at
472        #this timestep are
473        r0 = (D[0] + D[1])/2
474        r1 = (D[1] + D[2])/2
475        r2 = (D[2] + D[3])/2
476
477        #Let us see if the file function can find the correct
478        #values
479        q = f(0, point_id=0); assert num.allclose(r0, q)
480        q = f(0, point_id=1); assert num.allclose(r1, q)
481        q = f(0, point_id=2); assert num.allclose(r2, q)
482
483
484        ##################
485        #Now do it again for a timestep in the middle
486
487        timestep = 33
488        d_stage = num.reshape(num.take(stage[timestep, :],
489                                       [0,5,10,15],
490                                       axis=0),
491                              (4,1))
492        d_uh = num.reshape(num.take(xmomentum[timestep, :],
493                                    [0,5,10,15],
494                                    axis=0),
495                           (4,1))
496        d_vh = num.reshape(num.take(ymomentum[timestep, :],
497                                    [0,5,10,15],
498                                    axis=0),
499                           (4,1))
500        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
501
502        #Reference interpolated values at midpoints on diagonal at
503        #this timestep are
504        r0 = (D[0] + D[1])/2
505        r1 = (D[1] + D[2])/2
506        r2 = (D[2] + D[3])/2
507
508        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
509        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
510        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
511
512
513        ##################
514        #Now check temporal interpolation
515        #Halfway between timestep 15 and 16
516
517        timestep = 15
518        d_stage = num.reshape(num.take(stage[timestep, :],
519                                       [0,5,10,15],
520                                       axis=0),
521                              (4,1))
522        d_uh = num.reshape(num.take(xmomentum[timestep, :],
523                                    [0,5,10,15],
524                                    axis=0),
525                           (4,1))
526        d_vh = num.reshape(num.take(ymomentum[timestep, :],
527                                    [0,5,10,15],
528                                    axis=0),
529                           (4,1))
530        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
531
532        #Reference interpolated values at midpoints on diagonal at
533        #this timestep are
534        r0_0 = (D[0] + D[1])/2
535        r1_0 = (D[1] + D[2])/2
536        r2_0 = (D[2] + D[3])/2
537
538        #
539        timestep = 16
540        d_stage = num.reshape(num.take(stage[timestep, :],
541                                       [0,5,10,15],
542                                       axis=0),
543                              (4,1))
544        d_uh = num.reshape(num.take(xmomentum[timestep, :],
545                                    [0,5,10,15],
546                                    axis=0),
547                           (4,1))
548        d_vh = num.reshape(num.take(ymomentum[timestep, :],
549                                    [0,5,10,15],
550                                    axis=0),
551                           (4,1))
552        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
553
554        #Reference interpolated values at midpoints on diagonal at
555        #this timestep are
556        r0_1 = (D[0] + D[1])/2
557        r1_1 = (D[1] + D[2])/2
558        r2_1 = (D[2] + D[3])/2
559
560        # The reference values are
561        r0 = (r0_0 + r0_1)/2
562        r1 = (r1_0 + r1_1)/2
563        r2 = (r2_0 + r2_1)/2
564
565        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
566        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
567        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
568
569        ##################
570        #Finally check interpolation 2 thirds of the way
571        #between timestep 15 and 16
572
573        # The reference values are
574        r0 = (r0_0 + 2*r0_1)/3
575        r1 = (r1_0 + 2*r1_1)/3
576        r2 = (r2_0 + 2*r2_1)/3
577
578        #And the file function gives
579        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
580        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
581        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
582
583        fid.close()
584        import os
585        os.remove(filename)
586
587       
588
589
590    def test_spatio_temporal_file_function_time(self):
591        """Test that File function interpolates correctly
592        between given times.
593        NetCDF version (x,y,t dependency)
594        """
595
596        #Create NetCDF (sww) file to be read
597        # x: 0, 5, 10, 15
598        # y: -20, -10, 0, 10
599        # t: 0, 60, 120, ...., 1200
600        #
601        # test quantities (arbitrary but non-trivial expressions):
602        #
603        #   stage     = 3*x - y**2 + 2*t
604        #   xmomentum = exp( -((x-7)**2 + (y+5)**2)/20 ) * t**2
605        #   ymomentum = x**2 + y**2 * sin(t*pi/600)
606
607        #NOTE: Nice test that may render some of the others redundant.
608
609        import os, time
610        from anuga.config import time_format
611        from mesh_factory import rectangular
612        from shallow_water import Domain
613        import anuga.shallow_water.data_manager
614
615        finaltime = 1200
616        filename = 'test_file_function'
617
618        #Create a domain to hold test grid
619        #(0:15, -20:10)
620        points, vertices, boundary =\
621                rectangular(4, 4, 15, 30, origin = (0, -20))
622        #print "points", points
623
624        #print 'Number of elements', len(vertices)
625        domain = Domain(points, vertices, boundary)
626        domain.smooth = False
627        domain.default_order = 2
628        domain.set_datadir('.')
629        domain.set_name(filename)
630        domain.store = True
631
632        #print points
633        start = time.mktime(time.strptime('2000', '%Y'))
634        domain.starttime = start
635
636
637        #Store structure
638        domain.initialise_storage()
639
640        #Compute artificial time steps and store
641        dt = 60  #One minute intervals
642        t = 0.0
643        while t <= finaltime:
644            #Compute quantities
645            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
646            domain.set_quantity('stage', f1)
647
648            f2 = lambda x,y: x+y+t**2
649            domain.set_quantity('xmomentum', f2)
650
651            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
652            domain.set_quantity('ymomentum', f3)
653
654            #Store and advance time
655            domain.time = t
656            domain.store_timestep()
657            t += dt
658
659
660        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5]]
661     
662        #Deliberately set domain.starttime to too early
663        domain.starttime = start - 1
664
665        #Create file function
666        F = file_function(filename + '.sww', domain,
667                          quantities = domain.conserved_quantities,
668                          interpolation_points = interpolation_points)
669
670        #Check that FF updates fixes domain starttime
671        assert num.allclose(domain.starttime, start)
672
673        #Check that domain.starttime isn't updated if later
674        domain.starttime = start + 1
675        F = file_function(filename + '.sww', domain,
676                          quantities = domain.conserved_quantities,
677                          interpolation_points = interpolation_points)
678        assert num.allclose(domain.starttime, start+1)
679        domain.starttime = start
680
681
682        #Check linear interpolation in time
683        F = file_function(filename + '.sww', domain,
684                          quantities = domain.conserved_quantities,
685                          interpolation_points = interpolation_points)               
686        for id in range(len(interpolation_points)):
687            x = interpolation_points[id][0]
688            y = interpolation_points[id][1]
689
690            for i in range(20):
691                t = i*10
692                k = i%6
693
694                if k == 0:
695                    q0 = F(t, point_id=id)
696                    q1 = F(t+60, point_id=id)
697
698                if num.alltrue(q0 == NAN):
699                    actual = q0
700                else:
701                    actual = (k*q1 + (6-k)*q0)/6
702                q = F(t, point_id=id)
703                #print i, k, t, q
704                #print ' ', q0
705                #print ' ', q1
706                #print "q",q
707                #print "actual", actual
708                #print
709                if num.alltrue(q0 == NAN):
710                     self.failUnless(num.alltrue(q == actual), 'Fail!')
711                else:
712                    assert num.allclose(q, actual)
713
714
715        #Another check of linear interpolation in time
716        for id in range(len(interpolation_points)):
717            q60 = F(60, point_id=id)
718            q120 = F(120, point_id=id)
719
720            t = 90 #Halfway between 60 and 120
721            q = F(t, point_id=id)
722            assert num.allclose( (q120+q60)/2, q )
723
724            t = 100 #Two thirds of the way between between 60 and 120
725            q = F(t, point_id=id)
726            assert num.allclose(q60/3 + 2*q120/3, q)
727
728
729
730        #Check that domain.starttime isn't updated if later than file starttime but earlier
731        #than file end time
732        delta = 23
733        domain.starttime = start + delta
734        F = file_function(filename + '.sww', domain,
735                          quantities = domain.conserved_quantities,
736                          interpolation_points = interpolation_points)
737        assert num.allclose(domain.starttime, start+delta)
738
739
740
741
742        #Now try interpolation with delta offset
743        for id in range(len(interpolation_points)):           
744            x = interpolation_points[id][0]
745            y = interpolation_points[id][1]
746
747            for i in range(20):
748                t = i*10
749                k = i%6
750
751                if k == 0:
752                    q0 = F(t-delta, point_id=id)
753                    q1 = F(t+60-delta, point_id=id)
754
755                q = F(t-delta, point_id=id)
756                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
757
758
759        os.remove(filename + '.sww')
760
761
762
763    def Xtest_spatio_temporal_file_function_time(self):
764        # FIXME: This passes but needs some TLC
765        # Test that File function interpolates correctly
766        # When some points are outside the mesh
767
768        import os, time
769        from anuga.config import time_format
770        from mesh_factory import rectangular
771        from shallow_water import Domain
772        import anuga.shallow_water.data_manager 
773        from anuga.pmesh.mesh_interface import create_mesh_from_regions
774        finaltime = 1200
775       
776        filename = tempfile.mktemp()
777        #print "filename",filename
778        filename = 'test_file_function'
779
780        meshfilename = tempfile.mktemp(".tsh")
781
782        boundary_tags = {'walls':[0,1],'bom':[2]}
783       
784        polygon_absolute = [[0,-20],[10,-20],[10,15],[-20,15]]
785       
786        create_mesh_from_regions(polygon_absolute,
787                                 boundary_tags,
788                                 10000000,
789                                 filename=meshfilename)
790        domain = Domain(mesh_filename=meshfilename)
791        domain.smooth = False
792        domain.default_order = 2
793        domain.set_datadir('.')
794        domain.set_name(filename)
795        domain.store = True
796
797        #print points
798        start = time.mktime(time.strptime('2000', '%Y'))
799        domain.starttime = start
800       
801
802        #Store structure
803        domain.initialise_storage()
804
805        #Compute artificial time steps and store
806        dt = 60  #One minute intervals
807        t = 0.0
808        while t <= finaltime:
809            #Compute quantities
810            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
811            domain.set_quantity('stage', f1)
812
813            f2 = lambda x,y: x+y+t**2
814            domain.set_quantity('xmomentum', f2)
815
816            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
817            domain.set_quantity('ymomentum', f3)
818
819            #Store and advance time
820            domain.time = t
821            domain.store_timestep()
822            t += dt
823
824        interpolation_points = [[1,0]]
825        interpolation_points = [[100,1000]]
826       
827        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5],
828                                [78787,78787],[7878,3432]]
829           
830        #Deliberately set domain.starttime to too early
831        domain.starttime = start - 1
832
833        #Create file function
834        F = file_function(filename + '.sww', domain,
835                          quantities = domain.conserved_quantities,
836                          interpolation_points = interpolation_points)
837
838        #Check that FF updates fixes domain starttime
839        assert num.allclose(domain.starttime, start)
840
841        #Check that domain.starttime isn't updated if later
842        domain.starttime = start + 1
843        F = file_function(filename + '.sww', domain,
844                          quantities = domain.conserved_quantities,
845                          interpolation_points = interpolation_points)
846        assert num.allclose(domain.starttime, start+1)
847        domain.starttime = start
848
849
850        #Check linear interpolation in time
851        # checking points inside and outside the mesh
852        F = file_function(filename + '.sww', domain,
853                          quantities = domain.conserved_quantities,
854                          interpolation_points = interpolation_points)
855       
856        for id in range(len(interpolation_points)):
857            x = interpolation_points[id][0]
858            y = interpolation_points[id][1]
859
860            for i in range(20):
861                t = i*10
862                k = i%6
863
864                if k == 0:
865                    q0 = F(t, point_id=id)
866                    q1 = F(t+60, point_id=id)
867
868                if q0 == NAN:
869                    actual = q0
870                else:
871                    actual = (k*q1 + (6-k)*q0)/6
872                q = F(t, point_id=id)
873                #print i, k, t, q
874                #print ' ', q0
875                #print ' ', q1
876                #print "q",q
877                #print "actual", actual
878                #print
879                if q0 == NAN:
880                     self.failUnless( q == actual, 'Fail!')
881                else:
882                    assert num.allclose(q, actual)
883
884        # now lets check points inside the mesh
885        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14]] #, [10,-12.5]] - this point doesn't work WHY?
886        interpolation_points = [[10,-12.5]]
887           
888        print "len(interpolation_points)",len(interpolation_points) 
889        F = file_function(filename + '.sww', domain,
890                          quantities = domain.conserved_quantities,
891                          interpolation_points = interpolation_points)
892
893        domain.starttime = start
894
895
896        #Check linear interpolation in time
897        F = file_function(filename + '.sww', domain,
898                          quantities = domain.conserved_quantities,
899                          interpolation_points = interpolation_points)               
900        for id in range(len(interpolation_points)):
901            x = interpolation_points[id][0]
902            y = interpolation_points[id][1]
903
904            for i in range(20):
905                t = i*10
906                k = i%6
907
908                if k == 0:
909                    q0 = F(t, point_id=id)
910                    q1 = F(t+60, point_id=id)
911
912                if q0 == NAN:
913                    actual = q0
914                else:
915                    actual = (k*q1 + (6-k)*q0)/6
916                q = F(t, point_id=id)
917                print "############"
918                print "id, x, y ", id, x, y #k, t, q
919                print "t", t
920                #print ' ', q0
921                #print ' ', q1
922                print "q",q
923                print "actual", actual
924                #print
925                if q0 == NAN:
926                     self.failUnless( q == actual, 'Fail!')
927                else:
928                    assert num.allclose(q, actual)
929
930
931        #Another check of linear interpolation in time
932        for id in range(len(interpolation_points)):
933            q60 = F(60, point_id=id)
934            q120 = F(120, point_id=id)
935
936            t = 90 #Halfway between 60 and 120
937            q = F(t, point_id=id)
938            assert num.allclose( (q120+q60)/2, q )
939
940            t = 100 #Two thirds of the way between between 60 and 120
941            q = F(t, point_id=id)
942            assert num.allclose(q60/3 + 2*q120/3, q)
943
944
945
946        #Check that domain.starttime isn't updated if later than file starttime but earlier
947        #than file end time
948        delta = 23
949        domain.starttime = start + delta
950        F = file_function(filename + '.sww', domain,
951                          quantities = domain.conserved_quantities,
952                          interpolation_points = interpolation_points)
953        assert num.allclose(domain.starttime, start+delta)
954
955
956
957
958        #Now try interpolation with delta offset
959        for id in range(len(interpolation_points)):           
960            x = interpolation_points[id][0]
961            y = interpolation_points[id][1]
962
963            for i in range(20):
964                t = i*10
965                k = i%6
966
967                if k == 0:
968                    q0 = F(t-delta, point_id=id)
969                    q1 = F(t+60-delta, point_id=id)
970
971                q = F(t-delta, point_id=id)
972                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
973
974
975        os.remove(filename + '.sww')
976
977    def test_file_function_time_with_domain(self):
978        """Test that File function interpolates correctly
979        between given times. No x,y dependency here.
980        Use domain with starttime
981        """
982
983        #Write file
984        import os, time, calendar
985        from anuga.config import time_format
986        from math import sin, pi
987        from domain import Domain
988
989        finaltime = 1200
990        filename = 'test_file_function'
991        fid = open(filename + '.txt', 'w')
992        start = time.mktime(time.strptime('2000', '%Y'))
993        dt = 60  #One minute intervals
994        t = 0.0
995        while t <= finaltime:
996            t_string = time.strftime(time_format, time.gmtime(t+start))
997            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
998            t += dt
999
1000        fid.close()
1001
1002
1003        #Convert ASCII file to NetCDF (Which is what we really like!)
1004        timefile2netcdf(filename)
1005
1006
1007
1008        a = [0.0, 0.0]
1009        b = [4.0, 0.0]
1010        c = [0.0, 3.0]
1011
1012        points = [a, b, c]
1013        vertices = [[0,1,2]]
1014        domain = Domain(points, vertices)
1015
1016        # Check that domain.starttime is updated if non-existing
1017        F = file_function(filename + '.tms',
1018                          domain,
1019                          quantities = ['Attribute0', 'Attribute1', 'Attribute2']) 
1020        assert num.allclose(domain.starttime, start)
1021
1022        # Check that domain.starttime is updated if too early
1023        domain.starttime = start - 1
1024        F = file_function(filename + '.tms',
1025                          domain,
1026                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
1027        assert num.allclose(domain.starttime, start)
1028
1029        # Check that domain.starttime isn't updated if later
1030        domain.starttime = start + 1
1031        F = file_function(filename + '.tms',
1032                          domain,
1033                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
1034        assert num.allclose(domain.starttime, start+1)
1035
1036        domain.starttime = start
1037        F = file_function(filename + '.tms',
1038                          domain,
1039                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'],
1040                          use_cache=True)
1041       
1042
1043        #print F.precomputed_values
1044        #print 'F(60)', F(60)
1045       
1046        #Now try interpolation
1047        for i in range(20):
1048            t = i*10
1049            q = F(t)
1050
1051            #Exact linear intpolation
1052            assert num.allclose(q[0], 2*t)
1053            if i%6 == 0:
1054                assert num.allclose(q[1], t**2)
1055                assert num.allclose(q[2], sin(t*pi/600))
1056
1057        #Check non-exact
1058
1059        t = 90 #Halfway between 60 and 120
1060        q = F(t)
1061        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1062        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1063
1064
1065        t = 100 #Two thirds of the way between between 60 and 120
1066        q = F(t)
1067        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1068        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1069
1070        os.remove(filename + '.tms')
1071        os.remove(filename + '.txt')       
1072
1073    def test_file_function_time_with_domain_different_start(self):
1074        """Test that File function interpolates correctly
1075        between given times. No x,y dependency here.
1076        Use domain with a starttime later than that of file
1077
1078        ASCII version
1079        """
1080
1081        #Write file
1082        import os, time, calendar
1083        from anuga.config import time_format
1084        from math import sin, pi
1085        from domain import Domain
1086
1087        finaltime = 1200
1088        filename = 'test_file_function'
1089        fid = open(filename + '.txt', 'w')
1090        start = time.mktime(time.strptime('2000', '%Y'))
1091        dt = 60  #One minute intervals
1092        t = 0.0
1093        while t <= finaltime:
1094            t_string = time.strftime(time_format, time.gmtime(t+start))
1095            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1096            t += dt
1097
1098        fid.close()
1099
1100        #Convert ASCII file to NetCDF (Which is what we really like!)
1101        timefile2netcdf(filename)       
1102
1103        a = [0.0, 0.0]
1104        b = [4.0, 0.0]
1105        c = [0.0, 3.0]
1106
1107        points = [a, b, c]
1108        vertices = [[0,1,2]]
1109        domain = Domain(points, vertices)
1110
1111        #Check that domain.starttime isn't updated if later than file starttime but earlier
1112        #than file end time
1113        delta = 23
1114        domain.starttime = start + delta
1115        F = file_function(filename + '.tms', domain,
1116                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])       
1117        assert num.allclose(domain.starttime, start+delta)
1118
1119        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1120                                            277., 337., 397., 457., 517.,
1121                                            577., 637., 697., 757., 817.,
1122                                            877., 937., 997., 1057., 1117.,
1123                                            1177.])
1124
1125
1126        #Now try interpolation with delta offset
1127        for i in range(20):
1128            t = i*10
1129            q = F(t-delta)
1130
1131            #Exact linear intpolation
1132            assert num.allclose(q[0], 2*t)
1133            if i%6 == 0:
1134                assert num.allclose(q[1], t**2)
1135                assert num.allclose(q[2], sin(t*pi/600))
1136
1137        #Check non-exact
1138
1139        t = 90 #Halfway between 60 and 120
1140        q = F(t-delta)
1141        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1142        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1143
1144
1145        t = 100 #Two thirds of the way between between 60 and 120
1146        q = F(t-delta)
1147        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1148        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1149
1150
1151        os.remove(filename + '.tms')
1152        os.remove(filename + '.txt')               
1153
1154       
1155
1156    def test_file_function_time_with_domain_different_start_and_time_limit(self):
1157        """Test that File function interpolates correctly
1158        between given times. No x,y dependency here.
1159        Use domain with a starttime later than that of file
1160
1161        ASCII version
1162       
1163        This test also tests that time can be truncated.
1164        """
1165
1166        # Write file
1167        import os, time, calendar
1168        from anuga.config import time_format
1169        from math import sin, pi
1170        from domain import Domain
1171
1172        finaltime = 1200
1173        filename = 'test_file_function'
1174        fid = open(filename + '.txt', 'w')
1175        start = time.mktime(time.strptime('2000', '%Y'))
1176        dt = 60  #One minute intervals
1177        t = 0.0
1178        while t <= finaltime:
1179            t_string = time.strftime(time_format, time.gmtime(t+start))
1180            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1181            t += dt
1182
1183        fid.close()
1184
1185        # Convert ASCII file to NetCDF (Which is what we really like!)
1186        timefile2netcdf(filename)       
1187
1188        a = [0.0, 0.0]
1189        b = [4.0, 0.0]
1190        c = [0.0, 3.0]
1191
1192        points = [a, b, c]
1193        vertices = [[0,1,2]]
1194        domain = Domain(points, vertices)
1195
1196        # Check that domain.starttime isn't updated if later than file starttime but earlier
1197        # than file end time
1198        delta = 23
1199        domain.starttime = start + delta
1200        time_limit = domain.starttime + 600
1201        F = file_function(filename + '.tms', domain,
1202                          time_limit=time_limit,
1203                          quantities=['Attribute0', 'Attribute1', 'Attribute2'])       
1204        assert num.allclose(domain.starttime, start+delta)
1205
1206        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1207                                            277., 337., 397., 457., 517.,
1208                                            577.])       
1209
1210
1211
1212        # Now try interpolation with delta offset
1213        for i in range(20):
1214            t = i*10
1215            q = F(t-delta)
1216
1217            #Exact linear intpolation
1218            assert num.allclose(q[0], 2*t)
1219            if i%6 == 0:
1220                assert num.allclose(q[1], t**2)
1221                assert num.allclose(q[2], sin(t*pi/600))
1222
1223        # Check non-exact
1224        t = 90 #Halfway between 60 and 120
1225        q = F(t-delta)
1226        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1227        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1228
1229
1230        t = 100 # Two thirds of the way between between 60 and 120
1231        q = F(t-delta)
1232        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1233        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1234
1235
1236        os.remove(filename + '.tms')
1237        os.remove(filename + '.txt')               
1238
1239       
1240       
1241       
1242
1243
1244    def test_apply_expression_to_dictionary(self):
1245
1246        #FIXME: Division is not expected to work for integers.
1247        #This must be caught.
1248        foo = num.array([[1,2,3], [4,5,6]], num.float)
1249
1250        bar = num.array([[-1,0,5], [6,1,1]], num.float)                 
1251
1252        D = {'X': foo, 'Y': bar}
1253
1254        Z = apply_expression_to_dictionary('X+Y', D)       
1255        assert num.allclose(Z, foo+bar)
1256
1257        Z = apply_expression_to_dictionary('X*Y', D)       
1258        assert num.allclose(Z, foo*bar)       
1259
1260        Z = apply_expression_to_dictionary('4*X+Y', D)       
1261        assert num.allclose(Z, 4*foo+bar)       
1262
1263        # test zero division is OK
1264        Z = apply_expression_to_dictionary('X/Y', D)
1265        assert num.allclose(1/Z, 1/(foo/bar)) # can't compare inf to inf
1266
1267        # make an error for zero on zero
1268        # this is really an error in numeric, SciPy core can handle it
1269        # Z = apply_expression_to_dictionary('0/Y', D)
1270
1271        #Check exceptions
1272        try:
1273            #Wrong name
1274            Z = apply_expression_to_dictionary('4*X+A', D)       
1275        except NameError:
1276            pass
1277        else:
1278            msg = 'Should have raised a NameError Exception'
1279            raise msg
1280
1281
1282        try:
1283            #Wrong order
1284            Z = apply_expression_to_dictionary(D, '4*X+A')       
1285        except AssertionError:
1286            pass
1287        else:
1288            msg = 'Should have raised a AssertionError Exception'
1289            raise msg       
1290       
1291
1292    def test_multiple_replace(self):
1293        """Hard test that checks a true word-by-word simultaneous replace
1294        """
1295       
1296        D = {'x': 'xi', 'y': 'eta', 'xi':'lam'}
1297        exp = '3*x+y + xi'
1298       
1299        new = multiple_replace(exp, D)
1300       
1301        assert new == '3*xi+eta + lam'
1302                         
1303
1304
1305    def test_point_on_line_obsolete(self):
1306        """Test that obsolete call issues appropriate warning"""
1307
1308        #Turn warning into an exception
1309        import warnings
1310        warnings.filterwarnings('error')
1311
1312        try:
1313            assert point_on_line( 0, 0.5, 0,1, 0,0 )
1314        except DeprecationWarning:
1315            pass
1316        else:
1317            msg = 'point_on_line should have issued a DeprecationWarning'
1318            raise Exception(msg)   
1319
1320        warnings.resetwarnings()
1321   
1322    def test_get_revision_number(self):
1323        """test_get_revision_number(self):
1324
1325        Test that revision number can be retrieved.
1326        """
1327        if os.environ.has_key('USER') and os.environ['USER'] == 'dgray':
1328            # I have a known snv incompatability issue,
1329            # so I'm skipping this test.
1330            # FIXME when SVN is upgraded on our clusters
1331            pass
1332        else:   
1333            n = get_revision_number()
1334            assert n>=0
1335
1336
1337       
1338    def test_add_directories(self):
1339       
1340        import tempfile
1341        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1342        directories = ['ja','ne','ke']
1343        kens_dir = add_directories(root_dir, directories)
1344        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1345               sep + 'ke'
1346        assert access(root_dir,F_OK)
1347
1348        add_directories(root_dir, directories)
1349        assert access(root_dir,F_OK)
1350       
1351        #clean up!
1352        os.rmdir(kens_dir)
1353        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1354        os.rmdir(root_dir + sep + 'ja')
1355        os.rmdir(root_dir)
1356
1357    def test_add_directories_bad(self):
1358       
1359        import tempfile
1360        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1361        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1362       
1363        try:
1364            kens_dir = add_directories(root_dir, directories)
1365        except OSError:
1366            pass
1367        else:
1368            msg = 'bad dir name should give OSError'
1369            raise Exception(msg)   
1370           
1371        #clean up!
1372        os.rmdir(root_dir)
1373
1374    def test_check_list(self):
1375
1376        check_list(['stage','xmomentum'])
1377
1378       
1379    def test_add_directories(self):
1380       
1381        import tempfile
1382        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1383        directories = ['ja','ne','ke']
1384        kens_dir = add_directories(root_dir, directories)
1385        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1386               sep + 'ke'
1387        assert access(root_dir,F_OK)
1388
1389        add_directories(root_dir, directories)
1390        assert access(root_dir,F_OK)
1391       
1392        #clean up!
1393        os.rmdir(kens_dir)
1394        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1395        os.rmdir(root_dir + sep + 'ja')
1396        os.rmdir(root_dir)
1397
1398    def test_add_directories_bad(self):
1399       
1400        import tempfile
1401        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1402        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1403       
1404        try:
1405            kens_dir = add_directories(root_dir, directories)
1406        except OSError:
1407            pass
1408        else:
1409            msg = 'bad dir name should give OSError'
1410            raise Exception(msg)   
1411           
1412        #clean up!
1413        os.rmdir(root_dir)
1414
1415    def test_check_list(self):
1416
1417        check_list(['stage','xmomentum'])
1418
1419######
1420# Test the remove_lone_verts() function
1421######
1422       
1423    def test_remove_lone_verts_a(self):
1424        verts = [[0,0],[1,0],[0,1]]
1425        tris = [[0,1,2]]
1426        new_verts, new_tris = remove_lone_verts(verts, tris)
1427        self.failUnless(new_verts.tolist() == verts)
1428        self.failUnless(new_tris.tolist() == tris)
1429
1430    def test_remove_lone_verts_b(self):
1431        verts = [[0,0],[1,0],[0,1],[99,99]]
1432        tris = [[0,1,2]]
1433        new_verts, new_tris = remove_lone_verts(verts, tris)
1434        self.failUnless(new_verts.tolist() == verts[0:3])
1435        self.failUnless(new_tris.tolist() == tris)
1436       
1437    def test_remove_lone_verts_c(self):
1438        verts = [[99,99],[0,0],[1,0],[99,99],[0,1],[99,99]]
1439        tris = [[1,2,4]]
1440        new_verts, new_tris = remove_lone_verts(verts, tris)
1441        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1442        self.failUnless(new_tris.tolist() == [[0,1,2]])
1443     
1444    def test_remove_lone_verts_d(self):
1445        verts = [[0,0],[1,0],[99,99],[0,1]]
1446        tris = [[0,1,3]]
1447        new_verts, new_tris = remove_lone_verts(verts, tris)
1448        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1449        self.failUnless(new_tris.tolist() == [[0,1,2]])
1450       
1451    def test_remove_lone_verts_e(self):
1452        verts = [[0,0],[1,0],[0,1],[99,99],[99,99],[99,99]]
1453        tris = [[0,1,2]]
1454        new_verts, new_tris = remove_lone_verts(verts, tris)
1455        self.failUnless(new_verts.tolist() == verts[0:3])
1456        self.failUnless(new_tris.tolist() == tris)
1457     
1458    def test_remove_lone_verts_f(self):
1459        verts = [[0,0],[1,0],[99,99],[0,1],[99,99],[1,1],[99,99]]
1460        tris = [[0,1,3],[0,1,5]]
1461        new_verts, new_tris = remove_lone_verts(verts, tris)
1462        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1],[1,1]])
1463        self.failUnless(new_tris.tolist() == [[0,1,2],[0,1,3]])
1464       
1465######
1466#
1467######
1468       
1469    def test_get_min_max_values(self):
1470       
1471        list=[8,9,6,1,4]
1472        min1, max1 = get_min_max_values(list)
1473       
1474        assert min1==1 
1475        assert max1==9
1476       
1477    def test_get_min_max_values1(self):
1478       
1479        list=[-8,-9,-6,-1,-4]
1480        min1, max1 = get_min_max_values(list)
1481       
1482#        print 'min1,max1',min1,max1
1483        assert min1==-9 
1484        assert max1==-1
1485
1486#    def test_get_min_max_values2(self):
1487#        '''
1488#        The min and max supplied are greater than the ones in the
1489#        list and therefore are the ones returned
1490#        '''
1491#        list=[-8,-9,-6,-1,-4]
1492#        min1, max1 = get_min_max_values(list,-10,10)
1493#       
1494##        print 'min1,max1',min1,max1
1495#        assert min1==-10
1496#        assert max1==10
1497       
1498    def test_make_plots_from_csv_files(self):
1499       
1500        #if sys.platform == 'win32':  #Windows
1501            try: 
1502                import pylab
1503            except ImportError:
1504                #ANUGA don't need pylab to work so the system doesn't
1505                #rely on pylab being installed
1506                return
1507           
1508       
1509            current_dir=getcwd()+sep+'abstract_2d_finite_volumes'
1510            temp_dir = tempfile.mkdtemp('','figures')
1511    #        print 'temp_dir',temp_dir
1512            fileName = temp_dir+sep+'time_series_3.csv'
1513            file = open(fileName,"w")
1514            file.write("time,stage,speed,momentum,elevation\n\
15151.0, 0, 0, 0, 10 \n\
15162.0, 5, 2, 4, 10 \n\
15173.0, 3, 3, 5, 10 \n")
1518            file.close()
1519   
1520            fileName1 = temp_dir+sep+'time_series_4.csv'
1521            file1 = open(fileName1,"w")
1522            file1.write("time,stage,speed,momentum,elevation\n\
15231.0, 0, 0, 0, 5 \n\
15242.0, -5, -2, -4, 5 \n\
15253.0, -4, -3, -5, 5 \n")
1526            file1.close()
1527   
1528            fileName2 = temp_dir+sep+'time_series_5.csv'
1529            file2 = open(fileName2,"w")
1530            file2.write("time,stage,speed,momentum,elevation\n\
15311.0, 0, 0, 0, 7 \n\
15322.0, 4, -0.45, 57, 7 \n\
15333.0, 6, -0.5, 56, 7 \n")
1534            file2.close()
1535           
1536            dir, name=os.path.split(fileName)
1537            csv2timeseries_graphs(directories_dic={dir:['gauge', 0, 0]},
1538                                  output_dir=temp_dir,
1539                                  base_name='time_series_',
1540                                  plot_numbers=['3-5'],
1541                                  quantities=['speed','stage','momentum'],
1542                                  assess_all_csv_files=True,
1543                                  extra_plot_name='test')
1544           
1545            #print dir+sep+name[:-4]+'_stage_test.png'
1546            assert(access(dir+sep+name[:-4]+'_stage_test.png',F_OK)==True)
1547            assert(access(dir+sep+name[:-4]+'_speed_test.png',F_OK)==True)
1548            assert(access(dir+sep+name[:-4]+'_momentum_test.png',F_OK)==True)
1549   
1550            dir1, name1=os.path.split(fileName1)
1551            assert(access(dir+sep+name1[:-4]+'_stage_test.png',F_OK)==True)
1552            assert(access(dir+sep+name1[:-4]+'_speed_test.png',F_OK)==True)
1553            assert(access(dir+sep+name1[:-4]+'_momentum_test.png',F_OK)==True)
1554   
1555   
1556            dir2, name2=os.path.split(fileName2)
1557            assert(access(dir+sep+name2[:-4]+'_stage_test.png',F_OK)==True)
1558            assert(access(dir+sep+name2[:-4]+'_speed_test.png',F_OK)==True)
1559            assert(access(dir+sep+name2[:-4]+'_momentum_test.png',F_OK)==True)
1560   
1561            del_dir(temp_dir)
1562       
1563
1564    def test_sww2csv_gauges(self):
1565
1566        def elevation_function(x, y):
1567            return -x
1568       
1569        """Most of this test was copied from test_interpolate
1570        test_interpole_sww2csv
1571       
1572        This is testing the gauge_sww2csv function, by creating a sww file and
1573        then exporting the gauges and checking the results.
1574        """
1575       
1576        # Create mesh
1577        mesh_file = tempfile.mktemp(".tsh")   
1578        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1579        m = Mesh()
1580        m.add_vertices(points)
1581        m.auto_segment()
1582        m.generate_mesh(verbose=False)
1583        m.export_mesh_file(mesh_file)
1584       
1585        # Create shallow water domain
1586        domain = Domain(mesh_file)
1587        os.remove(mesh_file)
1588       
1589        domain.default_order=2
1590       
1591        # This test was made before tight_slope_limiters were introduced
1592        # Since were are testing interpolation values this is OK
1593        domain.tight_slope_limiters = 0 
1594       
1595
1596        # Set some field values
1597        domain.set_quantity('elevation', elevation_function)
1598        domain.set_quantity('friction', 0.03)
1599        domain.set_quantity('xmomentum', 3.0)
1600        domain.set_quantity('ymomentum', 4.0)
1601
1602        ######################
1603        # Boundary conditions
1604        B = Transmissive_boundary(domain)
1605        domain.set_boundary( {'exterior': B})
1606
1607        # This call mangles the stage values.
1608        domain.distribute_to_vertices_and_edges()
1609        domain.set_quantity('stage', 1.0)
1610
1611
1612        domain.set_name('datatest' + str(time.time()))
1613        domain.smooth = True
1614        domain.reduction = mean
1615
1616
1617        sww = SWW_file(domain)
1618        sww.store_connectivity()
1619        sww.store_timestep()
1620        domain.set_quantity('stage', 10.0) # This is automatically limited
1621        # so it will not be less than the elevation
1622        domain.time = 2.
1623        sww.store_timestep()
1624
1625        # test the function
1626        points = [[5.0,1.],[0.5,2.]]
1627
1628        points_file = tempfile.mktemp(".csv")
1629#        points_file = 'test_point.csv'
1630        file_id = open(points_file,"w")
1631        file_id.write("name, easting, northing, elevation \n\
1632point1, 5.0, 1.0, 3.0\n\
1633point2, 0.5, 2.0, 9.0\n")
1634        file_id.close()
1635
1636       
1637        sww2csv_gauges(sww.filename, 
1638                       points_file,
1639                       verbose=False,
1640                       use_cache=False)
1641
1642#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1643        point1_answers_array = [[0.0,0.0,1.0,6.0,-5.0,3.0,4.0], [2.0,2.0/3600.,10.0,15.0,-5.0,3.0,4.0]]
1644        point1_filename = 'gauge_point1.csv'
1645        point1_handle = file(point1_filename)
1646        point1_reader = reader(point1_handle)
1647        point1_reader.next()
1648
1649        line=[]
1650        for i,row in enumerate(point1_reader):
1651#            print 'i',i,'row',row
1652            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1653                         float(row[4]),float(row[5]),float(row[6])])
1654#            print 'assert line',line[i],'point1',point1_answers_array[i]
1655            assert num.allclose(line[i], point1_answers_array[i])
1656
1657        point2_answers_array = [[0.0,0.0,1.0,1.5,-0.5,3.0,4.0], [2.0,2.0/3600.,10.0,10.5,-0.5,3.0,4.0]]
1658        point2_filename = 'gauge_point2.csv' 
1659        point2_handle = file(point2_filename)
1660        point2_reader = reader(point2_handle)
1661        point2_reader.next()
1662                       
1663        line=[]
1664        for i,row in enumerate(point2_reader):
1665#            print 'i',i,'row',row
1666            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1667                         float(row[4]),float(row[5]),float(row[6])])
1668#            print 'assert line',line[i],'point1',point1_answers_array[i]
1669            assert num.allclose(line[i], point2_answers_array[i])
1670                         
1671        # clean up
1672        point1_handle.close()
1673        point2_handle.close()
1674        #print "sww.filename",sww.filename
1675        os.remove(sww.filename)
1676        os.remove(points_file)
1677        os.remove(point1_filename)
1678        os.remove(point2_filename)
1679
1680
1681
1682    def test_sww2csv_gauges1(self):
1683        from anuga.pmesh.mesh import Mesh
1684        from anuga.shallow_water import Domain, Transmissive_boundary
1685        from csv import reader,writer
1686        import time
1687        import string
1688
1689        def elevation_function(x, y):
1690            return -x
1691       
1692        """Most of this test was copied from test_interpolate
1693        test_interpole_sww2csv
1694       
1695        This is testing the gauge_sww2csv function, by creating a sww file and
1696        then exporting the gauges and checking the results.
1697       
1698        This tests the ablity not to have elevation in the points file and
1699        not store xmomentum and ymomentum
1700        """
1701       
1702        # Create mesh
1703        mesh_file = tempfile.mktemp(".tsh")   
1704        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1705        m = Mesh()
1706        m.add_vertices(points)
1707        m.auto_segment()
1708        m.generate_mesh(verbose=False)
1709        m.export_mesh_file(mesh_file)
1710       
1711        # Create shallow water domain
1712        domain = Domain(mesh_file)
1713        os.remove(mesh_file)
1714       
1715        domain.default_order=2
1716
1717        # Set some field values
1718        domain.set_quantity('elevation', elevation_function)
1719        domain.set_quantity('friction', 0.03)
1720        domain.set_quantity('xmomentum', 3.0)
1721        domain.set_quantity('ymomentum', 4.0)
1722
1723        ######################
1724        # Boundary conditions
1725        B = Transmissive_boundary(domain)
1726        domain.set_boundary( {'exterior': B})
1727
1728        # This call mangles the stage values.
1729        domain.distribute_to_vertices_and_edges()
1730        domain.set_quantity('stage', 1.0)
1731
1732
1733        domain.set_name('datatest' + str(time.time()))
1734        domain.smooth = True
1735        domain.reduction = mean
1736
1737        sww = SWW_file(domain)
1738        sww.store_connectivity()
1739        sww.store_timestep()
1740        domain.set_quantity('stage', 10.0) # This is automatically limited
1741        # so it will not be less than the elevation
1742        domain.time = 2.
1743        sww.store_timestep()
1744
1745        # test the function
1746        points = [[5.0,1.],[0.5,2.]]
1747
1748        points_file = tempfile.mktemp(".csv")
1749#        points_file = 'test_point.csv'
1750        file_id = open(points_file,"w")
1751        file_id.write("name,easting,northing \n\
1752point1, 5.0, 1.0\n\
1753point2, 0.5, 2.0\n")
1754        file_id.close()
1755
1756        sww2csv_gauges(sww.filename, 
1757                            points_file,
1758                            quantities=['stage', 'elevation'],
1759                            use_cache=False,
1760                            verbose=False)
1761
1762        point1_answers_array = [[0.0,1.0,-5.0], [2.0,10.0,-5.0]]
1763        point1_filename = 'gauge_point1.csv'
1764        point1_handle = file(point1_filename)
1765        point1_reader = reader(point1_handle)
1766        point1_reader.next()
1767
1768        line=[]
1769        for i,row in enumerate(point1_reader):
1770#            print 'i',i,'row',row
1771            # note the 'hole' (element 1) below - skip the new 'hours' field
1772            line.append([float(row[0]),float(row[2]),float(row[3])])
1773            #print 'line',line[i],'point1',point1_answers_array[i]
1774            assert num.allclose(line[i], point1_answers_array[i])
1775
1776        point2_answers_array = [[0.0,1.0,-0.5], [2.0,10.0,-0.5]]
1777        point2_filename = 'gauge_point2.csv' 
1778        point2_handle = file(point2_filename)
1779        point2_reader = reader(point2_handle)
1780        point2_reader.next()
1781                       
1782        line=[]
1783        for i,row in enumerate(point2_reader):
1784#            print 'i',i,'row',row
1785            # note the 'hole' (element 1) below - skip the new 'hours' field
1786            line.append([float(row[0]),float(row[2]),float(row[3])])
1787#            print 'line',line[i],'point1',point1_answers_array[i]
1788            assert num.allclose(line[i], point2_answers_array[i])
1789                         
1790        # clean up
1791        point1_handle.close()
1792        point2_handle.close()
1793        #print "sww.filename",sww.filename
1794        os.remove(sww.filename)
1795        os.remove(points_file)
1796        os.remove(point1_filename)
1797        os.remove(point2_filename)
1798
1799
1800    def test_sww2csv_gauges2(self):
1801
1802        def elevation_function(x, y):
1803            return -x
1804       
1805        """Most of this test was copied from test_interpolate
1806        test_interpole_sww2csv
1807       
1808        This is testing the gauge_sww2csv function, by creating a sww file and
1809        then exporting the gauges and checking the results.
1810       
1811        This is the same as sww2csv_gauges except set domain.set_starttime to 5.
1812        Therefore testing the storing of the absolute time in the csv files
1813        """
1814       
1815        # Create mesh
1816        mesh_file = tempfile.mktemp(".tsh")   
1817        points = [[0.0,0.0],[6.0,0.0],[6.0,6.0],[0.0,6.0]]
1818        m = Mesh()
1819        m.add_vertices(points)
1820        m.auto_segment()
1821        m.generate_mesh(verbose=False)
1822        m.export_mesh_file(mesh_file)
1823       
1824        # Create shallow water domain
1825        domain = Domain(mesh_file)
1826        os.remove(mesh_file)
1827       
1828        domain.default_order=2
1829
1830        # This test was made before tight_slope_limiters were introduced
1831        # Since were are testing interpolation values this is OK
1832        domain.tight_slope_limiters = 0         
1833
1834        # Set some field values
1835        domain.set_quantity('elevation', elevation_function)
1836        domain.set_quantity('friction', 0.03)
1837        domain.set_quantity('xmomentum', 3.0)
1838        domain.set_quantity('ymomentum', 4.0)
1839        domain.set_starttime(5)
1840
1841        ######################
1842        # Boundary conditions
1843        B = Transmissive_boundary(domain)
1844        domain.set_boundary( {'exterior': B})
1845
1846        # This call mangles the stage values.
1847        domain.distribute_to_vertices_and_edges()
1848        domain.set_quantity('stage', 1.0)
1849       
1850
1851
1852        domain.set_name('datatest' + str(time.time()))
1853        domain.smooth = True
1854        domain.reduction = mean
1855
1856        sww = SWW_file(domain)
1857        sww.store_connectivity()
1858        sww.store_timestep()
1859        domain.set_quantity('stage', 10.0) # This is automatically limited
1860        # so it will not be less than the elevation
1861        domain.time = 2.
1862        sww.store_timestep()
1863
1864        # test the function
1865        points = [[5.0,1.],[0.5,2.]]
1866
1867        points_file = tempfile.mktemp(".csv")
1868#        points_file = 'test_point.csv'
1869        file_id = open(points_file,"w")
1870        file_id.write("name, easting, northing, elevation \n\
1871point1, 5.0, 1.0, 3.0\n\
1872point2, 0.5, 2.0, 9.0\n")
1873        file_id.close()
1874
1875       
1876        sww2csv_gauges(sww.filename, 
1877                            points_file,
1878                            verbose=False,
1879                            use_cache=False)
1880
1881#        point1_answers_array = [[0.0,1.0,-5.0,3.0,4.0], [2.0,10.0,-5.0,3.0,4.0]]
1882        point1_answers_array = [[5.0,5.0/3600.,1.0,6.0,-5.0,3.0,4.0], [7.0,7.0/3600.,10.0,15.0,-5.0,3.0,4.0]]
1883        point1_filename = 'gauge_point1.csv'
1884        point1_handle = file(point1_filename)
1885        point1_reader = reader(point1_handle)
1886        point1_reader.next()
1887
1888        line=[]
1889        for i,row in enumerate(point1_reader):
1890            #print 'i',i,'row',row
1891            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1892                         float(row[4]), float(row[5]), float(row[6])])
1893            #print 'assert line',line[i],'point1',point1_answers_array[i]
1894            assert num.allclose(line[i], point1_answers_array[i])
1895
1896        point2_answers_array = [[5.0,5.0/3600.,1.0,1.5,-0.5,3.0,4.0], [7.0,7.0/3600.,10.0,10.5,-0.5,3.0,4.0]]
1897        point2_filename = 'gauge_point2.csv' 
1898        point2_handle = file(point2_filename)
1899        point2_reader = reader(point2_handle)
1900        point2_reader.next()
1901                       
1902        line=[]
1903        for i,row in enumerate(point2_reader):
1904            #print 'i',i,'row',row
1905            line.append([float(row[0]),float(row[1]),float(row[2]),float(row[3]),
1906                         float(row[4]),float(row[5]), float(row[6])])
1907            #print 'assert line',line[i],'point1',point1_answers_array[i]
1908            assert num.allclose(line[i], point2_answers_array[i])
1909                         
1910        # clean up
1911        point1_handle.close()
1912        point2_handle.close()
1913        #print "sww.filename",sww.filename
1914        os.remove(sww.filename)
1915        os.remove(points_file)
1916        os.remove(point1_filename)
1917        os.remove(point2_filename)
1918
1919
1920    def test_greens_law(self):
1921
1922        from math import sqrt
1923       
1924        d1 = 80.0
1925        d2 = 20.0
1926        h1 = 1.0
1927        h2 = greens_law(d1,d2,h1)
1928
1929        assert h2==sqrt(2.0)
1930       
1931    def test_calc_bearings(self):
1932 
1933        from math import atan, degrees
1934        #Test East
1935        uh = 1
1936        vh = 1.e-15
1937        angle = calc_bearing(uh, vh)
1938        if 89 < angle < 91: v=1
1939        assert v==1
1940        #Test West
1941        uh = -1
1942        vh = 1.e-15
1943        angle = calc_bearing(uh, vh)
1944        if 269 < angle < 271: v=1
1945        assert v==1
1946        #Test North
1947        uh = 1.e-15
1948        vh = 1
1949        angle = calc_bearing(uh, vh)
1950        if -1 < angle < 1: v=1
1951        assert v==1
1952        #Test South
1953        uh = 1.e-15
1954        vh = -1
1955        angle = calc_bearing(uh, vh)
1956        if 179 < angle < 181: v=1
1957        assert v==1
1958        #Test South-East
1959        uh = 1
1960        vh = -1
1961        angle = calc_bearing(uh, vh)
1962        if 134 < angle < 136: v=1
1963        assert v==1
1964        #Test North-East
1965        uh = 1
1966        vh = 1
1967        angle = calc_bearing(uh, vh)
1968        if 44 < angle < 46: v=1
1969        assert v==1
1970        #Test South-West
1971        uh = -1
1972        vh = -1
1973        angle = calc_bearing(uh, vh)
1974        if 224 < angle < 226: v=1
1975        assert v==1
1976        #Test North-West
1977        uh = -1
1978        vh = 1
1979        angle = calc_bearing(uh, vh)
1980        if 314 < angle < 316: v=1
1981        assert v==1
1982       
1983
1984#-------------------------------------------------------------
1985
1986if __name__ == "__main__":
1987    suite = unittest.makeSuite(Test_Util, 'test')
1988#    runner = unittest.TextTestRunner(verbosity=2)
1989    runner = unittest.TextTestRunner(verbosity=1)
1990    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.