source: inundation/pyvolution/least_squares.py @ 1904

Last change on this file since 1904 was 1904, checked in by ole, 19 years ago

Modified Interpolation_function to allow for time-independent quantities (such as elevation) and wrote test

File size: 48.3 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20import exceptions
21class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME (DSG-ON) reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    NOTE: Throws IOErrors, for a variety of file problems.
75   
76    mesh_origin is the same but refers to the input tsh file.
77    FIXME: When the tsh format contains it own origin, these parameters can go.
78    FIXME: And both origins should be obtained from the specified files.
79    """
80
81    from load_mesh.loadASCII import import_mesh_file, \
82                 import_points_file, export_mesh_file, \
83                 concatinate_attributelist
84
85
86    try:
87        mesh_dict = import_mesh_file(mesh_file)
88    except IOError,e:
89        if display_errors:
90            print "Could not load bad file. ", e
91        raise IOError  #Re-raise exception
92       
93    vertex_coordinates = mesh_dict['vertices']
94    triangles = mesh_dict['triangles']
95    if type(mesh_dict['vertex_attributes']) == ArrayType:
96        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
97    else:
98        old_point_attributes = mesh_dict['vertex_attributes']
99
100    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
101        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
102    else:
103        old_title_list = mesh_dict['vertex_attribute_titles']
104
105    if verbose: print 'tsh file %s loaded' %mesh_file
106
107    # load in the .pts file
108    try:
109        point_dict = import_points_file(point_file, verbose=verbose)
110    except IOError,e:
111        if display_errors:
112            print "Could not load bad file. ", e
113        raise IOError  #Re-raise exception 
114
115    point_coordinates = point_dict['pointlist']
116    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
117
118    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
119        data_origin = point_dict['geo_reference'].get_origin()
120    else:
121        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
122
123    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
124        mesh_origin = mesh_dict['geo_reference'].get_origin()
125    else:
126        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
127
128    if verbose: print "points file loaded"
129    if verbose:print "fitting to mesh"
130    f = fit_to_mesh(vertex_coordinates,
131                    triangles,
132                    point_coordinates,
133                    point_attributes,
134                    alpha = alpha,
135                    verbose = verbose,
136                    expand_search = expand_search,
137                    data_origin = data_origin,
138                    mesh_origin = mesh_origin,
139                    precrop = precrop)
140    if verbose: print "finished fitting to mesh"
141
142    # convert array to list of lists
143    new_point_attributes = f.tolist()
144    #FIXME have this overwrite attributes with the same title - DSG
145    #Put the newer attributes last
146    if old_title_list <> []:
147        old_title_list.extend(title_list)
148        #FIXME can this be done a faster way? - DSG
149        for i in range(len(old_point_attributes)):
150            old_point_attributes[i].extend(new_point_attributes[i])
151        mesh_dict['vertex_attributes'] = old_point_attributes
152        mesh_dict['vertex_attribute_titles'] = old_title_list
153    else:
154        mesh_dict['vertex_attributes'] = new_point_attributes
155        mesh_dict['vertex_attribute_titles'] = title_list
156
157    #FIXME (Ole): Remember to output mesh_origin as well
158    if verbose: print "exporting to file ",mesh_output_file
159
160    try:
161        export_mesh_file(mesh_output_file, mesh_dict)
162    except IOError,e:
163        if display_errors:
164            print "Could not write file. ", e
165        raise IOError
166
167def fit_to_mesh(vertex_coordinates,
168                triangles,
169                point_coordinates,
170                point_attributes,
171                alpha = DEFAULT_ALPHA,
172                verbose = False,
173                expand_search = False,
174                data_origin = None,
175                mesh_origin = None,
176                precrop = False):
177    """
178    Fit a smooth surface to a triangulation,
179    given data points with attributes.
180
181
182        Inputs:
183
184          vertex_coordinates: List of coordinate pairs [xi, eta] of points
185          constituting mesh (or a an m x 2 Numeric array)
186
187          triangles: List of 3-tuples (or a Numeric array) of
188          integers representing indices of all vertices in the mesh.
189
190          point_coordinates: List of coordinate pairs [x, y] of data points
191          (or an nx2 Numeric array)
192
193          alpha: Smoothing parameter.
194
195          point_attributes: Vector or array of data at the point_coordinates.
196
197          data_origin and mesh_origin are 3-tuples consisting of
198          UTM zone, easting and northing. If specified
199          point coordinates and vertex coordinates are assumed to be
200          relative to their respective origins.
201
202    """
203    interp = Interpolation(vertex_coordinates,
204                           triangles,
205                           point_coordinates,
206                           alpha = alpha,
207                           verbose = verbose,
208                           expand_search = expand_search,
209                           data_origin = data_origin,
210                           mesh_origin = mesh_origin,
211                           precrop = precrop)
212
213    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
214    return vertex_attributes
215
216
217
218def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
219                    verbose = False, reduction = 1):
220    """Fits attributes from pts file to MxN rectangular mesh
221
222    Read pts file and create rectangular mesh of resolution MxN such that
223    it covers all points specified in pts file.
224
225    FIXME: This may be a temporary function until we decide on
226    netcdf formats etc
227
228    FIXME: Uses elevation hardwired
229    """
230
231    import  mesh_factory
232    from load_mesh.loadASCII import import_points_file
233   
234    if verbose: print 'Read pts'
235    points_dict = import_points_file(pts_name)
236    #points, attributes = util.read_xya(pts_name)
237
238    #Reduce number of points a bit
239    points = points_dict['pointlist'][::reduction]
240    elevation = points_dict['attributelist']['elevation']  #Must be elevation
241    elevation = elevation[::reduction]
242
243    if verbose: print 'Got %d data points' %len(points)
244
245    if verbose: print 'Create mesh'
246    #Find extent
247    max_x = min_x = points[0][0]
248    max_y = min_y = points[0][1]
249    for point in points[1:]:
250        x = point[0]
251        if x > max_x: max_x = x
252        if x < min_x: min_x = x
253        y = point[1]
254        if y > max_y: max_y = y
255        if y < min_y: min_y = y
256
257    #Create appropriate mesh
258    vertex_coordinates, triangles, boundary =\
259         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
260                                (min_x, min_y))
261
262    #Fit attributes to mesh
263    vertex_attributes = fit_to_mesh(vertex_coordinates,
264                        triangles,
265                        points,
266                        elevation, alpha=alpha, verbose=verbose)
267
268
269
270    return vertex_coordinates, triangles, boundary, vertex_attributes
271
272
273
274class Interpolation:
275
276    def __init__(self,
277                 vertex_coordinates,
278                 triangles,
279                 point_coordinates = None,
280                 alpha = None,
281                 verbose = False,
282                 expand_search = True,
283                 interp_only = False,
284                 max_points_per_cell = 30,
285                 mesh_origin = None,
286                 data_origin = None,
287                 precrop = False):
288
289
290        """ Build interpolation matrix mapping from
291        function values at vertices to function values at data points
292
293        Inputs:
294
295          vertex_coordinates: List of coordinate pairs [xi, eta] of
296          points constituting mesh (or a an m x 2 Numeric array)
297          Points may appear multiple times
298          (e.g. if vertices have discontinuities)
299
300          triangles: List of 3-tuples (or a Numeric array) of
301          integers representing indices of all vertices in the mesh.
302
303          point_coordinates: List of coordinate pairs [x, y] of
304          data points (or an nx2 Numeric array)
305          If point_coordinates is absent, only smoothing matrix will
306          be built
307
308          alpha: Smoothing parameter
309
310          data_origin and mesh_origin are 3-tuples consisting of
311          UTM zone, easting and northing. If specified
312          point coordinates and vertex coordinates are assumed to be
313          relative to their respective origins.
314
315        """
316        from util import ensure_numeric
317
318        #Convert input to Numeric arrays
319        triangles = ensure_numeric(triangles, Int)
320        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
321
322        #Build underlying mesh
323        if verbose: print 'Building mesh'
324        #self.mesh = General_mesh(vertex_coordinates, triangles,
325        #FIXME: Trying the normal mesh while testing precrop,
326        #       The functionality of boundary_polygon is needed for that
327
328        #FIXME - geo ref does not have to go into mesh.
329        # Change the point co-ords to conform to the
330        # mesh co-ords early in the code
331        if mesh_origin == None:
332            geo = None
333        else:
334            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
335        self.mesh = Mesh(vertex_coordinates, triangles,
336                         geo_reference = geo)
337       
338        self.mesh.check_integrity()
339
340        self.data_origin = data_origin
341
342        self.point_indices = None
343
344        #Smoothing parameter
345        if alpha is None:
346            self.alpha = DEFAULT_ALPHA
347        else:   
348            self.alpha = alpha
349
350
351        if point_coordinates is not None:
352            if verbose: print 'Building interpolation matrix'
353            self.build_interpolation_matrix_A(point_coordinates,
354                                              verbose = verbose,
355                                              expand_search = expand_search,
356                                              interp_only = interp_only, 
357                                              max_points_per_cell =\
358                                              max_points_per_cell,
359                                              data_origin = data_origin,
360                                              precrop = precrop)
361        #Build coefficient matrices
362        if interp_only == False:
363            self.build_coefficient_matrix_B(point_coordinates,
364                                        verbose = verbose,
365                                        expand_search = expand_search,
366                                        max_points_per_cell =\
367                                        max_points_per_cell,
368                                        data_origin = data_origin,
369                                        precrop = precrop)
370
371
372    def set_point_coordinates(self, point_coordinates,
373                              data_origin = None,
374                              verbose = False,
375                              precrop = True):
376        """
377        A public interface to setting the point co-ordinates.
378        """
379        if point_coordinates is not None:
380            if verbose: print 'Building interpolation matrix'
381            self.build_interpolation_matrix_A(point_coordinates,
382                                              verbose = verbose,
383                                              data_origin = data_origin,
384                                              precrop = precrop)
385        self.build_coefficient_matrix_B(point_coordinates, data_origin)
386
387    def build_coefficient_matrix_B(self, point_coordinates=None,
388                                   verbose = False, expand_search = True,
389                                   max_points_per_cell=30,
390                                   data_origin = None,
391                                   precrop = False):
392        """Build final coefficient matrix"""
393
394
395        if self.alpha <> 0:
396            if verbose: print 'Building smoothing matrix'
397            self.build_smoothing_matrix_D()
398
399        if point_coordinates is not None:
400            if self.alpha <> 0:
401                self.B = self.AtA + self.alpha*self.D
402            else:
403                self.B = self.AtA
404
405            #Convert self.B matrix to CSR format for faster matrix vector
406            self.B = Sparse_CSR(self.B)
407
408    def build_interpolation_matrix_A(self, point_coordinates,
409                                     verbose = False, expand_search = True,
410                                     max_points_per_cell=30,
411                                     data_origin = None,
412                                     precrop = False,
413                                     interp_only = False):
414        """Build n x m interpolation matrix, where
415        n is the number of data points and
416        m is the number of basis functions phi_k (one per vertex)
417
418        This algorithm uses a quad tree data structure for fast binning of data points
419        origin is a 3-tuple consisting of UTM zone, easting and northing.
420        If specified coordinates are assumed to be relative to this origin.
421
422        This one will override any data_origin that may be specified in
423        interpolation instance
424
425        """
426
427
428        #FIXME (Ole): Check that this function is memeory efficient.
429        #6 million datapoints and 300000 basis functions
430        #causes out-of-memory situation
431        #First thing to check is whether there is room for self.A and self.AtA
432        #
433        #Maybe we need some sort of blocking
434
435        from quad import build_quadtree
436        from util import ensure_numeric
437
438        if data_origin is None:
439            data_origin = self.data_origin #Use the one from
440                                           #interpolation instance
441
442        #Convert input to Numeric arrays just in case.
443        point_coordinates = ensure_numeric(point_coordinates, Float)
444
445        #Keep track of discarded points (if any).
446        #This is only registered if precrop is True
447        self.cropped_points = False
448
449        #Shift data points to same origin as mesh (if specified)
450
451        #FIXME this will shift if there was no geo_ref.
452        #But all this should be removed anyhow.
453        #change coords before this point
454        mesh_origin = self.mesh.geo_reference.get_origin()
455        if point_coordinates is not None:
456            if data_origin is not None:
457                if mesh_origin is not None:
458
459                    #Transformation:
460                    #
461                    #Let x_0 be the reference point of the point coordinates
462                    #and xi_0 the reference point of the mesh.
463                    #
464                    #A point coordinate (x + x_0) is then made relative
465                    #to xi_0 by
466                    #
467                    # x_new = x + x_0 - xi_0
468                    #
469                    #and similarly for eta
470
471                    x_offset = data_origin[1] - mesh_origin[1]
472                    y_offset = data_origin[2] - mesh_origin[2]
473                else: #Shift back to a zero origin
474                    x_offset = data_origin[1]
475                    y_offset = data_origin[2]
476
477                point_coordinates[:,0] += x_offset
478                point_coordinates[:,1] += y_offset
479            else:
480                if mesh_origin is not None:
481                    #Use mesh origin for data points
482                    point_coordinates[:,0] -= mesh_origin[1]
483                    point_coordinates[:,1] -= mesh_origin[2]
484
485
486
487        #Remove points falling outside mesh boundary
488        #This reduced one example from 1356 seconds to 825 seconds
489        if precrop is True:
490            from Numeric import take
491            from util import inside_polygon
492
493            if verbose: print 'Getting boundary polygon'
494            P = self.mesh.get_boundary_polygon()
495
496            if verbose: print 'Getting indices inside mesh boundary'
497            indices = inside_polygon(point_coordinates, P, verbose = verbose)
498
499
500            if len(indices) != point_coordinates.shape[0]:
501                self.cropped_points = True
502                if verbose:
503                    print 'Done - %d points outside mesh have been cropped.'\
504                          %(point_coordinates.shape[0] - len(indices))
505
506            point_coordinates = take(point_coordinates, indices)
507            self.point_indices = indices
508
509
510
511
512        #Build n x m interpolation matrix
513        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
514        n = point_coordinates.shape[0]     #Nbr of data points
515
516        if verbose: print 'Number of datapoints: %d' %n
517        if verbose: print 'Number of basis functions: %d' %m
518
519        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
520        #However, Sparse_CSR does not have the same methods as Sparse yet
521        #The tests will reveal what needs to be done
522        #self.A = Sparse_CSR(Sparse(n,m))
523        #self.AtA = Sparse_CSR(Sparse(m,m))
524        self.A = Sparse(n,m)
525        self.AtA = Sparse(m,m)
526
527        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
528        root = build_quadtree(self.mesh,
529                              max_points_per_cell = max_points_per_cell)
530
531        #Compute matrix elements
532        for i in range(n):
533            #For each data_coordinate point
534
535            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
536            x = point_coordinates[i]
537
538            #Find vertices near x
539            candidate_vertices = root.search(x[0], x[1])
540            is_more_elements = True
541
542            element_found, sigma0, sigma1, sigma2, k = \
543                self.search_triangles_of_vertices(candidate_vertices, x)
544            while not element_found and is_more_elements and expand_search:
545                #if verbose: print 'Expanding search'
546                candidate_vertices, branch = root.expand_search()
547                if branch == []:
548                    # Searching all the verts from the root cell that haven't
549                    # been searched.  This is the last try
550                    element_found, sigma0, sigma1, sigma2, k = \
551                      self.search_triangles_of_vertices(candidate_vertices, x)
552                    is_more_elements = False
553                else:
554                    element_found, sigma0, sigma1, sigma2, k = \
555                      self.search_triangles_of_vertices(candidate_vertices, x)
556
557
558            #Update interpolation matrix A if necessary
559            if element_found is True:
560                #Assign values to matrix A
561
562                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
563                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
564                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
565
566                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
567                js     = [j0,j1,j2]
568
569                for j in js:
570                    self.A[i,j] = sigmas[j]
571                    for k in js:
572                        if interp_only == False:
573                            self.AtA[j,k] += sigmas[j]*sigmas[k]
574            else:
575                pass
576                #Ok if there is no triangle for datapoint
577                #(as in brute force version)
578                #raise 'Could not find triangle for point', x
579
580
581
582    def search_triangles_of_vertices(self, candidate_vertices, x):
583            #Find triangle containing x:
584            element_found = False
585
586            # This will be returned if element_found = False
587            sigma2 = -10.0
588            sigma0 = -10.0
589            sigma1 = -10.0
590            k = -10.0
591
592            #For all vertices in same cell as point x
593            for v in candidate_vertices:
594
595                #for each triangle id (k) which has v as a vertex
596                for k, _ in self.mesh.vertexlist[v]:
597
598                    #Get the three vertex_points of candidate triangle
599                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
600                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
601                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
602
603                    #print "PDSG - k", k
604                    #print "PDSG - xi0", xi0
605                    #print "PDSG - xi1", xi1
606                    #print "PDSG - xi2", xi2
607                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
608                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
609
610                    #Get the three normals
611                    n0 = self.mesh.get_normal(k, 0)
612                    n1 = self.mesh.get_normal(k, 1)
613                    n2 = self.mesh.get_normal(k, 2)
614
615
616                    #Compute interpolation
617                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
618                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
619                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
620
621                    #print "PDSG - sigma0", sigma0
622                    #print "PDSG - sigma1", sigma1
623                    #print "PDSG - sigma2", sigma2
624
625                    #FIXME: Maybe move out to test or something
626                    epsilon = 1.0e-6
627                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
628
629                    #Check that this triangle contains the data point
630
631                    #Sigmas can get negative within
632                    #machine precision on some machines (e.g nautilus)
633                    #Hence the small eps
634                    eps = 1.0e-15
635                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
636                        element_found = True
637                        break
638
639                if element_found is True:
640                    #Don't look for any other triangle
641                    break
642            return element_found, sigma0, sigma1, sigma2, k
643
644
645
646    def build_interpolation_matrix_A_brute(self, point_coordinates):
647        """Build n x m interpolation matrix, where
648        n is the number of data points and
649        m is the number of basis functions phi_k (one per vertex)
650
651        This is the brute force which is too slow for large problems,
652        but could be used for testing
653        """
654
655        from util import ensure_numeric
656
657        #Convert input to Numeric arrays
658        point_coordinates = ensure_numeric(point_coordinates, Float)
659
660        #Build n x m interpolation matrix
661        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
662        n = point_coordinates.shape[0]     #Nbr of data points
663
664        self.A = Sparse(n,m)
665        self.AtA = Sparse(m,m)
666
667        #Compute matrix elements
668        for i in range(n):
669            #For each data_coordinate point
670
671            x = point_coordinates[i]
672            element_found = False
673            k = 0
674            while not element_found and k < len(self.mesh):
675                #For each triangle (brute force)
676                #FIXME: Real algorithm should only visit relevant triangles
677
678                #Get the three vertex_points
679                xi0 = self.mesh.get_vertex_coordinate(k, 0)
680                xi1 = self.mesh.get_vertex_coordinate(k, 1)
681                xi2 = self.mesh.get_vertex_coordinate(k, 2)
682
683                #Get the three normals
684                n0 = self.mesh.get_normal(k, 0)
685                n1 = self.mesh.get_normal(k, 1)
686                n2 = self.mesh.get_normal(k, 2)
687
688                #Compute interpolation
689                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
690                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
691                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
692
693                #FIXME: Maybe move out to test or something
694                epsilon = 1.0e-6
695                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
696
697                #Check that this triangle contains data point
698                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
699                    element_found = True
700                    #Assign values to matrix A
701
702                    j0 = self.mesh.triangles[k,0] #Global vertex id
703                    #self.A[i, j0] = sigma0
704
705                    j1 = self.mesh.triangles[k,1] #Global vertex id
706                    #self.A[i, j1] = sigma1
707
708                    j2 = self.mesh.triangles[k,2] #Global vertex id
709                    #self.A[i, j2] = sigma2
710
711                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
712                    js     = [j0,j1,j2]
713
714                    for j in js:
715                        self.A[i,j] = sigmas[j]
716                        for k in js:
717                            self.AtA[j,k] += sigmas[j]*sigmas[k]
718                k = k+1
719
720
721
722    def get_A(self):
723        return self.A.todense()
724
725    def get_B(self):
726        return self.B.todense()
727
728    def get_D(self):
729        return self.D.todense()
730
731        #FIXME: Remember to re-introduce the 1/n factor in the
732        #interpolation term
733
734    def build_smoothing_matrix_D(self):
735        """Build m x m smoothing matrix, where
736        m is the number of basis functions phi_k (one per vertex)
737
738        The smoothing matrix is defined as
739
740        D = D1 + D2
741
742        where
743
744        [D1]_{k,l} = \int_\Omega
745           \frac{\partial \phi_k}{\partial x}
746           \frac{\partial \phi_l}{\partial x}\,
747           dx dy
748
749        [D2]_{k,l} = \int_\Omega
750           \frac{\partial \phi_k}{\partial y}
751           \frac{\partial \phi_l}{\partial y}\,
752           dx dy
753
754
755        The derivatives \frac{\partial \phi_k}{\partial x},
756        \frac{\partial \phi_k}{\partial x} for a particular triangle
757        are obtained by computing the gradient a_k, b_k for basis function k
758        """
759
760        #FIXME: algorithm might be optimised by computing local 9x9
761        #"element stiffness matrices:
762
763        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
764
765        self.D = Sparse(m,m)
766
767        #For each triangle compute contributions to D = D1+D2
768        for i in range(len(self.mesh)):
769
770            #Get area
771            area = self.mesh.areas[i]
772
773            #Get global vertex indices
774            v0 = self.mesh.triangles[i,0]
775            v1 = self.mesh.triangles[i,1]
776            v2 = self.mesh.triangles[i,2]
777
778            #Get the three vertex_points
779            xi0 = self.mesh.get_vertex_coordinate(i, 0)
780            xi1 = self.mesh.get_vertex_coordinate(i, 1)
781            xi2 = self.mesh.get_vertex_coordinate(i, 2)
782
783            #Compute gradients for each vertex
784            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
785                              1, 0, 0)
786
787            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
788                              0, 1, 0)
789
790            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
791                              0, 0, 1)
792
793            #Compute diagonal contributions
794            self.D[v0,v0] += (a0*a0 + b0*b0)*area
795            self.D[v1,v1] += (a1*a1 + b1*b1)*area
796            self.D[v2,v2] += (a2*a2 + b2*b2)*area
797
798            #Compute contributions for basis functions sharing edges
799            e01 = (a0*a1 + b0*b1)*area
800            self.D[v0,v1] += e01
801            self.D[v1,v0] += e01
802
803            e12 = (a1*a2 + b1*b2)*area
804            self.D[v1,v2] += e12
805            self.D[v2,v1] += e12
806
807            e20 = (a2*a0 + b2*b0)*area
808            self.D[v2,v0] += e20
809            self.D[v0,v2] += e20
810
811
812    def fit(self, z):
813        """Fit a smooth surface to given 1d array of data points z.
814
815        The smooth surface is computed at each vertex in the underlying
816        mesh using the formula given in the module doc string.
817
818        Pre Condition:
819          self.A, self.AtA and self.B have been initialised
820
821        Inputs:
822          z: Single 1d vector or array of data at the point_coordinates.
823        """
824
825        #Convert input to Numeric arrays
826        from util import ensure_numeric
827        z = ensure_numeric(z, Float)
828
829        if len(z.shape) > 1 :
830            raise VectorShapeError, 'Can only deal with 1d data vector'
831
832        if self.point_indices is not None:
833            #Remove values for any points that were outside mesh
834            z = take(z, self.point_indices)
835
836        #Compute right hand side based on data
837        Atz = self.A.trans_mult(z)
838
839
840        #Check sanity
841        n, m = self.A.shape
842        if n<m and self.alpha == 0.0:
843            msg = 'ERROR (least_squares): Too few data points\n'
844            msg += 'There are only %d data points and alpha == 0. ' %n
845            msg += 'Need at least %d\n' %m
846            msg += 'Alternatively, set smoothing parameter alpha to a small '
847            msg += 'positive value,\ne.g. 1.0e-3.'
848            raise msg
849
850
851
852        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
853        #FIXME: Should we store the result here for later use? (ON)
854
855
856    def fit_points(self, z, verbose=False):
857        """Like fit, but more robust when each point has two or more attributes
858        FIXME (Ole): The name fit_points doesn't carry any meaning
859        for me. How about something like fit_multiple or fit_columns?
860        """
861
862        try:
863            if verbose: print 'Solving penalised least_squares problem'
864            return self.fit(z)
865        except VectorShapeError, e:
866            # broadcasting is not supported.
867
868            #Convert input to Numeric arrays
869            from util import ensure_numeric
870            z = ensure_numeric(z, Float)
871
872            #Build n x m interpolation matrix
873            m = self.mesh.coordinates.shape[0] #Number of vertices
874            n = z.shape[1]                     #Number of data points
875
876            f = zeros((m,n), Float) #Resulting columns
877
878            for i in range(z.shape[1]):
879                f[:,i] = self.fit(z[:,i])
880
881            return f
882
883
884    def interpolate(self, f):
885        """Evaluate smooth surface f at data points implied in self.A.
886
887        The mesh values representing a smooth surface are
888        assumed to be specified in f. This argument could,
889        for example have been obtained from the method self.fit()
890
891        Pre Condition:
892          self.A has been initialised
893
894        Inputs:
895          f: Vector or array of data at the mesh vertices.
896          If f is an array, interpolation will be done for each column as
897          per underlying matrix-matrix multiplication
898
899        Output:
900          Interpolated values at data points implied in self.A
901
902        """
903
904        return self.A * f
905
906    def cull_outsiders(self, f):
907        pass
908
909
910
911
912class Interpolation_function:
913    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
914    which is interpolated from time series defined at vertices of
915    triangular mesh (such as those stored in sww files)
916
917    Let m be the number of vertices, n the number of triangles
918    and p the number of timesteps.
919
920    Mandatory input
921        time:               px1 array of monotonously increasing times (Float)
922        quantities:         Dictionary of arrays or 1 array (Float)
923                            The arrays must either have dimensions pxm or mx1.
924                            The resulting function will be time dependent in
925                            the former case while it will be constan with
926                            respect to time in the latter case.
927       
928    Optional input:
929        quantity_names:     List of keys into the quantities dictionary
930        vertex_coordinates: mx2 array of coordinates (Float)
931        triangles:          nx3 array of indices into vertex_coordinates (Int)
932        interpolation_points: Nx2 array of coordinates to be interpolated to
933        verbose:            Level of reporting
934   
935   
936    The quantities returned by the callable object are specified by
937    the list quantities which must contain the names of the
938    quantities to be returned and also reflect the order, e.g. for
939    the shallow water wave equation, on would have
940    quantities = ['stage', 'xmomentum', 'ymomentum']
941
942    The parameter interpolation_points decides at which points interpolated
943    quantities are to be computed whenever object is called.
944    If None, return average value
945    """
946
947   
948   
949    def __init__(self,
950                 time,
951                 quantities,
952                 quantity_names = None, 
953                 vertex_coordinates = None,
954                 triangles = None,
955                 interpolation_points = None,
956                 verbose = False):
957        """Initialise object and build spatial interpolation if required
958        """
959
960        from Numeric import array, zeros, Float, alltrue, concatenate,\
961             reshape, ArrayType
962
963
964        from util import mean, ensure_numeric
965        from config import time_format
966        import types
967
968
969
970        #Check temporal info
971        time = ensure_numeric(time)       
972        msg = 'Time must be a monotonuosly '
973        msg += 'increasing sequence %s' %time
974        assert alltrue(time[1:] - time[:-1] >= 0 ), msg
975
976
977        #Check if quantities is a single array only
978        if type(quantities) != types.DictType:
979            quantities = ensure_numeric(quantities)
980            quantity_names = ['Attribute']
981
982            #Make it a dictionary
983            quantities = {quantity_names[0]: quantities}
984
985
986        #Use keys if no names are specified
987        if quantity_names is None:
988            quantity_names = quantities.keys()
989
990
991        #Check spatial info
992        if vertex_coordinates is None:
993            self.spatial = False
994        else:   
995            vertex_coordinates = ensure_numeric(vertex_coordinates)
996
997            assert triangles is not None, 'Triangles array must be specified'
998            triangles = ensure_numeric(triangles)
999            self.spatial = True           
1000           
1001
1002 
1003        #Save for use with statistics
1004        self.quantity_names = quantity_names       
1005        self.quantities = quantities       
1006        self.vertex_coordinates = vertex_coordinates
1007        self.interpolation_points = interpolation_points
1008        self.T = time[:]  # Time assumed to be relative to starttime
1009        self.index = 0    # Initial time index
1010        self.precomputed_values = {}
1011           
1012
1013
1014        #Precomputed spatial interpolation if requested
1015        if interpolation_points is not None:
1016            if self.spatial is False:
1017                raise 'Triangles and vertex_coordinates must be specified'
1018           
1019            try:
1020                self.interpolation_points = ensure_numeric(interpolation_points)
1021            except:
1022                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1023                      'or a list of points\n'
1024                msg += 'I got: %s.' %(str(self.interpolation_points)[:60] +\
1025                                      '...')
1026                raise msg
1027
1028
1029            m = len(self.interpolation_points)
1030            p = len(self.T)
1031           
1032            for name in quantity_names:
1033                self.precomputed_values[name] = zeros((p, m), Float)
1034
1035            #Build interpolator
1036            interpol = Interpolation(vertex_coordinates,
1037                                     triangles,
1038                                     point_coordinates = \
1039                                     self.interpolation_points,
1040                                     alpha = 0,
1041                                     precrop = False, 
1042                                     verbose = verbose)
1043
1044            if verbose: print 'Interpolate'
1045            for i, t in enumerate(self.T):
1046                #Interpolate quantities at this timestep
1047                if verbose and i%((p+10)/10)==0:
1048                    print ' time step %d of %d' %(i, p)
1049                   
1050                for name in quantity_names:
1051                    if len(quantities[name].shape) == 2:
1052                        result = interpol.interpolate(quantities[name][i,:])
1053                    else:
1054                       #Assume no time dependency
1055                       result = interpol.interpolate(quantities[name][:])
1056                       
1057                    self.precomputed_values[name][i, :] = result
1058                   
1059                       
1060
1061            #Report
1062            if verbose:
1063                print self.statistics()
1064                #self.print_statistics()
1065           
1066        else:
1067            #Store quantitites as is
1068            for name in quantity_names:
1069                self.precomputed_values[name] = quantities[name]
1070
1071
1072        #else:
1073        #    #Return an average, making this a time series
1074        #    for name in quantity_names:
1075        #        self.values[name] = zeros(len(self.T), Float)
1076        #
1077        #    if verbose: print 'Compute mean values'
1078        #    for i, t in enumerate(self.T):
1079        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1080        #        for name in quantity_names:
1081        #           self.values[name][i] = mean(quantities[name][i,:])
1082
1083
1084
1085
1086    def __repr__(self):
1087        #return 'Interpolation function (spatio-temporal)'
1088        return self.statistics()
1089   
1090
1091    def __call__(self, t, point_id = None, x = None, y = None):
1092        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1093
1094        Inputs:
1095          t: time - Model time. Must lie within existing timesteps
1096          point_id: index of one of the preprocessed points.
1097          x, y:     Overrides location, point_id ignored
1098         
1099          If spatial info is present and all of x,y,point_id
1100          are None an exception is raised
1101                   
1102          If no spatial info is present, point_id and x,y arguments are ignored
1103          making f a function of time only.
1104
1105         
1106          FIXME: point_id could also be a slice
1107          FIXME: What if x and y are vectors?
1108          FIXME: What about f(x,y) without t?
1109        """
1110
1111        from math import pi, cos, sin, sqrt
1112        from Numeric import zeros, Float
1113        from util import mean       
1114
1115        if self.spatial is True:
1116            if point_id is None:
1117                if x is None or y is None:
1118                    msg = 'Either point_id or x and y must be specified'
1119                    raise msg
1120            else:
1121                if self.interpolation_points is None:
1122                    msg = 'Interpolation_function must be instantiated ' +\
1123                          'with a list of interpolation points before parameter ' +\
1124                          'point_id can be used'
1125                    raise msg
1126
1127
1128        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1129        msg += ' does not match model time: %s\n' %t
1130        if t < self.T[0]: raise msg
1131        if t > self.T[-1]: raise msg
1132
1133        oldindex = self.index #Time index
1134
1135        #Find current time slot
1136        while t > self.T[self.index]: self.index += 1
1137        while t < self.T[self.index]: self.index -= 1
1138
1139        if t == self.T[self.index]:
1140            #Protect against case where t == T[-1] (last time)
1141            # - also works in general when t == T[i]
1142            ratio = 0
1143        else:
1144            #t is now between index and index+1
1145            ratio = (t - self.T[self.index])/\
1146                    (self.T[self.index+1] - self.T[self.index])
1147
1148        #Compute interpolated values
1149        q = zeros(len(self.quantity_names), Float)
1150
1151        for i, name in enumerate(self.quantity_names):
1152            Q = self.precomputed_values[name]
1153
1154            if self.spatial is False:
1155                #If there is no spatial info               
1156                assert len(Q.shape) == 1
1157
1158                Q0 = Q[self.index]
1159                if ratio > 0: Q1 = Q[self.index+1]
1160
1161            else:
1162                if x is not None and y is not None:
1163                    #Interpolate to x, y
1164                   
1165                    raise 'x,y interpolation not yet implemented'
1166                else:
1167                    #Use precomputed point
1168                    Q0 = Q[self.index, point_id]
1169                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1170
1171            #Linear temporal interpolation   
1172            if ratio > 0:
1173                q[i] = Q0 + ratio*(Q1 - Q0)
1174            else:
1175                q[i] = Q0
1176
1177
1178        #Return vector of interpolated values
1179        #if len(q) == 1:
1180        #    return q[0]
1181        #else:
1182        #    return q
1183
1184
1185        #Return vector of interpolated values
1186        #FIXME:
1187        if self.spatial is True:
1188            return q
1189        else:
1190            #Replicate q according to x and y
1191            #This is e.g used for Wind_stress
1192            if x == None or y == None: 
1193                return q
1194            else:
1195                try:
1196                    N = len(x)
1197                except:
1198                    return q
1199                else:
1200                    from Numeric import ones, Float
1201                    #x is a vector - Create one constant column for each value
1202                    N = len(x)
1203                    assert len(y) == N, 'x and y must have same length'
1204                    res = []
1205                    for col in q:
1206                        res.append(col*ones(N, Float))
1207                       
1208                return res
1209
1210
1211    def statistics(self):
1212        """Output statistics about interpolation_function
1213        """
1214       
1215        vertex_coordinates = self.vertex_coordinates
1216        interpolation_points = self.interpolation_points               
1217        quantity_names = self.quantity_names
1218        quantities = self.quantities
1219        precomputed_values = self.precomputed_values                 
1220               
1221        x = vertex_coordinates[:,0]
1222        y = vertex_coordinates[:,1]               
1223
1224        str =  '------------------------------------------------\n'
1225        str += 'Interpolation_function (spatio-temporal) statistics:\n'
1226        str += '  Extent:\n'
1227        str += '    x in [%f, %f], len(x) == %d\n'\
1228               %(min(x), max(x), len(x))
1229        str += '    y in [%f, %f], len(y) == %d\n'\
1230               %(min(y), max(y), len(y))
1231        str += '    t in [%f, %f], len(t) == %d\n'\
1232               %(min(self.T), max(self.T), len(self.T))
1233        str += '  Quantities:\n'
1234        for name in quantity_names:
1235            q = quantities[name][:].flat
1236            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1237
1238        if interpolation_points is not None:   
1239            str += '  Interpolation points (xi, eta):'\
1240                   ' number of points == %d\n' %interpolation_points.shape[0]
1241            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1242                                            max(interpolation_points[:,0]))
1243            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1244                                             max(interpolation_points[:,1]))
1245            str += '  Interpolated quantities (over all timesteps):\n'
1246       
1247            for name in quantity_names:
1248                q = precomputed_values[name][:].flat
1249                str += '    %s at interpolation points in [%f, %f]\n'\
1250                       %(name, min(q), max(q))
1251        str += '------------------------------------------------\n'
1252
1253        return str
1254
1255        #FIXME: Delete
1256        #print '------------------------------------------------'
1257        #print 'Interpolation_function statistics:'
1258        #print '  Extent:'
1259        #print '    x in [%f, %f], len(x) == %d'\
1260        #      %(min(x), max(x), len(x))
1261        #print '    y in [%f, %f], len(y) == %d'\
1262        #      %(min(y), max(y), len(y))
1263        #print '    t in [%f, %f], len(t) == %d'\
1264        #      %(min(self.T), max(self.T), len(self.T))
1265        #print '  Quantities:'
1266        #for name in quantity_names:
1267        #    q = quantities[name][:].flat
1268        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1269        #print '  Interpolation points (xi, eta):'\
1270        #      ' number of points == %d ' %interpolation_points.shape[0]
1271        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1272        #                             max(interpolation_points[:,0]))
1273        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1274        #                              max(interpolation_points[:,1]))
1275        #print '  Interpolated quantities (over all timesteps):'
1276        #
1277        #for name in quantity_names:
1278        #    q = precomputed_values[name][:].flat
1279        #    print '    %s at interpolation points in [%f, %f]'\
1280        #          %(name, min(q), max(q))
1281        #print '------------------------------------------------'
1282
1283
1284#-------------------------------------------------------------
1285if __name__ == "__main__":
1286    """
1287    Load in a mesh and data points with attributes.
1288    Fit the attributes to the mesh.
1289    Save a new mesh file.
1290    """
1291    import os, sys
1292    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1293            %os.path.basename(sys.argv[0])
1294
1295    if len(sys.argv) < 4:
1296        print usage
1297    else:
1298        mesh_file = sys.argv[1]
1299        point_file = sys.argv[2]
1300        mesh_output_file = sys.argv[3]
1301
1302        expand_search = False
1303        if len(sys.argv) > 4:
1304            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1305                expand_search = True
1306            else:
1307                expand_search = False
1308
1309        verbose = False
1310        if len(sys.argv) > 5:
1311            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1312                verbose = False
1313            else:
1314                verbose = True
1315
1316        if len(sys.argv) > 6:
1317            alpha = sys.argv[6]
1318        else:
1319            alpha = DEFAULT_ALPHA
1320
1321        # This is used more for testing
1322        if len(sys.argv) > 7:
1323            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1324                display_errors = False
1325            else:
1326                display_errors = True
1327           
1328        t0 = time.time()
1329        try:
1330            fit_to_mesh_file(mesh_file,
1331                         point_file,
1332                         mesh_output_file,
1333                         alpha,
1334                         verbose= verbose,
1335                         expand_search = expand_search,
1336                         display_errors = display_errors)
1337        except IOError,e:
1338            import sys; sys.exit(1)
1339
1340        print 'That took %.2f seconds' %(time.time()-t0)
1341
Note: See TracBrowser for help on using the repository browser.