source: anuga_core/source/anuga_parallel/parallel_shallow_water.py @ 7449

Last change on this file since 7449 was 7449, checked in by steve, 15 years ago

Testing unit tests

File size: 7.6 KB
Line 
"""Class Parallel_domain -
2D triangular domains for finite-volume computations of
the shallow water equation, with extra structures to allow
communication between this Parallel_domain and the other
Parallel_domains running on other processors.

This module contains a specialisation of the shallow water class Domain
(imported here from anuga.interface).

Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
Geoscience Australia, 2004-2005

"""
13
14from anuga.interface import Domain
15
16
17import numpy as num
18
19import pypar
20
21
class Parallel_domain(Domain):
    """Shallow water Domain holding one process's share of a partitioned mesh.

    Each process owns "full" triangles and holds "ghost" copies of
    triangles whose conserved quantities are refreshed from other
    processes (or from itself) by update_ghosts.  update_timestep makes
    every process use the minimum flux_timestep over all processes.
    Communication wall-clock time is accumulated in the
    communication_*_time attributes.
    """

    def __init__(self, coordinates, vertices,
                 boundary=None,
                 full_send_dict=None,
                 ghost_recv_dict=None,
                 number_of_full_nodes=None,
                 number_of_full_triangles=None):
        """Create the local part of a distributed domain.

        All arguments are forwarded to Domain.__init__, together with
        processor=pypar.rank() and numproc=pypar.size().

        full_send_dict / ghost_recv_dict map a processor id to a sequence
        stored as (local ids, global ids, values): element [0] is an index
        array of local triangle ids and element [2] is a 2-D buffer with
        one column per conserved quantity (see update_ghosts).
        """
        Domain.__init__(self,
                        coordinates,
                        vertices,
                        boundary,
                        full_send_dict=full_send_dict,
                        ghost_recv_dict=ghost_recv_dict,
                        processor=pypar.rank(),
                        numproc=pypar.size(),
                        number_of_full_nodes=number_of_full_nodes,
                        number_of_full_triangles=number_of_full_triangles)

        # One-element buffers for synchronisation of timesteps; pypar
        # reduce/broadcast read and write these arrays directly.
        # (num.float was only an alias of builtin float and is gone from
        # NumPy >= 1.24, so builtin float is used.)
        self.local_timestep = num.zeros(1, float)
        self.global_timestep = num.zeros(1, float)

        # One slot per process; filled by update_timestep_1.
        self.local_timesteps = num.zeros(self.numproc, float)

        # Accumulated wall-clock seconds spent communicating.
        self.communication_time = 0.0
        self.communication_reduce_time = 0.0
        self.communication_broadcast_time = 0.0

    def set_name(self, name):
        """Assign a per-process name of the form '<name>_P<processor>_<numproc>'.

        A trailing '.sww' extension is stripped before the suffix is added;
        the result is passed on to Domain.set_name.
        """
        if name.endswith('.sww'):
            name = name[:-4]

        # Call parent's method with processor number attached.
        Domain.set_name(self, name + '_P%d_%d' % (self.processor, self.numproc))

    def check_integrity(self):
        """Run Domain.check_integrity, then check conserved-quantity order.

        Asserts that the conserved quantities are exactly stage, xmomentum,
        ymomentum in that order -- presumably because update_ghosts packs
        them into buffer columns by position, so all processes must agree.
        """
        Domain.check_integrity(self)

        msg = 'Will need to check global and local numbering'
        assert self.conserved_quantities[0] == 'stage', msg
        assert self.conserved_quantities[1] == 'xmomentum', msg
        assert self.conserved_quantities[2] == 'ymomentum', msg

    def update_timestep_1(self, yieldstep, finaltime):
        """Alternative timestep synchronisation using one broadcast per process.

        Every process broadcasts its flux_timestep in turn; the minimum over
        all processes becomes flux_timestep, then Domain.update_timestep is
        called.  Not wired in by default (see the alias note after
        update_timestep).
        """
        import time

        t0 = time.time()

        # Broadcast local timestep from every processor to every other.
        for pid in range(self.numproc):
            # Reset every round: the previous broadcast presumably left the
            # root's value in local_timestep on non-root processes.
            self.local_timestep[0] = self.flux_timestep
            pypar.broadcast(self.local_timestep, pid, bypass=True)
            self.local_timesteps[pid] = self.local_timestep[0]

        self.flux_timestep = min(self.local_timesteps)

        pypar.barrier()
        self.communication_broadcast_time += time.time() - t0

        # The global timestep is found before Domain does its update.
        Domain.update_timestep(self, yieldstep, finaltime)

    def update_timestep(self, yieldstep, finaltime):
        """Set flux_timestep to the global minimum, then call Domain's update.

        The local flux_timestep is MIN-reduced onto processor 0 and the
        result broadcast back, so every process advances with the same
        timestep.  Domain.update_timestep runs afterwards so its statistics
        see the global value.
        """
        import time

        pypar.barrier()

        # Compute minimal timestep across all processes.
        self.local_timestep[0] = self.flux_timestep
        use_reduce_broadcast = True
        t0 = time.time()
        if use_reduce_broadcast:
            pypar.reduce(self.local_timestep, pypar.MIN, 0,
                         buffer=self.global_timestep)
        else:
            # Alternative: straight sends and receives to processor 0.
            self.global_timestep[0] = self.flux_timestep

            if self.processor == 0:
                for i in range(1, self.numproc):
                    pypar.receive(i, buffer=self.local_timestep)

                    if self.local_timestep[0] < self.global_timestep[0]:
                        self.global_timestep[0] = self.local_timestep[0]
            else:
                pypar.send(self.local_timestep, 0, use_buffer=True)

        self.communication_reduce_time += time.time() - t0

        # Broadcast minimal timestep to all.
        t0 = time.time()
        pypar.broadcast(self.global_timestep, 0)
        self.communication_broadcast_time += time.time() - t0

        self.flux_timestep = self.global_timestep[0]

        # Update local statistics now that the global timestep is known.
        Domain.update_timestep(self, yieldstep, finaltime)

        # FIXME (Ole) We should update the variable min_timestep for use
        # with write_time (or redo write_time)

    # To use the broadcast-based variant instead:
    # update_timestep = update_timestep_1

    def update_ghosts(self):
        """Refresh ghost-triangle centroid values of all conserved quantities.

        Processes take turns in rank order.  In round iproc, process iproc
        packs the centroid values of the full triangles listed in
        full_send_dict into each peer's buffer and sends it.  Every other
        process that has an entry for iproc in ghost_recv_dict receives
        that buffer and scatters it into its ghost triangles.  Finally,
        ghosts whose source triangles live on this same process are filled
        by a local copy.
        """
        import time
        t0 = time.time()

        # Update of non-local ghost cells.
        for iproc in range(self.numproc):
            if iproc == self.processor:
                # Send data from this processor to the other processors.
                for send_proc in self.full_send_dict:
                    if send_proc == iproc:
                        continue  # handled by the local copy below

                    Idf = self.full_send_dict[send_proc][0]
                    Xout = self.full_send_dict[send_proc][2]

                    for i, q in enumerate(self.conserved_quantities):
                        Q_cv = self.quantities[q].centroid_values
                        Xout[:, i] = num.take(Q_cv, Idf)

                    pypar.send(Xout, int(send_proc), use_buffer=True)

            elif iproc in self.ghost_recv_dict:
                # Receive data from the iproc processor.
                Idg = self.ghost_recv_dict[iproc][0]
                X = self.ghost_recv_dict[iproc][2]

                X = pypar.receive(int(iproc), buffer=X)

                for i, q in enumerate(self.conserved_quantities):
                    Q_cv = self.quantities[q].centroid_values
                    num.put(Q_cv, Idg, X[:, i])

        # Local update of ghost cells whose source triangles are on this
        # same process.  Entries are stored as (local id, global id, value);
        # only the local ids are needed here.
        iproc = self.processor
        if iproc in self.full_send_dict:
            Idf = self.full_send_dict[iproc][0]
            Idg = self.ghost_recv_dict[iproc][0]

            for q in self.conserved_quantities:
                Q_cv = self.quantities[q].centroid_values
                num.put(Q_cv, Idg, num.take(Q_cv, Idf))

        self.communication_time += time.time() - t0
Note: See TracBrowser for help on using the repository browser.