source: trunk/anuga_core/source/anuga/abstract_2d_finite_volumes/test_util.py @ 8050

Last change on this file since 8050 was 7814, checked in by hudson, 14 years ago

New filename conventions for file conversion. Filenames must always be passed in with the correct extension.

File size: 53.4 KB
Line 
1#!/usr/bin/env python
2
3
4import unittest
5from math import sqrt, pi
6import tempfile, os
7from os import access, F_OK,sep, removedirs,remove,mkdir,getcwd
8
9from anuga.abstract_2d_finite_volumes.util import *
10from anuga.config import epsilon
11from anuga.config import netcdf_mode_r, netcdf_mode_w, netcdf_mode_a
12from anuga.file_conversion.file_conversion import timefile2netcdf
13from anuga.utilities.file_utils import del_dir
14
15from anuga.utilities.numerical_tools import NAN
16
17from sys import platform
18
19from anuga.pmesh.mesh import Mesh
20import time
21from mesh_factory import rectangular
22from anuga.coordinate_transforms.geo_reference import Geo_reference
23from anuga.shallow_water.shallow_water_domain import Domain
24from generic_boundary_conditions import \
25                                Transmissive_boundary, Dirichlet_boundary
26from anuga.file.sww import SWW_file
27from csv import reader,writer
28import time
29import string
30
31import numpy as num
32
33
34def test_function(x, y):
35    return x+y
36
37class Test_Util(unittest.TestCase):
38    def setUp(self):
39        pass
40
41    def tearDown(self):
42        pass
43
44
45
46
47    #Geometric
48    #def test_distance(self):
49    #    from anuga.abstract_2d_finite_volumes.util import distance#
50    #
51    #    self.failUnless( distance([4,2],[7,6]) == 5.0,
52    #                     'Distance is wrong!')
53    #    self.failUnless( allclose(distance([7,6],[9,8]), 2.82842712475),
54    #                    'distance is wrong!')
55    #    self.failUnless( allclose(distance([9,8],[4,2]), 7.81024967591),
56    #                    'distance is wrong!')
57    #
58    #    self.failUnless( distance([9,8],[4,2]) == distance([4,2],[9,8]),
59    #                    'distance is wrong!')
60
61
62    def test_file_function_time1(self):
63        """Test that File function interpolates correctly
64        between given times. No x,y dependency here.
65        """
66
67        #Write file
68        import os, time
69        from anuga.config import time_format
70        from math import sin, pi
71
72        #Typical ASCII file
73        finaltime = 1200
74        filename = 'test_file_function'
75        fid = open(filename + '.txt', 'w')
76        start = time.mktime(time.strptime('2000', '%Y'))
77        dt = 60  #One minute intervals
78        t = 0.0
79        while t <= finaltime:
80            t_string = time.strftime(time_format, time.gmtime(t+start))
81            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
82            t += dt
83
84        fid.close()
85
86        #Convert ASCII file to NetCDF (Which is what we really like!)
87        timefile2netcdf(filename+'.txt')
88
89
90        #Create file function from time series
91        F = file_function(filename + '.tms',
92                          quantities = ['Attribute0',
93                                        'Attribute1',
94                                        'Attribute2'])
95       
96        #Now try interpolation
97        for i in range(20):
98            t = i*10
99            q = F(t)
100
101            #Exact linear intpolation
102            assert num.allclose(q[0], 2*t)
103            if i%6 == 0:
104                assert num.allclose(q[1], t**2)
105                assert num.allclose(q[2], sin(t*pi/600))
106
107        #Check non-exact
108
109        t = 90 #Halfway between 60 and 120
110        q = F(t)
111        assert num.allclose( (120**2 + 60**2)/2, q[1] )
112        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
113
114
115        t = 100 #Two thirds of the way between between 60 and 120
116        q = F(t)
117        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
118        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
119
120        os.remove(filename + '.txt')
121        os.remove(filename + '.tms')       
122
123
124       
125    def test_spatio_temporal_file_function_basic(self):
126        """Test that spatio temporal file function performs the correct
127        interpolations in both time and space
128        NetCDF version (x,y,t dependency)       
129        """
130        import time
131
132        #Create sww file of simple propagation from left to right
133        #through rectangular domain
134
135        #Create basic mesh and shallow water domain
136        points, vertices, boundary = rectangular(3, 3)
137        domain1 = Domain(points, vertices, boundary)
138
139        from anuga.utilities.numerical_tools import mean
140        domain1.reduction = mean
141        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
142                              # only one value.
143
144        domain1.default_order = 2
145        domain1.store = True
146        domain1.set_datadir('.')
147        sww_file = 'spatio_temporal_boundary_source_%d' %(id(self))
148        domain1.set_name(sww_file)
149
150        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
151        domain1.set_quantity('elevation', 0)
152        domain1.set_quantity('friction', 0)
153        domain1.set_quantity('stage', 0)
154
155        # Boundary conditions
156        B0 = Dirichlet_boundary([0,0,0])
157        B6 = Dirichlet_boundary([0.6,0,0])
158        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
159        domain1.check_integrity()
160
161        finaltime = 8
162        #Evolution
163        t0 = -1
164        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
165            #print 'Timesteps: %.16f, %.16f' %(t0, t)
166            #if t == t0:
167            #    msg = 'Duplicate timestep found: %f, %f' %(t0, t)
168            #   raise msg
169            t0 = t
170             
171            #domain1.write_time()
172
173
174        #Now read data from sww and check
175        from Scientific.IO.NetCDF import NetCDFFile
176        filename = domain1.get_name() + '.sww'
177        fid = NetCDFFile(filename)
178
179        x = fid.variables['x'][:]
180        y = fid.variables['y'][:]
181        stage = fid.variables['stage'][:]
182        xmomentum = fid.variables['xmomentum'][:]
183        ymomentum = fid.variables['ymomentum'][:]
184        time = fid.variables['time'][:]
185
186        #Take stage vertex values at last timestep on diagonal
187        #Diagonal is identified by vertices: 0, 5, 10, 15
188
189        last_time_index = len(time)-1 #Last last_time_index
190        d_stage = num.reshape(num.take(stage[last_time_index, :],
191                                       [0,5,10,15],
192                                       axis=0),
193                              (4,1))
194        d_uh = num.reshape(num.take(xmomentum[last_time_index, :],
195                                    [0,5,10,15],
196                                   axis=0),
197                           (4,1))
198        d_vh = num.reshape(num.take(ymomentum[last_time_index, :],
199                                    [0,5,10,15],
200                                   axis=0),
201                           (4,1))
202        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
203
204        #Reference interpolated values at midpoints on diagonal at
205        #this timestep are
206        r0 = (D[0] + D[1])/2
207        r1 = (D[1] + D[2])/2
208        r2 = (D[2] + D[3])/2
209
210        #And the midpoints are found now
211        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15], axis=0)
212        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15], axis=0)
213
214        diag = num.concatenate( (Dx, Dy), axis=1)
215        d_midpoints = (diag[1:] + diag[:-1])/2
216
217        #Let us see if the file function can find the correct
218        #values at the midpoints at the last timestep:
219        f = file_function(filename, domain1,
220                          interpolation_points = d_midpoints)
221
222        T = f.get_time()
223        msg = 'duplicate timesteps: %.16f and %.16f' %(T[-1], T[-2])
224        assert not T[-1] == T[-2], msg
225        t = time[last_time_index]
226        q = f(t, point_id=0); assert num.allclose(r0, q)
227        q = f(t, point_id=1); assert num.allclose(r1, q)
228        q = f(t, point_id=2); assert num.allclose(r2, q)
229
230
231        ##################
232        #Now do the same for the first timestep
233
234        timestep = 0 #First timestep
235        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
236        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
237        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
238        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
239
240        #Reference interpolated values at midpoints on diagonal at
241        #this timestep are
242        r0 = (D[0] + D[1])/2
243        r1 = (D[1] + D[2])/2
244        r2 = (D[2] + D[3])/2
245
246        #Let us see if the file function can find the correct
247        #values
248        q = f(0, point_id=0); assert num.allclose(r0, q)
249        q = f(0, point_id=1); assert num.allclose(r1, q)
250        q = f(0, point_id=2); assert num.allclose(r2, q)
251
252
253        ##################
254        #Now do it again for a timestep in the middle
255
256        timestep = 33
257        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
258        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
259        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
260        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
261
262        #Reference interpolated values at midpoints on diagonal at
263        #this timestep are
264        r0 = (D[0] + D[1])/2
265        r1 = (D[1] + D[2])/2
266        r2 = (D[2] + D[3])/2
267
268        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
269        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
270        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
271
272
273        ##################
274        #Now check temporal interpolation
275        #Halfway between timestep 15 and 16
276
277        timestep = 15
278        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
279        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
280        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
281        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
282
283        #Reference interpolated values at midpoints on diagonal at
284        #this timestep are
285        r0_0 = (D[0] + D[1])/2
286        r1_0 = (D[1] + D[2])/2
287        r2_0 = (D[2] + D[3])/2
288
289        #
290        timestep = 16
291        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
292        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
293        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
294        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
295
296        #Reference interpolated values at midpoints on diagonal at
297        #this timestep are
298        r0_1 = (D[0] + D[1])/2
299        r1_1 = (D[1] + D[2])/2
300        r2_1 = (D[2] + D[3])/2
301
302        # The reference values are
303        r0 = (r0_0 + r0_1)/2
304        r1 = (r1_0 + r1_1)/2
305        r2 = (r2_0 + r2_1)/2
306
307        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
308        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
309        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
310
311        ##################
312        #Finally check interpolation 2 thirds of the way
313        #between timestep 15 and 16
314
315        # The reference values are
316        r0 = (r0_0 + 2*r0_1)/3
317        r1 = (r1_0 + 2*r1_1)/3
318        r2 = (r2_0 + 2*r2_1)/3
319
320        #And the file function gives
321        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
322        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
323        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
324
325        fid.close()
326        import os
327        os.remove(filename)
328
329
330
331    def test_spatio_temporal_file_function_different_origin(self):
332        """Test that spatio temporal file function performs the correct
333        interpolations in both time and space where space is offset by
334        xllcorner and yllcorner
335        NetCDF version (x,y,t dependency)       
336        """
337        xllcorner = 2048
338        yllcorner = 11000
339        zone = 2
340
341        #Create basic mesh and shallow water domain
342        points, vertices, boundary = rectangular(3, 3)
343        domain1 = Domain(points, vertices, boundary,
344                         geo_reference = Geo_reference(xllcorner = xllcorner,
345                                                       yllcorner = yllcorner))
346       
347
348        from anuga.utilities.numerical_tools import mean       
349        domain1.reduction = mean
350        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
351                              # only one value.
352
353        domain1.default_order = 2
354        domain1.store = True
355        domain1.set_datadir('.')
356        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
357
358        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
359        domain1.set_quantity('elevation', 0)
360        domain1.set_quantity('friction', 0)
361        domain1.set_quantity('stage', 0)
362
363        # Boundary conditions
364        B0 = Dirichlet_boundary([0,0,0])
365        B6 = Dirichlet_boundary([0.6,0,0])
366        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
367        domain1.check_integrity()
368
369        finaltime = 8
370        #Evolution
371        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
372            pass
373            #domain1.write_time()
374
375
376        #Now read data from sww and check
377        from Scientific.IO.NetCDF import NetCDFFile
378        filename = domain1.get_name() + '.sww'
379        fid = NetCDFFile(filename)
380
381        x = fid.variables['x'][:]
382        y = fid.variables['y'][:]
383        # we 'cast' to 64 bit floats to pass this test
384        # SWW file quantities are stored as 32 bits
385        x = num.array(x, num.float)
386        y = num.array(y, num.float)
387
388        stage = fid.variables['stage'][:]
389        xmomentum = fid.variables['xmomentum'][:]
390        ymomentum = fid.variables['ymomentum'][:]
391        time = fid.variables['time'][:]
392
393        #Take stage vertex values at last timestep on diagonal
394        #Diagonal is identified by vertices: 0, 5, 10, 15
395
396        last_time_index = len(time)-1 #Last last_time_index     
397        d_stage = num.reshape(num.take(stage[last_time_index, :],
398                                       [0,5,10,15],
399                                       axis=0),
400                              (4,1))
401        d_uh = num.reshape(num.take(xmomentum[last_time_index, :],
402                                    [0,5,10,15],
403                                   axis=0),
404                           (4,1))
405        d_vh = num.reshape(num.take(ymomentum[last_time_index, :],
406                                    [0,5,10,15],
407                                    axis=0),
408                           (4,1))
409        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
410
411        #Reference interpolated values at midpoints on diagonal at
412        #this timestep are
413        r0 = (D[0] + D[1])/2
414        r1 = (D[1] + D[2])/2
415        r2 = (D[2] + D[3])/2
416
417        #And the midpoints are found now
418        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15], axis=0)
419        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15], axis=0)
420
421        diag = num.concatenate((Dx, Dy), axis=1)
422        d_midpoints = (diag[1:] + diag[:-1])/2
423
424
425        #Adjust for georef - make interpolation points absolute
426        d_midpoints[:,0] += xllcorner
427        d_midpoints[:,1] += yllcorner               
428
429        #Let us see if the file function can find the correct
430        #values at the midpoints at the last timestep:
431        f = file_function(filename, domain1,
432                          interpolation_points = d_midpoints)
433
434        t = time[last_time_index]                         
435
436        q = f(t, point_id=0)
437        msg = '\nr0=%s\nq=%s' % (str(r0), str(q))
438        assert num.allclose(r0, q), msg
439
440        q = f(t, point_id=1)
441        msg = '\nr1=%s\nq=%s' % (str(r1), str(q))
442        assert num.allclose(r1, q), msg
443
444        q = f(t, point_id=2)
445        msg = '\nr2=%s\nq=%s' % (str(r2), str(q))
446        assert num.allclose(r2, q), msg
447
448
449        ##################
450        #Now do the same for the first timestep
451
452        timestep = 0 #First timestep
453        d_stage = num.reshape(num.take(stage[timestep, :],
454                                       [0,5,10,15],
455                                       axis=0),
456                              (4,1))
457        d_uh = num.reshape(num.take(xmomentum[timestep, :],
458                                    [0,5,10,15],
459                                    axis=0),
460                           (4,1))
461        d_vh = num.reshape(num.take(ymomentum[timestep, :],
462                                    [0,5,10,15],
463                                    axis=0),
464                           (4,1))
465        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
466
467        #Reference interpolated values at midpoints on diagonal at
468        #this timestep are
469        r0 = (D[0] + D[1])/2
470        r1 = (D[1] + D[2])/2
471        r2 = (D[2] + D[3])/2
472
473        #Let us see if the file function can find the correct
474        #values
475        q = f(0, point_id=0); assert num.allclose(r0, q)
476        q = f(0, point_id=1); assert num.allclose(r1, q)
477        q = f(0, point_id=2); assert num.allclose(r2, q)
478
479
480        ##################
481        #Now do it again for a timestep in the middle
482
483        timestep = 33
484        d_stage = num.reshape(num.take(stage[timestep, :],
485                                       [0,5,10,15],
486                                       axis=0),
487                              (4,1))
488        d_uh = num.reshape(num.take(xmomentum[timestep, :],
489                                    [0,5,10,15],
490                                    axis=0),
491                           (4,1))
492        d_vh = num.reshape(num.take(ymomentum[timestep, :],
493                                    [0,5,10,15],
494                                    axis=0),
495                           (4,1))
496        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
497
498        #Reference interpolated values at midpoints on diagonal at
499        #this timestep are
500        r0 = (D[0] + D[1])/2
501        r1 = (D[1] + D[2])/2
502        r2 = (D[2] + D[3])/2
503
504        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
505        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
506        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
507
508
509        ##################
510        #Now check temporal interpolation
511        #Halfway between timestep 15 and 16
512
513        timestep = 15
514        d_stage = num.reshape(num.take(stage[timestep, :],
515                                       [0,5,10,15],
516                                       axis=0),
517                              (4,1))
518        d_uh = num.reshape(num.take(xmomentum[timestep, :],
519                                    [0,5,10,15],
520                                    axis=0),
521                           (4,1))
522        d_vh = num.reshape(num.take(ymomentum[timestep, :],
523                                    [0,5,10,15],
524                                    axis=0),
525                           (4,1))
526        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
527
528        #Reference interpolated values at midpoints on diagonal at
529        #this timestep are
530        r0_0 = (D[0] + D[1])/2
531        r1_0 = (D[1] + D[2])/2
532        r2_0 = (D[2] + D[3])/2
533
534        #
535        timestep = 16
536        d_stage = num.reshape(num.take(stage[timestep, :],
537                                       [0,5,10,15],
538                                       axis=0),
539                              (4,1))
540        d_uh = num.reshape(num.take(xmomentum[timestep, :],
541                                    [0,5,10,15],
542                                    axis=0),
543                           (4,1))
544        d_vh = num.reshape(num.take(ymomentum[timestep, :],
545                                    [0,5,10,15],
546                                    axis=0),
547                           (4,1))
548        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
549
550        #Reference interpolated values at midpoints on diagonal at
551        #this timestep are
552        r0_1 = (D[0] + D[1])/2
553        r1_1 = (D[1] + D[2])/2
554        r2_1 = (D[2] + D[3])/2
555
556        # The reference values are
557        r0 = (r0_0 + r0_1)/2
558        r1 = (r1_0 + r1_1)/2
559        r2 = (r2_0 + r2_1)/2
560
561        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
562        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
563        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
564
565        ##################
566        #Finally check interpolation 2 thirds of the way
567        #between timestep 15 and 16
568
569        # The reference values are
570        r0 = (r0_0 + 2*r0_1)/3
571        r1 = (r1_0 + 2*r1_1)/3
572        r2 = (r2_0 + 2*r2_1)/3
573
574        #And the file function gives
575        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
576        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
577        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
578
579        fid.close()
580        import os
581        os.remove(filename)
582
583       
584
585
586    def test_spatio_temporal_file_function_time(self):
587        """Test that File function interpolates correctly
588        between given times.
589        NetCDF version (x,y,t dependency)
590        """
591
592        #Create NetCDF (sww) file to be read
593        # x: 0, 5, 10, 15
594        # y: -20, -10, 0, 10
595        # t: 0, 60, 120, ...., 1200
596        #
597        # test quantities (arbitrary but non-trivial expressions):
598        #
599        #   stage     = 3*x - y**2 + 2*t
600        #   xmomentum = exp( -((x-7)**2 + (y+5)**2)/20 ) * t**2
601        #   ymomentum = x**2 + y**2 * sin(t*pi/600)
602
603        #NOTE: Nice test that may render some of the others redundant.
604
605        import os, time
606        from anuga.config import time_format
607        from mesh_factory import rectangular
608        from anuga.shallow_water.shallow_water_domain import Domain
609
610        finaltime = 1200
611        filename = 'test_file_function'
612
613        #Create a domain to hold test grid
614        #(0:15, -20:10)
615        points, vertices, boundary =\
616                rectangular(4, 4, 15, 30, origin = (0, -20))
617        #print "points", points
618
619        #print 'Number of elements', len(vertices)
620        domain = Domain(points, vertices, boundary)
621        domain.smooth = False
622        domain.default_order = 2
623        domain.set_datadir('.')
624        domain.set_name(filename)
625        domain.store = True
626
627        #print points
628        start = time.mktime(time.strptime('2000', '%Y'))
629        domain.starttime = start
630
631
632        #Store structure
633        domain.initialise_storage()
634
635        #Compute artificial time steps and store
636        dt = 60  #One minute intervals
637        t = 0.0
638        while t <= finaltime:
639            #Compute quantities
640            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
641            domain.set_quantity('stage', f1)
642
643            f2 = lambda x,y: x+y+t**2
644            domain.set_quantity('xmomentum', f2)
645
646            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
647            domain.set_quantity('ymomentum', f3)
648
649            #Store and advance time
650            domain.time = t
651            domain.store_timestep()
652            t += dt
653
654
655        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5]]
656     
657        #Deliberately set domain.starttime to too early
658        domain.starttime = start - 1
659
660        #Create file function
661        F = file_function(filename + '.sww', domain,
662                          quantities = domain.conserved_quantities,
663                          interpolation_points = interpolation_points)
664
665        #Check that FF updates fixes domain starttime
666        assert num.allclose(domain.starttime, start)
667
668        #Check that domain.starttime isn't updated if later
669        domain.starttime = start + 1
670        F = file_function(filename + '.sww', domain,
671                          quantities = domain.conserved_quantities,
672                          interpolation_points = interpolation_points)
673        assert num.allclose(domain.starttime, start+1)
674        domain.starttime = start
675
676
677        #Check linear interpolation in time
678        F = file_function(filename + '.sww', domain,
679                          quantities = domain.conserved_quantities,
680                          interpolation_points = interpolation_points)               
681        for id in range(len(interpolation_points)):
682            x = interpolation_points[id][0]
683            y = interpolation_points[id][1]
684
685            for i in range(20):
686                t = i*10
687                k = i%6
688
689                if k == 0:
690                    q0 = F(t, point_id=id)
691                    q1 = F(t+60, point_id=id)
692
693                if num.alltrue(q0 == NAN):
694                    actual = q0
695                else:
696                    actual = (k*q1 + (6-k)*q0)/6
697                q = F(t, point_id=id)
698                #print i, k, t, q
699                #print ' ', q0
700                #print ' ', q1
701                #print "q",q
702                #print "actual", actual
703                #print
704                if num.alltrue(q0 == NAN):
705                     self.failUnless(num.alltrue(q == actual), 'Fail!')
706                else:
707                    assert num.allclose(q, actual)
708
709
710        #Another check of linear interpolation in time
711        for id in range(len(interpolation_points)):
712            q60 = F(60, point_id=id)
713            q120 = F(120, point_id=id)
714
715            t = 90 #Halfway between 60 and 120
716            q = F(t, point_id=id)
717            assert num.allclose( (q120+q60)/2, q )
718
719            t = 100 #Two thirds of the way between between 60 and 120
720            q = F(t, point_id=id)
721            assert num.allclose(q60/3 + 2*q120/3, q)
722
723
724
725        #Check that domain.starttime isn't updated if later than file starttime but earlier
726        #than file end time
727        delta = 23
728        domain.starttime = start + delta
729        F = file_function(filename + '.sww', domain,
730                          quantities = domain.conserved_quantities,
731                          interpolation_points = interpolation_points)
732        assert num.allclose(domain.starttime, start+delta)
733
734
735
736
737        #Now try interpolation with delta offset
738        for id in range(len(interpolation_points)):           
739            x = interpolation_points[id][0]
740            y = interpolation_points[id][1]
741
742            for i in range(20):
743                t = i*10
744                k = i%6
745
746                if k == 0:
747                    q0 = F(t-delta, point_id=id)
748                    q1 = F(t+60-delta, point_id=id)
749
750                q = F(t-delta, point_id=id)
751                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
752
753
754        os.remove(filename + '.sww')
755
756
757
758    def Xtest_spatio_temporal_file_function_time(self):
759        # FIXME: This passes but needs some TLC
760        # Test that File function interpolates correctly
761        # When some points are outside the mesh
762
763        import os, time
764        from anuga.config import time_format
765        from mesh_factory import rectangular
766        from shallow_water import Domain
767        import anuga.shallow_water.data_manager 
768        from anuga.pmesh.mesh_interface import create_mesh_from_regions
769        finaltime = 1200
770       
771        filename = tempfile.mktemp()
772        #print "filename",filename
773        filename = 'test_file_function'
774
775        meshfilename = tempfile.mktemp(".tsh")
776
777        boundary_tags = {'walls':[0,1],'bom':[2]}
778       
779        polygon_absolute = [[0,-20],[10,-20],[10,15],[-20,15]]
780       
781        create_mesh_from_regions(polygon_absolute,
782                                 boundary_tags,
783                                 10000000,
784                                 filename=meshfilename)
785        domain = Domain(mesh_filename=meshfilename)
786        domain.smooth = False
787        domain.default_order = 2
788        domain.set_datadir('.')
789        domain.set_name(filename)
790        domain.store = True
791
792        #print points
793        start = time.mktime(time.strptime('2000', '%Y'))
794        domain.starttime = start
795       
796
797        #Store structure
798        domain.initialise_storage()
799
800        #Compute artificial time steps and store
801        dt = 60  #One minute intervals
802        t = 0.0
803        while t <= finaltime:
804            #Compute quantities
805            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
806            domain.set_quantity('stage', f1)
807
808            f2 = lambda x,y: x+y+t**2
809            domain.set_quantity('xmomentum', f2)
810
811            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
812            domain.set_quantity('ymomentum', f3)
813
814            #Store and advance time
815            domain.time = t
816            domain.store_timestep()
817            t += dt
818
819        interpolation_points = [[1,0]]
820        interpolation_points = [[100,1000]]
821       
822        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5],
823                                [78787,78787],[7878,3432]]
824           
825        #Deliberately set domain.starttime to too early
826        domain.starttime = start - 1
827
828        #Create file function
829        F = file_function(filename + '.sww', domain,
830                          quantities = domain.conserved_quantities,
831                          interpolation_points = interpolation_points)
832
833        #Check that FF updates fixes domain starttime
834        assert num.allclose(domain.starttime, start)
835
836        #Check that domain.starttime isn't updated if later
837        domain.starttime = start + 1
838        F = file_function(filename + '.sww', domain,
839                          quantities = domain.conserved_quantities,
840                          interpolation_points = interpolation_points)
841        assert num.allclose(domain.starttime, start+1)
842        domain.starttime = start
843
844
845        #Check linear interpolation in time
846        # checking points inside and outside the mesh
847        F = file_function(filename + '.sww', domain,
848                          quantities = domain.conserved_quantities,
849                          interpolation_points = interpolation_points)
850       
851        for id in range(len(interpolation_points)):
852            x = interpolation_points[id][0]
853            y = interpolation_points[id][1]
854
855            for i in range(20):
856                t = i*10
857                k = i%6
858
859                if k == 0:
860                    q0 = F(t, point_id=id)
861                    q1 = F(t+60, point_id=id)
862
863                if q0 == NAN:
864                    actual = q0
865                else:
866                    actual = (k*q1 + (6-k)*q0)/6
867                q = F(t, point_id=id)
868                #print i, k, t, q
869                #print ' ', q0
870                #print ' ', q1
871                #print "q",q
872                #print "actual", actual
873                #print
874                if q0 == NAN:
875                     self.failUnless( q == actual, 'Fail!')
876                else:
877                    assert num.allclose(q, actual)
878
879        # now lets check points inside the mesh
880        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14]] #, [10,-12.5]] - this point doesn't work WHY?
881        interpolation_points = [[10,-12.5]]
882           
883        print "len(interpolation_points)",len(interpolation_points) 
884        F = file_function(filename + '.sww', domain,
885                          quantities = domain.conserved_quantities,
886                          interpolation_points = interpolation_points)
887
888        domain.starttime = start
889
890
891        #Check linear interpolation in time
892        F = file_function(filename + '.sww', domain,
893                          quantities = domain.conserved_quantities,
894                          interpolation_points = interpolation_points)               
895        for id in range(len(interpolation_points)):
896            x = interpolation_points[id][0]
897            y = interpolation_points[id][1]
898
899            for i in range(20):
900                t = i*10
901                k = i%6
902
903                if k == 0:
904                    q0 = F(t, point_id=id)
905                    q1 = F(t+60, point_id=id)
906
907                if q0 == NAN:
908                    actual = q0
909                else:
910                    actual = (k*q1 + (6-k)*q0)/6
911                q = F(t, point_id=id)
912                print "############"
913                print "id, x, y ", id, x, y #k, t, q
914                print "t", t
915                #print ' ', q0
916                #print ' ', q1
917                print "q",q
918                print "actual", actual
919                #print
920                if q0 == NAN:
921                     self.failUnless( q == actual, 'Fail!')
922                else:
923                    assert num.allclose(q, actual)
924
925
926        #Another check of linear interpolation in time
927        for id in range(len(interpolation_points)):
928            q60 = F(60, point_id=id)
929            q120 = F(120, point_id=id)
930
931            t = 90 #Halfway between 60 and 120
932            q = F(t, point_id=id)
933            assert num.allclose( (q120+q60)/2, q )
934
935            t = 100 #Two thirds of the way between between 60 and 120
936            q = F(t, point_id=id)
937            assert num.allclose(q60/3 + 2*q120/3, q)
938
939
940
941        #Check that domain.starttime isn't updated if later than file starttime but earlier
942        #than file end time
943        delta = 23
944        domain.starttime = start + delta
945        F = file_function(filename + '.sww', domain,
946                          quantities = domain.conserved_quantities,
947                          interpolation_points = interpolation_points)
948        assert num.allclose(domain.starttime, start+delta)
949
950
951
952
953        #Now try interpolation with delta offset
954        for id in range(len(interpolation_points)):           
955            x = interpolation_points[id][0]
956            y = interpolation_points[id][1]
957
958            for i in range(20):
959                t = i*10
960                k = i%6
961
962                if k == 0:
963                    q0 = F(t-delta, point_id=id)
964                    q1 = F(t+60-delta, point_id=id)
965
966                q = F(t-delta, point_id=id)
967                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
968
969
970        os.remove(filename + '.sww')
971
972    def test_file_function_time_with_domain(self):
973        """Test that File function interpolates correctly
974        between given times. No x,y dependency here.
975        Use domain with starttime
976        """
977
978        #Write file
979        import os, time, calendar
980        from anuga.config import time_format
981        from math import sin, pi
982
983        finaltime = 1200
984        filename = 'test_file_function'
985        fid = open(filename + '.txt', 'w')
986        start = time.mktime(time.strptime('2000', '%Y'))
987        dt = 60  #One minute intervals
988        t = 0.0
989        while t <= finaltime:
990            t_string = time.strftime(time_format, time.gmtime(t+start))
991            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
992            t += dt
993
994        fid.close()
995
996
997        #Convert ASCII file to NetCDF (Which is what we really like!)
998        timefile2netcdf(filename+'.txt')
999
1000
1001
1002        a = [0.0, 0.0]
1003        b = [4.0, 0.0]
1004        c = [0.0, 3.0]
1005
1006        points = [a, b, c]
1007        vertices = [[0,1,2]]
1008        domain = Domain(points, vertices)
1009
1010        # Check that domain.starttime is updated if non-existing
1011        F = file_function(filename + '.tms',
1012                          domain,
1013                          quantities = ['Attribute0', 'Attribute1', 'Attribute2']) 
1014        assert num.allclose(domain.starttime, start)
1015
1016        # Check that domain.starttime is updated if too early
1017        domain.starttime = start - 1
1018        F = file_function(filename + '.tms',
1019                          domain,
1020                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
1021        assert num.allclose(domain.starttime, start)
1022
1023        # Check that domain.starttime isn't updated if later
1024        domain.starttime = start + 1
1025        F = file_function(filename + '.tms',
1026                          domain,
1027                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
1028        assert num.allclose(domain.starttime, start+1)
1029
1030        domain.starttime = start
1031        F = file_function(filename + '.tms',
1032                          domain,
1033                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'],
1034                          use_cache=True)
1035       
1036
1037        #print F.precomputed_values
1038        #print 'F(60)', F(60)
1039       
1040        #Now try interpolation
1041        for i in range(20):
1042            t = i*10
1043            q = F(t)
1044
1045            #Exact linear intpolation
1046            assert num.allclose(q[0], 2*t)
1047            if i%6 == 0:
1048                assert num.allclose(q[1], t**2)
1049                assert num.allclose(q[2], sin(t*pi/600))
1050
1051        #Check non-exact
1052
1053        t = 90 #Halfway between 60 and 120
1054        q = F(t)
1055        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1056        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1057
1058
1059        t = 100 #Two thirds of the way between between 60 and 120
1060        q = F(t)
1061        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1062        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1063
1064        os.remove(filename + '.tms')
1065        os.remove(filename + '.txt')       
1066
1067    def test_file_function_time_with_domain_different_start(self):
1068        """Test that File function interpolates correctly
1069        between given times. No x,y dependency here.
1070        Use domain with a starttime later than that of file
1071
1072        ASCII version
1073        """
1074
1075        #Write file
1076        import os, time, calendar
1077        from anuga.config import time_format
1078        from math import sin, pi
1079
1080        finaltime = 1200
1081        filename = 'test_file_function'
1082        fid = open(filename + '.txt', 'w')
1083        start = time.mktime(time.strptime('2000', '%Y'))
1084        dt = 60  #One minute intervals
1085        t = 0.0
1086        while t <= finaltime:
1087            t_string = time.strftime(time_format, time.gmtime(t+start))
1088            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1089            t += dt
1090
1091        fid.close()
1092
1093        #Convert ASCII file to NetCDF (Which is what we really like!)
1094        timefile2netcdf(filename+'.txt')       
1095
1096        a = [0.0, 0.0]
1097        b = [4.0, 0.0]
1098        c = [0.0, 3.0]
1099
1100        points = [a, b, c]
1101        vertices = [[0,1,2]]
1102        domain = Domain(points, vertices)
1103
1104        #Check that domain.starttime isn't updated if later than file starttime but earlier
1105        #than file end time
1106        delta = 23
1107        domain.starttime = start + delta
1108        F = file_function(filename + '.tms', domain,
1109                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])       
1110        assert num.allclose(domain.starttime, start+delta)
1111
1112        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1113                                            277., 337., 397., 457., 517.,
1114                                            577., 637., 697., 757., 817.,
1115                                            877., 937., 997., 1057., 1117.,
1116                                            1177.])
1117
1118
1119        #Now try interpolation with delta offset
1120        for i in range(20):
1121            t = i*10
1122            q = F(t-delta)
1123
1124            #Exact linear intpolation
1125            assert num.allclose(q[0], 2*t)
1126            if i%6 == 0:
1127                assert num.allclose(q[1], t**2)
1128                assert num.allclose(q[2], sin(t*pi/600))
1129
1130        #Check non-exact
1131
1132        t = 90 #Halfway between 60 and 120
1133        q = F(t-delta)
1134        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1135        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1136
1137
1138        t = 100 #Two thirds of the way between between 60 and 120
1139        q = F(t-delta)
1140        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1141        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1142
1143
1144        os.remove(filename + '.tms')
1145        os.remove(filename + '.txt')               
1146
1147       
1148
1149    def test_file_function_time_with_domain_different_start_and_time_limit(self):
1150        """Test that File function interpolates correctly
1151        between given times. No x,y dependency here.
1152        Use domain with a starttime later than that of file
1153
1154        ASCII version
1155       
1156        This test also tests that time can be truncated.
1157        """
1158
1159        # Write file
1160        import os, time, calendar
1161        from anuga.config import time_format
1162        from math import sin, pi
1163
1164        finaltime = 1200
1165        filename = 'test_file_function'
1166        fid = open(filename + '.txt', 'w')
1167        start = time.mktime(time.strptime('2000', '%Y'))
1168        dt = 60  #One minute intervals
1169        t = 0.0
1170        while t <= finaltime:
1171            t_string = time.strftime(time_format, time.gmtime(t+start))
1172            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1173            t += dt
1174
1175        fid.close()
1176
1177        # Convert ASCII file to NetCDF (Which is what we really like!)
1178        timefile2netcdf(filename + '.txt')       
1179
1180        a = [0.0, 0.0]
1181        b = [4.0, 0.0]
1182        c = [0.0, 3.0]
1183
1184        points = [a, b, c]
1185        vertices = [[0,1,2]]
1186        domain = Domain(points, vertices)
1187
1188        # Check that domain.starttime isn't updated if later than file starttime but earlier
1189        # than file end time
1190        delta = 23
1191        domain.starttime = start + delta
1192        time_limit = domain.starttime + 600
1193        F = file_function(filename + '.tms', domain,
1194                          time_limit=time_limit,
1195                          quantities=['Attribute0', 'Attribute1', 'Attribute2'])       
1196        assert num.allclose(domain.starttime, start+delta)
1197
1198        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1199                                            277., 337., 397., 457., 517.,
1200                                            577.])       
1201
1202
1203
1204        # Now try interpolation with delta offset
1205        for i in range(20):
1206            t = i*10
1207            q = F(t-delta)
1208
1209            #Exact linear intpolation
1210            assert num.allclose(q[0], 2*t)
1211            if i%6 == 0:
1212                assert num.allclose(q[1], t**2)
1213                assert num.allclose(q[2], sin(t*pi/600))
1214
1215        # Check non-exact
1216        t = 90 #Halfway between 60 and 120
1217        q = F(t-delta)
1218        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1219        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1220
1221
1222        t = 100 # Two thirds of the way between between 60 and 120
1223        q = F(t-delta)
1224        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1225        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1226
1227
1228        os.remove(filename + '.tms')
1229        os.remove(filename + '.txt')               
1230
1231       
1232       
1233       
1234
1235
1236    def test_apply_expression_to_dictionary(self):
1237
1238        #FIXME: Division is not expected to work for integers.
1239        #This must be caught.
1240        foo = num.array([[1,2,3], [4,5,6]], num.float)
1241
1242        bar = num.array([[-1,0,5], [6,1,1]], num.float)                 
1243
1244        D = {'X': foo, 'Y': bar}
1245
1246        Z = apply_expression_to_dictionary('X+Y', D)       
1247        assert num.allclose(Z, foo+bar)
1248
1249        Z = apply_expression_to_dictionary('X*Y', D)       
1250        assert num.allclose(Z, foo*bar)       
1251
1252        Z = apply_expression_to_dictionary('4*X+Y', D)       
1253        assert num.allclose(Z, 4*foo+bar)       
1254
1255        # test zero division is OK
1256        Z = apply_expression_to_dictionary('X/Y', D)
1257        assert num.allclose(1/Z, 1/(foo/bar)) # can't compare inf to inf
1258
1259        # make an error for zero on zero
1260        # this is really an error in numeric, SciPy core can handle it
1261        # Z = apply_expression_to_dictionary('0/Y', D)
1262
1263        #Check exceptions
1264        try:
1265            #Wrong name
1266            Z = apply_expression_to_dictionary('4*X+A', D)       
1267        except NameError:
1268            pass
1269        else:
1270            msg = 'Should have raised a NameError Exception'
1271            raise msg
1272
1273
1274        try:
1275            #Wrong order
1276            Z = apply_expression_to_dictionary(D, '4*X+A')       
1277        except AssertionError:
1278            pass
1279        else:
1280            msg = 'Should have raised a AssertionError Exception'
1281            raise msg       
1282       
1283
1284    def test_multiple_replace(self):
1285        """Hard test that checks a true word-by-word simultaneous replace
1286        """
1287       
1288        D = {'x': 'xi', 'y': 'eta', 'xi':'lam'}
1289        exp = '3*x+y + xi'
1290       
1291        new = multiple_replace(exp, D)
1292       
1293        assert new == '3*xi+eta + lam'
1294                         
1295   
1296    def test_get_revision_number(self):
1297        """test_get_revision_number(self):
1298
1299        Test that revision number can be retrieved.
1300        """
1301        if os.environ.has_key('USER') and os.environ['USER'] == 'dgray':
1302            # I have a known snv incompatability issue,
1303            # so I'm skipping this test.
1304            # FIXME when SVN is upgraded on our clusters
1305            pass
1306        else:   
1307            n = get_revision_number()
1308            assert n>=0
1309
1310
1311       
1312    def test_add_directories(self):
1313       
1314        import tempfile
1315        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1316        directories = ['ja','ne','ke']
1317        kens_dir = add_directories(root_dir, directories)
1318        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1319               sep + 'ke'
1320        assert access(root_dir,F_OK)
1321
1322        add_directories(root_dir, directories)
1323        assert access(root_dir,F_OK)
1324       
1325        #clean up!
1326        os.rmdir(kens_dir)
1327        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1328        os.rmdir(root_dir + sep + 'ja')
1329        os.rmdir(root_dir)
1330
1331    def test_add_directories_bad(self):
1332       
1333        import tempfile
1334        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1335        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1336       
1337        try:
1338            kens_dir = add_directories(root_dir, directories)
1339        except OSError:
1340            pass
1341        else:
1342            msg = 'bad dir name should give OSError'
1343            raise Exception(msg)   
1344           
1345        #clean up!
1346        os.rmdir(root_dir)
1347
1348    def test_check_list(self):
1349
1350        check_list(['stage','xmomentum'])
1351
1352       
1353    def test_add_directories(self):
1354       
1355        import tempfile
1356        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1357        directories = ['ja','ne','ke']
1358        kens_dir = add_directories(root_dir, directories)
1359        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1360               sep + 'ke'
1361        assert access(root_dir,F_OK)
1362
1363        add_directories(root_dir, directories)
1364        assert access(root_dir,F_OK)
1365       
1366        #clean up!
1367        os.rmdir(kens_dir)
1368        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1369        os.rmdir(root_dir + sep + 'ja')
1370        os.rmdir(root_dir)
1371
1372    def test_add_directories_bad(self):
1373       
1374        import tempfile
1375        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1376        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1377       
1378        try:
1379            kens_dir = add_directories(root_dir, directories)
1380        except OSError:
1381            pass
1382        else:
1383            msg = 'bad dir name should give OSError'
1384            raise Exception(msg)   
1385           
1386        #clean up!
1387        os.rmdir(root_dir)
1388
1389    def test_check_list(self):
1390
1391        check_list(['stage','xmomentum'])
1392
1393######
1394# Test the remove_lone_verts() function
1395######
1396       
1397    def test_remove_lone_verts_a(self):
1398        verts = [[0,0],[1,0],[0,1]]
1399        tris = [[0,1,2]]
1400        new_verts, new_tris = remove_lone_verts(verts, tris)
1401        self.failUnless(new_verts.tolist() == verts)
1402        self.failUnless(new_tris.tolist() == tris)
1403
1404    def test_remove_lone_verts_b(self):
1405        verts = [[0,0],[1,0],[0,1],[99,99]]
1406        tris = [[0,1,2]]
1407        new_verts, new_tris = remove_lone_verts(verts, tris)
1408        self.failUnless(new_verts.tolist() == verts[0:3])
1409        self.failUnless(new_tris.tolist() == tris)
1410       
1411    def test_remove_lone_verts_c(self):
1412        verts = [[99,99],[0,0],[1,0],[99,99],[0,1],[99,99]]
1413        tris = [[1,2,4]]
1414        new_verts, new_tris = remove_lone_verts(verts, tris)
1415        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1416        self.failUnless(new_tris.tolist() == [[0,1,2]])
1417     
1418    def test_remove_lone_verts_d(self):
1419        verts = [[0,0],[1,0],[99,99],[0,1]]
1420        tris = [[0,1,3]]
1421        new_verts, new_tris = remove_lone_verts(verts, tris)
1422        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1423        self.failUnless(new_tris.tolist() == [[0,1,2]])
1424       
1425    def test_remove_lone_verts_e(self):
1426        verts = [[0,0],[1,0],[0,1],[99,99],[99,99],[99,99]]
1427        tris = [[0,1,2]]
1428        new_verts, new_tris = remove_lone_verts(verts, tris)
1429        self.failUnless(new_verts.tolist() == verts[0:3])
1430        self.failUnless(new_tris.tolist() == tris)
1431     
1432    def test_remove_lone_verts_f(self):
1433        verts = [[0,0],[1,0],[99,99],[0,1],[99,99],[1,1],[99,99]]
1434        tris = [[0,1,3],[0,1,5]]
1435        new_verts, new_tris = remove_lone_verts(verts, tris)
1436        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1],[1,1]])
1437        self.failUnless(new_tris.tolist() == [[0,1,2],[0,1,3]])
1438       
1439######
1440#
1441######
1442       
1443    def test_get_min_max_values(self):
1444       
1445        list=[8,9,6,1,4]
1446        min1, max1 = get_min_max_values(list)
1447       
1448        assert min1==1 
1449        assert max1==9
1450       
1451    def test_get_min_max_values1(self):
1452       
1453        list=[-8,-9,-6,-1,-4]
1454        min1, max1 = get_min_max_values(list)
1455       
1456#        print 'min1,max1',min1,max1
1457        assert min1==-9 
1458        assert max1==-1
1459
1460#    def test_get_min_max_values2(self):
1461#        '''
1462#        The min and max supplied are greater than the ones in the
1463#        list and therefore are the ones returned
1464#        '''
1465#        list=[-8,-9,-6,-1,-4]
1466#        min1, max1 = get_min_max_values(list,-10,10)
1467#       
1468##        print 'min1,max1',min1,max1
1469#        assert min1==-10
1470#        assert max1==10
1471       
1472    def test_make_plots_from_csv_files(self):
1473       
1474        #if sys.platform == 'win32':  #Windows
1475            try: 
1476                import pylab
1477            except ImportError:
1478                #ANUGA don't need pylab to work so the system doesn't
1479                #rely on pylab being installed
1480                return
1481           
1482       
1483            current_dir=getcwd()+sep+'abstract_2d_finite_volumes'
1484            temp_dir = tempfile.mkdtemp('','figures')
1485    #        print 'temp_dir',temp_dir
1486            fileName = temp_dir+sep+'time_series_3.csv'
1487            file = open(fileName,"w")
1488            file.write("time,stage,speed,momentum,elevation\n\
14891.0, 0, 0, 0, 10 \n\
14902.0, 5, 2, 4, 10 \n\
14913.0, 3, 3, 5, 10 \n")
1492            file.close()
1493   
1494            fileName1 = temp_dir+sep+'time_series_4.csv'
1495            file1 = open(fileName1,"w")
1496            file1.write("time,stage,speed,momentum,elevation\n\
14971.0, 0, 0, 0, 5 \n\
14982.0, -5, -2, -4, 5 \n\
14993.0, -4, -3, -5, 5 \n")
1500            file1.close()
1501   
1502            fileName2 = temp_dir+sep+'time_series_5.csv'
1503            file2 = open(fileName2,"w")
1504            file2.write("time,stage,speed,momentum,elevation\n\
15051.0, 0, 0, 0, 7 \n\
15062.0, 4, -0.45, 57, 7 \n\
15073.0, 6, -0.5, 56, 7 \n")
1508            file2.close()
1509           
1510            dir, name=os.path.split(fileName)
1511            csv2timeseries_graphs(directories_dic={dir:['gauge', 0, 0]},
1512                                  output_dir=temp_dir,
1513                                  base_name='time_series_',
1514                                  plot_numbers=['3-5'],
1515                                  quantities=['speed','stage','momentum'],
1516                                  assess_all_csv_files=True,
1517                                  extra_plot_name='test')
1518           
1519            #print dir+sep+name[:-4]+'_stage_test.png'
1520            assert(access(dir+sep+name[:-4]+'_stage_test.png',F_OK)==True)
1521            assert(access(dir+sep+name[:-4]+'_speed_test.png',F_OK)==True)
1522            assert(access(dir+sep+name[:-4]+'_momentum_test.png',F_OK)==True)
1523   
1524            dir1, name1=os.path.split(fileName1)
1525            assert(access(dir+sep+name1[:-4]+'_stage_test.png',F_OK)==True)
1526            assert(access(dir+sep+name1[:-4]+'_speed_test.png',F_OK)==True)
1527            assert(access(dir+sep+name1[:-4]+'_momentum_test.png',F_OK)==True)
1528   
1529   
1530            dir2, name2=os.path.split(fileName2)
1531            assert(access(dir+sep+name2[:-4]+'_stage_test.png',F_OK)==True)
1532            assert(access(dir+sep+name2[:-4]+'_speed_test.png',F_OK)==True)
1533            assert(access(dir+sep+name2[:-4]+'_momentum_test.png',F_OK)==True)
1534   
1535            del_dir(temp_dir)
1536       
1537
1538
1539    def test_greens_law(self):
1540
1541        from math import sqrt
1542       
1543        d1 = 80.0
1544        d2 = 20.0
1545        h1 = 1.0
1546        h2 = greens_law(d1,d2,h1)
1547
1548        assert h2==sqrt(2.0)
1549       
1550    def test_calc_bearings(self):
1551 
1552        from math import atan, degrees
1553        #Test East
1554        uh = 1
1555        vh = 1.e-15
1556        angle = calc_bearing(uh, vh)
1557        if 89 < angle < 91: v=1
1558        assert v==1
1559        #Test West
1560        uh = -1
1561        vh = 1.e-15
1562        angle = calc_bearing(uh, vh)
1563        if 269 < angle < 271: v=1
1564        assert v==1
1565        #Test North
1566        uh = 1.e-15
1567        vh = 1
1568        angle = calc_bearing(uh, vh)
1569        if -1 < angle < 1: v=1
1570        assert v==1
1571        #Test South
1572        uh = 1.e-15
1573        vh = -1
1574        angle = calc_bearing(uh, vh)
1575        if 179 < angle < 181: v=1
1576        assert v==1
1577        #Test South-East
1578        uh = 1
1579        vh = -1
1580        angle = calc_bearing(uh, vh)
1581        if 134 < angle < 136: v=1
1582        assert v==1
1583        #Test North-East
1584        uh = 1
1585        vh = 1
1586        angle = calc_bearing(uh, vh)
1587        if 44 < angle < 46: v=1
1588        assert v==1
1589        #Test South-West
1590        uh = -1
1591        vh = -1
1592        angle = calc_bearing(uh, vh)
1593        if 224 < angle < 226: v=1
1594        assert v==1
1595        #Test North-West
1596        uh = -1
1597        vh = 1
1598        angle = calc_bearing(uh, vh)
1599        if 314 < angle < 316: v=1
1600        assert v==1
1601       
1602    def test_calc_bearings_zero_vector(self): 
1603        from math import atan, degrees
1604
1605        uh = 0
1606        vh = 0
1607        angle = calc_bearing(uh, vh)
1608
1609        assert angle == NAN
1610       
1611#-------------------------------------------------------------
1612
1613if __name__ == "__main__":
1614    suite = unittest.makeSuite(Test_Util, 'test')
1615#    runner = unittest.TextTestRunner(verbosity=2)
1616    runner = unittest.TextTestRunner(verbosity=1)
1617    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.