source: anuga_core/source/pypar-numeric/pypar.py @ 7695

Last change on this file since 7695 was 5838, checked in by steve, 16 years ago

Updated setup.py

File size: 24.5 KB
Line 
1# =============================================================================
2# pypar.py - Parallel Python using MPI
3# Copyright (C) 2001, 2002, 2003 Ole M. Nielsen
4#              (Center for Mathematics and its Applications ANU and APAC)
5#
6#    This program is free software; you can redistribute it and/or modify
7#    it under the terms of the GNU General Public License as published by
8#    the Free Software Foundation; either version 2 of the License, or
9#    (at your option) any later version.
10#
11#    This program is distributed in the hope that it will be useful,
12#    but WITHOUT ANY WARRANTY; without even the implied warranty of
13#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14#    GNU General Public License (http://www.gnu.org/copyleft/gpl.html)
15#    for more details.
16#
17#    You should have received a copy of the GNU General Public License
18#    along with this program; if not, write to the Free Software
19#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307
20#
21#
22# Contact address: Ole.Nielsen@anu.edu.au
23#
24# Version: See pypar.__version__
25# =============================================================================
26
27"""Module pypar.py - Parallel Python using MPI
28
29Public functions:
30
31size() -- Number of processors
32rank() -- Id of current processor
33get_processor_name() -- Return host name of current node
34
35send() -- Blocking send (all types)
36receive() -- Blocking receive (all types)
37broadcast() -- Broadcast
38time() -- MPI wall time
39barrier() -- Synchronisation point. Makes processors wait until all processors
40             have reached this point.
41abort() -- Terminate all processes.
42finalize() -- Cleanup MPI. No parallelism can take place after this point.
43
44
45See doc strings of individual functions for detailed documentation.
46"""
47
48# Meta data
49from __metadata__ import __version__, __date__, __author__
50
51
# Module-wide constants
#
# Message tags
max_tag = 32767       # Largest tag value (MPI_TAG_UB didn't work and returned 0)
default_tag = 1       # Tag applied to payload messages when none is specified
control_tag = 13849   # Reserved tag that identifies control information

# Layout of the control message
control_data_max_size = 64  # Maximal length of the string holding control data
control_sep = ':'           # Field separator inside control info (NOT ',')
60
61
62#---------------------------------------------------------------------------
63# Communication functions
64#--------------------------------------------------------------------------
65
def send(x, destination, use_buffer=False, vanilla=False,
         tag=default_tag, bypass=False):
    """Blocking MPI send of x to the process with rank destination.

       The appropriate protocol ('array', 'string' or 'vanilla') is
       determined automatically by create_control_info and the
       corresponding send function is called.  Unless use_buffer is set,
       type and size information is passed on as a preceding control
       message to simplify the matching receive call.

       The variable x can be any (picklable) type, but
       Numeric variables and text strings will be most efficient.

       Parameters:
         destination: Rank of the receiving process (integer).
         use_buffer:  If true, no control message is sent; the receiver
                      must then supply its own buffer.
         vanilla:     If true, force the vanilla (pickle) protocol
                      for any type.
         tag:         Message tag.  Must differ from control_tag.
         bypass:      If True, all admin and error checks get bypassed to
                      reduce the latency.  Should only be used for sending
                      Numeric arrays and should be matched with a bypass
                      in the corresponding receive command.

       Raises:
         AssertionError: destination is not an integer or tag is reserved.
         ValueError:     create_control_info reported an unknown protocol.
    """

    if bypass:
        send_array(x, destination, tag)
        return

    # Input check.  The argument is wrapped in a 1-tuple so that building
    # the message cannot itself fail when a tuple is passed by mistake.
    errmsg = 'Destination id (%s) must be an integer.' % (destination,)
    assert isinstance(destination, int), errmsg

    errmsg = 'Tag %d is reserved by pypar - please use another.' % control_tag
    assert tag != control_tag, errmsg

    # Create metadata about object to be sent.  For the vanilla protocol
    # the returned x is the pickled representation of the original object.
    control_info, x = create_control_info(x, vanilla, return_object=True)
    protocol = control_info[0]

    # Possibly transmit control data.  Plain truthiness is used so that
    # use_buffer=0 behaves like use_buffer=False (an identity test against
    # False would silently skip the control message for 0).
    if not use_buffer:
        send_control_info(control_info, destination)

    # Transmit payload data
    if protocol == 'array':
        send_array(x, destination, tag)
    elif protocol in ['string', 'vanilla']:
        send_string(x, destination, tag)
    else:
        # String exceptions are not valid exceptions; raise a real one.
        raise ValueError('Unknown protocol: %s' % protocol)
116
117     
def receive(source, buffer=None, vanilla=False, tag=default_tag,
            return_status=False, bypass=False):
    """receive - blocking MPI receive

       Receive data from source.

       Optional parameters:
         buffer: Use specified buffer for received data (faster). Default None.
         vanilla: Specify to enforce vanilla protocol for any type.
                  Default False.
         tag: Only receive messages tagged as specified. Default default_tag.
         return_status: Return Status object along with result. Default False.
         bypass: Skip all admin and error checks (see below). Default False.

       If no buffer is specified, receive will try to receive a
       preceding message containing protocol, type, size and shape and
       then create a suitable buffer.

       If buffer is specified the corresponding send must specify
       use_buffer = True.
       The variable buffer can be any (picklable) type, but
       Numeric variables and text strings will be most efficient.

       Appropriate protocol will be automatically determined
       and corresponding receive function called.

       If bypass is True, all admin and error checks
       get bypassed to reduce the latency. Should only
       be used for receiving Numeric arrays and should
       be matched with a bypass in the corresponding send command.
       Also buffer must be specified.

       Returns:
         The received object, or the pair (object, Status) if
         return_status is true.

       Raises:
         AssertionError: source is not an integer or tag is reserved.
         ValueError:     the control information names an unknown protocol.
    """

    if bypass:
        # No checks at all: buffer is handed straight to the extension.
        stat = receive_array(buffer, source, tag)
    else:
        # Input check (argument wrapped in a 1-tuple so that formatting the
        # message cannot fail when a tuple is passed by mistake).
        errmsg = 'Source id (%s) must be an integer.' % (source,)
        assert isinstance(source, int), errmsg

        errmsg = 'Tag %d is reserved by pypar - please use another.'\
                 % control_tag
        assert tag != control_tag, errmsg

        # Either receive or create metadata about object to receive
        if buffer is None:
            protocol, typecode, size, shape = receive_control_info(source)
        else:
            protocol, typecode, size, shape = create_control_info(buffer,
                                                                  vanilla)

        # Receive payload data
        if protocol == 'array':
            if buffer is None:
                import Numeric
                buffer = Numeric.zeros(size, typecode)
                buffer = Numeric.reshape(buffer, shape)

            stat = receive_array(buffer, source, tag)

        elif protocol == 'string':
            if buffer is None:
                buffer = ' ' * size

            stat = receive_string(buffer, source, tag)

        elif protocol == 'vanilla':
            from cPickle import dumps, loads
            if buffer is None:
                s = ' ' * size
            else:
                s = dumps(buffer, 1)
                s = s + ' ' * int(0.1 * len(s))  # safety margin

            # The received pickle is read back from s, so receive_string is
            # presumably filling s in place - TODO confirm against mpiext.
            stat = receive_string(s, source, tag)

            # NOTE(review): unpickling executes whatever the sender encoded;
            # this is only safe when every MPI peer is trusted.
            buffer = loads(s)  # Replace buffer with received result
        else:
            # String exceptions are not valid exceptions; raise a real one.
            raise ValueError('Unknown protocol: %s' % protocol)

    # Return received data and possibly the status object
    if return_status:
        return buffer, Status(stat)
    else:
        return buffer
206
207
def broadcast(buffer, root, vanilla=False, bypass=False):
    """Wrapper for MPI bcast.

       Broadcast buffer from the process with rank root to all other
       processes.

       Automatically determine appropriate protocol
       and call corresponding broadcast function.

       The variable buffer can be any (picklable) type, but
       Numeric variables and text strings will be most efficient.
       Setting vanilla = 1 forces vanilla mode for any type.

       If bypass is True, all admin and error checks
       get bypassed to reduce the latency.  In that case buffer is passed
       directly to broadcast_array and None is returned.

       Returns:
         The broadcast object (None in bypass mode).

       Raises:
         AssertionError: root is not an integer.
         ValueError:     create_control_info reported an unknown protocol.
    """

    if bypass:
        broadcast_array(buffer, root)
        return

    # Input check (argument wrapped in a 1-tuple so that formatting the
    # message cannot fail when a tuple is passed by mistake).
    errmsg = 'Root id (%s) must be an integer.' % (root,)
    assert isinstance(root, int), errmsg

    # Create metadata about object to be sent
    protocol = create_control_info(buffer, vanilla)[0]

    # Broadcast
    if protocol == 'array':
        broadcast_array(buffer, root)
    elif protocol == 'string':
        broadcast_string(buffer, root)
    elif protocol == 'vanilla':
        from cPickle import loads, dumps
        s = dumps(buffer, 1)
        s = s + ' ' * int(0.1 * len(s))  # safety margin

        # The result is unpickled from s afterwards, so broadcast_string is
        # presumably overwriting s in place - TODO confirm against mpiext.
        broadcast_string(s, root)
        buffer = loads(s)
    else:
        # String exceptions are not valid exceptions; raise a real one.
        raise ValueError('Unknown protocol: %s' % protocol)

    return buffer
258
259
def scatter(x, root, buffer=None, vanilla=False):
    """Sends data x from process with rank root to all other processes.

       Create appropriate buffer (unless one is supplied) and receive data.
       Return scattered result (same type as x).

       Scatter only makes sense for arrays or strings.

       Parameters:
         x:       Numeric array or string to distribute.
         root:    Rank of the sending process (integer).
         buffer:  Optional receive buffer.  If None, one is created whose
                  length along axis 0 (arrays) or total length (strings)
                  is that of x divided by the number of processes.
         vanilla: Accepted for interface compatibility; not used here.

       Raises:
         AssertionError: root is not an integer.
         TypeError:      x is neither a supported Numeric array nor a string.
         ValueError:     create_control_info reported an unknown protocol.
    """

    # Input check (argument wrapped in a 1-tuple so that formatting the
    # message cannot fail when a tuple is passed by mistake).
    errmsg = 'Root id (%s) must be an integer.' % (root,)
    assert isinstance(root, int), errmsg

    # Imported under an alias because 'size' is also used below for the
    # element count of x.
    from mpiext import size as mpi_size
    numproc = mpi_size()         # Needed to determine buffer size

    # Create metadata about object to be sent
    protocol, typecode, size, shape = create_control_info(x)

    # Scatter
    if protocol == 'array':
        if buffer is None:
            import Numeric

            # Modify shape along axis=0 to match size.
            # Explicit floor division keeps the extent an integer.
            shape = list(shape)
            shape[0] //= numproc
            count = Numeric.product(shape)

            buffer = Numeric.zeros(count, typecode)
            buffer = Numeric.reshape(buffer, shape)

        scatter_array(x, buffer, root)
    elif protocol == 'string':
        if buffer is None:
            buffer = ' ' * (size // numproc)

        scatter_string(x, buffer, root)
    elif protocol == 'vanilla':
        errmsg = 'Scatter is only supported for Numeric arrays and strings.\n'
        errmsg += 'If you wish to distribute a general sequence, '
        errmsg += 'please use send and receive commands or broadcast.'
        # String exceptions are not valid exceptions; raise a real one.
        raise TypeError(errmsg)
    else:
        raise ValueError('Unknown protocol: %s' % protocol)

    return buffer
309
310
def gather(x, root, buffer=None, vanilla=0):
    """Gather values from all processes to root.

       Create appropriate buffer (unless one is supplied) and receive data.

       Gather only makes sense for arrays or strings.

       Parameters:
         x:       Numeric array or string contributed by this process.
         root:    Rank of the collecting process (integer).
         buffer:  Optional receive buffer.  If None, one is created that is
                  numproc times as long as x (along axis 0 for arrays).
         vanilla: Accepted for interface compatibility; not used here.

       Returns:
         The buffer handed to the gather call.

       Raises:
         AssertionError: root is not an integer.
         TypeError:      x is neither a supported Numeric array nor a string.
         ValueError:     create_control_info reported an unknown protocol.
    """

    # Input check (argument wrapped in a 1-tuple so that formatting the
    # message cannot fail when a tuple is passed by mistake).
    errmsg = 'Root id (%s) must be an integer.' % (root,)
    assert isinstance(root, int), errmsg

    # Imported under an alias because 'size' is also used below for the
    # element count of x.
    from mpiext import size as mpi_size
    numproc = mpi_size()         # Needed to determine buffer size

    # Create metadata about object to be gathered
    protocol, typecode, size, shape = create_control_info(x)

    # Gather
    if protocol == 'array':
        if buffer is None:
            import Numeric
            buffer = Numeric.zeros(size * numproc, typecode)

            # Modify shape along axis=0 to match size
            shape = list(shape)
            shape[0] *= numproc
            buffer = Numeric.reshape(buffer, shape)

        gather_array(x, buffer, root)
    elif protocol == 'string':
        if buffer is None:
            buffer = ' ' * size * numproc

        gather_string(x, buffer, root)
    elif protocol == 'vanilla':
        errmsg = 'Gather is only supported for Numeric arrays and strings.\n'
        errmsg += 'If you wish to distribute a general sequence, '
        errmsg += 'please use send and receive commands or broadcast.'
        # String exceptions are not valid exceptions; raise a real one.
        raise TypeError(errmsg)
    else:
        raise ValueError('Unknown protocol: %s' % protocol)

    return buffer
356
357
def reduce(x, op, root, buffer=None, vanilla=0, bypass=False):
    """Reduce elements in x to buffer (of the same size as x)
       at root applying operation op elementwise.

       Parameters:
         x:       Numeric array contributed by this process.
         op:      Reduction operation, passed through to reduce_array.
         root:    Rank of the process receiving the result (integer).
         buffer:  Optional result buffer.
         vanilla: Accepted for interface compatibility; not used here.
         bypass:  If True, all admin and error checks get bypassed to
                  reduce the latency.  The buffer must be specified
                  explicitly in this case, and None is returned.

       Returns:
         The buffer handed to reduce_array (None in bypass mode).

       Raises:
         AssertionError: root is not an integer.
         TypeError:      x is a string or a general (vanilla) object.
         ValueError:     create_control_info reported an unknown protocol.
    """

    if bypass:
        reduce_array(x, buffer, op, root)
        return

    # Input check (argument wrapped in a 1-tuple so that formatting the
    # message cannot fail when a tuple is passed by mistake).
    errmsg = 'Root id (%s) must be an integer.' % (root,)
    assert isinstance(root, int), errmsg

    # Imported under an alias because 'size' is also used below for the
    # element count of x.
    from mpiext import size as mpi_size
    numproc = mpi_size()         # Needed to determine buffer size

    # Create metadata about object
    protocol, typecode, size, shape = create_control_info(x)

    # Reduce
    if protocol == 'array':
        if buffer is None:
            import Numeric

            # NOTE(review): the default buffer is numproc times the size of
            # x, although the docstring says the result has the same size
            # as x.  Left unchanged here - confirm against reduce_array.
            buffer = Numeric.zeros(size * numproc, typecode)

            # Modify shape along axis=0 to match size
            shape = list(shape)
            shape[0] *= numproc
            buffer = Numeric.reshape(buffer, shape)

        reduce_array(x, buffer, op, root)
    elif protocol == 'vanilla' or protocol == 'string':
        # String exceptions are not valid exceptions; raise a real one.
        raise TypeError('Protocol: %s unsupported for reduce' % protocol)
    else:
        raise ValueError('Unknown protocol: %s' % protocol)

    return buffer
405
406
407#---------------------------------------------------------
408# AUXILIARY FUNCTIONS
409#---------------------------------------------------------
def balance(N, P, p):
    """Compute p'th interval when N is distributed over P bins.

    This function computes boundaries of sub intervals of [0:N] such
    that they are almost equally sized with their sizes differing
    by no more than 1.

    As such, this function is suitable for partitioning an interval equally
    across P processors.

    Inputs:
       N: Upper bound of full interval.
       P: Total number of processors
       p: Local processor id

    Outputs:
       Nlo: Lower bound of p'th sub-interval
       Nhi: Upper bound of p'th sub-interval

    Example:
       To partition the interval [0:29] among 4 processors:

       Nlo, Nhi = pypar.balance(29, 4, p)

       with p in [0,1,2,3]

       and the subintervals are

       p          Nlo      Nhi
       -----------------------
       0           0        8
       1           8       15
       2          15       22
       3          22       29

    Note that the interval bounds follow the Python convention of
    list slicing such that the last element of Nlo:Nhi is, in fact, Nhi-1
    """

    # Exact floor division.  Going through float(N)/P rounds N once it
    # exceeds 2**53 and then yields wrong (even negative) bounds.
    L = int(N // P)   # Minimal number of elements per bin
    K = N - P*L       # Number of bins that hold one extra element

    if p < K:
        # One of the first K bins: L + 1 elements each
        Nlo = p*L + p
        Nhi = Nlo + L + 1
    else:
        # Remaining bins: L elements each, shifted by the K extras
        Nlo = p*L + K
        Nhi = Nlo + L

    return Nlo, Nhi
465
466
467# Obsolete functions
468# (for backwards compatibility - remove in version 2.0)
469
def raw_send(x, destination, tag=default_tag, vanilla=0):
    """Obsolete: send x without a preceding control message.

       Equivalent to send(...) with use_buffer=True.  Returns nothing.
    """
    send(x, destination,
         tag=tag,
         vanilla=vanilla,
         use_buffer=True)
472
473
def raw_receive(x, source, tag=default_tag, vanilla=0, return_status=0):
    """Obsolete: receive from source using x as the buffer.

       Returns whatever receive(...) returns when called with buffer=x.
    """
    return receive(source,
                   buffer=x,
                   vanilla=vanilla,
                   tag=tag,
                   return_status=return_status)
478
def raw_scatter(x, buffer, source, vanilla=0):
    """Obsolete: scatter x from process source using the given buffer.

       Calls scatter(...) with the supplied buffer.  Returns nothing.
    """
    scatter(x, source,
            vanilla=vanilla,
            buffer=buffer)
481
def raw_gather(x, buffer, source, vanilla=0):
    """Obsolete: gather x from all processes at source using the given buffer.

       Calls gather(...) with the supplied buffer.  Returns nothing.
       The vanilla argument is now forwarded instead of being replaced by
       a hard-coded 0, in line with the other raw_* wrappers.
    """
    gather(x, source, buffer=buffer, vanilla=vanilla)
484
def raw_reduce(x, buffer, op, source, vanilla=0):
    """Obsolete: reduce x with operation op at source using the given buffer.

       Calls reduce(...) with the supplied buffer.  Returns nothing.
       The vanilla argument is now forwarded instead of being replaced by
       a hard-coded 0, in line with the other raw_* wrappers.
    """
    reduce(x, op, source, buffer=buffer, vanilla=vanilla)
487
def bcast(buffer, root, vanilla=False):
    """Obsolete alias for broadcast(buffer, root, vanilla)."""
    result = broadcast(buffer, root, vanilla)
    return result
490
def Wtime():
    """Obsolete alias for time()."""
    now = time()
    return now
493
def Get_processor_name():
    """Obsolete alias for get_processor_name()."""
    name = get_processor_name()
    return name
496
def Initialized():
    """Obsolete alias for initialized()."""
    flag = initialized()
    return flag
499
def Finalize():
    """Obsolete alias for finalize().  Returns nothing."""
    finalize()
    return None
502
def Abort():
    """Obsolete alias for abort().  Returns nothing."""
    abort()
    return None
505
def Barrier():
    """Obsolete alias for barrier().  Returns nothing."""
    barrier()
    return None
508   
509     
510
511#---------------------------------------------------------
512# INTERNAL FUNCTIONS
513#---------------------------------------------------------
514
class Status:
    """MPI Status block returned by receive if
       specified with parameter return_status=True.

       Attributes (taken from the first five entries of status_tuple):
         source: Id of sender
         tag:    Tag of received message
         error:  MPI error code
         length: Number of elements transmitted
         size:   Size of one element
    """

    def __init__(self, status_tuple):
        """Unpack status_tuple = (source, tag, error, length, size)."""
        self.source = status_tuple[0]  # Id of sender
        self.tag = status_tuple[1]     # Tag of received message
        self.error = status_tuple[2]   # MPI Error code
        self.length = status_tuple[3]  # Number of elements transmitted
        self.size = status_tuple[4]    # Size of one element

    def __repr__(self):
        """Return a multi-line, human readable summary of all fields."""
        # The template is built completely before '%' is applied.  The
        # former 'a' + 'b' % values form formatted only the second literal
        # (three slots) with five values and always raised TypeError.
        template = ('Pypar Status Object:\n'
                    '  source=%d\n'
                    '  tag=%d\n'
                    '  error=%d\n'
                    '  length=%d\n'
                    '  size=%d\n')
        return template % (self.source, self.tag, self.error,
                           self.length, self.size)

    def bytes(self):
        """Number of bytes transmitted (excl control info)
        """
        return self.length * self.size
536 
537
538
def create_control_info(x, vanilla=0, return_object=False):
    """Determine which protocol to use for communication:
       (Numeric) arrays, strings, or vanilla based on x's type.

       There are three protocols:
       'array':   Numeric arrays of type 'i', 'l', 'f', 'd', 'F' or 'D' can be
                  communicated with mpiext.send_array and mpiext.receive_array.
       'string':  Text strings can be communicated with mpiext.send_string and
                  mpiext.receive_string.
       'vanilla': All other types can be communicated as string representations
                  provided that the objects
                  can be serialised using pickle (or cPickle).
                  The latter mode is less efficient than the
                  first two but it can handle general structures.

       Rules:
       If keyword argument vanilla is true, vanilla is chosen regardless of
       x's type.
       Otherwise if x is a string, the string protocol is chosen.
       If x is an array, the 'array' protocol is chosen provided that x has one
       of the admissible typecodes.

       The optional argument return_object asks to return object as well.
       This is useful in case it gets modified as in the case of general
       structures using the vanilla protocol.

       Returns:
         [protocol, typecode, size, shape], or the pair
         ([protocol, typecode, size, shape], x) if return_object is true.
         For the vanilla protocol x is the pickled object and size is the
         length of that pickle.
    """

    # Default values
    protocol = 'vanilla'
    typecode = ' '
    size = 0
    shape = ()

    # Determine protocol in case
    if not vanilla:
        if type(x) is str:
            # Exact type match on purpose: instances of str subclasses go
            # through pickle, as before.
            protocol = 'string'
            typecode = 'c'
            size = len(x)
        elif type(x).__name__ == 'array':  # Numeric isn't imported yet
            try:
                import Numeric
            except ImportError:
                # Only a failed import selects the fallback; a bare except
                # would also swallow KeyboardInterrupt and SystemExit.
                print("WARNING (pypar.py): Numeric module could not be "
                      "imported, reverting to vanilla mode")
                protocol = 'vanilla'
            else:
                typecode = x.typecode()
                if typecode in ['i', 'l', 'f', 'd', 'F', 'D']:
                    protocol = 'array'
                    shape = x.shape
                    size = Numeric.product(shape)
                else:
                    print("WARNING (pypar.py): Numeric object type %s "
                          "is not supported." % (x.typecode(),))
                    print("Only types 'i', 'l', 'f', 'd', 'F', 'D' are "
                          "supported, Reverting to vanilla mode.")
                    protocol = 'vanilla'

    # Pickle general structures using the vanilla protocol
    if protocol == 'vanilla':
        from cPickle import dumps
        x = dumps(x, 1)
        size = len(x)  # Let count be length of pickled object

    # Return
    if return_object:
        return [protocol, typecode, size, shape], x
    else:
        return [protocol, typecode, size, shape]
611
612
613
614#----------------------------------------------
615
def send_control_info(control_info, destination):
    """Send control info to destination.

       control_info is a sequence of fields; each field is converted with
       str() and the results are joined with control_sep into one string
       that is sent with the reserved tag control_tag.

       Raises:
         ValueError: the joined string is longer than control_data_max_size.
    """

    # Convert to strings and join (str.join replaces the deprecated
    # string.join function)
    control_msg = control_sep.join([str(c) for c in control_info])

    if len(control_msg) > control_data_max_size:
        errmsg = 'Length of control_info exceeds specified maximium (%d)'\
                 % control_data_max_size
        errmsg += ' - Please increase it (in pypar.py)'
        # String exceptions are not valid exceptions; raise a real one.
        raise ValueError(errmsg)

    send_string(control_msg, destination, control_tag)
632
633 
def receive_control_info(source):
    """Receive control info from source.

       Returns the list [protocol, typecode, size, shape] where the first
       two entries are the received text fields and the last two are the
       received text fields evaluated back into Python values.
    """

    # Blank receive buffer of the maximal control message length
    raw = ' ' * control_data_max_size

    # The status is only used for trimming below - Status objects are
    # reserved for payload communications.
    status = receive_string(raw, source, control_tag)

    # Trim buffer to actual received length, then split into fields
    fields = raw[:status[3]].split(control_sep)
    assert len(fields) == 4, 'len(control_info) = %d' % len(fields)

    # NOTE(review): eval runs on text received from another process;
    # this is only safe when every MPI peer is trusted.
    fields[2] = eval(fields[2])  # Convert back to int
    fields[3] = eval(fields[3])  # Convert back to tuple

    return fields
655
656
657#----------------------------------------------------------------------------
658# Initialise module
659#----------------------------------------------------------------------------
660
661
662# Take care of situation where module is part of package
663import sys, os, string, os.path
# Directory part of the dotted module name, '.' when there is none,
# always terminated by the path separator.
dirname = os.path.dirname(__name__.replace('.', os.sep)).strip() or '.'

if not dirname.endswith(os.sep):
    dirname = dirname + os.sep
671
672
673
674# Import MPI extension
675#
676# Verify existence of mpiext.so.
677
try:
    import mpiext
except:
    # Deliberately broad: whatever goes wrong while loading the extension,
    # report it and flag the error instead of aborting the import of pypar.
    errmsg = ''.join(['ERROR: C extension mpiext could not be imported.\n',
                      'Please compile mpiext.c e.g. by running\n',
                      '  python install.py\n',
                      'in the pypar directory, or by using\n',
                      '  python setup.py install\n'])
    print(errmsg)
    error = 1
else:

    # Determine if MPI program is allowed to run sequentially on the current
    # platform.  Attempting to check this automatically may cause some
    # systems to hang.

    if sys.platform not in ['linux2', 'sunos5', 'win32', 'darwin']:
        # Platform: Alpha 'osf1V5'
        # Try to initialise and finalise MPI in a separate interpreter and
        # use the exit status of that command as the error flag.
        cmdstring = '"import mpiext, sys; mpiext.init(sys.argv); mpiext.finalize()"'
        s = 'python -c %s >/dev/null 2>/dev/null' % cmdstring
        error = os.system(s)
    else:
        # Linux (LAM, MPICH) or Sun (MPICH)
        error = 0  # Sequential execution of MPI is allowed

    # The check is performed in a separate shell.
    # Reason: The Alpha server, LAM/Linux or the Sun cannot recover from a
    # try:
    #   mpiext.init(sys.argv)

    # However, on LAM/Linux, this test causes the system to hang.
    # Verified (OMN 12/12/2)
    # If lamboot is started, the system will hang when init is called
    # again further down in this file.
    # If lamboot is not loaded, error will be nonzero as it should.
    # I don't know how to deal with this.
    #
    # Comparison of two strategies using LAM
    #
    # Strategy 1: Assume seq execution is OK (i.e. set error = 0)
    # Strategy 2: Try to test if mpi can be initialised (in a separate shell)
    #
    #
    # Strategy 1 (currently used)
    #                    | Lam booted  | Lam not booted
    #-----------------------------------------------------
    #
    # Sequential exec    |  OK         | Not OK
    # Parallel exec      |  OK         | Not OK
    #
    #
    # Strategy 2
    #                    | Lam booted  | Lam not booted
    #-----------------------------------------------------
    #
    # Sequential exec    |  Hangs      | Not OK
    # Parallel exec      |  Hangs      | OK
    #
737
738
739
740# Initialise MPI
741#
742# Attempt to initialise mpiext.so
743# If this fails, define a rudimentary interface suitable for
744# sequential execution.
745
if not error:
    # The extension is usable: expose its functions and constants under
    # pypar's own names.
    # NOTE(review): 'MPI_TAG_UB as max_tag' rebinds the module-level name
    # max_tag - confirm that the value exported by mpiext is usable.
    from mpiext import (size, rank, barrier, time,
                        get_processor_name,
                        init, initialized, finalize, abort,
                        send_string, receive_string,
                        send_array, receive_array,
                        broadcast_string, broadcast_array,
                        scatter_string, scatter_array,
                        gather_string, gather_array,
                        reduce_array,
                        MPI_ANY_TAG as any_tag, MPI_TAG_UB as max_tag,
                        MPI_ANY_SOURCE as any_source,
                        MAX, MIN, SUM, PROD, LAND, BAND,
                        LOR, BOR, LXOR, BXOR)

    init(sys.argv)  # Initialise MPI with cmd line (needed by MPICH/Linux)

    # Only the process with rank 0 announces the successful start-up
    if rank() == 0:
        print("Pypar (version %s) initialised MPI OK with %d processors"
              % (__version__, size()))

else:
    print("WARNING: MPI library could not be initialised - running sequentially")

    # Define rudimentary functions to keep sequential programs happy

    def size():
        """Sequential stand-in: there is exactly one process."""
        return 1

    def rank():
        """Sequential stand-in: the only process has id 0."""
        return 0

    def get_processor_name():
        """Host name from $HOST, else $HOSTNAME, else 'Unknown'."""
        import os
        for variable in ('HOST', 'HOSTNAME'):
            if variable in os.environ:
                return os.environ[variable]
        return 'Unknown'

    def abort():
        """Sequential stand-in: exit the interpreter."""
        import sys
        sys.exit()

    def finalize():
        """Sequential stand-in: nothing to clean up."""
        pass

    def barrier():
        """Sequential stand-in: nothing to wait for."""
        pass

    def time():
        """Sequential stand-in: wall clock time from the time module."""
        import time as clock
        return clock.time()
797
798
799
800
801
802
Note: See TracBrowser for help on using the repository browser.