source: trunk/anuga_core/source/anuga/shallow_water/checkpoint.py @ 9486

Last change on this file since 9486 was 9265, checked in by steve, 11 years ago

Moving checkpoint.py to shallow_water

File size: 3.1 KB
Line 
"""
Procedures to support checkpointing.

Checkpointing is already available through the domain object. Set it up with

    domain.set_checkpointing(checkpoint_step, checkpoint_dir)

checkpoint_step: the number of yieldsteps between saving checkpoint files.
checkpoint_dir: the name of the directory where the checkpoint files are stored.

When restarting a calculation, however, no domain is available yet, so the
last stored domain must be read back in. Do that via

    domain = load_checkpoint_file(domain_name, checkpoint_dir)

"""
18
19from anuga import send, receive, myid, numprocs, barrier
20from time import time as walltime
21
22
23
def load_checkpoint_file(domain_name='domain', checkpoint_dir='.', time=None):
    """
    Load a pickled domain from a checkpoint file.

    domain_name: base name of the checkpoint files. When running in parallel
                 (numprocs > 1) the suffix '_P<numprocs>_<myid>' is appended,
                 so each processor reads its own file.
    checkpoint_dir: directory holding the checkpoint files.
    time: checkpoint time to load. If None, the times reported by
          _get_checkpoint_times are tried from the latest to the earliest.

    A time is accepted only when every processor loaded its own file for that
    time successfully (the success flags are exchanged between all
    processors), so all processors restart from the same time.

    Returns the unpickled domain with last_walltime set to the current wall
    time and the communication timers reset to 0.0.

    Raises Exception("Unable to open checkpoint file") if no candidate time
    exists or no time could be loaded on all processors.
    """

    from os.path import join
    try:
        import cPickle as pickle  # Python 2: faster C implementation
    except ImportError:
        import pickle

    if numprocs > 1:
        domain_name = domain_name + '_P{}_{}'.format(numprocs, myid)

    if time is None:
        # Candidate times in ascending order; tolerate a None result.
        times = sorted(_get_checkpoint_times(domain_name, checkpoint_dir) or ())
    else:
        times = [float(time)]

    if not times:
        raise Exception("Unable to open checkpoint file")

    overall = False
    for t in reversed(times):

        pickle_name = join(checkpoint_dir, domain_name) + '_' + str(t) + '.pickle'

        # NOTE: unpickling can execute arbitrary code; checkpoint_dir must
        # only contain trusted files written by this application.
        try:
            with open(pickle_name, 'rb') as pickle_file:
                domain = pickle.load(pickle_file)
            success = True
        except Exception:
            success = False

        # Tell every other processor whether this one succeeded, then combine
        # their answers. '&' rather than 'and' so receive() is always called
        # and no message is left unread for the next round.
        overall = success
        for cpu in range(numprocs):
            if cpu != myid:
                send(success, cpu)

        for cpu in range(numprocs):
            if cpu != myid:
                overall = overall & receive(cpu)

        barrier()

        if overall:
            break

    if not overall:
        raise Exception("Unable to open checkpoint file")

    # Reset timing bookkeeping so statistics start fresh after the restart.
    domain.last_walltime = walltime()
    domain.communication_time = 0.0
    domain.communication_reduce_time = 0.0
    domain.communication_broadcast_time = 0.0

    return domain


def _get_checkpoint_times(domain_name, checkpoint_dir):
    """
    Return the set of checkpoint times available on every processor.

    Scans the top level of checkpoint_dir for files named
    '<domain_name>_<time>.pickle' and collects <time> as floats. When running
    in parallel, the local set is exchanged with every other processor and
    intersected, so only times for which all processors have a file remain.

    domain_name: base name of the checkpoint files, already including any
                 per-processor suffix.
    checkpoint_dir: directory holding the checkpoint files.

    Returns a (possibly empty) set of floats. Every processor always takes
    part in the exchange, even when it found no files, so a processor with
    an empty directory cannot leave the others blocked in receive().
    """

    import os

    # Top level only: the loader opens files directly in checkpoint_dir, so
    # files in subdirectories could never be loaded. os.walk yields nothing
    # for a missing directory, hence the empty default.
    filenames = next(os.walk(checkpoint_dir), (checkpoint_dir, [], []))[2]

    times = set()
    for filename in filenames:
        root, ext = os.path.splitext(filename)
        if ext != '.pickle':
            continue
        base, _, stamp = root.rpartition('_')
        if base != domain_name:
            continue
        try:
            times.add(float(stamp))
        except ValueError:
            # Matching prefix but the suffix is not a time: not a checkpoint.
            continue

    combined = times
    for cpu in range(numprocs):
        if cpu != myid:
            send(times, cpu)
            combined = combined & receive(cpu)

    return combined
Note: See TracBrowser for help on using the repository browser.