1 | #external modules |
---|
2 | import os |
---|
3 | import unittest |
---|
4 | import numpy as num |
---|
5 | from Scientific.IO.NetCDF import NetCDFFile |
---|
6 | |
---|
7 | # ANUGA modules |
---|
8 | from anuga.config import netcdf_float32, netcdf_float64 |
---|
9 | from anuga.file.csv_file import load_csv_as_dict |
---|
10 | |
---|
11 | # Local modules |
---|
12 | from csv2sts import csv2sts |
---|
13 | |
---|
# Scratch files used by the tests: the CSV fixture that is written out and
# the .sts file that csv2sts generates from it.
testfile_csv, sts_out = 'small___.csv', 'sts_out.sts'

# Coordinates passed to csv2sts and compared against the generated file.
lat, lon = 10, 20
20 | |
---|
21 | |
---|
22 | |
---|
class Test_csv2sts(unittest.TestCase):
    """
    Test csv to NetCDFFile conversion functionality.

    Each test works on two scratch files in the current directory: the CSV
    fixture ``testfile_csv`` (written in setUp) and the converter output
    ``sts_out``. Both are removed again in tearDown, whether or not the
    test passed.
    """

    @staticmethod
    def _remove_if_present(path):
        """Delete ``path`` if it exists; do nothing otherwise."""
        if os.path.exists(path):
            os.remove(path)

    def setUp(self):
        """Write the space-delimited CSV fixture and clear stale output."""
        self.verbose = True

        # A leftover output file from an earlier, failed run must not be
        # mistaken for the result of the conversion under test.
        self._remove_if_present(sts_out)

        with open(testfile_csv, 'w') as fid:
            fid.write("""time stage
0 4
1 150.66667
2 150.83334
3 151.
4 151.16667
5 -34.
6 -34.16667
7 -34.33333
8 -34.5
9 -1.
10 -5.
11 -9.
12 -13.
""")

    def tearDown(self):
        """Remove the CSV fixture and any generated .sts file."""
        self._remove_if_present(testfile_csv)
        self._remove_if_present(sts_out)

    def test_missing_input_file(self):
        """
        Test that a missing csv file raises the correct exception.
        """
        try:
            csv2sts('somename_not_here.csv', sts_out, lat, lon)
        except IOError:
            # Expected outcome: the input file does not exist.
            pass
        except Exception:
            # Any other exception type is the wrong failure mode.
            # KeyboardInterrupt/SystemExit are deliberately not trapped.
            self.fail('Missing file raised wrong exception.')
        else:
            self.fail('Missing file did not raise an exception.')

    def test_csv2sts_output(self):
        """
        Test that a csv file is correctly rendered to .sts (NetCDF) format.
        """
        csv2sts(testfile_csv, sts_out, latitude=lat, longitude=lon)
        self._check_generated_sts()

    def test_run_via_commandline(self):
        """
        Make sure that the python file functions as a command-line tool.
        """
        cmd = 'python csv2sts.py --latitude ' + str(lat) + ' --lon ' + str(lon)
        cmd += ' ' + testfile_csv + ' ' + sts_out
        if self.verbose:
            print(cmd)

        # The exit status is not inspected. Because setUp removed any stale
        # output, a failed command shows up when the check below tries to
        # open the missing .sts file.
        os.system(cmd)
        self._check_generated_sts()

    def _check_generated_sts(self):
        """
        Check that the data read back from ``sts_out`` matches the CSV.

        Compares the latitude/longitude attributes with ``lat``/``lon``,
        the number of variables with the number of CSV columns, and every
        variable's values with the corresponding CSV column.
        """
        sts = NetCDFFile(sts_out)
        try:
            data, _names = load_csv_as_dict(testfile_csv, delimiter=' ',
                                            d_type=num.float64)

            self.assertTrue(sts.latitude == lat, 'latitude does not match')
            self.assertTrue(sts.longitude == lon, 'longitude does not match')

            self.assertTrue(len(sts.variables) == len(data),
                            'num variables does not match')

            # make sure data is returned in exactly the expected format
            for key, values in data.items():
                self.assertTrue(list(sts.variables[key][:]) == values,
                                'stored data does not match')
        finally:
            # Release the handle so tearDown can delete the file.
            sts.close()
101 | |
---|
if __name__ == "__main__":
    # Gather every test* method of Test_csv2sts and run them with the
    # plain-text console runner.
    all_tests = unittest.TestLoader().loadTestsFromTestCase(Test_csv2sts)
    unittest.TextTestRunner().run(all_tests)