source: anuga_core/source/anuga/abstract_2d_finite_volumes/test_domain.py @ 7737

Last change on this file since 7737 was 7737, checked in by hudson, 14 years ago

Various refactorings, all unit tests pass.
Domain renamed to generic domain.

File size: 33.2 KB
Line 
1#!/usr/bin/env python
2
3import unittest
4from math import sqrt
5
6from anuga.abstract_2d_finite_volumes.generic_domain import *
7from anuga.pmesh.mesh_interface import create_mesh_from_regions
8from anuga.config import epsilon
9from anuga.shallow_water import Reflective_boundary
10from anuga.shallow_water import Dirichlet_boundary
11import numpy as num
12from anuga.pmesh.mesh import Segment, Vertex, Mesh
13
14
15def add_to_verts(tag, elements, domain):
16    if tag == "mound":
17        domain.test = "Mound"
18
19
20
21class Test_Domain(unittest.TestCase):
22    def setUp(self):
23        pass
24
25
26    def tearDown(self):
27        pass
28
29
30    def test_simple(self):
31        a = [0.0, 0.0]
32        b = [0.0, 2.0]
33        c = [2.0,0.0]
34        d = [0.0, 4.0]
35        e = [2.0, 2.0]
36        f = [4.0,0.0]
37
38        points = [a, b, c, d, e, f]
39        #bac, bce, ecf, dbe, daf, dae
40        vertices = [ [1,0,2], [1,2,4], [4,2,5], [3,1,4]]
41
42        conserved_quantities = ['stage', 'xmomentum', 'ymomentum']
43        evolved_quantities = ['stage', 'xmomentum', 'ymomentum', 'xvelocity']
44       
45        other_quantities = ['elevation', 'friction']
46
47        domain = Generic_Domain(points, vertices, None,
48                        conserved_quantities, evolved_quantities, other_quantities)
49        domain.check_integrity()
50
51        for name in conserved_quantities + other_quantities:
52            assert domain.quantities.has_key(name)
53
54
55        assert num.alltrue(domain.get_conserved_quantities(0, edge=1) == 0.)
56
57
58
59    def test_CFL(self):
60        a = [0.0, 0.0]
61        b = [0.0, 2.0]
62        c = [2.0,0.0]
63        d = [0.0, 4.0]
64        e = [2.0, 2.0]
65        f = [4.0,0.0]
66
67        points = [a, b, c, d, e, f]
68        #bac, bce, ecf, dbe, daf, dae
69        vertices = [ [1,0,2], [1,2,4], [4,2,5], [3,1,4]]
70
71        conserved_quantities = ['stage', 'xmomentum', 'ymomentum']
72        evolved_quantities = ['stage', 'xmomentum', 'ymomentum', 'xvelocity']
73       
74        other_quantities = ['elevation', 'friction']
75
76        domain = Generic_Domain(points, vertices, None,
77                        conserved_quantities, evolved_quantities, other_quantities)
78
79        try:
80            domain.set_CFL(-0.1)
81        except:
82            pass
83        else:
84            msg = 'Should have caught a negative cfl'
85            raise Exception, msg
86
87
88       
89        try:
90            domain.set_CFL(2.0)
91        except:
92            pass
93        else:
94            msg = 'Should have warned of cfl>1.0'
95            raise Exception, msg
96
97        assert domain.CFL == 2.0
98       
99
100    def test_conserved_quantities(self):
101
102        a = [0.0, 0.0]
103        b = [0.0, 2.0]
104        c = [2.0,0.0]
105        d = [0.0, 4.0]
106        e = [2.0, 2.0]
107        f = [4.0,0.0]
108
109        points = [a, b, c, d, e, f]
110        #bac, bce, ecf, dbe, daf, dae
111        vertices = [ [1,0,2], [1,2,4], [4,2,5], [3,1,4]]
112
113        domain = Generic_Domain(points, vertices, boundary=None,
114                        conserved_quantities =\
115                        ['stage', 'xmomentum', 'ymomentum'])
116
117
118        domain.set_quantity('stage', [[1,2,3], [5,5,5],
119                                      [0,0,9], [-6, 3, 3]])
120
121        domain.set_quantity('xmomentum', [[1,2,3], [5,5,5],
122                                          [0,0,9], [-6, 3, 3]])
123
124        domain.check_integrity()
125
126        #Centroids
127        q = domain.get_conserved_quantities(0)
128        assert num.allclose(q, [2., 2., 0.])
129
130        q = domain.get_conserved_quantities(1)
131        assert num.allclose(q, [5., 5., 0.])
132
133        q = domain.get_conserved_quantities(2)
134        assert num.allclose(q, [3., 3., 0.])
135
136        q = domain.get_conserved_quantities(3)
137        assert num.allclose(q, [0., 0., 0.])
138
139
140        #Edges
141        q = domain.get_conserved_quantities(0, edge=0)
142        assert num.allclose(q, [2.5, 2.5, 0.])
143        q = domain.get_conserved_quantities(0, edge=1)
144        assert num.allclose(q, [2., 2., 0.])
145        q = domain.get_conserved_quantities(0, edge=2)
146        assert num.allclose(q, [1.5, 1.5, 0.])
147
148        for i in range(3):
149            q = domain.get_conserved_quantities(1, edge=i)
150            assert num.allclose(q, [5, 5, 0.])
151
152
153        q = domain.get_conserved_quantities(2, edge=0)
154        assert num.allclose(q, [4.5, 4.5, 0.])
155        q = domain.get_conserved_quantities(2, edge=1)
156        assert num.allclose(q, [4.5, 4.5, 0.])
157        q = domain.get_conserved_quantities(2, edge=2)
158        assert num.allclose(q, [0., 0., 0.])
159
160
161        q = domain.get_conserved_quantities(3, edge=0)
162        assert num.allclose(q, [3., 3., 0.])
163        q = domain.get_conserved_quantities(3, edge=1)
164        assert num.allclose(q, [-1.5, -1.5, 0.])
165        q = domain.get_conserved_quantities(3, edge=2)
166        assert num.allclose(q, [-1.5, -1.5, 0.])
167
168
169
170    def test_create_quantity_from_expression(self):
171        """Quantity created from other quantities using arbitrary expression
172
173        """
174
175
176        a = [0.0, 0.0]
177        b = [0.0, 2.0]
178        c = [2.0,0.0]
179        d = [0.0, 4.0]
180        e = [2.0, 2.0]
181        f = [4.0,0.0]
182
183        points = [a, b, c, d, e, f]
184        #bac, bce, ecf, dbe, daf, dae
185        vertices = [ [1,0,2], [1,2,4], [4,2,5], [3,1,4]]
186
187        domain = Generic_Domain(points, vertices, boundary=None,
188                        conserved_quantities =\
189                        ['stage', 'xmomentum', 'ymomentum'],
190                        other_quantities = ['elevation', 'friction'])
191
192
193        domain.set_quantity('elevation', -1)
194
195
196        domain.set_quantity('stage', [[1,2,3], [5,5,5],
197                                      [0,0,9], [-6, 3, 3]])
198
199        domain.set_quantity('xmomentum', [[1,2,3], [5,5,5],
200                                          [0,0,9], [-6, 3, 3]])
201
202        domain.set_quantity('ymomentum', [[3,3,3], [4,2,1],
203                                          [2,4,-1], [1, 0, 1]])
204
205        domain.check_integrity()
206
207
208
209        expression = 'stage - elevation'
210        Q = domain.create_quantity_from_expression(expression)
211
212        assert num.allclose(Q.vertex_values, [[2,3,4], [6,6,6],
213                                              [1,1,10], [-5, 4, 4]])
214
215        expression = '(xmomentum*xmomentum + ymomentum*ymomentum)**0.5'
216        Q = domain.create_quantity_from_expression(expression)
217
218        X = domain.quantities['xmomentum'].vertex_values
219        Y = domain.quantities['ymomentum'].vertex_values
220
221        assert num.allclose(Q.vertex_values, (X**2 + Y**2)**0.5)
222
223
224
225    def test_set_quanitities_to_be_monitored(self):
226        """test_set_quanitities_to_be_monitored
227        """
228
229        a = [0.0, 0.0]
230        b = [0.0, 2.0]
231        c = [2.0,0.0]
232        d = [0.0, 4.0]
233        e = [2.0, 2.0]
234        f = [4.0,0.0]
235
236        points = [a, b, c, d, e, f]
237        #bac, bce, ecf, dbe, daf, dae
238        vertices = [ [1,0,2], [1,2,4], [4,2,5], [3,1,4]]
239
240
241        domain = Generic_Domain(points, vertices, boundary=None,
242                        conserved_quantities =\
243                        ['stage', 'xmomentum', 'ymomentum'],
244                        other_quantities = ['elevation', 'friction', 'depth'])
245
246
247        assert domain.quantities_to_be_monitored is None
248        domain.set_quantities_to_be_monitored(['stage', 'stage-elevation'])
249        assert len(domain.quantities_to_be_monitored) == 2
250        assert domain.quantities_to_be_monitored.has_key('stage')
251        assert domain.quantities_to_be_monitored.has_key('stage-elevation')
252        for key in domain.quantities_to_be_monitored['stage'].keys():
253            assert domain.quantities_to_be_monitored['stage'][key] is None
254
255
256        # Check that invalid requests are dealt with
257        try:
258            domain.set_quantities_to_be_monitored(['yyyyy'])       
259        except:
260            pass
261        else:
262            msg = 'Should have caught illegal quantity'
263            raise Exception, msg
264
265        try:
266            domain.set_quantities_to_be_monitored(['stage-xx'])       
267        except NameError:
268            pass
269        else:
270            msg = 'Should have caught illegal quantity'
271            raise Exception, msg
272
273        try:
274            domain.set_quantities_to_be_monitored('stage', 'stage-elevation')
275        except:
276            pass
277        else:
278            msg = 'Should have caught too many arguments'
279            raise Exception, msg
280
281        try:
282            domain.set_quantities_to_be_monitored('stage', 'blablabla')
283        except:
284            pass
285        else:
286            msg = 'Should have caught polygon as a string'
287            raise Exception, msg       
288
289
290
291        # Now try with a polygon restriction
292        domain.set_quantities_to_be_monitored('xmomentum',
293                                              polygon=[[1,1], [1,3], [3,3], [3,1]],
294                                              time_interval = [0,3])
295        assert domain.monitor_indices[0] == 1
296        assert domain.monitor_time_interval[0] == 0
297        assert domain.monitor_time_interval[1] == 3       
298       
299
300    def test_set_quantity_from_expression(self):
301        """Quantity set using arbitrary expression
302
303        """
304
305
306        a = [0.0, 0.0]
307        b = [0.0, 2.0]
308        c = [2.0,0.0]
309        d = [0.0, 4.0]
310        e = [2.0, 2.0]
311        f = [4.0,0.0]
312
313        points = [a, b, c, d, e, f]
314        #bac, bce, ecf, dbe, daf, dae
315        vertices = [ [1,0,2], [1,2,4], [4,2,5], [3,1,4]]
316
317        domain = Generic_Domain(points, vertices, boundary=None,
318                        conserved_quantities =\
319                        ['stage', 'xmomentum', 'ymomentum'],
320                        other_quantities = ['elevation', 'friction', 'depth'])
321
322
323        domain.set_quantity('elevation', -1)
324
325
326        domain.set_quantity('stage', [[1,2,3], [5,5,5],
327                                      [0,0,9], [-6, 3, 3]])
328
329        domain.set_quantity('xmomentum', [[1,2,3], [5,5,5],
330                                          [0,0,9], [-6, 3, 3]])
331
332        domain.set_quantity('ymomentum', [[3,3,3], [4,2,1],
333                                          [2,4,-1], [1, 0, 1]])
334
335
336
337
338        domain.set_quantity('depth', expression = 'stage - elevation')
339
340        domain.check_integrity()
341
342
343
344
345        Q = domain.quantities['depth']
346
347        assert num.allclose(Q.vertex_values, [[2,3,4], [6,6,6],
348                                              [1,1,10], [-5, 4, 4]])
349
350
351
352                                     
353    def test_add_quantity(self):
354        """Test that quantities already set can be added to using
355        add_quantity
356
357        """
358
359
360        a = [0.0, 0.0]
361        b = [0.0, 2.0]
362        c = [2.0,0.0]
363        d = [0.0, 4.0]
364        e = [2.0, 2.0]
365        f = [4.0,0.0]
366
367        points = [a, b, c, d, e, f]
368        #bac, bce, ecf, dbe, daf, dae
369        vertices = [ [1,0,2], [1,2,4], [4,2,5], [3,1,4]]
370
371        domain = Generic_Domain(points, vertices, boundary=None,
372                        conserved_quantities =\
373                        ['stage', 'xmomentum', 'ymomentum'],
374                        other_quantities = ['elevation', 'friction', 'depth'])
375
376
377        A = num.array([[1,2,3], [5,5,-5], [0,0,9], [-6,3,3]], num.float)
378        B = num.array([[2,4,4], [3,2,1], [6,-3,4], [4,5,-1]], num.float)
379       
380        # Shorthands
381        stage = domain.quantities['stage']
382        elevation = domain.quantities['elevation']
383        depth = domain.quantities['depth']
384       
385        # Go testing
386        domain.set_quantity('elevation', A)
387        domain.add_quantity('elevation', B)
388        assert num.allclose(elevation.vertex_values, A+B)
389       
390        domain.add_quantity('elevation', 4)
391        assert num.allclose(elevation.vertex_values, A+B+4)       
392       
393       
394        # Test using expression
395        domain.set_quantity('stage', [[1,2,3], [5,5,5],
396                                      [0,0,9], [-6, 3, 3]])       
397        domain.set_quantity('depth', 1.0)                                     
398        domain.add_quantity('depth', expression = 'stage - elevation')       
399        assert num.allclose(depth.vertex_values, stage.vertex_values-elevation.vertex_values+1)
400               
401       
402        # Check self referential expression
403        reference = 2*stage.vertex_values - depth.vertex_values
404        domain.add_quantity('stage', expression = 'stage - depth')               
405        assert num.allclose(stage.vertex_values, reference)       
406                                     
407
408        # Test using a function
409        def f(x, y):
410            return x+y
411           
412        domain.set_quantity('elevation', f)           
413        domain.set_quantity('stage', 5.0)
414        domain.set_quantity('depth', expression = 'stage - elevation')
415       
416        domain.add_quantity('depth', f)
417        assert num.allclose(stage.vertex_values, depth.vertex_values)               
418         
419           
420       
421       
422                                     
423                                     
424    def test_setting_timestepping_method(self):
425        """test_setting_timestepping_method
426        """
427
428        a = [0.0, 0.0]
429        b = [0.0, 2.0]
430        c = [2.0,0.0]
431        d = [0.0, 4.0]
432        e = [2.0, 2.0]
433        f = [4.0,0.0]
434
435        points = [a, b, c, d, e, f]
436        #bac, bce, ecf, dbe, daf, dae
437        vertices = [ [1,0,2], [1,2,4], [4,2,5], [3,1,4]]
438
439
440        domain = Generic_Domain(points, vertices, boundary=None,
441                        conserved_quantities =\
442                        ['stage', 'xmomentum', 'ymomentum'],
443                        other_quantities = ['elevation', 'friction', 'depth'])
444
445
446        domain.timestepping_method = None
447
448
449        # Check that invalid requests are dealt with
450        try:
451            domain.set_timestepping_method('eee')       
452        except:
453            pass
454        else:
455            msg = 'Should have caught illegal method'
456            raise Exception, msg
457
458
459        #Should have no trouble with euler, rk2 or rk3
460        domain.set_timestepping_method('euler')
461        domain.set_timestepping_method('rk2')
462        domain.set_timestepping_method('rk3')
463
464        domain.set_timestepping_method(1)
465        domain.set_timestepping_method(2)
466        domain.set_timestepping_method(3)
467
468        #test get timestepping method
469        assert domain.get_timestepping_method() == 'rk3'
470
471
472
473    def test_boundary_indices(self):
474
475        from anuga.config import default_boundary_tag
476
477
478        a = [0.0, 0.5]
479        b = [0.0, 0.0]
480        c = [0.5, 0.5]
481
482        points = [a, b, c]
483        vertices = [ [0,1,2] ]
484        domain = Generic_Domain(points, vertices)
485
486        domain.set_boundary( {default_boundary_tag: Dirichlet_boundary([5,2,1])} )
487
488
489        domain.check_integrity()
490
491        assert num.allclose(domain.neighbours, [[-1,-2,-3]])
492
493
494
495    def test_boundary_conditions(self):
496
497        a = [0.0, 0.0]
498        b = [0.0, 2.0]
499        c = [2.0,0.0]
500        d = [0.0, 4.0]
501        e = [2.0, 2.0]
502        f = [4.0,0.0]
503
504        points = [a, b, c, d, e, f]
505        #bac, bce, ecf, dbe
506        vertices = [ [1,0,2], [1,2,4], [4,2,5], [3,1,4] ]
507        boundary = { (0, 0): 'First',
508                     (0, 2): 'First',
509                     (2, 0): 'Second',
510                     (2, 1): 'Second',
511                     (3, 1): 'Second',
512                     (3, 2): 'Second'}
513
514
515        domain = Generic_Domain(points, vertices, boundary,
516                        conserved_quantities =\
517                        ['stage', 'xmomentum', 'ymomentum'])
518        domain.check_integrity()
519
520
521
522        domain.set_quantity('stage', [[1,2,3], [5,5,5],
523                                      [0,0,9], [-6, 3, 3]])
524
525
526        domain.set_boundary( {'First': Dirichlet_boundary([5,2,1]),
527                              'Second': Transmissive_boundary(domain)} )
528
529        domain.update_boundary()
530
531        assert domain.quantities['stage'].boundary_values[0] == 5. #Dirichlet
532        assert domain.quantities['stage'].boundary_values[1] == 5. #Dirichlet
533        assert domain.quantities['stage'].boundary_values[2] ==\
534               domain.get_conserved_quantities(2, edge=0)[0] #Transmissive (4.5)
535        assert domain.quantities['stage'].boundary_values[3] ==\
536               domain.get_conserved_quantities(2, edge=1)[0] #Transmissive (4.5)
537        assert domain.quantities['stage'].boundary_values[4] ==\
538               domain.get_conserved_quantities(3, edge=1)[0] #Transmissive (-1.5)
539        assert domain.quantities['stage'].boundary_values[5] ==\
540               domain.get_conserved_quantities(3, edge=2)[0] #Transmissive (-1.5)
541
542        #Check enumeration
543        for k, ((vol_id, edge_id), _) in enumerate(domain.boundary_objects):
544            assert domain.neighbours[vol_id, edge_id] == -k-1
545
546
547
548
549    def test_conserved_evolved_boundary_conditions(self):
550
551        a = [0.0, 0.0]
552        b = [0.0, 2.0]
553        c = [2.0,0.0]
554        d = [0.0, 4.0]
555        e = [2.0, 2.0]
556        f = [4.0,0.0]
557
558        points = [a, b, c, d, e, f]
559        #bac, bce, ecf, dbe
560        vertices = [ [1,0,2], [1,2,4], [4,2,5], [3,1,4] ]
561        boundary = { (0, 0): 'First',
562                     (0, 2): 'First',
563                     (2, 0): 'Second',
564                     (2, 1): 'Second',
565                     (3, 1): 'Second',
566                     (3, 2): 'Second'}
567
568
569 
570        try:
571            domain = Generic_Domain(points, vertices, boundary,
572                            conserved_quantities = ['stage', 'xmomentum', 'ymomentum'],
573                            evolved_quantities =\
574                                   ['stage', 'xmomentum', 'xvelocity', 'ymomentum', 'yvelocity'])
575        except:
576            pass
577        else:
578            msg = 'Should have caught the evolved quantities not being in order'
579            raise Exception, msg           
580
581
582        domain = Generic_Domain(points, vertices, boundary,
583                        conserved_quantities = ['stage', 'xmomentum', 'ymomentum'],
584                        evolved_quantities =\
585                        ['stage', 'xmomentum', 'ymomentum', 'xvelocity', 'yvelocity'])
586
587
588        domain.set_quantity('stage', [[1,2,3], [5,5,5],
589                                      [0,0,9], [6, -3, 3]])
590
591
592        domain.set_boundary( {'First': Dirichlet_boundary([5,2,1,4,6]),
593                              'Second': Transmissive_boundary(domain)} )
594
595        try:
596            domain.update_boundary()
597        except:
598            pass
599        else:
600            msg = 'Should have caught the lack of conserved_values_to_evolved_values member function'
601            raise Exception, msg
602
603
604        def  conserved_values_to_evolved_values(q_cons, q_evol):
605
606            q_evol[0:3] = q_cons
607            q_evol[3] = q_cons[1]/q_cons[0]
608            q_evol[4] = q_cons[2]/q_cons[0]
609
610            return q_evol
611
612        domain.conserved_values_to_evolved_values = conserved_values_to_evolved_values
613
614        domain.update_boundary()
615
616
617        assert domain.quantities['stage'].boundary_values[0] == 5. #Dirichlet
618        assert domain.quantities['stage'].boundary_values[1] == 5. #Dirichlet
619        assert domain.quantities['xvelocity'].boundary_values[0] == 4. #Dirichlet
620        assert domain.quantities['yvelocity'].boundary_values[1] == 6. #Dirichlet
621
622        q_cons = domain.get_conserved_quantities(2, edge=0) #Transmissive
623        assert domain.quantities['stage'    ].boundary_values[2] == q_cons[0]
624        assert domain.quantities['xmomentum'].boundary_values[2] == q_cons[1]
625        assert domain.quantities['ymomentum'].boundary_values[2] == q_cons[2]
626        assert domain.quantities['xvelocity'].boundary_values[2] == q_cons[1]/q_cons[0]
627        assert domain.quantities['yvelocity'].boundary_values[2] == q_cons[2]/q_cons[0]
628
629        q_cons = domain.get_conserved_quantities(2, edge=1) #Transmissive
630        assert domain.quantities['stage'    ].boundary_values[3] == q_cons[0]
631        assert domain.quantities['xmomentum'].boundary_values[3] == q_cons[1]
632        assert domain.quantities['ymomentum'].boundary_values[3] == q_cons[2]
633        assert domain.quantities['xvelocity'].boundary_values[3] == q_cons[1]/q_cons[0]
634        assert domain.quantities['yvelocity'].boundary_values[3] == q_cons[2]/q_cons[0]       
635
636
637        q_cons = domain.get_conserved_quantities(3, edge=1) #Transmissive
638        assert domain.quantities['stage'    ].boundary_values[4] == q_cons[0]
639        assert domain.quantities['xmomentum'].boundary_values[4] == q_cons[1]
640        assert domain.quantities['ymomentum'].boundary_values[4] == q_cons[2]
641        assert domain.quantities['xvelocity'].boundary_values[4] == q_cons[1]/q_cons[0]
642        assert domain.quantities['yvelocity'].boundary_values[4] == q_cons[2]/q_cons[0]               
643
644
645        q_cons = domain.get_conserved_quantities(3, edge=2) #Transmissive
646        assert domain.quantities['stage'    ].boundary_values[5] == q_cons[0]
647        assert domain.quantities['xmomentum'].boundary_values[5] == q_cons[1]
648        assert domain.quantities['ymomentum'].boundary_values[5] == q_cons[2]
649        assert domain.quantities['xvelocity'].boundary_values[5] == q_cons[1]/q_cons[0]
650        assert domain.quantities['yvelocity'].boundary_values[5] == q_cons[2]/q_cons[0]
651 
652
653    def test_distribute_first_order(self):
654        """Domain implements a default first order gradient limiter
655        """
656
657        a = [0.0, 0.0]
658        b = [0.0, 2.0]
659        c = [2.0,0.0]
660        d = [0.0, 4.0]
661        e = [2.0, 2.0]
662        f = [4.0,0.0]
663
664        points = [a, b, c, d, e, f]
665        #bac, bce, ecf, dbe
666        vertices = [ [1,0,2], [1,2,4], [4,2,5], [3,1,4] ]
667        boundary = { (0, 0): 'Third',
668                     (0, 2): 'First',
669                     (2, 0): 'Second',
670                     (2, 1): 'Second',
671                     (3, 1): 'Second',
672                     (3, 2): 'Third'}
673
674
675        domain = Generic_Domain(points, vertices, boundary,
676                        conserved_quantities =\
677                        ['stage', 'xmomentum', 'ymomentum'])
678        domain.set_default_order(1)
679        domain.check_integrity()
680
681
682        domain.set_quantity('stage', [[1,2,3], [5,5,5],
683                                      [0,0,9], [-6, 3, 3]])
684
685        assert num.allclose( domain.quantities['stage'].centroid_values,
686                             [2,5,3,0] )
687
688        domain.set_quantity('xmomentum', [[1,1,1], [2,2,2],
689                                          [3,3,3], [4, 4, 4]])
690
691        domain.set_quantity('ymomentum', [[10,10,10], [20,20,20],
692                                          [30,30,30], [40, 40, 40]])
693
694
695        domain.distribute_to_vertices_and_edges()
696
697        #First order extrapolation
698        assert num.allclose( domain.quantities['stage'].vertex_values,
699                             [[ 2.,  2.,  2.],
700                              [ 5.,  5.,  5.],
701                              [ 3.,  3.,  3.],
702                              [ 0.,  0.,  0.]])
703
704
705
706
707    def test_update_conserved_quantities(self):
708        a = [0.0, 0.0]
709        b = [0.0, 2.0]
710        c = [2.0,0.0]
711        d = [0.0, 4.0]
712        e = [2.0, 2.0]
713        f = [4.0,0.0]
714
715        points = [a, b, c, d, e, f]
716        #bac, bce, ecf, dbe
717        vertices = [ [1,0,2], [1,2,4], [4,2,5], [3,1,4] ]
718        boundary = { (0, 0): 'Third',
719                     (0, 2): 'First',
720                     (2, 0): 'Second',
721                     (2, 1): 'Second',
722                     (3, 1): 'Second',
723                     (3, 2): 'Third'}
724
725
726        domain = Generic_Domain(points, vertices, boundary,
727                        conserved_quantities =\
728                        ['stage', 'xmomentum', 'ymomentum'])
729        domain.check_integrity()
730
731
732        domain.set_quantity('stage', [1,2,3,4], location='centroids')
733        domain.set_quantity('xmomentum', [1,2,3,4], location='centroids')
734        domain.set_quantity('ymomentum', [1,2,3,4], location='centroids')
735
736
737        #Assign some values to update vectors
738        #Set explicit_update
739
740
741        for name in domain.conserved_quantities:
742            domain.quantities[name].explicit_update = num.array([4.,3.,2.,1.])
743            domain.quantities[name].semi_implicit_update = num.array([1.,1.,1.,1.])
744
745
746        #Update with given timestep (assuming no other forcing terms)
747        domain.timestep = 0.1
748        domain.update_conserved_quantities()
749
750        sem = num.array([1.,1.,1.,1.])/num.array([1, 2, 3, 4])
751        denom = num.ones(4, num.float) - domain.timestep*sem
752
753#        x = array([1, 2, 3, 4]) + array( [.4,.3,.2,.1] )
754#        x /= denom
755
756        x = num.array([1., 2., 3., 4.])
757        x += domain.timestep*num.array( [4,3,2,1] )
758        x /= denom
759
760
761        for name in domain.conserved_quantities:
762            assert num.allclose(domain.quantities[name].centroid_values, x)
763
764
765    def test_set_region(self):
766        """Set quantities for sub region
767        """
768
769        a = [0.0, 0.0]
770        b = [0.0, 2.0]
771        c = [2.0,0.0]
772        d = [0.0, 4.0]
773        e = [2.0, 2.0]
774        f = [4.0,0.0]
775
776        points = [a, b, c, d, e, f]
777        #bac, bce, ecf, dbe
778        vertices = [ [1,0,2], [1,2,4], [4,2,5], [3,1,4] ]
779        boundary = { (0, 0): 'Third',
780                     (0, 2): 'First',
781                     (2, 0): 'Second',
782                     (2, 1): 'Second',
783                     (3, 1): 'Second',
784                     (3, 2): 'Third'}
785
786        domain = Generic_Domain(points, vertices, boundary,
787                        conserved_quantities =\
788                        ['stage', 'xmomentum', 'ymomentum'])
789        domain.set_default_order(1)                       
790        domain.check_integrity()
791
792        domain.set_quantity('stage', [[1,2,3], [5,5,5],
793                                      [0,0,9], [-6, 3, 3]])
794
795        assert num.allclose( domain.quantities['stage'].centroid_values,
796                             [2,5,3,0] )
797
798        domain.set_quantity('xmomentum', [[1,1,1], [2,2,2],
799                                          [3,3,3], [4, 4, 4]])
800
801        domain.set_quantity('ymomentum', [[10,10,10], [20,20,20],
802                                          [30,30,30], [40, 40, 40]])
803
804
805        domain.distribute_to_vertices_and_edges()
806
807        #First order extrapolation
808        assert num.allclose( domain.quantities['stage'].vertex_values,
809                             [[ 2.,  2.,  2.],
810                              [ 5.,  5.,  5.],
811                              [ 3.,  3.,  3.],
812                              [ 0.,  0.,  0.]])
813
814        domain.build_tagged_elements_dictionary({'mound':[0,1]})
815        domain.set_region([add_to_verts])
816
817        self.failUnless(domain.test == "Mound",
818                        'set region failed')
819
820                             
821    def test_rectangular_periodic_and_ghosts(self):
822
823        from mesh_factory import rectangular_periodic
824       
825
826        M=5
827        N=2
828        points, vertices, boundary, full_send_dict, ghost_recv_dict = rectangular_periodic(M, N)
829
830        assert num.allclose(ghost_recv_dict[0][0], [24, 25, 26, 27,  0,  1,  2,  3])
831        assert num.allclose(full_send_dict[0][0] , [ 4,  5,  6,  7, 20, 21, 22, 23])
832
833        conserved_quantities = ['quant1', 'quant2']
834        domain = Generic_Domain(points, vertices, boundary, conserved_quantities,
835                        full_send_dict=full_send_dict,
836                        ghost_recv_dict=ghost_recv_dict)
837
838
839
840
841        assert num.allclose(ghost_recv_dict[0][0], [24, 25, 26, 27,  0,  1,  2,  3])
842        assert num.allclose(full_send_dict[0][0] , [ 4,  5,  6,  7, 20, 21, 22, 23])
843
844        def xylocation(x,y):
845            return 15*x + 9*y
846
847       
848        domain.set_quantity('quant1',xylocation,location='centroids')
849        domain.set_quantity('quant2',xylocation,location='centroids')
850
851
852        assert num.allclose(domain.quantities['quant1'].centroid_values,
853                            [  0.5,   1.,   5.,    5.5,   3.5,   4.,    8.,    8.5,   6.5,  7.,   11.,   11.5,   9.5,
854                               10.,   14.,   14.5,  12.5,  13.,   17.,   17.5,  15.5,  16.,   20.,   20.5,
855                               18.5,  19.,   23.,   23.5])
856
857
858
859        assert num.allclose(domain.quantities['quant2'].centroid_values,
860                            [  0.5,   1.,   5.,    5.5,   3.5,   4.,    8.,    8.5,   6.5,  7.,   11.,   11.5,   9.5,
861                               10.,   14.,   14.5,  12.5,  13.,   17.,   17.5,  15.5,  16.,   20.,   20.5,
862                               18.5,  19.,   23.,   23.5])
863
864        domain.update_ghosts()
865
866
867        assert num.allclose(domain.quantities['quant1'].centroid_values,
868                            [  15.5,  16.,   20.,   20.5,   3.5,   4.,    8.,    8.5,   6.5,  7.,   11.,   11.5,   9.5,
869                               10.,   14.,   14.5,  12.5,  13.,   17.,   17.5,  15.5,  16.,   20.,   20.5,
870                                3.5,   4.,    8.,    8.5])
871
872
873
874        assert num.allclose(domain.quantities['quant2'].centroid_values,
875                            [  15.5,  16.,   20.,   20.5,   3.5,   4.,    8.,    8.5,   6.5,  7.,   11.,   11.5,   9.5,
876                               10.,   14.,   14.5,  12.5,  13.,   17.,   17.5,  15.5,  16.,   20.,   20.5,
877                                3.5,   4.,    8.,    8.5])
878
879       
880        assert num.allclose(domain.tri_full_flag, [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
881                                                   1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0])
882
883        #Test that points are arranged in a counter clock wise order
884        domain.check_integrity()
885
886
887    def test_that_mesh_methods_exist(self):
888        """test_that_mesh_methods_exist
889       
890        Test that relavent mesh methods are made available in
891        domain through composition
892        """
893        from mesh_factory import rectangular
894        from shallow_water import Domain
895
896        # Create basic mesh
897        points, vertices, boundary = rectangular(1, 3)
898
899        # Create shallow water domain
900        domain = Domain(points, vertices, boundary)                             
901       
902       
903        domain.get_centroid_coordinates()
904        domain.get_radii()
905        domain.get_areas()
906        domain.get_area()
907        domain.get_vertex_coordinates()
908        domain.get_triangles()
909        domain.get_nodes()
910        domain.get_number_of_nodes()
911        domain.get_normal(0,0)
912        domain.get_triangle_containing_point([0.4,0.5])
913        domain.get_intersecting_segments([[0.0, 0.0], [0.0, 1.0]])
914        domain.get_disconnected_triangles()
915        domain.get_boundary_tags()
916        domain.get_boundary_polygon()
917        #domain.get_number_of_triangles_per_node()
918        domain.get_triangles_and_vertices_per_node()
919        domain.get_interpolation_object()
920        domain.get_tagged_elements()
921        domain.get_lone_vertices()
922        domain.get_unique_vertices()
923        g = domain.get_georeference()
924        domain.set_georeference(g)
925        domain.build_tagged_elements_dictionary()
926        domain.statistics()
927        domain.get_extent()
928
929    def NOtest_vertex_within_hole(self):
930        """ NOTE: This test fails - it is designed to test fitting on
931            a mesh with a hole, but more info is needed on the specific
932            problem."""
933       
934        # For test_fitting_using_shallow_water_domain example
935        def linear_function(point):
936            point = num.array(point)
937            return point[:,0]+point[:,1]       
938       
939        meshname = 'test_mesh.msh'
940        verbose = False
941        W = 0
942        S = 0
943        E = 10
944        N = 10
945
946        bounding_polygon = [[W, S], [E, S], [E, N], [W, N]]
947        hole = [[[.1,.1], [9.9,1.1], [9.9,9.9], [1.1,9.9]]]
948
949        create_mesh_from_regions(bounding_polygon,
950                                 boundary_tags={'south': [0], 
951                                                'east': [1], 
952                                                'north': [2], 
953                                                'west': [3]},
954                                 maximum_triangle_area=1,
955                                 filename=meshname,
956                                 interior_holes = hole,
957                                 use_cache=False,
958                                 verbose=verbose)
959
960        domain = Domain(meshname, use_cache=False, verbose=verbose)
961        quantity = Quantity(domain)
962       
963         # Get (enough) datapoints (relative to georef)
964        data_points     = [[ 0.66666667, 0.66666667],
965                           [ 1.33333333, 1.33333333],
966                           [ 2.66666667, 0.66666667],
967                           [ 0.66666667, 2.66666667],
968                           [ 0.0,        1.0],
969                           [ 0.0,        3.0],
970                           [ 1.0,        0.0],
971                           [ 1.0,        1.0],
972                           [ 1.0,        2.0],
973                           [ 1.0,        3.0],
974                           [ 2.0,        1.0],
975                           [ 3.0,        0.0],
976                           [ 3.0,        1.0]]
977
978
979        attributes = linear_function(data_points)
980        att = 'spam_and_eggs'
981
982        # Create .txt file
983        ptsfile = "points.txt"
984        file = open(ptsfile, "w")
985        file.write(" x,y," + att + " \n")
986        for data_point, attribute in map(None, data_points, attributes):
987            row = (str(data_point[0]) + ',' +
988                   str(data_point[1]) + ',' +
989                   str(attribute))
990            file.write(row + "\n")
991        file.close()
992
993        # Check that values can be set from file
994        quantity.set_values(filename=ptsfile, attribute_name=att, alpha=0)
995        answer = linear_function(quantity.domain.get_vertex_coordinates())
996
997        assert num.allclose(quantity.vertex_values.flat, answer)
998
999        # Check that values can be set from file using default attribute
1000        quantity.set_values(filename = ptsfile, alpha = 0)
1001        assert num.allclose(quantity.vertex_values.flat, answer)       
1002       
1003       
1004
1005#-------------------------------------------------------------
1006
1007if __name__ == "__main__":
1008    suite = unittest.makeSuite(Test_Domain,'test')
1009    runner = unittest.TextTestRunner()
1010    runner.run(suite)
Note: See TracBrowser for help on using the repository browser.