source: pypar/pypar.py @ 1627

Last change on this file since 1627 was 126, checked in by ole, 19 years ago

Added bypass to reduce and broadcast

File size: 23.2 KB
Line 
1# =============================================================================
2# pypar.py - Parallel Python using MPI
3# Copyright (C) 2001, 2002, 2003 Ole M. Nielsen
4#              (Center for Mathematics and its Applications ANU and APAC)
5#
6#    This program is free software; you can redistribute it and/or modify
7#    it under the terms of the GNU General Public License as published by
8#    the Free Software Foundation; either version 2 of the License, or
9#    (at your option) any later version.
10#
11#    This program is distributed in the hope that it will be useful,
12#    but WITHOUT ANY WARRANTY; without even the implied warranty of
13#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14#    GNU General Public License (http://www.gnu.org/copyleft/gpl.html)
15#    for more details.
16#
17#    You should have received a copy of the GNU General Public License
18#    along with this program; if not, write to the Free Software
19#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307
20#
21#
22# Contact address: Ole.Nielsen@anu.edu.au
23#
24# Version: See pypar.__version__
25# =============================================================================
26
"""Module pypar.py - Parallel Python using MPI

Public functions:

size() -- Number of processors
rank() -- Id of current processor
get_processor_name() -- Return host name of current node

send() -- Blocking send (all types)
receive() -- Blocking receive (all types)
broadcast() -- Broadcast
scatter() -- Distribute pieces of an array (along axis 0) or string from
             root to all processors
gather() -- Collect arrays or strings from all processors at root
reduce() -- Combine arrays elementwise at root with an operator (e.g. SUM)
time() -- MPI wall time
barrier() -- Synchronisation point. Makes processors wait until all processors
             have reached this point.
abort() -- Terminate all processes.
finalize() -- Cleanup MPI. No parallelism can take place after this point.

balance() -- Compute the index interval owned by one of P processors when
             N items are distributed over them.

See doc strings of individual functions for detailed documentation.
"""

# Meta data
__version__ = '1.9.2'
__date__ = '11 July 2005'
__author__ = 'Ole M. Nielsen'


# Constants
#
max_tag = 32767      # Max tag value (MPI_TAG_UB didn't work and returned 0)
control_tag = 13849  # Reserved tag used to identify control information;
                     # send() and receive() refuse it for payload messages
default_tag = 1      # Tag used as default if not specified

# Separator for fields in control info. It must NOT be ',' because the shape
# field is sent as str(tuple), which itself contains commas.
control_sep = ':'
# Maximal size of string holding control data; receive_control_info
# allocates a receive buffer of exactly this many characters.
control_data_max_size = 64
62
63
64#---------------------------------------------------------------------------
65# Communication functions
66#--------------------------------------------------------------------------
67
def send(x, destination, use_buffer=False, vanilla=False,
         tag=default_tag, bypass=False):
    """Wrapper for easy MPI send.
       Send x to destination.

       Automatically determine appropriate protocol
       and call corresponding send function.
       Also passes type and size information on as preceding message to
       simplify the receive call.

       The variable x can be any (picklable) type, but
       Numeric variables and text strings will be most efficient.
       Setting vanilla = 1 forces vanilla mode for any type.

       Parameters:
         x: object to send.
         destination: rank (integer) of the receiving process.
         use_buffer: if true, no control message is sent; the matching
                     receive must then supply its own buffer. Default False.
         vanilla: force the pickle based 'vanilla' protocol. Default False.
         tag: message tag; must differ from the reserved control_tag.
         bypass: see below.

       If bypass is True, all admin and error checks
       get bypassed to reduce the latency. Should only
       be used for sending Numeric arrays and should be matched
       with a bypass in the corresponding receive command.

       Raises:
         AssertionError if destination is not an integer or tag is reserved.
         Exception if the protocol chosen for x is unknown.
    """
    import types

    if bypass:
        # Fast path: no validation and no control message.
        send_array(x, destination, tag)
        return

    # Input check
    errmsg = 'Destination id (%s) must be an integer.' % destination
    assert type(destination) == types.IntType, errmsg

    errmsg = 'Tag %d is reserved by pypar - please use another.' % control_tag
    assert tag != control_tag, errmsg

    # Create metadata about object to be sent. For the vanilla protocol
    # x is replaced by its pickled string.
    control_info, x = create_control_info(x, vanilla, return_object=True)
    protocol = control_info[0]

    # Transmit control data unless the receiver supplies its own buffer.
    # Test truthiness rather than identity with False so that callers
    # passing 0 still get the control message receive() expects.
    if not use_buffer:
        send_control_info(control_info, destination)

    # Transmit payload data
    if protocol == 'array':
        send_array(x, destination, tag)
    elif protocol in ['string', 'vanilla']:
        send_string(x, destination, tag)
    else:
        raise Exception('Unknown protocol: %s' % protocol)
118
119     
def receive(source, buffer=None, vanilla=False, tag=default_tag,
            return_status=False, bypass=False):
    """receive - blocking MPI receive

       Receive data from source.

       Optional parameters:
         buffer: Use specified buffer for received data (faster). Default None.
         vanilla: Specify to enforce vanilla protocol for any type. Default False
         tag: Only received messages tagged as specified. Default default_tag
         return_status: Return Status object along with result. Default False.
         bypass: Skip all checks and control messages (see below).
                 Default False.

       If no buffer is specified, receive will try to receive a
       preceding message containing protocol, type, size and shape and
       then create a suitable buffer.

       If buffer is specified the corresponding send must specify
       use_buffer = True.
       The variable buffer can be any (picklable) type, but
       Numeric variables and text strings will be most efficient.

       Appropriate protocol will be automatically determined
       and corresponding receive function called.

       If bypass is True, all admin and error checks
       get bypassed to reduce the latency. Should only
       be used for receiving Numerical arrays and should
       be matched with a bypass in the corresponding send command.
       Also buffer must be specified.

       Returns:
         The received object, or (object, Status) if return_status is true.

       Raises:
         ValueError if bypass is True and no buffer is given.
         AssertionError if source is not an integer or tag is reserved.
         Exception if the protocol is unknown.
    """

    if bypass:
        if buffer is None:
            raise ValueError('bypass mode must be used with specified buffer')
        stat = receive_array(buffer, source, tag)
    else:
        import types

        # Input check
        errmsg = 'Source id (%s) must be an integer.' % source
        assert type(source) == types.IntType, errmsg

        errmsg = 'Tag %d is reserved by pypar - please use another.' % control_tag
        assert tag != control_tag, errmsg

        # Either receive or create metadata about object to receive
        if buffer is None:
            protocol, typecode, size, shape = receive_control_info(source)
        else:
            protocol, typecode, size, shape = create_control_info(buffer, vanilla)

        # Receive payload data
        if protocol == 'array':
            if buffer is None:
                import Numeric
                buffer = Numeric.zeros(size, typecode)
                buffer = Numeric.reshape(buffer, shape)

            stat = receive_array(buffer, source, tag)

        elif protocol == 'string':
            if buffer is None:
                # receive_string writes into the buffer object in place.
                # CPython shares str objects of length 0 and 1 (e.g. ' '*1
                # is the cached ' '), so always allocate at least two
                # characters and trim afterwards.
                buffer = ' ' * max(size, 2)
                stat = receive_string(buffer, source, tag)
                buffer = buffer[:size]
            else:
                stat = receive_string(buffer, source, tag)

        elif protocol == 'vanilla':
            try:
                from cPickle import dumps, loads
            except ImportError:
                from pickle import dumps, loads
            if buffer is None:
                s = ' ' * size
            else:
                # Size the receive buffer from the pickled local buffer,
                # padded by 10% in case the incoming pickle is slightly larger.
                s = dumps(buffer, 1)
                s = s + ' ' * int(0.1 * len(s))

            stat = receive_string(s, source, tag)
            buffer = loads(s)  # Replace buffer with received result
        else:
            raise Exception('Unknown protocol: %s' % protocol)

    # Return received data and possibly the status object
    if return_status:
        return buffer, Status(stat)
    else:
        return buffer
208
209
def broadcast(buffer, root, vanilla=False, bypass=False):
    """Wrapper for MPI bcast.

       Broadcast buffer from the process with rank root to all other processes.

       Automatically determine appropriate protocol
       and call corresponding send function.

       The variable buffer can be any (picklable) type, but
       Numeric variables and text strings will be most efficient.
       Setting vanilla = 1 forces vanilla mode for any type.

       If bypass is True, all admin and error checks
       get bypassed to reduce the latency (buffer must then be a
       Numeric array).

       Returns:
         For arrays and strings the same buffer object that was passed in,
         which is therefore expected to be filled in place on non-root
         processes. For the vanilla protocol a newly unpickled object.
         NOTE(review): for strings and vanilla, non-root processes presumably
         need a buffer at least as large as root's - confirm against mpiext.

       Raises:
         AssertionError if root is not an integer.
         Exception if the protocol is unknown.
    """

    if bypass:
        broadcast_array(buffer, root)
        # Return the buffer, as the non-bypass path does.
        return buffer

    import types

    # Input check
    errmsg = 'Root id (%s) must be an integer.' % root
    assert type(root) == types.IntType, errmsg

    # Create metadata about object to be sent
    protocol = create_control_info(buffer, vanilla)[0]

    # Broadcast
    if protocol == 'array':
        broadcast_array(buffer, root)
    elif protocol == 'string':
        broadcast_string(buffer, root)
    elif protocol == 'vanilla':
        try:
            from cPickle import loads, dumps
        except ImportError:
            from pickle import loads, dumps
        s = dumps(buffer, 1)
        s = s + ' ' * int(0.1 * len(s))  # safety margin for root's pickle

        broadcast_string(s, root)
        buffer = loads(s)
    else:
        raise Exception('Unknown protocol: %s' % protocol)

    return buffer
260
261
def scatter(x, root, buffer=None, vanilla=False):
    """Sends data x from process with rank root to all other processes.

       Create appropriate buffer and receive data.
       Return scattered result (same type as x)

       Scatter makes only sense for arrays or strings.
       When no buffer is given, arrays are split along axis 0 and strings
       into equal pieces. The piece length is
       len // numproc (integer division).

       The vanilla argument is accepted for backwards compatibility but is
       not used.

       Raises:
         AssertionError if root is not an integer.
         Exception for objects that are neither arrays nor strings.
    """

    import types
    import mpiext
    numproc = mpiext.size()  # Needed to determine buffer size

    # Input check
    errmsg = 'Root id (%s) must be an integer.' % root
    assert type(root) == types.IntType, errmsg

    # Create metadata about object to be sent
    protocol, typecode, total, shape = create_control_info(x)

    # Scatter
    if protocol == 'array':
        if buffer is None:
            import Numeric

            # Modify shape along axis=0 to match size; floor division keeps
            # the dimension an integer.
            shape = list(shape)
            shape[0] //= numproc
            count = Numeric.product(shape)

            buffer = Numeric.zeros(count, typecode)
            buffer = Numeric.reshape(buffer, shape)

        scatter_array(x, buffer, root)
    elif protocol == 'string':
        if buffer is None:
            buffer = ' ' * (total // numproc)

        scatter_string(x, buffer, root)
    elif protocol == 'vanilla':
        errmsg = 'Scatter is only supported for Numeric arrays and strings.\n'
        errmsg += 'If you wish to distribute a general sequence, '
        errmsg += 'please use send and receive commands or broadcast.'
        raise Exception(errmsg)
    else:
        raise Exception('Unknown protocol: %s' % protocol)

    return buffer
311
312
def gather(x, root, buffer=None, vanilla=0):
    """Gather values from all processes to root

       Create appropriate buffer and receive data.

       Gather only makes sense for arrays or strings. When no buffer is
       given, one numproc times as long as x (along axis 0 for arrays) is
       allocated.

       The vanilla argument is accepted for backwards compatibility but is
       not used.

       Raises:
         AssertionError if root is not an integer.
         Exception for objects that are neither arrays nor strings.
    """

    import types
    import mpiext
    numproc = mpiext.size()  # Needed to determine buffer size

    # Input check
    errmsg = 'Root id (%s) must be an integer.' % root
    assert type(root) == types.IntType, errmsg

    # Create metadata about object to be gathered
    protocol, typecode, total, shape = create_control_info(x)

    # Gather
    if protocol == 'array':
        if buffer is None:
            import Numeric
            buffer = Numeric.zeros(total * numproc, typecode)

            # Modify shape along axis=0 to match size
            shape = list(shape)
            shape[0] *= numproc
            buffer = Numeric.reshape(buffer, shape)

        gather_array(x, buffer, root)
    elif protocol == 'string':
        if buffer is None:
            buffer = ' ' * total * numproc

        gather_string(x, buffer, root)
    elif protocol == 'vanilla':
        errmsg = 'Gather is only supported for Numeric arrays and strings.\n'
        errmsg += 'If you wish to distribute a general sequence, '
        errmsg += 'please use send and receive commands or broadcast.'
        raise Exception(errmsg)
    else:
        raise Exception('Unknown protocol: %s' % protocol)

    return buffer
358
359
def reduce(x, op, root, buffer=None, vanilla=0, bypass=False):
    """Reduce elements in x to buffer (of the same size as x)
       at root applying operation op elementwise.

       Parameters:
         x: Numeric array to reduce (strings and general objects are not
            supported).
         op: reduction operator such as SUM, MAX, MIN or PROD.
         root: rank (integer) of the process receiving the result.
         buffer: optional result array. If None, an array with the same
                 shape and typecode as x is allocated.
         vanilla: accepted for backwards compatibility; not used.

       If bypass is True, all admin and error checks
       get bypassed to reduce the latency.
       The buffer must be specified explicitly in this case.

       Returns:
         buffer (in both normal and bypass mode).

       Raises:
         AssertionError if root is not an integer.
         Exception for strings, general objects or unknown protocols.
    """

    if bypass:
        reduce_array(x, buffer, op, root)
        return buffer

    import types

    # Input check
    errmsg = 'Root id (%s) must be an integer.' % root
    assert type(root) == types.IntType, errmsg

    # Create metadata about object
    protocol, typecode, count, shape = create_control_info(x)

    # Reduce
    if protocol == 'array':
        if buffer is None:
            import Numeric
            # An elementwise reduction yields exactly as many elements as x,
            # so the result buffer has x's shape (not numproc times larger).
            buffer = Numeric.zeros(count, typecode)
            buffer = Numeric.reshape(buffer, shape)

        reduce_array(x, buffer, op, root)
    elif protocol in ('vanilla', 'string'):
        raise Exception('Protocol: %s unsupported for reduce' % protocol)
    else:
        raise Exception('Unknown protocol: %s' % protocol)

    return buffer
407
408
409#---------------------------------------------------------
410# AUXILIARY FUNCTIONS
411#---------------------------------------------------------
def balance(N, P, p):
    """Compute p'th interval when N is distributed over P bins.

    The N items are split into P contiguous half-open intervals whose
    lengths differ by at most one; the first N - P*(N//P) bins receive one
    extra item.

    Parameters:
      N: total number of items (assumed non-negative).
      P: number of bins, e.g. processors (nonzero).
      p: index of the bin whose interval is wanted (0 <= p < P).

    Returns:
      Tuple (Nlo, Nhi): bin p owns items Nlo, ..., Nhi-1.
    """

    # Integer floor division is exact for arbitrarily large N, unlike
    # floor(float(N)/P), which loses precision beyond 2**53.
    L = int(N // P)
    K = N - P*L  # Number of bins that get one extra item
    if p < K:
        Nlo = p*L + p
        Nhi = Nlo + L + 1
    else:
        Nlo = p*L + K
        Nhi = Nlo + L

    return Nlo, Nhi
428
429
430# Obsolete functions
431# (for backwards compatibility - remove in version 2.0)
432
def raw_send(x, destination, tag=default_tag, vanilla=0):
    """Obsolete: send x without a preceding control message.

    Same as send(..., use_buffer=True); the matching receive must supply
    its own buffer.
    """
    send(x, destination, use_buffer=True, tag=tag, vanilla=vanilla)


def raw_receive(x, source, tag=default_tag, vanilla=0, return_status=0):
    """Obsolete: receive into the supplied buffer x (see receive())."""
    return receive(source, tag=tag, vanilla=vanilla,
                   return_status=return_status, buffer=x)


def raw_scatter(x, buffer, source, vanilla=0):
    """Obsolete: scatter x from source into the supplied buffer."""
    scatter(x, source, buffer=buffer, vanilla=vanilla)


def raw_gather(x, buffer, source, vanilla=0):
    """Obsolete: gather x at source into the supplied buffer."""
    # Forward the caller's vanilla flag instead of a hard-coded 0.
    gather(x, source, buffer=buffer, vanilla=vanilla)


def raw_reduce(x, buffer, op, source, vanilla=0):
    """Obsolete: reduce x with op at source into the supplied buffer."""
    # Forward the caller's vanilla flag instead of a hard-coded 0.
    reduce(x, op, source, buffer=buffer, vanilla=vanilla)
450
def bcast(buffer, root, vanilla=False):
    """Obsolete name for broadcast(); kept for backwards compatibility."""
    return broadcast(buffer, root, vanilla=vanilla)


def Wtime():
    """Obsolete name for time()."""
    return time()


def Get_processor_name():
    """Obsolete name for get_processor_name()."""
    return get_processor_name()


def Initialized():
    """Obsolete name for initialized()."""
    return initialized()


def Finalize():
    """Obsolete name for finalize()."""
    finalize()


def Abort():
    """Obsolete name for abort()."""
    abort()


def Barrier():
    """Obsolete name for barrier()."""
    barrier()
471   
472     
473
474#---------------------------------------------------------
475# INTERNAL FUNCTIONS
476#---------------------------------------------------------
477
class Status:
    """MPI Status block returned by receive if
       specified with parameter return_status=True.

       Built from the status tuple returned by the low-level receive
       functions:
         source: id of sender
         tag:    tag of received message
         error:  MPI error code
         length: number of elements transmitted
         size:   size of one element (bytes, see bytes())
    """

    def __init__(self, status_tuple):
        self.source = status_tuple[0]  # Id of sender
        self.tag = status_tuple[1]     # Tag of received message
        self.error = status_tuple[2]   # MPI Error code
        self.length = status_tuple[3]  # Number of elements transmitted
        self.size = status_tuple[4]    # Size of one element

    def __repr__(self):
        # The adjacent literals form one template before % is applied, so
        # all five fields are substituted into the complete message.
        return ('Pypar Status Object:\n  source=%d\n  tag=%d\n  '
                'error=%d\n  length=%d\n  size=%d\n'
                % (self.source, self.tag, self.error, self.length,
                   self.size))

    def bytes(self):
        """Number of bytes transmitted (excl control info)
        """
        return self.length * self.size
499 
500
501
def create_control_info(x, vanilla=0, return_object=False):
    """Determine which protocol to use for communication:
       (Numeric) arrays, strings, or vanilla based x's type.

       There are three protocols:
       'array':   Numeric arrays of type 'i', 'l', 'f', 'd', 'F' or 'D' can be
                  communicated  with mpiext.send_array and mpiext.receive_array.
       'string':  Text strings can be communicated with mpiext.send_string and
                  mpiext.receive_string.
       'vanilla': All other types can be communicated as string representations
                  provided that the objects
                  can be serialised using pickle (or cPickle).
                  The latter mode is less efficient than the
                  first two but it can handle general structures.

       Rules:
       If keyword argument vanilla == 1, vanilla is chosen regardless of
       x's type.
       Otherwise if x is a string, the string protocol is chosen
       If x is an array, the 'array' protocol is chosen provided that x has one
       of the admissible typecodes.

       The optional argument return_object asks to return object as well.
       This is useful in case it gets modified as in the case of general structures
       using the vanilla protocol.

       Returns:
         [protocol, typecode, size, shape], or
         ([protocol, typecode, size, shape], x) if return_object is true.
         For the vanilla protocol x is replaced by its pickled string and
         size is the length of that string; for strings size is len(x);
         for arrays size is the number of elements.
    """

    # Default values (these describe the vanilla protocol)
    protocol = 'vanilla'
    typecode = ' '
    size = 0
    shape = ()

    # Determine protocol in case
    if not vanilla:
        if isinstance(x, str):
            protocol = 'string'
            typecode = 'c'
            size = len(x)
        elif type(x).__name__ == 'array':
            # Checked by type name so that Numeric is only imported when an
            # array-like object actually shows up.
            try:
                import Numeric
            except ImportError:
                print('WARNING (pypar.py): Numeric module could not be '
                      'imported, reverting to vanilla mode')
                protocol = 'vanilla'
            else:
                typecode = x.typecode()
                if typecode in ['i', 'l', 'f', 'd', 'F', 'D']:
                    protocol = 'array'
                    shape = x.shape
                    size = Numeric.product(shape)
                else:
                    print('WARNING (pypar.py): Numeric object type %s '
                          'is not supported.' % typecode)
                    print("Only types 'i', 'l', 'f', 'd', 'F', 'D' are "
                          "supported, Reverting to vanilla mode.")
                    protocol = 'vanilla'

    # Pickle general structures using the vanilla protocol
    if protocol == 'vanilla':
        try:
            from cPickle import dumps
        except ImportError:  # cPickle does not exist on Python 3
            from pickle import dumps
        x = dumps(x, 1)
        size = len(x)  # Let count be length of pickled object

    if return_object:
        return [protocol, typecode, size, shape], x
    else:
        return [protocol, typecode, size, shape]
574
575
576
577#----------------------------------------------
578
def send_control_info(control_info, destination):
    """Send control info to destination.

    Each field is converted with str() and the fields are joined with
    control_sep into one string. That string is sent with the reserved
    control_tag, so it cannot be confused with payload messages.

    Raises:
      Exception if the encoded message is longer than control_data_max_size,
      the fixed buffer size used by receive_control_info.
    """
    control_msg = control_sep.join([str(c) for c in control_info])

    if len(control_msg) > control_data_max_size:
        errmsg = 'Length of control_info exceeds specified maximum (%d)'\
                 % control_data_max_size
        errmsg += ' - Please increase it (in pypar.py)'
        raise Exception(errmsg)

    send_string(control_msg, destination, control_tag)
595
596 
def receive_control_info(source):
    """Receive control info from source.

    Returns the list [protocol, typecode, size, shape] sent by
    send_control_info, with size and shape converted back from their
    string form.
    """
    # Fixed-size buffer; receive_string writes the message into it in place.
    raw = ' ' * control_data_max_size
    status = receive_string(raw, source, control_tag)
    # The status is deliberately not wrapped in a Status object: those are
    # reserved for payload communications.

    # status[3] is the number of characters actually received.
    fields = raw[:status[3]].split(control_sep)

    assert len(fields) == 4, 'len(control_info) = %d' % len(fields)

    # NOTE(review): eval() of text received from another process. MPI peers
    # are normally trusted, but a literal parser would be safer.
    fields[2] = eval(fields[2])  # size back to int
    fields[3] = eval(fields[3])  # shape back to tuple

    return fields
618
619
620#----------------------------------------------------------------------------
621# Initialise module
622#----------------------------------------------------------------------------
623
624
625# Take care of situation where module is part of package
626import sys, os, string, os.path
# Directory prefix derived from the dotted module name ('pkg.pypar' ->
# 'pkg' + os.sep); '.' + os.sep when the module is not inside a package.
# str.replace is used instead of the deprecated string.replace function.
dirname = os.path.dirname(__name__.replace('.', os.sep)).strip()

if not dirname:
    dirname = '.'

if dirname[-1] != os.sep:
    dirname += os.sep
634
635
636
# Import MPI extension
#
# Verify existence of mpiext.so.

try:
    import mpiext
except ImportError:
    errmsg = 'ERROR: C extension mpiext could not be imported.\n'
    errmsg += 'Please compile mpiext.c e.g. by running\n'
    errmsg += '  python install.py\n'
    errmsg += 'in the pypar directory, or by using\n'
    errmsg += '  python setup.py install\n'
    # Include the underlying reason (e.g. a missing MPI shared library),
    # which the generic advice above would otherwise hide.
    errmsg += '(Reason: %s)\n' % sys.exc_info()[1]
    raise Exception(errmsg)
650
651 
# Determine if MPI program is allowed to run sequentially on current platform
# Attempting to check this automatically may cause some systems to hang.
#
# Linux is matched by prefix. Python 2 releases before 2.7.3 report 'linux3'
# on Linux 3.x kernels, and Python 3.3+ reports plain 'linux'. Without the
# prefix match, Linux would fall through to the probe below, which the notes
# further down say can hang under LAM.

if sys.platform.startswith('linux') or \
   sys.platform in ['sunos5', 'win32', 'darwin']:
    # Linux (LAM,MPICH) or Sun (MPICH)
    error = 0  # Sequential execution of MPI is allowed
else:
    # Platform: Alpha 'osf1V5'
    # Probe MPI initialisation in a child interpreter. The lower-case
    # finalize is the name this module imports from mpiext. The running
    # interpreter is used rather than whatever 'python' is first on PATH.
    cmdstring = '"import mpiext, sys; mpiext.init(sys.argv); mpiext.finalize()"'
    s = '"%s" -c %s >/dev/null 2>/dev/null' % (sys.executable or 'python',
                                                cmdstring)
    error = os.system(s)

    # The check is performed in a separate shell.
    # Reason: The Alpha server, LAM/Linux or the Sun cannot recover from a
    # try:
    #   mpiext.init(sys.argv)

    # However, on LAM/Linux, this test causes system to hang.
    # Verified (OMN 12/12/2)
    # If lamboot is started, the system, will hang when init is called
    # again further down in this file.
    # If lamboot is not loaded error will be nonzero as it should.
    # I don't know how to deal with this
    #
    # Comparisons of two strategies using LAM
    #
    # Strategy 1: Assume seq execution is OK (i.e. set error = 0)
    # Strategy 2: Try to test if mpi can be initialised (in a separate shell)
    #
    #
    # Strategy 1 (currently used)
    #                    | Lam booted  | Lam not booted
    #-----------------------------------------------------
    #
    # Sequential exec    |  OK         | Not OK
    # Parallel exec      |  OK         | Not OK
    #
    #
    # Strategy 2
    #                    | Lam booted  | Lam not booted
    #-----------------------------------------------------
    #
    # Sequential exec    |  Hangs      | Not OK
    # Parallel exec      |  Hangs      | OK
    #
698
699
700
# Initialise MPI
#
# Attempt to initialise mpiext.so
# If this fails, define a rudimentary interface suitable for
# sequential execution.

if error:
    print("WARNING: MPI library could not be initialised - running sequentially")

    # Define rudimentary functions to keep sequential programs happy.
    # NOTE(review): the communication primitives (send_string, send_array,
    # ...) are not defined in this mode, so send/receive and the collective
    # operations raise NameError if called.

    def size(): return 1
    def rank(): return 0

    def get_processor_name():
        import os
        # Prefer $HOST, then $HOSTNAME, otherwise 'Unknown'.
        return os.environ.get('HOST', os.environ.get('HOSTNAME', 'Unknown'))

    def abort():
        import sys
        sys.exit()

    def finalize(): pass

    def initialized():
        # MPI is never initialised in sequential mode. Defined so that
        # initialized() (and the obsolete Initialized()) exist in both modes.
        return False

    def barrier(): pass

    def time():
        import time
        return time.time()

else:

    # MPI_TAG_UB is deliberately not imported as max_tag. It is presumably
    # the MPI attribute key for the tag upper bound, not the bound itself,
    # and importing it under that name would overwrite the max_tag constant
    # defined at the top of this module (see the comment there).
    from mpiext import size, rank, barrier, time,\
         get_processor_name,\
         init, initialized, finalize, abort,\
         send_string, receive_string,\
         send_array, receive_array, broadcast_string, broadcast_array,\
         scatter_string, scatter_array,\
         gather_string, gather_array,\
         reduce_array,\
         MPI_ANY_TAG as any_tag,\
         MPI_ANY_SOURCE as any_source,\
         MAX, MIN, SUM, PROD, LAND, BAND,\
         LOR, BOR, LXOR, BXOR

    init(sys.argv)  # Initialise MPI with cmd line (needed by MPICH/Linux)

    if rank() == 0:
        print("Pypar (version %s) initialised MPI OK with %d processors"
              % (__version__, size()))
758
759
760
761
762
763
Note: See TracBrowser for help on using the repository browser.