source: anuga_core/source/anuga_parallel/parallel_shallow_water.py @ 3906

Last change on this file since 3906 was 3893, checked in by ole, 19 years ago

Allowed set_name to accept extension '.sww'

File size: 9.3 KB
Line 
"""Class Parallel_Domain -
2D triangular domains for finite-volume computations of
the shallow water equation, with extra structures to allow
communication between this domain and other Parallel_Domains.

This module contains a specialisation of class Domain
from module shallow_water_domain.py (package anuga.shallow_water).

Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
Geoscience Australia, 2004-2005

"""
13
14import logging, logging.config
# Module logger for the parallel code; quiet (WARNING) by default.
logger = logging.getLogger('parallel')
logger.setLevel(logging.WARNING)

# Optional logging configuration: if a valid log.ini is present in the
# working directory it may override the default above; otherwise the
# default is kept.
# NOTE(review): fileConfig disables already-created loggers that the config
# file does not mention, which would include 'parallel' -- confirm that
# log.ini configures it.
try:
    logging.config.fileConfig('log.ini')
except Exception:
    # Best-effort: a missing or malformed log.ini is not an error, but unlike
    # a bare 'except:' this does not swallow SystemExit/KeyboardInterrupt.
    pass
22
23
24from anuga.shallow_water.shallow_water_domain import *
25from Numeric import zeros, Float, Int, ones, allclose, array
26
27import pypar
28
29
class Parallel_Domain(Domain):
    """Shallow water Domain holding one partition of a distributed mesh.

    The domain is created with processor=pypar.rank() and
    numproc=pypar.size(), so each pypar process holds its own sub-domain.
    On top of Domain it:
      * agrees on a common timestep across all processes
        (update_timestep, or the alternative update_timestep_1), and
      * copies conserved-quantity centroid values from "full" cells to
        the corresponding "ghost" cells (update_ghosts).

    Wall-clock time (seconds) spent communicating is accumulated in
    communication_time, communication_reduce_time and
    communication_broadcast_time.
    """

    def __init__(self, coordinates, vertices, boundary=None,
                 full_send_dict=None, ghost_recv_dict=None):
        """Create the sub-domain for the calling process.

        coordinates, vertices, boundary -- passed unchanged to Domain.
        full_send_dict -- maps a processor id to a list whose element [0]
            holds the local ids of full cells to send and element [2] a
            send buffer with one column per conserved quantity.
        ghost_recv_dict -- maps a processor id to a list whose element [0]
            holds the local ids of ghost cells and element [2] a receive
            buffer with one column per conserved quantity.
        """
        Domain.__init__(self,
                        coordinates,
                        vertices,
                        boundary,
                        full_send_dict=full_send_dict,
                        ghost_recv_dict=ghost_recv_dict,
                        processor=pypar.rank(),
                        numproc=pypar.size())

        # NOTE(review): self.processor, self.numproc and the communication
        # buffers are never assigned in this class; they are expected to be
        # set up by Domain.__init__ from the arguments above -- confirm.
        self.ghost_recv_dict = ghost_recv_dict

        # One-element buffers used to synchronise timesteps.
        self.local_timestep = zeros(1, Float)
        self.global_timestep = zeros(1, Float)

        # One slot per process; filled by update_timestep_1.
        self.local_timesteps = zeros(self.numproc, Float)

        # Accumulated communication wall-clock time in seconds.
        self.communication_time = 0.0
        self.communication_reduce_time = 0.0
        self.communication_broadcast_time = 0.0

    def set_name(self, name):
        """Assign a name tagged with this process's rank.

        A trailing '.sww' is stripped, then '_P<processor>_<numproc>' is
        appended before delegating to Domain.set_name. Each process
        therefore gets a different name.
        """
        if name.endswith('.sww'):
            name = name[:-len('.sww')]

        # Call parent's method with the processor number attached.
        Domain.set_name(self, name + '_P%d_%d' % (self.processor, self.numproc))

    def check_integrity(self):
        """Run Domain's checks, then check the conserved-quantity order.

        The first three conserved quantities must be 'stage', 'xmomentum'
        and 'ymomentum', in that order (update_ghosts packs them into
        buffer columns by position).

        Raises:
            AssertionError if the order differs. The error is raised
            explicitly, so the check also runs under 'python -O'.
            IndexError if fewer than three conserved quantities exist.
        """
        Domain.check_integrity(self)

        msg = 'Will need to check global and local numbering'
        expected = ('stage', 'xmomentum', 'ymomentum')
        for i, name in enumerate(expected):
            if self.conserved_quantities[i] != name:
                raise AssertionError(msg)

    def update_timestep_1(self, yieldstep, finaltime):
        """Agree on a global timestep using one broadcast per process.

        This is an alternative to update_timestep and is only used if it
        is bound as update_timestep (see the note after that method).
        Each process in turn broadcasts its local timestep. Every process
        collects the values in self.local_timesteps and takes the minimum.
        Domain.update_timestep is called last, so it sees the agreed value.
        """
        import time

        t0 = time.time()

        # Broadcast local timestep from every processor to every other.
        for pid in range(self.numproc):
            self.local_timestep[0] = self.timestep
            pypar.broadcast(self.local_timestep, pid, bypass=True)
            self.local_timesteps[pid] = self.local_timestep[0]

        self.timestep = min(self.local_timesteps)

        pypar.barrier()
        self.communication_broadcast_time += time.time() - t0

        # LINDA: called after the global timestep is known, not before.
        Domain.update_timestep(self, yieldstep, finaltime)

    def update_timestep(self, yieldstep, finaltime):
        """Set self.timestep to the minimum timestep over all processes.

        The local values are combined on process 0. This uses a MIN
        reduce, or point-to-point messages in the disabled alternative
        branch. The result is then broadcast from process 0. Only
        afterwards is Domain.update_timestep called, so its local
        statistics use the agreed timestep.
        """
        import time

        pypar.barrier()

        # Compute minimal timestep across all processes.
        self.local_timestep[0] = self.timestep
        use_reduce_broadcast = True

        t0 = time.time()
        if use_reduce_broadcast:
            pypar.reduce(self.local_timestep, pypar.MIN, 0,
                         buffer=self.global_timestep,
                         bypass=True)
        else:
            # Alternative: straight sends and receives to process 0.
            self.global_timestep[0] = self.timestep

            if self.processor == 0:
                for i in range(1, self.numproc):
                    pypar.receive(i,
                                  buffer=self.local_timestep,
                                  bypass=True)

                    if self.local_timestep[0] < self.global_timestep[0]:
                        self.global_timestep[0] = self.local_timestep[0]
            else:
                pypar.send(self.local_timestep, 0,
                           use_buffer=True, bypass=True)

        self.communication_reduce_time += time.time() - t0

        # Broadcast minimal timestep from process 0 to all.
        t0 = time.time()
        pypar.broadcast(self.global_timestep, 0,
                        bypass=True)
        self.communication_broadcast_time += time.time() - t0

        self.timestep = self.global_timestep[0]

        # LINDA: update local statistics now that the timestep is global.
        Domain.update_timestep(self, yieldstep, finaltime)

        # FIXME (Ole) We should update the variable min_timestep for use
        # with write_time (or redo write_time)

    # To agree on the timestep with per-process broadcasts instead, bind:
    # update_timestep = update_timestep_1

    def update_ghosts(self):
        """Copy conserved-quantity centroid values from full to ghost cells.

        Communication runs in rounds, one per process id:
          * In round iproc, process iproc fills the element-[2] buffer of
            each other entry in its full_send_dict and sends it.
          * Every other process that has iproc in its ghost_recv_dict
            receives the data and writes it into its ghost cells.
        All processes visit the rounds in the same order.

        NOTE(review): this assumes full_send_dict on one process mirrors
        ghost_recv_dict on its neighbours. That is not checked here.

        Finally, the entry keyed by this process's own rank (if any) is
        copied locally without communication.
        """
        from Numeric import take, put
        import time
        t0 = time.time()

        # Update of non-local ghost cells.
        for iproc in range(self.numproc):
            if iproc == self.processor:
                # Send data from this processor to the other processors.
                for send_proc in self.full_send_dict:
                    if send_proc != iproc:
                        Idf = self.full_send_dict[send_proc][0]
                        Xout = self.full_send_dict[send_proc][2]

                        # One buffer column per conserved quantity.
                        for i, q in enumerate(self.conserved_quantities):
                            Q_cv = self.quantities[q].centroid_values
                            Xout[:, i] = take(Q_cv, Idf)

                        pypar.send(Xout, send_proc,
                                   use_buffer=True, bypass=True)

            elif iproc in self.ghost_recv_dict:
                # Receive data from the iproc processor.
                Idg = self.ghost_recv_dict[iproc][0]
                X = self.ghost_recv_dict[iproc][2]

                X = pypar.receive(iproc, buffer=X, bypass=True)

                for i, q in enumerate(self.conserved_quantities):
                    Q_cv = self.quantities[q].centroid_values
                    put(Q_cv, Idg, X[:, i])

        # Local update of ghost cells.
        iproc = self.processor
        if iproc in self.full_send_dict:
            # LINDA: full and ghost entries are stored as
            # (local id, global id, value).
            Idf = self.full_send_dict[iproc][0]
            Idg = self.ghost_recv_dict[iproc][0]

            for i, q in enumerate(self.conserved_quantities):
                Q_cv = self.quantities[q].centroid_values
                put(Q_cv, Idg, take(Q_cv, Idf))

        self.communication_time += time.time() - t0

    def write_time(self):
        """Print time and timestep statistics, prefixed with this rank.

        The timestep part depends on min_timestep and max_timestep:
          * equal      -> a single 'delta t' value,
          * min > max  -> no timestep information (NOTE(review):
                          presumably means no step recorded yet),
          * otherwise  -> the range 'delta t in [min, max]'.
        """
        if self.min_timestep == self.max_timestep:
            print('Processor %d/%d, Time = %.4f, delta t = %.8f, steps=%d (%d)'
                  % (self.processor, self.numproc,
                     self.time, self.min_timestep, self.number_of_steps,
                     self.number_of_first_order_steps))
        elif self.min_timestep > self.max_timestep:
            print('Processor %d/%d, Time = %.4f, steps=%d (%d)'
                  % (self.processor, self.numproc,
                     self.time, self.number_of_steps,
                     self.number_of_first_order_steps))
        else:
            print('Processor %d/%d, Time = %.4f, delta t in [%.8f, %.8f], steps=%d (%d)'
                  % (self.processor, self.numproc,
                     self.time, self.min_timestep,
                     self.max_timestep, self.number_of_steps,
                     self.number_of_first_order_steps))

    def evolve(self, yieldstep=None, finaltime=None):
        """Specialisation of the basic evolve method from Domain.

        This is a generator. It re-yields every value produced by
        Domain.evolve(yieldstep, finaltime) unchanged, handing control
        back to the caller's loop at each step.
        """
        # Call basic machinery from parent class and pass control on to
        # the outer loop for more specific actions.
        for t in Domain.evolve(self, yieldstep, finaltime):
            yield t
Note: See TracBrowser for help on using the repository browser.