source: trunk/anuga_core/source/anuga/abstract_2d_finite_volumes/test_util.py @ 7780

Last change on this file since 7780 was 7780, checked in by hudson, 13 years ago

Almost all failing tests fixed.

File size: 53.9 KB
Line 
1#!/usr/bin/env python
2
3
4import unittest
5from math import sqrt, pi
6import tempfile, os
7from os import access, F_OK,sep, removedirs,remove,mkdir,getcwd
8
9from anuga.abstract_2d_finite_volumes.util import *
10from anuga.config import epsilon
11from anuga.config import netcdf_mode_r, netcdf_mode_w, netcdf_mode_a
12from anuga.file_conversion.file_conversion import timefile2netcdf
13from anuga.utilities.file_utils import del_dir
14
15from anuga.utilities.numerical_tools import NAN
16
17from sys import platform
18
19from anuga.pmesh.mesh import Mesh
20import time
21from mesh_factory import rectangular
22from anuga.coordinate_transforms.geo_reference import Geo_reference
23from anuga.shallow_water.shallow_water_domain import Domain
24from generic_boundary_conditions import \
25                                Transmissive_boundary, Dirichlet_boundary
26from anuga.file.sww import SWW_file
27from csv import reader,writer
28import time
29import string
30
31import numpy as num
32
33
34def test_function(x, y):
35    return x+y
36
37class Test_Util(unittest.TestCase):
38    def setUp(self):
39        pass
40
41    def tearDown(self):
42        pass
43
44
45
46
47    #Geometric
48    #def test_distance(self):
49    #    from anuga.abstract_2d_finite_volumes.util import distance#
50    #
51    #    self.failUnless( distance([4,2],[7,6]) == 5.0,
52    #                     'Distance is wrong!')
53    #    self.failUnless( allclose(distance([7,6],[9,8]), 2.82842712475),
54    #                    'distance is wrong!')
55    #    self.failUnless( allclose(distance([9,8],[4,2]), 7.81024967591),
56    #                    'distance is wrong!')
57    #
58    #    self.failUnless( distance([9,8],[4,2]) == distance([4,2],[9,8]),
59    #                    'distance is wrong!')
60
61
62    def test_file_function_time1(self):
63        """Test that File function interpolates correctly
64        between given times. No x,y dependency here.
65        """
66
67        #Write file
68        import os, time
69        from anuga.config import time_format
70        from math import sin, pi
71
72        #Typical ASCII file
73        finaltime = 1200
74        filename = 'test_file_function'
75        fid = open(filename + '.txt', 'w')
76        start = time.mktime(time.strptime('2000', '%Y'))
77        dt = 60  #One minute intervals
78        t = 0.0
79        while t <= finaltime:
80            t_string = time.strftime(time_format, time.gmtime(t+start))
81            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
82            t += dt
83
84        fid.close()
85
86        #Convert ASCII file to NetCDF (Which is what we really like!)
87        timefile2netcdf(filename)
88
89
90        #Create file function from time series
91        F = file_function(filename + '.tms',
92                          quantities = ['Attribute0',
93                                        'Attribute1',
94                                        'Attribute2'])
95       
96        #Now try interpolation
97        for i in range(20):
98            t = i*10
99            q = F(t)
100
101            #Exact linear intpolation
102            assert num.allclose(q[0], 2*t)
103            if i%6 == 0:
104                assert num.allclose(q[1], t**2)
105                assert num.allclose(q[2], sin(t*pi/600))
106
107        #Check non-exact
108
109        t = 90 #Halfway between 60 and 120
110        q = F(t)
111        assert num.allclose( (120**2 + 60**2)/2, q[1] )
112        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
113
114
115        t = 100 #Two thirds of the way between between 60 and 120
116        q = F(t)
117        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
118        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
119
120        os.remove(filename + '.txt')
121        os.remove(filename + '.tms')       
122
123
124       
125    def test_spatio_temporal_file_function_basic(self):
126        """Test that spatio temporal file function performs the correct
127        interpolations in both time and space
128        NetCDF version (x,y,t dependency)       
129        """
130        import time
131
132        #Create sww file of simple propagation from left to right
133        #through rectangular domain
134
135        #Create basic mesh and shallow water domain
136        points, vertices, boundary = rectangular(3, 3)
137        domain1 = Domain(points, vertices, boundary)
138
139        from anuga.utilities.numerical_tools import mean
140        domain1.reduction = mean
141        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
142                              # only one value.
143
144        domain1.default_order = 2
145        domain1.store = True
146        domain1.set_datadir('.')
147        sww_file = 'spatio_temporal_boundary_source_%d' %(id(self))
148        domain1.set_name(sww_file)
149
150        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
151        domain1.set_quantity('elevation', 0)
152        domain1.set_quantity('friction', 0)
153        domain1.set_quantity('stage', 0)
154
155        # Boundary conditions
156        B0 = Dirichlet_boundary([0,0,0])
157        B6 = Dirichlet_boundary([0.6,0,0])
158        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
159        domain1.check_integrity()
160
161        finaltime = 8
162        #Evolution
163        t0 = -1
164        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
165            #print 'Timesteps: %.16f, %.16f' %(t0, t)
166            #if t == t0:
167            #    msg = 'Duplicate timestep found: %f, %f' %(t0, t)
168            #   raise msg
169            t0 = t
170             
171            #domain1.write_time()
172
173
174        #Now read data from sww and check
175        from Scientific.IO.NetCDF import NetCDFFile
176        filename = domain1.get_name() + '.sww'
177        fid = NetCDFFile(filename)
178
179        x = fid.variables['x'][:]
180        y = fid.variables['y'][:]
181        stage = fid.variables['stage'][:]
182        xmomentum = fid.variables['xmomentum'][:]
183        ymomentum = fid.variables['ymomentum'][:]
184        time = fid.variables['time'][:]
185
186        #Take stage vertex values at last timestep on diagonal
187        #Diagonal is identified by vertices: 0, 5, 10, 15
188
189        last_time_index = len(time)-1 #Last last_time_index
190        d_stage = num.reshape(num.take(stage[last_time_index, :],
191                                       [0,5,10,15],
192                                       axis=0),
193                              (4,1))
194        d_uh = num.reshape(num.take(xmomentum[last_time_index, :],
195                                    [0,5,10,15],
196                                   axis=0),
197                           (4,1))
198        d_vh = num.reshape(num.take(ymomentum[last_time_index, :],
199                                    [0,5,10,15],
200                                   axis=0),
201                           (4,1))
202        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
203
204        #Reference interpolated values at midpoints on diagonal at
205        #this timestep are
206        r0 = (D[0] + D[1])/2
207        r1 = (D[1] + D[2])/2
208        r2 = (D[2] + D[3])/2
209
210        #And the midpoints are found now
211        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15], axis=0)
212        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15], axis=0)
213
214        diag = num.concatenate( (Dx, Dy), axis=1)
215        d_midpoints = (diag[1:] + diag[:-1])/2
216
217        #Let us see if the file function can find the correct
218        #values at the midpoints at the last timestep:
219        f = file_function(filename, domain1,
220                          interpolation_points = d_midpoints)
221
222        T = f.get_time()
223        msg = 'duplicate timesteps: %.16f and %.16f' %(T[-1], T[-2])
224        assert not T[-1] == T[-2], msg
225        t = time[last_time_index]
226        q = f(t, point_id=0); assert num.allclose(r0, q)
227        q = f(t, point_id=1); assert num.allclose(r1, q)
228        q = f(t, point_id=2); assert num.allclose(r2, q)
229
230
231        ##################
232        #Now do the same for the first timestep
233
234        timestep = 0 #First timestep
235        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
236        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
237        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
238        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
239
240        #Reference interpolated values at midpoints on diagonal at
241        #this timestep are
242        r0 = (D[0] + D[1])/2
243        r1 = (D[1] + D[2])/2
244        r2 = (D[2] + D[3])/2
245
246        #Let us see if the file function can find the correct
247        #values
248        q = f(0, point_id=0); assert num.allclose(r0, q)
249        q = f(0, point_id=1); assert num.allclose(r1, q)
250        q = f(0, point_id=2); assert num.allclose(r2, q)
251
252
253        ##################
254        #Now do it again for a timestep in the middle
255
256        timestep = 33
257        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
258        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
259        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
260        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
261
262        #Reference interpolated values at midpoints on diagonal at
263        #this timestep are
264        r0 = (D[0] + D[1])/2
265        r1 = (D[1] + D[2])/2
266        r2 = (D[2] + D[3])/2
267
268        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
269        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
270        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
271
272
273        ##################
274        #Now check temporal interpolation
275        #Halfway between timestep 15 and 16
276
277        timestep = 15
278        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
279        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
280        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
281        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
282
283        #Reference interpolated values at midpoints on diagonal at
284        #this timestep are
285        r0_0 = (D[0] + D[1])/2
286        r1_0 = (D[1] + D[2])/2
287        r2_0 = (D[2] + D[3])/2
288
289        #
290        timestep = 16
291        d_stage = num.reshape(num.take(stage[timestep, :], [0,5,10,15], axis=0), (4,1))
292        d_uh = num.reshape(num.take(xmomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
293        d_vh = num.reshape(num.take(ymomentum[timestep, :], [0,5,10,15], axis=0), (4,1))
294        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
295
296        #Reference interpolated values at midpoints on diagonal at
297        #this timestep are
298        r0_1 = (D[0] + D[1])/2
299        r1_1 = (D[1] + D[2])/2
300        r2_1 = (D[2] + D[3])/2
301
302        # The reference values are
303        r0 = (r0_0 + r0_1)/2
304        r1 = (r1_0 + r1_1)/2
305        r2 = (r2_0 + r2_1)/2
306
307        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
308        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
309        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
310
311        ##################
312        #Finally check interpolation 2 thirds of the way
313        #between timestep 15 and 16
314
315        # The reference values are
316        r0 = (r0_0 + 2*r0_1)/3
317        r1 = (r1_0 + 2*r1_1)/3
318        r2 = (r2_0 + 2*r2_1)/3
319
320        #And the file function gives
321        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
322        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
323        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
324
325        fid.close()
326        import os
327        os.remove(filename)
328
329
330
331    def test_spatio_temporal_file_function_different_origin(self):
332        """Test that spatio temporal file function performs the correct
333        interpolations in both time and space where space is offset by
334        xllcorner and yllcorner
335        NetCDF version (x,y,t dependency)       
336        """
337        xllcorner = 2048
338        yllcorner = 11000
339        zone = 2
340
341        #Create basic mesh and shallow water domain
342        points, vertices, boundary = rectangular(3, 3)
343        domain1 = Domain(points, vertices, boundary,
344                         geo_reference = Geo_reference(xllcorner = xllcorner,
345                                                       yllcorner = yllcorner))
346       
347
348        from anuga.utilities.numerical_tools import mean       
349        domain1.reduction = mean
350        domain1.smooth = True #NOTE: Mimic sww output where each vertex has
351                              # only one value.
352
353        domain1.default_order = 2
354        domain1.store = True
355        domain1.set_datadir('.')
356        domain1.set_name('spatio_temporal_boundary_source_%d' %(id(self)))
357
358        #Bed-slope, friction and IC at vertices (and interpolated elsewhere)
359        domain1.set_quantity('elevation', 0)
360        domain1.set_quantity('friction', 0)
361        domain1.set_quantity('stage', 0)
362
363        # Boundary conditions
364        B0 = Dirichlet_boundary([0,0,0])
365        B6 = Dirichlet_boundary([0.6,0,0])
366        domain1.set_boundary({'left': B6, 'top': B6, 'right': B0, 'bottom': B0})
367        domain1.check_integrity()
368
369        finaltime = 8
370        #Evolution
371        for t in domain1.evolve(yieldstep = 0.1, finaltime = finaltime):
372            pass
373            #domain1.write_time()
374
375
376        #Now read data from sww and check
377        from Scientific.IO.NetCDF import NetCDFFile
378        filename = domain1.get_name() + '.sww'
379        fid = NetCDFFile(filename)
380
381        x = fid.variables['x'][:]
382        y = fid.variables['y'][:]
383        # we 'cast' to 64 bit floats to pass this test
384        # SWW file quantities are stored as 32 bits
385        x = num.array(x, num.float)
386        y = num.array(y, num.float)
387
388        stage = fid.variables['stage'][:]
389        xmomentum = fid.variables['xmomentum'][:]
390        ymomentum = fid.variables['ymomentum'][:]
391        time = fid.variables['time'][:]
392
393        #Take stage vertex values at last timestep on diagonal
394        #Diagonal is identified by vertices: 0, 5, 10, 15
395
396        last_time_index = len(time)-1 #Last last_time_index     
397        d_stage = num.reshape(num.take(stage[last_time_index, :],
398                                       [0,5,10,15],
399                                       axis=0),
400                              (4,1))
401        d_uh = num.reshape(num.take(xmomentum[last_time_index, :],
402                                    [0,5,10,15],
403                                   axis=0),
404                           (4,1))
405        d_vh = num.reshape(num.take(ymomentum[last_time_index, :],
406                                    [0,5,10,15],
407                                    axis=0),
408                           (4,1))
409        D = num.concatenate((d_stage, d_uh, d_vh), axis=1)
410
411        #Reference interpolated values at midpoints on diagonal at
412        #this timestep are
413        r0 = (D[0] + D[1])/2
414        r1 = (D[1] + D[2])/2
415        r2 = (D[2] + D[3])/2
416
417        #And the midpoints are found now
418        Dx = num.take(num.reshape(x, (16,1)), [0,5,10,15], axis=0)
419        Dy = num.take(num.reshape(y, (16,1)), [0,5,10,15], axis=0)
420
421        diag = num.concatenate((Dx, Dy), axis=1)
422        d_midpoints = (diag[1:] + diag[:-1])/2
423
424
425        #Adjust for georef - make interpolation points absolute
426        d_midpoints[:,0] += xllcorner
427        d_midpoints[:,1] += yllcorner               
428
429        #Let us see if the file function can find the correct
430        #values at the midpoints at the last timestep:
431        f = file_function(filename, domain1,
432                          interpolation_points = d_midpoints)
433
434        t = time[last_time_index]                         
435
436        q = f(t, point_id=0)
437        msg = '\nr0=%s\nq=%s' % (str(r0), str(q))
438        assert num.allclose(r0, q), msg
439
440        q = f(t, point_id=1)
441        msg = '\nr1=%s\nq=%s' % (str(r1), str(q))
442        assert num.allclose(r1, q), msg
443
444        q = f(t, point_id=2)
445        msg = '\nr2=%s\nq=%s' % (str(r2), str(q))
446        assert num.allclose(r2, q), msg
447
448
449        ##################
450        #Now do the same for the first timestep
451
452        timestep = 0 #First timestep
453        d_stage = num.reshape(num.take(stage[timestep, :],
454                                       [0,5,10,15],
455                                       axis=0),
456                              (4,1))
457        d_uh = num.reshape(num.take(xmomentum[timestep, :],
458                                    [0,5,10,15],
459                                    axis=0),
460                           (4,1))
461        d_vh = num.reshape(num.take(ymomentum[timestep, :],
462                                    [0,5,10,15],
463                                    axis=0),
464                           (4,1))
465        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
466
467        #Reference interpolated values at midpoints on diagonal at
468        #this timestep are
469        r0 = (D[0] + D[1])/2
470        r1 = (D[1] + D[2])/2
471        r2 = (D[2] + D[3])/2
472
473        #Let us see if the file function can find the correct
474        #values
475        q = f(0, point_id=0); assert num.allclose(r0, q)
476        q = f(0, point_id=1); assert num.allclose(r1, q)
477        q = f(0, point_id=2); assert num.allclose(r2, q)
478
479
480        ##################
481        #Now do it again for a timestep in the middle
482
483        timestep = 33
484        d_stage = num.reshape(num.take(stage[timestep, :],
485                                       [0,5,10,15],
486                                       axis=0),
487                              (4,1))
488        d_uh = num.reshape(num.take(xmomentum[timestep, :],
489                                    [0,5,10,15],
490                                    axis=0),
491                           (4,1))
492        d_vh = num.reshape(num.take(ymomentum[timestep, :],
493                                    [0,5,10,15],
494                                    axis=0),
495                           (4,1))
496        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
497
498        #Reference interpolated values at midpoints on diagonal at
499        #this timestep are
500        r0 = (D[0] + D[1])/2
501        r1 = (D[1] + D[2])/2
502        r2 = (D[2] + D[3])/2
503
504        q = f(timestep/10., point_id=0); assert num.allclose(r0, q)
505        q = f(timestep/10., point_id=1); assert num.allclose(r1, q)
506        q = f(timestep/10., point_id=2); assert num.allclose(r2, q)
507
508
509        ##################
510        #Now check temporal interpolation
511        #Halfway between timestep 15 and 16
512
513        timestep = 15
514        d_stage = num.reshape(num.take(stage[timestep, :],
515                                       [0,5,10,15],
516                                       axis=0),
517                              (4,1))
518        d_uh = num.reshape(num.take(xmomentum[timestep, :],
519                                    [0,5,10,15],
520                                    axis=0),
521                           (4,1))
522        d_vh = num.reshape(num.take(ymomentum[timestep, :],
523                                    [0,5,10,15],
524                                    axis=0),
525                           (4,1))
526        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
527
528        #Reference interpolated values at midpoints on diagonal at
529        #this timestep are
530        r0_0 = (D[0] + D[1])/2
531        r1_0 = (D[1] + D[2])/2
532        r2_0 = (D[2] + D[3])/2
533
534        #
535        timestep = 16
536        d_stage = num.reshape(num.take(stage[timestep, :],
537                                       [0,5,10,15],
538                                       axis=0),
539                              (4,1))
540        d_uh = num.reshape(num.take(xmomentum[timestep, :],
541                                    [0,5,10,15],
542                                    axis=0),
543                           (4,1))
544        d_vh = num.reshape(num.take(ymomentum[timestep, :],
545                                    [0,5,10,15],
546                                    axis=0),
547                           (4,1))
548        D = num.concatenate( (d_stage, d_uh, d_vh), axis=1)
549
550        #Reference interpolated values at midpoints on diagonal at
551        #this timestep are
552        r0_1 = (D[0] + D[1])/2
553        r1_1 = (D[1] + D[2])/2
554        r2_1 = (D[2] + D[3])/2
555
556        # The reference values are
557        r0 = (r0_0 + r0_1)/2
558        r1 = (r1_0 + r1_1)/2
559        r2 = (r2_0 + r2_1)/2
560
561        q = f((timestep - 0.5)/10., point_id=0); assert num.allclose(r0, q)
562        q = f((timestep - 0.5)/10., point_id=1); assert num.allclose(r1, q)
563        q = f((timestep - 0.5)/10., point_id=2); assert num.allclose(r2, q)
564
565        ##################
566        #Finally check interpolation 2 thirds of the way
567        #between timestep 15 and 16
568
569        # The reference values are
570        r0 = (r0_0 + 2*r0_1)/3
571        r1 = (r1_0 + 2*r1_1)/3
572        r2 = (r2_0 + 2*r2_1)/3
573
574        #And the file function gives
575        q = f((timestep - 1.0/3)/10., point_id=0); assert num.allclose(r0, q)
576        q = f((timestep - 1.0/3)/10., point_id=1); assert num.allclose(r1, q)
577        q = f((timestep - 1.0/3)/10., point_id=2); assert num.allclose(r2, q)
578
579        fid.close()
580        import os
581        os.remove(filename)
582
583       
584
585
586    def test_spatio_temporal_file_function_time(self):
587        """Test that File function interpolates correctly
588        between given times.
589        NetCDF version (x,y,t dependency)
590        """
591
592        #Create NetCDF (sww) file to be read
593        # x: 0, 5, 10, 15
594        # y: -20, -10, 0, 10
595        # t: 0, 60, 120, ...., 1200
596        #
597        # test quantities (arbitrary but non-trivial expressions):
598        #
599        #   stage     = 3*x - y**2 + 2*t
600        #   xmomentum = exp( -((x-7)**2 + (y+5)**2)/20 ) * t**2
601        #   ymomentum = x**2 + y**2 * sin(t*pi/600)
602
603        #NOTE: Nice test that may render some of the others redundant.
604
605        import os, time
606        from anuga.config import time_format
607        from mesh_factory import rectangular
608        from anuga.shallow_water.shallow_water_domain import Domain
609
610        finaltime = 1200
611        filename = 'test_file_function'
612
613        #Create a domain to hold test grid
614        #(0:15, -20:10)
615        points, vertices, boundary =\
616                rectangular(4, 4, 15, 30, origin = (0, -20))
617        #print "points", points
618
619        #print 'Number of elements', len(vertices)
620        domain = Domain(points, vertices, boundary)
621        domain.smooth = False
622        domain.default_order = 2
623        domain.set_datadir('.')
624        domain.set_name(filename)
625        domain.store = True
626
627        #print points
628        start = time.mktime(time.strptime('2000', '%Y'))
629        domain.starttime = start
630
631
632        #Store structure
633        domain.initialise_storage()
634
635        #Compute artificial time steps and store
636        dt = 60  #One minute intervals
637        t = 0.0
638        while t <= finaltime:
639            #Compute quantities
640            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
641            domain.set_quantity('stage', f1)
642
643            f2 = lambda x,y: x+y+t**2
644            domain.set_quantity('xmomentum', f2)
645
646            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
647            domain.set_quantity('ymomentum', f3)
648
649            #Store and advance time
650            domain.time = t
651            domain.store_timestep()
652            t += dt
653
654
655        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5]]
656     
657        #Deliberately set domain.starttime to too early
658        domain.starttime = start - 1
659
660        #Create file function
661        F = file_function(filename + '.sww', domain,
662                          quantities = domain.conserved_quantities,
663                          interpolation_points = interpolation_points)
664
665        #Check that FF updates fixes domain starttime
666        assert num.allclose(domain.starttime, start)
667
668        #Check that domain.starttime isn't updated if later
669        domain.starttime = start + 1
670        F = file_function(filename + '.sww', domain,
671                          quantities = domain.conserved_quantities,
672                          interpolation_points = interpolation_points)
673        assert num.allclose(domain.starttime, start+1)
674        domain.starttime = start
675
676
677        #Check linear interpolation in time
678        F = file_function(filename + '.sww', domain,
679                          quantities = domain.conserved_quantities,
680                          interpolation_points = interpolation_points)               
681        for id in range(len(interpolation_points)):
682            x = interpolation_points[id][0]
683            y = interpolation_points[id][1]
684
685            for i in range(20):
686                t = i*10
687                k = i%6
688
689                if k == 0:
690                    q0 = F(t, point_id=id)
691                    q1 = F(t+60, point_id=id)
692
693                if num.alltrue(q0 == NAN):
694                    actual = q0
695                else:
696                    actual = (k*q1 + (6-k)*q0)/6
697                q = F(t, point_id=id)
698                #print i, k, t, q
699                #print ' ', q0
700                #print ' ', q1
701                #print "q",q
702                #print "actual", actual
703                #print
704                if num.alltrue(q0 == NAN):
705                     self.failUnless(num.alltrue(q == actual), 'Fail!')
706                else:
707                    assert num.allclose(q, actual)
708
709
710        #Another check of linear interpolation in time
711        for id in range(len(interpolation_points)):
712            q60 = F(60, point_id=id)
713            q120 = F(120, point_id=id)
714
715            t = 90 #Halfway between 60 and 120
716            q = F(t, point_id=id)
717            assert num.allclose( (q120+q60)/2, q )
718
719            t = 100 #Two thirds of the way between between 60 and 120
720            q = F(t, point_id=id)
721            assert num.allclose(q60/3 + 2*q120/3, q)
722
723
724
725        #Check that domain.starttime isn't updated if later than file starttime but earlier
726        #than file end time
727        delta = 23
728        domain.starttime = start + delta
729        F = file_function(filename + '.sww', domain,
730                          quantities = domain.conserved_quantities,
731                          interpolation_points = interpolation_points)
732        assert num.allclose(domain.starttime, start+delta)
733
734
735
736
737        #Now try interpolation with delta offset
738        for id in range(len(interpolation_points)):           
739            x = interpolation_points[id][0]
740            y = interpolation_points[id][1]
741
742            for i in range(20):
743                t = i*10
744                k = i%6
745
746                if k == 0:
747                    q0 = F(t-delta, point_id=id)
748                    q1 = F(t+60-delta, point_id=id)
749
750                q = F(t-delta, point_id=id)
751                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
752
753
754        os.remove(filename + '.sww')
755
756
757
758    def Xtest_spatio_temporal_file_function_time(self):
759        # FIXME: This passes but needs some TLC
760        # Test that File function interpolates correctly
761        # When some points are outside the mesh
762
763        import os, time
764        from anuga.config import time_format
765        from mesh_factory import rectangular
766        from shallow_water import Domain
767        import anuga.shallow_water.data_manager 
768        from anuga.pmesh.mesh_interface import create_mesh_from_regions
769        finaltime = 1200
770       
771        filename = tempfile.mktemp()
772        #print "filename",filename
773        filename = 'test_file_function'
774
775        meshfilename = tempfile.mktemp(".tsh")
776
777        boundary_tags = {'walls':[0,1],'bom':[2]}
778       
779        polygon_absolute = [[0,-20],[10,-20],[10,15],[-20,15]]
780       
781        create_mesh_from_regions(polygon_absolute,
782                                 boundary_tags,
783                                 10000000,
784                                 filename=meshfilename)
785        domain = Domain(mesh_filename=meshfilename)
786        domain.smooth = False
787        domain.default_order = 2
788        domain.set_datadir('.')
789        domain.set_name(filename)
790        domain.store = True
791
792        #print points
793        start = time.mktime(time.strptime('2000', '%Y'))
794        domain.starttime = start
795       
796
797        #Store structure
798        domain.initialise_storage()
799
800        #Compute artificial time steps and store
801        dt = 60  #One minute intervals
802        t = 0.0
803        while t <= finaltime:
804            #Compute quantities
805            f1 = lambda x,y: 3*x - y**2 + 2*t + 4
806            domain.set_quantity('stage', f1)
807
808            f2 = lambda x,y: x+y+t**2
809            domain.set_quantity('xmomentum', f2)
810
811            f3 = lambda x,y: x**2 + y**2 * num.sin(t*num.pi/600)
812            domain.set_quantity('ymomentum', f3)
813
814            #Store and advance time
815            domain.time = t
816            domain.store_timestep()
817            t += dt
818
819        interpolation_points = [[1,0]]
820        interpolation_points = [[100,1000]]
821       
822        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14], [10,-12.5],
823                                [78787,78787],[7878,3432]]
824           
825        #Deliberately set domain.starttime to too early
826        domain.starttime = start - 1
827
828        #Create file function
829        F = file_function(filename + '.sww', domain,
830                          quantities = domain.conserved_quantities,
831                          interpolation_points = interpolation_points)
832
833        #Check that FF updates fixes domain starttime
834        assert num.allclose(domain.starttime, start)
835
836        #Check that domain.starttime isn't updated if later
837        domain.starttime = start + 1
838        F = file_function(filename + '.sww', domain,
839                          quantities = domain.conserved_quantities,
840                          interpolation_points = interpolation_points)
841        assert num.allclose(domain.starttime, start+1)
842        domain.starttime = start
843
844
845        #Check linear interpolation in time
846        # checking points inside and outside the mesh
847        F = file_function(filename + '.sww', domain,
848                          quantities = domain.conserved_quantities,
849                          interpolation_points = interpolation_points)
850       
851        for id in range(len(interpolation_points)):
852            x = interpolation_points[id][0]
853            y = interpolation_points[id][1]
854
855            for i in range(20):
856                t = i*10
857                k = i%6
858
859                if k == 0:
860                    q0 = F(t, point_id=id)
861                    q1 = F(t+60, point_id=id)
862
863                if q0 == NAN:
864                    actual = q0
865                else:
866                    actual = (k*q1 + (6-k)*q0)/6
867                q = F(t, point_id=id)
868                #print i, k, t, q
869                #print ' ', q0
870                #print ' ', q1
871                #print "q",q
872                #print "actual", actual
873                #print
874                if q0 == NAN:
875                     self.failUnless( q == actual, 'Fail!')
876                else:
877                    assert num.allclose(q, actual)
878
879        # now lets check points inside the mesh
880        interpolation_points = [[0,-20], [1,0], [0,1], [1.1, 3.14]] #, [10,-12.5]] - this point doesn't work WHY?
881        interpolation_points = [[10,-12.5]]
882           
883        print "len(interpolation_points)",len(interpolation_points) 
884        F = file_function(filename + '.sww', domain,
885                          quantities = domain.conserved_quantities,
886                          interpolation_points = interpolation_points)
887
888        domain.starttime = start
889
890
891        #Check linear interpolation in time
892        F = file_function(filename + '.sww', domain,
893                          quantities = domain.conserved_quantities,
894                          interpolation_points = interpolation_points)               
895        for id in range(len(interpolation_points)):
896            x = interpolation_points[id][0]
897            y = interpolation_points[id][1]
898
899            for i in range(20):
900                t = i*10
901                k = i%6
902
903                if k == 0:
904                    q0 = F(t, point_id=id)
905                    q1 = F(t+60, point_id=id)
906
907                if q0 == NAN:
908                    actual = q0
909                else:
910                    actual = (k*q1 + (6-k)*q0)/6
911                q = F(t, point_id=id)
912                print "############"
913                print "id, x, y ", id, x, y #k, t, q
914                print "t", t
915                #print ' ', q0
916                #print ' ', q1
917                print "q",q
918                print "actual", actual
919                #print
920                if q0 == NAN:
921                     self.failUnless( q == actual, 'Fail!')
922                else:
923                    assert num.allclose(q, actual)
924
925
926        #Another check of linear interpolation in time
927        for id in range(len(interpolation_points)):
928            q60 = F(60, point_id=id)
929            q120 = F(120, point_id=id)
930
931            t = 90 #Halfway between 60 and 120
932            q = F(t, point_id=id)
933            assert num.allclose( (q120+q60)/2, q )
934
935            t = 100 #Two thirds of the way between between 60 and 120
936            q = F(t, point_id=id)
937            assert num.allclose(q60/3 + 2*q120/3, q)
938
939
940
941        #Check that domain.starttime isn't updated if later than file starttime but earlier
942        #than file end time
943        delta = 23
944        domain.starttime = start + delta
945        F = file_function(filename + '.sww', domain,
946                          quantities = domain.conserved_quantities,
947                          interpolation_points = interpolation_points)
948        assert num.allclose(domain.starttime, start+delta)
949
950
951
952
953        #Now try interpolation with delta offset
954        for id in range(len(interpolation_points)):           
955            x = interpolation_points[id][0]
956            y = interpolation_points[id][1]
957
958            for i in range(20):
959                t = i*10
960                k = i%6
961
962                if k == 0:
963                    q0 = F(t-delta, point_id=id)
964                    q1 = F(t+60-delta, point_id=id)
965
966                q = F(t-delta, point_id=id)
967                assert num.allclose(q, (k*q1 + (6-k)*q0)/6)
968
969
970        os.remove(filename + '.sww')
971
972    def test_file_function_time_with_domain(self):
973        """Test that File function interpolates correctly
974        between given times. No x,y dependency here.
975        Use domain with starttime
976        """
977
978        #Write file
979        import os, time, calendar
980        from anuga.config import time_format
981        from math import sin, pi
982
983        finaltime = 1200
984        filename = 'test_file_function'
985        fid = open(filename + '.txt', 'w')
986        start = time.mktime(time.strptime('2000', '%Y'))
987        dt = 60  #One minute intervals
988        t = 0.0
989        while t <= finaltime:
990            t_string = time.strftime(time_format, time.gmtime(t+start))
991            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
992            t += dt
993
994        fid.close()
995
996
997        #Convert ASCII file to NetCDF (Which is what we really like!)
998        timefile2netcdf(filename)
999
1000
1001
1002        a = [0.0, 0.0]
1003        b = [4.0, 0.0]
1004        c = [0.0, 3.0]
1005
1006        points = [a, b, c]
1007        vertices = [[0,1,2]]
1008        domain = Domain(points, vertices)
1009
1010        # Check that domain.starttime is updated if non-existing
1011        F = file_function(filename + '.tms',
1012                          domain,
1013                          quantities = ['Attribute0', 'Attribute1', 'Attribute2']) 
1014        assert num.allclose(domain.starttime, start)
1015
1016        # Check that domain.starttime is updated if too early
1017        domain.starttime = start - 1
1018        F = file_function(filename + '.tms',
1019                          domain,
1020                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
1021        assert num.allclose(domain.starttime, start)
1022
1023        # Check that domain.starttime isn't updated if later
1024        domain.starttime = start + 1
1025        F = file_function(filename + '.tms',
1026                          domain,
1027                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])
1028        assert num.allclose(domain.starttime, start+1)
1029
1030        domain.starttime = start
1031        F = file_function(filename + '.tms',
1032                          domain,
1033                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'],
1034                          use_cache=True)
1035       
1036
1037        #print F.precomputed_values
1038        #print 'F(60)', F(60)
1039       
1040        #Now try interpolation
1041        for i in range(20):
1042            t = i*10
1043            q = F(t)
1044
1045            #Exact linear intpolation
1046            assert num.allclose(q[0], 2*t)
1047            if i%6 == 0:
1048                assert num.allclose(q[1], t**2)
1049                assert num.allclose(q[2], sin(t*pi/600))
1050
1051        #Check non-exact
1052
1053        t = 90 #Halfway between 60 and 120
1054        q = F(t)
1055        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1056        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1057
1058
1059        t = 100 #Two thirds of the way between between 60 and 120
1060        q = F(t)
1061        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1062        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1063
1064        os.remove(filename + '.tms')
1065        os.remove(filename + '.txt')       
1066
1067    def test_file_function_time_with_domain_different_start(self):
1068        """Test that File function interpolates correctly
1069        between given times. No x,y dependency here.
1070        Use domain with a starttime later than that of file
1071
1072        ASCII version
1073        """
1074
1075        #Write file
1076        import os, time, calendar
1077        from anuga.config import time_format
1078        from math import sin, pi
1079
1080        finaltime = 1200
1081        filename = 'test_file_function'
1082        fid = open(filename + '.txt', 'w')
1083        start = time.mktime(time.strptime('2000', '%Y'))
1084        dt = 60  #One minute intervals
1085        t = 0.0
1086        while t <= finaltime:
1087            t_string = time.strftime(time_format, time.gmtime(t+start))
1088            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1089            t += dt
1090
1091        fid.close()
1092
1093        #Convert ASCII file to NetCDF (Which is what we really like!)
1094        timefile2netcdf(filename)       
1095
1096        a = [0.0, 0.0]
1097        b = [4.0, 0.0]
1098        c = [0.0, 3.0]
1099
1100        points = [a, b, c]
1101        vertices = [[0,1,2]]
1102        domain = Domain(points, vertices)
1103
1104        #Check that domain.starttime isn't updated if later than file starttime but earlier
1105        #than file end time
1106        delta = 23
1107        domain.starttime = start + delta
1108        F = file_function(filename + '.tms', domain,
1109                          quantities = ['Attribute0', 'Attribute1', 'Attribute2'])       
1110        assert num.allclose(domain.starttime, start+delta)
1111
1112        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1113                                            277., 337., 397., 457., 517.,
1114                                            577., 637., 697., 757., 817.,
1115                                            877., 937., 997., 1057., 1117.,
1116                                            1177.])
1117
1118
1119        #Now try interpolation with delta offset
1120        for i in range(20):
1121            t = i*10
1122            q = F(t-delta)
1123
1124            #Exact linear intpolation
1125            assert num.allclose(q[0], 2*t)
1126            if i%6 == 0:
1127                assert num.allclose(q[1], t**2)
1128                assert num.allclose(q[2], sin(t*pi/600))
1129
1130        #Check non-exact
1131
1132        t = 90 #Halfway between 60 and 120
1133        q = F(t-delta)
1134        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1135        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1136
1137
1138        t = 100 #Two thirds of the way between between 60 and 120
1139        q = F(t-delta)
1140        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1141        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1142
1143
1144        os.remove(filename + '.tms')
1145        os.remove(filename + '.txt')               
1146
1147       
1148
1149    def test_file_function_time_with_domain_different_start_and_time_limit(self):
1150        """Test that File function interpolates correctly
1151        between given times. No x,y dependency here.
1152        Use domain with a starttime later than that of file
1153
1154        ASCII version
1155       
1156        This test also tests that time can be truncated.
1157        """
1158
1159        # Write file
1160        import os, time, calendar
1161        from anuga.config import time_format
1162        from math import sin, pi
1163
1164        finaltime = 1200
1165        filename = 'test_file_function'
1166        fid = open(filename + '.txt', 'w')
1167        start = time.mktime(time.strptime('2000', '%Y'))
1168        dt = 60  #One minute intervals
1169        t = 0.0
1170        while t <= finaltime:
1171            t_string = time.strftime(time_format, time.gmtime(t+start))
1172            fid.write('%s, %f %f %f\n' %(t_string, 2*t, t**2, sin(t*pi/600)))
1173            t += dt
1174
1175        fid.close()
1176
1177        # Convert ASCII file to NetCDF (Which is what we really like!)
1178        timefile2netcdf(filename)       
1179
1180        a = [0.0, 0.0]
1181        b = [4.0, 0.0]
1182        c = [0.0, 3.0]
1183
1184        points = [a, b, c]
1185        vertices = [[0,1,2]]
1186        domain = Domain(points, vertices)
1187
1188        # Check that domain.starttime isn't updated if later than file starttime but earlier
1189        # than file end time
1190        delta = 23
1191        domain.starttime = start + delta
1192        time_limit = domain.starttime + 600
1193        F = file_function(filename + '.tms', domain,
1194                          time_limit=time_limit,
1195                          quantities=['Attribute0', 'Attribute1', 'Attribute2'])       
1196        assert num.allclose(domain.starttime, start+delta)
1197
1198        assert num.allclose(F.get_time(), [-23., 37., 97., 157., 217.,
1199                                            277., 337., 397., 457., 517.,
1200                                            577.])       
1201
1202
1203
1204        # Now try interpolation with delta offset
1205        for i in range(20):
1206            t = i*10
1207            q = F(t-delta)
1208
1209            #Exact linear intpolation
1210            assert num.allclose(q[0], 2*t)
1211            if i%6 == 0:
1212                assert num.allclose(q[1], t**2)
1213                assert num.allclose(q[2], sin(t*pi/600))
1214
1215        # Check non-exact
1216        t = 90 #Halfway between 60 and 120
1217        q = F(t-delta)
1218        assert num.allclose( (120**2 + 60**2)/2, q[1] )
1219        assert num.allclose( (sin(120*pi/600) + sin(60*pi/600))/2, q[2] )
1220
1221
1222        t = 100 # Two thirds of the way between between 60 and 120
1223        q = F(t-delta)
1224        assert num.allclose( 2*120**2/3 + 60**2/3, q[1] )
1225        assert num.allclose( 2*sin(120*pi/600)/3 + sin(60*pi/600)/3, q[2] )
1226
1227
1228        os.remove(filename + '.tms')
1229        os.remove(filename + '.txt')               
1230
1231       
1232       
1233       
1234
1235
1236    def test_apply_expression_to_dictionary(self):
1237
1238        #FIXME: Division is not expected to work for integers.
1239        #This must be caught.
1240        foo = num.array([[1,2,3], [4,5,6]], num.float)
1241
1242        bar = num.array([[-1,0,5], [6,1,1]], num.float)                 
1243
1244        D = {'X': foo, 'Y': bar}
1245
1246        Z = apply_expression_to_dictionary('X+Y', D)       
1247        assert num.allclose(Z, foo+bar)
1248
1249        Z = apply_expression_to_dictionary('X*Y', D)       
1250        assert num.allclose(Z, foo*bar)       
1251
1252        Z = apply_expression_to_dictionary('4*X+Y', D)       
1253        assert num.allclose(Z, 4*foo+bar)       
1254
1255        # test zero division is OK
1256        Z = apply_expression_to_dictionary('X/Y', D)
1257        assert num.allclose(1/Z, 1/(foo/bar)) # can't compare inf to inf
1258
1259        # make an error for zero on zero
1260        # this is really an error in numeric, SciPy core can handle it
1261        # Z = apply_expression_to_dictionary('0/Y', D)
1262
1263        #Check exceptions
1264        try:
1265            #Wrong name
1266            Z = apply_expression_to_dictionary('4*X+A', D)       
1267        except NameError:
1268            pass
1269        else:
1270            msg = 'Should have raised a NameError Exception'
1271            raise msg
1272
1273
1274        try:
1275            #Wrong order
1276            Z = apply_expression_to_dictionary(D, '4*X+A')       
1277        except AssertionError:
1278            pass
1279        else:
1280            msg = 'Should have raised a AssertionError Exception'
1281            raise msg       
1282       
1283
1284    def test_multiple_replace(self):
1285        """Hard test that checks a true word-by-word simultaneous replace
1286        """
1287       
1288        D = {'x': 'xi', 'y': 'eta', 'xi':'lam'}
1289        exp = '3*x+y + xi'
1290       
1291        new = multiple_replace(exp, D)
1292       
1293        assert new == '3*xi+eta + lam'
1294                         
1295
1296
1297    def test_point_on_line_obsolete(self):
1298        """Test that obsolete call issues appropriate warning"""
1299
1300        #Turn warning into an exception
1301        import warnings
1302        warnings.filterwarnings('error')
1303
1304        try:
1305            assert point_on_line( 0, 0.5, 0,1, 0,0 )
1306        except DeprecationWarning:
1307            pass
1308        else:
1309            msg = 'point_on_line should have issued a DeprecationWarning'
1310            raise Exception(msg)   
1311
1312        warnings.resetwarnings()
1313   
1314    def test_get_revision_number(self):
1315        """test_get_revision_number(self):
1316
1317        Test that revision number can be retrieved.
1318        """
1319        if os.environ.has_key('USER') and os.environ['USER'] == 'dgray':
1320            # I have a known snv incompatability issue,
1321            # so I'm skipping this test.
1322            # FIXME when SVN is upgraded on our clusters
1323            pass
1324        else:   
1325            n = get_revision_number()
1326            assert n>=0
1327
1328
1329       
1330    def test_add_directories(self):
1331       
1332        import tempfile
1333        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1334        directories = ['ja','ne','ke']
1335        kens_dir = add_directories(root_dir, directories)
1336        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1337               sep + 'ke'
1338        assert access(root_dir,F_OK)
1339
1340        add_directories(root_dir, directories)
1341        assert access(root_dir,F_OK)
1342       
1343        #clean up!
1344        os.rmdir(kens_dir)
1345        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1346        os.rmdir(root_dir + sep + 'ja')
1347        os.rmdir(root_dir)
1348
1349    def test_add_directories_bad(self):
1350       
1351        import tempfile
1352        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1353        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1354       
1355        try:
1356            kens_dir = add_directories(root_dir, directories)
1357        except OSError:
1358            pass
1359        else:
1360            msg = 'bad dir name should give OSError'
1361            raise Exception(msg)   
1362           
1363        #clean up!
1364        os.rmdir(root_dir)
1365
1366    def test_check_list(self):
1367
1368        check_list(['stage','xmomentum'])
1369
1370       
1371    def test_add_directories(self):
1372       
1373        import tempfile
1374        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1375        directories = ['ja','ne','ke']
1376        kens_dir = add_directories(root_dir, directories)
1377        assert kens_dir == root_dir + sep + 'ja' + sep + 'ne' + \
1378               sep + 'ke'
1379        assert access(root_dir,F_OK)
1380
1381        add_directories(root_dir, directories)
1382        assert access(root_dir,F_OK)
1383       
1384        #clean up!
1385        os.rmdir(kens_dir)
1386        os.rmdir(root_dir + sep + 'ja' + sep + 'ne')
1387        os.rmdir(root_dir + sep + 'ja')
1388        os.rmdir(root_dir)
1389
1390    def test_add_directories_bad(self):
1391       
1392        import tempfile
1393        root_dir = tempfile.mkdtemp('_test_util', 'test_util_')
1394        directories = ['/\/!@#@#$%^%&*((*:*:','ne','ke']
1395       
1396        try:
1397            kens_dir = add_directories(root_dir, directories)
1398        except OSError:
1399            pass
1400        else:
1401            msg = 'bad dir name should give OSError'
1402            raise Exception(msg)   
1403           
1404        #clean up!
1405        os.rmdir(root_dir)
1406
1407    def test_check_list(self):
1408
1409        check_list(['stage','xmomentum'])
1410
1411######
1412# Test the remove_lone_verts() function
1413######
1414       
1415    def test_remove_lone_verts_a(self):
1416        verts = [[0,0],[1,0],[0,1]]
1417        tris = [[0,1,2]]
1418        new_verts, new_tris = remove_lone_verts(verts, tris)
1419        self.failUnless(new_verts.tolist() == verts)
1420        self.failUnless(new_tris.tolist() == tris)
1421
1422    def test_remove_lone_verts_b(self):
1423        verts = [[0,0],[1,0],[0,1],[99,99]]
1424        tris = [[0,1,2]]
1425        new_verts, new_tris = remove_lone_verts(verts, tris)
1426        self.failUnless(new_verts.tolist() == verts[0:3])
1427        self.failUnless(new_tris.tolist() == tris)
1428       
1429    def test_remove_lone_verts_c(self):
1430        verts = [[99,99],[0,0],[1,0],[99,99],[0,1],[99,99]]
1431        tris = [[1,2,4]]
1432        new_verts, new_tris = remove_lone_verts(verts, tris)
1433        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1434        self.failUnless(new_tris.tolist() == [[0,1,2]])
1435     
1436    def test_remove_lone_verts_d(self):
1437        verts = [[0,0],[1,0],[99,99],[0,1]]
1438        tris = [[0,1,3]]
1439        new_verts, new_tris = remove_lone_verts(verts, tris)
1440        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1]])
1441        self.failUnless(new_tris.tolist() == [[0,1,2]])
1442       
1443    def test_remove_lone_verts_e(self):
1444        verts = [[0,0],[1,0],[0,1],[99,99],[99,99],[99,99]]
1445        tris = [[0,1,2]]
1446        new_verts, new_tris = remove_lone_verts(verts, tris)
1447        self.failUnless(new_verts.tolist() == verts[0:3])
1448        self.failUnless(new_tris.tolist() == tris)
1449     
1450    def test_remove_lone_verts_f(self):
1451        verts = [[0,0],[1,0],[99,99],[0,1],[99,99],[1,1],[99,99]]
1452        tris = [[0,1,3],[0,1,5]]
1453        new_verts, new_tris = remove_lone_verts(verts, tris)
1454        self.failUnless(new_verts.tolist() == [[0,0],[1,0],[0,1],[1,1]])
1455        self.failUnless(new_tris.tolist() == [[0,1,2],[0,1,3]])
1456       
1457######
1458#
1459######
1460       
1461    def test_get_min_max_values(self):
1462       
1463        list=[8,9,6,1,4]
1464        min1, max1 = get_min_max_values(list)
1465       
1466        assert min1==1 
1467        assert max1==9
1468       
1469    def test_get_min_max_values1(self):
1470       
1471        list=[-8,-9,-6,-1,-4]
1472        min1, max1 = get_min_max_values(list)
1473       
1474#        print 'min1,max1',min1,max1
1475        assert min1==-9 
1476        assert max1==-1
1477
1478#    def test_get_min_max_values2(self):
1479#        '''
1480#        The min and max supplied are greater than the ones in the
1481#        list and therefore are the ones returned
1482#        '''
1483#        list=[-8,-9,-6,-1,-4]
1484#        min1, max1 = get_min_max_values(list,-10,10)
1485#       
1486##        print 'min1,max1',min1,max1
1487#        assert min1==-10
1488#        assert max1==10
1489       
1490    def test_make_plots_from_csv_files(self):
1491       
1492        #if sys.platform == 'win32':  #Windows
1493            try: 
1494                import pylab
1495            except ImportError:
1496                #ANUGA don't need pylab to work so the system doesn't
1497                #rely on pylab being installed
1498                return
1499           
1500       
1501            current_dir=getcwd()+sep+'abstract_2d_finite_volumes'
1502            temp_dir = tempfile.mkdtemp('','figures')
1503    #        print 'temp_dir',temp_dir
1504            fileName = temp_dir+sep+'time_series_3.csv'
1505            file = open(fileName,"w")
1506            file.write("time,stage,speed,momentum,elevation\n\
15071.0, 0, 0, 0, 10 \n\
15082.0, 5, 2, 4, 10 \n\
15093.0, 3, 3, 5, 10 \n")
1510            file.close()
1511   
1512            fileName1 = temp_dir+sep+'time_series_4.csv'
1513            file1 = open(fileName1,"w")
1514            file1.write("time,stage,speed,momentum,elevation\n\
15151.0, 0, 0, 0, 5 \n\
15162.0, -5, -2, -4, 5 \n\
15173.0, -4, -3, -5, 5 \n")
1518            file1.close()
1519   
1520            fileName2 = temp_dir+sep+'time_series_5.csv'
1521            file2 = open(fileName2,"w")
1522            file2.write("time,stage,speed,momentum,elevation\n\
15231.0, 0, 0, 0, 7 \n\
15242.0, 4, -0.45, 57, 7 \n\
15253.0, 6, -0.5, 56, 7 \n")
1526            file2.close()
1527           
1528            dir, name=os.path.split(fileName)
1529            csv2timeseries_graphs(directories_dic={dir:['gauge', 0, 0]},
1530                                  output_dir=temp_dir,
1531                                  base_name='time_series_',
1532                                  plot_numbers=['3-5'],
1533                                  quantities=['speed','stage','momentum'],
1534                                  assess_all_csv_files=True,
1535                                  extra_plot_name='test')
1536           
1537            #print dir+sep+name[:-4]+'_stage_test.png'
1538            assert(access(dir+sep+name[:-4]+'_stage_test.png',F_OK)==True)
1539            assert(access(dir+sep+name[:-4]+'_speed_test.png',F_OK)==True)
1540            assert(access(dir+sep+name[:-4]+'_momentum_test.png',F_OK)==True)
1541   
1542            dir1, name1=os.path.split(fileName1)
1543            assert(access(dir+sep+name1[:-4]+'_stage_test.png',F_OK)==True)
1544            assert(access(dir+sep+name1[:-4]+'_speed_test.png',F_OK)==True)
1545            assert(access(dir+sep+name1[:-4]+'_momentum_test.png',F_OK)==True)
1546   
1547   
1548            dir2, name2=os.path.split(fileName2)
1549            assert(access(dir+sep+name2[:-4]+'_stage_test.png',F_OK)==True)
1550            assert(access(dir+sep+name2[:-4]+'_speed_test.png',F_OK)==True)
1551            assert(access(dir+sep+name2[:-4]+'_momentum_test.png',F_OK)==True)
1552   
1553            del_dir(temp_dir)
1554       
1555
1556
1557    def test_greens_law(self):
1558
1559        from math import sqrt
1560       
1561        d1 = 80.0
1562        d2 = 20.0
1563        h1 = 1.0
1564        h2 = greens_law(d1,d2,h1)
1565
1566        assert h2==sqrt(2.0)
1567       
1568    def test_calc_bearings(self):
1569 
1570        from math import atan, degrees
1571        #Test East
1572        uh = 1
1573        vh = 1.e-15
1574        angle = calc_bearing(uh, vh)
1575        if 89 < angle < 91: v=1
1576        assert v==1
1577        #Test West
1578        uh = -1
1579        vh = 1.e-15
1580        angle = calc_bearing(uh, vh)
1581        if 269 < angle < 271: v=1
1582        assert v==1
1583        #Test North
1584        uh = 1.e-15
1585        vh = 1
1586        angle = calc_bearing(uh, vh)
1587        if -1 < angle < 1: v=1
1588        assert v==1
1589        #Test South
1590        uh = 1.e-15
1591        vh = -1
1592        angle = calc_bearing(uh, vh)
1593        if 179 < angle < 181: v=1
1594        assert v==1
1595        #Test South-East
1596        uh = 1
1597        vh = -1
1598        angle = calc_bearing(uh, vh)
1599        if 134 < angle < 136: v=1
1600        assert v==1
1601        #Test North-East
1602        uh = 1
1603        vh = 1
1604        angle = calc_bearing(uh, vh)
1605        if 44 < angle < 46: v=1
1606        assert v==1
1607        #Test South-West
1608        uh = -1
1609        vh = -1
1610        angle = calc_bearing(uh, vh)
1611        if 224 < angle < 226: v=1
1612        assert v==1
1613        #Test North-West
1614        uh = -1
1615        vh = 1
1616        angle = calc_bearing(uh, vh)
1617        if 314 < angle < 316: v=1
1618        assert v==1
1619       
1620    def test_calc_bearings_zero_vector(self): 
1621        from math import atan, degrees
1622
1623        uh = 0
1624        vh = 0
1625        angle = calc_bearing(uh, vh)
1626
1627        assert angle == NAN
1628       
1629#-------------------------------------------------------------
1630
1631if __name__ == "__main__":
1632    suite = unittest.makeSuite(Test_Util, 'test')
1633#    runner = unittest.TextTestRunner(verbosity=2)
1634    runner = unittest.TextTestRunner(verbosity=1)
1635    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.