1 | """ |
---|
2 | Procedures to support checkpointing |
---|
3 | |
---|
4 | There is already checkpointing available in domain. |
---|
5 | |
---|
6 | Setup with domain.set_checkpointing(checkpoint_step, checkpoint_dir) |
---|
7 | |
---|
8 | checkpoint_step: the number of yieldsteps between saving a checkpoint file |
---|
9 | checkpoint_dir: the name of the directory where the checkpoint files are stored. |
---|
10 | |
---|
11 | |
---|
12 | But if we are restarting a calculation there is no domain yet available, so we must |
---|
13 | read in the last stored domain. Do that via |
---|
14 | |
---|
15 | domain = load_checkpoint_file(domain_name, checkpoint_dir) |
---|
16 | |
---|
17 | """ |
---|
18 | |
---|
19 | from anuga import send, receive, myid, numprocs, barrier |
---|
20 | from time import time as walltime |
---|
21 | |
---|
22 | |
---|
23 | |
---|
def load_checkpoint_file(domain_name = 'domain', checkpoint_dir = '.', time = None):
    """Load a pickled domain from a checkpoint file.

    domain_name: base name of the checkpoint files. When running on more
        than one process, '_P<numprocs>_<myid>' is appended so that each
        process reads its own file.
    checkpoint_dir: directory containing the checkpoint files.
    time: time of the checkpoint to load. If None, the available checkpoint
        times are obtained from _get_checkpoint_times and tried from the
        latest to the earliest.

    A checkpoint time is accepted only when every process reports that it
    managed to load its own file for that time.

    Returns the unpickled domain with its walltime and communication timers
    reset.

    Raises Exception if no checkpoint time could be loaded by all processes.
    """

    from os.path import join

    # cPickle only exists on Python 2; fall back to pickle elsewhere. Doing
    # this outside the load loop keeps an ImportError from being mistaken
    # for an unreadable checkpoint file.
    try:
        import cPickle as pickle
    except ImportError:
        import pickle

    if numprocs > 1:
        domain_name = domain_name+'_P{}_{}'.format(numprocs,myid)

    if time is None:
        # will pull out the last available time
        found = _get_checkpoint_times(domain_name, checkpoint_dir)
        # Guard against a None/empty result so the failure below is the
        # intended "Unable to open checkpoint file" and not a TypeError.
        times = sorted(found) if found else []
    else:
        times = [float(time)]

    if not times:
        raise Exception("Unable to open checkpoint file")

    domain = None
    overall = False

    # Try the most recent checkpoint first and fall back to earlier ones.
    for checkpoint_time in reversed(times):

        pickle_name = join(checkpoint_dir,domain_name)+'_'+str(checkpoint_time)+'.pickle'

        # NOTE(review): unpickling can execute arbitrary code; only load
        # checkpoint files from a trusted directory.
        try:
            with open(pickle_name, 'rb') as pickle_file:
                domain = pickle.load(pickle_file)
            success = True
        except Exception:
            # Best effort: any failure to read this checkpoint just means
            # an earlier one is tried next.
            success = False

        # Tell every other process whether this one succeeded ...
        overall = success
        for cpu in range(numprocs):
            if cpu != myid:
                send(success,cpu)

        # ... and accept this time only if all of them succeeded too.
        for cpu in range(numprocs):
            if cpu != myid:
                overall = overall & receive(cpu)

        barrier()

        if overall:
            break

    if not overall:
        raise Exception("Unable to open checkpoint file")

    # Restart the timers so they do not include time from the previous run.
    domain.last_walltime = walltime()
    domain.communication_time = 0.0
    domain.communication_reduce_time = 0.0
    domain.communication_broadcast_time = 0.0

    return domain
---|
79 | |
---|
80 | |
---|
def _get_checkpoint_times(domain_name, checkpoint_dir):
    """Return the set of checkpoint times available to every process.

    Walks checkpoint_dir (os.walk, so subdirectories are included) and
    collects the time from each file whose name, without its extension, has
    the form '<domain_name>_<time>'. The local set is then sent to every
    other process and intersected with the sets received from them.

    domain_name: base name of the checkpoint files, compared against the
        part of the file name before the last underscore.
    checkpoint_dir: directory to search.

    Returns a set of floats. It is empty when no matching files exist, so
    the exchange with the other processes always takes place.
    """

    import os
    times = set()

    for (path, directory, filenames) in os.walk(checkpoint_dir):
        # A directory without files is simply skipped; returning early here
        # would discard times already found and bypass the exchange below.
        for filename in filenames:
            filebase = os.path.splitext(filename)[0].rpartition("_")
            if filebase[0] != domain_name:
                continue
            try:
                times.add(float(filebase[-1]))
            except ValueError:
                # Name matches but the suffix is not a number, so this is
                # not a checkpoint time.
                continue

    # Keep only the times that every process has a file for.
    combined = times
    for cpu in range(numprocs):
        if myid != cpu:
            send(times,cpu)
            rec = receive(cpu)
            combined = combined & rec

    return combined
---|
118 | |
---|