source: anuga_validation/Hinwood_2008/calc_rmsd.py @ 5698

Last change on this file since 5698 was 5698, checked in by duncan, 16 years ago

Hinwood - continuing fix for line overrunning graph in anuga report, plus words for the report

File size: 7.3 KB
Line 
1"""
2Functions used to calculate the root mean square deviation.
3
4Duncan Gray, GA - 2007
5
6"""
7
8
9#----------------------------------------------------------------------------
10# Import necessary modules
11#----------------------------------------------------------------------------
12
13# Standard modules
14import os
15from csv import writer
16from time import localtime, strftime
17from os.path import join
18
19# Related major packages
20from Numeric import zeros, Float, where, greater, less, compress, sqrt, sum
21from anuga.shallow_water.data_manager import csv2dict
22from anuga.utilities.numerical_tools import ensure_numeric, err, norm
23from anuga.utilities.interp import interp
24
25# Scenario specific imports
26import project
27
def get_max_min_condition_array(min, max, vector):
    """
    Return a mask selecting the entries of vector strictly inside (min, max).

    The mask can be passed to compress() to cut a time vector (and any
    array aligned with it) down to the time window of interest.

    min    -- exclusive lower bound.
    max    -- exclusive upper bound.
    vector -- sequence of values; converted with ensure_numeric.

    Returns an array the same length as vector that is truthy (1/True)
    where min < value < max and falsy (0/False) elsewhere.

    The values of vector do not need to be sorted, and no lower limit is
    imposed on min.
    """
    # The parameter names shadow the builtins min/max; they are kept so
    # existing keyword callers keep working.
    vector = ensure_numeric(vector)
    # Multiplying the two 0/1 masks is an element-wise logical AND.  This
    # replaces the previous "push values >= max down to a -1e10 sentinel"
    # trick, which only worked when min > -1e10.
    return greater(vector, min) * less(vector, max)
44
45   
def auto_rrms(outputdir_tag, scenarios, quantity='stage',
              y_location_tag=':0.0'):
    """
    Compare simulated and experimental gauge CSV files for each scenario
    and write per-gauge error statistics to a CSV file.

    For every scenario dict (keys used: 'scenario_id', 'gauge_x',
    'wave_times') this reads
      simulation: project.output_dir/<id><outputdir_tag>_<quantity>.csv
      experiment: <id>_exp_<quantity>.csv  (path relative to the cwd)
    It keeps only the experimental times strictly between wave_times[0]
    and wave_times[1], and interpolates the simulation onto those times.
    For each gauge it then computes
      err               -- absolute 2-norm of (simulated - experimental)
      Number_of_samples -- number of compared time samples
      rmsd              -- err / sqrt(Number_of_samples)
    and writes them to
      project.output_dir/<id><outputdir_tag>_<quantity>_err.csv

    outputdir_tag  -- suffix appended to each scenario id in file names.
    scenarios      -- iterable of scenario dicts.
    quantity       -- quantity name used in file names, e.g. 'stage'.
    y_location_tag -- suffix appended to each gauge x value to form the
                      simulation column name (experiment columns use x only).
    """
    for run_data in scenarios:
        scenario_id = run_data['scenario_id']
        outputdir_name = scenario_id + outputdir_tag
        file_sim = join(project.output_dir,
                        outputdir_name + '_' + quantity + ".csv")
        file_exp = scenario_id + '_exp_' + quantity + '.csv'
        file_err = join(project.output_dir,
                        outputdir_name + "_" + quantity + "_err.csv")

        # Experiment columns are named by x only; simulation columns by x
        # followed by the y location tag.
        location_exps = [str(gauge_x) for gauge_x in run_data['gauge_x']]
        location_sims = [x + y_location_tag for x in location_exps]

        simulation, _ = csv2dict(file_sim)
        experiment, _ = csv2dict(file_exp)

        time_sim = ensure_numeric([float(x) for x in simulation['time']])
        time_exp = ensure_numeric([float(x) for x in experiment['Time']])
        condition = get_max_min_condition_array(run_data['wave_times'][0],
                                                run_data['wave_times'][1],
                                                time_exp)
        time_exp_cut = compress(condition, time_exp)

        print("Writing to  " + file_err)

        err_list = []
        points = []
        rmsd_list = []
        for location_sim, location_exp in zip(location_sims, location_exps):
            quantity_sim = [float(x) for x in simulation[location_sim]]
            quantity_exp = [float(x) for x in experiment[location_exp]]
            quantity_exp_cut = compress(condition, quantity_exp)

            # Interpolate the simulation onto the experimental sample times
            quantity_sim_interp = interp(quantity_sim, time_sim, time_exp_cut)
            assert len(quantity_sim_interp) == len(quantity_exp_cut)

            n_samples = len(quantity_sim_interp)
            # Absolute (relative=False) 2-norm of the difference.
            l2_err = err(quantity_sim_interp, quantity_exp_cut, 2,
                         relative=False)
            err_list.append(l2_err)
            points.append(n_samples)
            # NOTE(review): divides by zero if no experimental sample falls
            # strictly inside wave_times.
            rmsd_list.append(l2_err / sqrt(n_samples))
        assert len(location_exps) == len(err_list)

        # Write the results for this scenario.  Binary mode is what the
        # Python 2 csv module expects.  try/finally guarantees the file is
        # flushed and closed even if writing fails.
        err_fid = open(file_err, "wb")
        try:
            a_writer = writer(err_fid)
            a_writer.writerow(["x location", "err", "Number_of_samples",
                               "rmsd"])
            a_writer.writerows(zip(location_exps, err_list, points,
                                   rmsd_list))
        finally:
            err_fid.close()
114
115
116
def load_sensors(quantity_file):
    """
    Load a sensor CSV file.

    The first row is a header: a time column label followed by one label
    per sensor, of the form 'x' or 'x:y'.  Every following non-blank row
    holds a time followed by one value per sensor.

    Returns (times, locations, sensors):
      times     -- 1-D Float array holding the first column of each data row.
      locations -- list of floats, the x part of each sensor label (any
                   ':y' suffix is dropped).
      sensors   -- 2-D Float array of shape (number of data rows,
                   number of sensor labels in the header).

    Raises ValueError if the file is empty (no header row).
    """
    fid = open(quantity_file)
    try:
        lines = fid.readlines()
    finally:
        fid.close()

    if not lines:
        raise ValueError('%s is empty; expected a header row' % quantity_file)

    title = lines.pop(0)
    # Ignore blank lines, e.g. extra newlines at the end of the file.
    rows = [line for line in lines if line.strip()]

    quantity_locations = title.split(',')[1:]  # drop the 'time' label
    # label.split(':')[0] drops the y location
    locations = [float(label.split(':')[0]) for label in quantity_locations]

    n_time = len(rows)
    # Take the sensor count from the header so a file without data rows
    # still yields correctly shaped (empty) arrays.
    n_sensors = len(quantity_locations)
    times = zeros(n_time, Float)
    sensors = zeros((n_time, n_sensors), Float)

    for i, line in enumerate(rows):
        fields = [float(value) for value in line.split(',')]
        times[i] = fields[0]
        sensors[i] = fields[1:]  # 1: to skip the time column

    return times, locations, sensors
150
151   
def err_files(scenarios, outputdir_tag, quantity='stage'):
    """
    Return the err CSV path of every scenario, in scenario order.

    Each path is
      project.output_dir/<scenario_id><outputdir_tag>_<quantity>_err.csv
    """
    suffix = "_" + quantity + "_err.csv"
    return [join(project.output_dir,
                 scenario['scenario_id'] + outputdir_tag + suffix)
            for scenario in scenarios]
164   
165
def compare_different_settings(outputdir_tag, scenarios, quantity='stage'):
    """
    Pool the per-gauge errors of all scenarios into one overall RMSD.

    Reads each scenario's err CSV file (paths from err_files), collects
    every row's 'err' and 'Number_of_samples', and combines them with
    err_addition.  It then prints outputdir_tag followed by
    RMSD = combined err / sqrt(total samples).

    Returns the overall RMSD.
    Raises ValueError if the err files hold no samples, rather than
    dividing by zero.
    """
    all_errs = []
    all_samples = []
    for file_err in err_files(scenarios, outputdir_tag, quantity=quantity):
        results, _ = csv2dict(file_err)
        all_errs.extend([float(x) for x in results['err']])
        all_samples.extend([float(x) for x in results['Number_of_samples']])

    # Combining everything in one call is equivalent to the previous
    # file-by-file accumulation: the combined err is the 2-norm of all errs.
    total_err, total_samples = err_addition(all_errs, all_samples)
    if total_samples == 0:
        raise ValueError('No samples found in the err files for ' +
                         outputdir_tag)
    rmsd = total_err / sqrt(total_samples)
    print(outputdir_tag + "   " + str(rmsd))
    return rmsd
186   
187   
188   
def err_addition(err_list, number_of_samples_list):
    """
    Combine several err values, and their sample counts, into one.

    Each err is a 2-norm of differences, sqrt(sum_i (x_i - y_i)^2), over
    its own set of samples.  The combined err is the 2-norm of the
    individual errs.  The combined sample count is the plain sum of the
    counts.

    Returns (combined_err, total_number_of_samples).

    If this gets used a lot, consider turning it into a small object.
    """
    total_samples = sum(ensure_numeric(number_of_samples_list))
    combined_err = norm(ensure_numeric(err_list))
    return combined_err, total_samples
202
203                 
204#-------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: optionally recompute the per-gauge error files
    # for every scenario, then print the pooled RMSD for this output tag.
    from scenarios import scenarios

    quantity = "stage"
    outputdir_tag = "_nolmts_wdth_0.1_z_0.0_ys_0.01_mta_0.01_A"

    # auto_rrms writes the *_err.csv files that compare_different_settings
    # reads.  Set calc_norms to False to reuse files from an earlier run.
    calc_norms = True
    if calc_norms:
        auto_rrms(outputdir_tag, scenarios, quantity, y_location_tag=':0.0')
    compare_different_settings(outputdir_tag, scenarios, quantity)
217   
Note: See TracBrowser for help on using the repository browser.