1 | #external modules |
---|
2 | import os |
---|
3 | import sys |
---|
4 | import unittest |
---|
5 | import numpy as num |
---|
6 | from Scientific.IO.NetCDF import NetCDFFile |
---|
7 | |
---|
8 | # ANUGA modules |
---|
9 | from anuga.config import netcdf_float32, netcdf_float64 |
---|
10 | from anuga.file.csv_file import load_csv_as_dict |
---|
11 | |
---|
12 | # Local modules |
---|
13 | from csv2sts import csv2sts |
---|
14 | |
---|
# Names of the scratch files used by the tests: the CSV fixture that is
# written for each test and the .sts (NetCDF) file csv2sts produces from it.
testfile_csv = 'small___.csv'
sts_out = 'sts_out.sts'

# Latitude/longitude handed to csv2sts and compared against the attributes
# stored in the generated .sts file.
lat = 10
lon = 20
---|
21 | |
---|
22 | |
---|
23 | |
---|
class Test_csv2sts(unittest.TestCase):
    """Test csv to NetCDFFile conversion functionality.

    Every test starts from a freshly written CSV fixture (``testfile_csv``)
    and any generated ``sts_out`` file is removed again afterwards, so one
    test can never pass by reading output left behind by another.
    """

    def setUp(self):
        """Write the space-delimited CSV fixture used by the tests."""
        self.verbose = True
        # The context manager closes the file even if the write fails.
        with open(testfile_csv, 'w') as fid:
            fid.write("""time stage
0 4
1 150.66667
2 150.83334
3 151.
4 151.16667
5 -34.
6 -34.16667
7 -34.33333
8 -34.5
9 -1.
10 -5.
11 -9.
12 -13.
""")

    def tearDown(self):
        """Remove the CSV fixture and any generated .sts file."""
        # sts_out is cleaned up here rather than on the success path of
        # _check_generated_sts so that a failed assertion cannot leave a
        # stale output file for the next test to pick up.
        for path in (testfile_csv, sts_out):
            if os.path.exists(path):
                os.remove(path)

    def test_missing_input_file(self):
        """Test that a missing csv file raises IOError."""
        # Any other exception type propagates and is reported as an error,
        # so a wrong exception still makes the test fail.
        self.assertRaises(IOError, csv2sts,
                          'somename_not_here.csv', sts_out, lat, lon)

    def test_csv2sts_output(self):
        """Test that a csv file is correctly rendered to .sts (NetCDF) format."""
        csv2sts(testfile_csv, sts_out, latitude=lat, longitude=lon)
        self._check_generated_sts()

    def test_run_via_commandline(self):
        """Make sure that the python file functions as a command-line tool."""
        # Imported locally: only this test needs it.
        import subprocess

        # Look for the script in the same directory as this unit test.
        script = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                              'csv2sts.py')

        # An argument list (no shell) keeps paths containing spaces intact,
        # and sys.executable runs the script with the interpreter that is
        # running the tests instead of whatever 'python' is on PATH.
        cmd = [sys.executable or 'python', script,
               '--latitude', str(lat),
               '--lon', str(lon),
               testfile_csv, sts_out]

        # The exit status is not checked here; a failed run shows up as a
        # missing or wrong output file in the check below.
        subprocess.call(cmd)
        self._check_generated_sts()

    def _check_generated_sts(self):
        """Check that ``sts_out`` holds the data written to ``testfile_csv``.

        Compares the latitude/longitude attributes, the number of variables
        and the values of every variable against the CSV fixture.
        """
        sts = NetCDFFile(sts_out)
        try:
            data, _ = load_csv_as_dict(testfile_csv, delimiter=' ',
                                       d_type=num.float64)

            self.assertEqual(sts.latitude, lat, 'latitude does not match')
            self.assertEqual(sts.longitude, lon, 'longitude does not match')

            self.assertEqual(len(sts.variables), len(data),
                             'num variables does not match')

            # Make sure data is returned in exactly the expected format.
            for key, values in data.items():
                self.assertEqual(list(sts.variables[key][:]), values,
                                 'stored data does not match')
        finally:
            # Release the handle so tearDown can delete the file; an open
            # handle prevents deletion on Windows.
            sts.close()
---|
109 | |
---|
if __name__ == "__main__":
    # unittest.makeSuite is deprecated and removed in Python 3.13; the
    # loader collects the same 'test'-prefixed methods by default.
    suite = unittest.TestLoader().loadTestsFromTestCase(Test_csv2sts)
    runner = unittest.TextTestRunner()
    runner.run(suite)
---|