source: inundation/pyvolution/least_squares.py @ 2447

Last change on this file since 2447 was 2447, checked in by ole, 18 years ago

Build in caching into least squares + comments

File size: 48.8 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20#import exceptions
21#class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from pyvolution.mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from pyvolution.sparse import Sparse, Sparse_CSR
29from pyvolution.cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME (DSG-ON) reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    NOTE: Throws IOErrors, for a variety of file problems.
75   
76    mesh_origin is the same but refers to the input tsh file.
77    FIXME: When the tsh format contains it own origin, these parameters can go.
78    FIXME: And both origins should be obtained from the specified files.
79    """
80
81    from load_mesh.loadASCII import import_mesh_file, \
82                 import_points_file, export_mesh_file, \
83                 concatinate_attributelist
84
85
86    try:
87        mesh_dict = import_mesh_file(mesh_file)
88    except IOError,e:
89        if display_errors:
90            print "Could not load bad file. ", e
91        raise IOError  #Re-raise exception
92       
93    vertex_coordinates = mesh_dict['vertices']
94    triangles = mesh_dict['triangles']
95    if type(mesh_dict['vertex_attributes']) == ArrayType:
96        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
97    else:
98        old_point_attributes = mesh_dict['vertex_attributes']
99
100    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
101        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
102    else:
103        old_title_list = mesh_dict['vertex_attribute_titles']
104
105    if verbose: print 'tsh file %s loaded' %mesh_file
106
107    # load in the .pts file
108    try:
109        point_dict = import_points_file(point_file, verbose=verbose)
110    except IOError,e:
111        if display_errors:
112            print "Could not load bad file. ", e
113        raise IOError  #Re-raise exception 
114
115    point_coordinates = point_dict['pointlist']
116    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
117
118    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
119        data_origin = point_dict['geo_reference'].get_origin()
120    else:
121        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
122
123    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
124        mesh_origin = mesh_dict['geo_reference'].get_origin()
125    else:
126        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
127
128    if verbose: print "points file loaded"
129    if verbose: print "fitting to mesh"
130    f = fit_to_mesh(vertex_coordinates,
131                    triangles,
132                    point_coordinates,
133                    point_attributes,
134                    alpha = alpha,
135                    verbose = verbose,
136                    expand_search = expand_search,
137                    data_origin = data_origin,
138                    mesh_origin = mesh_origin,
139                    precrop = precrop)
140    if verbose: print "finished fitting to mesh"
141
142    # convert array to list of lists
143    new_point_attributes = f.tolist()
144    #FIXME have this overwrite attributes with the same title - DSG
145    #Put the newer attributes last
146    if old_title_list <> []:
147        old_title_list.extend(title_list)
148        #FIXME can this be done a faster way? - DSG
149        for i in range(len(old_point_attributes)):
150            old_point_attributes[i].extend(new_point_attributes[i])
151        mesh_dict['vertex_attributes'] = old_point_attributes
152        mesh_dict['vertex_attribute_titles'] = old_title_list
153    else:
154        mesh_dict['vertex_attributes'] = new_point_attributes
155        mesh_dict['vertex_attribute_titles'] = title_list
156
157    #FIXME (Ole): Remember to output mesh_origin as well
158    if verbose: print "exporting to file ", mesh_output_file
159
160    try:
161        export_mesh_file(mesh_output_file, mesh_dict)
162    except IOError,e:
163        if display_errors:
164            print "Could not write file. ", e
165        raise IOError
166
167def fit_to_mesh(vertex_coordinates,
168                triangles,
169                point_coordinates,
170                point_attributes,
171                alpha = DEFAULT_ALPHA,
172                verbose = False,
173                expand_search = False,
174                data_origin = None,
175                mesh_origin = None,
176                precrop = False,
177                use_cache = False):
178    """
179    Fit a smooth surface to a triangulation,
180    given data points with attributes.
181
182
183        Inputs:
184
185          vertex_coordinates: List of coordinate pairs [xi, eta] of points
186          constituting mesh (or a an m x 2 Numeric array)
187
188          triangles: List of 3-tuples (or a Numeric array) of
189          integers representing indices of all vertices in the mesh.
190
191          point_coordinates: List of coordinate pairs [x, y] of data points
192          (or an nx2 Numeric array)
193
194          alpha: Smoothing parameter.
195
196          point_attributes: Vector or array of data at the point_coordinates.
197
198          data_origin and mesh_origin are 3-tuples consisting of
199          UTM zone, easting and northing. If specified
200          point coordinates and vertex coordinates are assumed to be
201          relative to their respective origins.
202
203    """
204
205    if use_cache is True:
206        from caching.caching import cache
207        interp = cache(_interpolation,
208                       (vertex_coordinates,
209                        triangles,
210                        point_coordinates),
211                       {'alpha': alpha,
212                        'verbose': verbose,
213                        'expand_search': expand_search,
214                        'data_origin': data_origin,
215                        'mesh_origin': mesh_origin,
216                        'precrop': precrop},
217                       verbose = verbose)       
218       
219    else:
220        interp = Interpolation(vertex_coordinates,
221                               triangles,
222                               point_coordinates,
223                               alpha = alpha,
224                               verbose = verbose,
225                               expand_search = expand_search,
226                               data_origin = data_origin,
227                               mesh_origin = mesh_origin,
228                               precrop = precrop)
229
230    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
231    return vertex_attributes
232
233
234
235def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
236                    verbose = False, reduction = 1):
237    """Fits attributes from pts file to MxN rectangular mesh
238
239    Read pts file and create rectangular mesh of resolution MxN such that
240    it covers all points specified in pts file.
241
242    FIXME: This may be a temporary function until we decide on
243    netcdf formats etc
244
245    FIXME: Uses elevation hardwired
246    """
247
248    import  mesh_factory
249    from load_mesh.loadASCII import import_points_file
250   
251    if verbose: print 'Read pts'
252    points_dict = import_points_file(pts_name)
253    #points, attributes = util.read_xya(pts_name)
254
255    #Reduce number of points a bit
256    points = points_dict['pointlist'][::reduction]
257    elevation = points_dict['attributelist']['elevation']  #Must be elevation
258    elevation = elevation[::reduction]
259
260    if verbose: print 'Got %d data points' %len(points)
261
262    if verbose: print 'Create mesh'
263    #Find extent
264    max_x = min_x = points[0][0]
265    max_y = min_y = points[0][1]
266    for point in points[1:]:
267        x = point[0]
268        if x > max_x: max_x = x
269        if x < min_x: min_x = x
270        y = point[1]
271        if y > max_y: max_y = y
272        if y < min_y: min_y = y
273
274    #Create appropriate mesh
275    vertex_coordinates, triangles, boundary =\
276         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
277                                (min_x, min_y))
278
279    #Fit attributes to mesh
280    vertex_attributes = fit_to_mesh(vertex_coordinates,
281                        triangles,
282                        points,
283                        elevation, alpha=alpha, verbose=verbose)
284
285
286
287    return vertex_coordinates, triangles, boundary, vertex_attributes
288
289
290def _interpolation(*args, **kwargs):
291    """Private function for use with caching. Reason is that classes
292    may change their byte code between runs which is annoying.
293    """
294   
295    return Interpolation(*args, **kwargs)
296
297
298class Interpolation:
299
300    def __init__(self,
301                 vertex_coordinates,
302                 triangles,
303                 point_coordinates = None,
304                 alpha = None,
305                 verbose = False,
306                 expand_search = True,
307                 interp_only = False,
308                 max_points_per_cell = 30,
309                 mesh_origin = None,
310                 data_origin = None,
311                 precrop = False):
312
313
314        """ Build interpolation matrix mapping from
315        function values at vertices to function values at data points
316
317        Inputs:
318
319          vertex_coordinates: List of coordinate pairs [xi, eta] of
320          points constituting mesh (or a an m x 2 Numeric array)
321          Points may appear multiple times
322          (e.g. if vertices have discontinuities)
323
324          triangles: List of 3-tuples (or a Numeric array) of
325          integers representing indices of all vertices in the mesh.
326
327          point_coordinates: List of coordinate pairs [x, y] of
328          data points (or an nx2 Numeric array)
329          If point_coordinates is absent, only smoothing matrix will
330          be built
331
332          alpha: Smoothing parameter
333
334          data_origin and mesh_origin are 3-tuples consisting of
335          UTM zone, easting and northing. If specified
336          point coordinates and vertex coordinates are assumed to be
337          relative to their respective origins.
338
339        """
340        from pyvolution.util import ensure_numeric
341
342        #Convert input to Numeric arrays
343        triangles = ensure_numeric(triangles, Int)
344        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
345
346        #Build underlying mesh
347        if verbose: print 'Building mesh'
348        #self.mesh = General_mesh(vertex_coordinates, triangles,
349        #FIXME: Trying the normal mesh while testing precrop,
350        #       The functionality of boundary_polygon is needed for that
351
352        #FIXME - geo ref does not have to go into mesh.
353        # Change the point co-ords to conform to the
354        # mesh co-ords early in the code
355        if mesh_origin is None:
356            geo = None
357        else:
358            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
359        self.mesh = Mesh(vertex_coordinates, triangles,
360                         geo_reference = geo)
361       
362        self.mesh.check_integrity()
363
364        self.data_origin = data_origin
365
366        self.point_indices = None
367
368        #Smoothing parameter
369        if alpha is None:
370            self.alpha = DEFAULT_ALPHA
371        else:   
372            self.alpha = alpha
373
374
375        if point_coordinates is not None:
376            if verbose: print 'Building interpolation matrix'
377            self.build_interpolation_matrix_A(point_coordinates,
378                                              verbose = verbose,
379                                              expand_search = expand_search,
380                                              interp_only = interp_only, 
381                                              max_points_per_cell =\
382                                              max_points_per_cell,
383                                              data_origin = data_origin,
384                                              precrop = precrop)
385        #Build coefficient matrices
386        if interp_only == False:
387            self.build_coefficient_matrix_B(point_coordinates,
388                                        verbose = verbose,
389                                        expand_search = expand_search,
390                                        max_points_per_cell =\
391                                        max_points_per_cell,
392                                        data_origin = data_origin,
393                                        precrop = precrop)
394
395    def set_point_coordinates(self, point_coordinates,
396                              data_origin = None,
397                              verbose = False,
398                              precrop = True):
399        """
400        A public interface to setting the point co-ordinates.
401        """
402        if point_coordinates is not None:
403            if verbose: print 'Building interpolation matrix'
404            self.build_interpolation_matrix_A(point_coordinates,
405                                              verbose = verbose,
406                                              data_origin = data_origin,
407                                              precrop = precrop)
408        self.build_coefficient_matrix_B(point_coordinates, data_origin)
409
410    def build_coefficient_matrix_B(self, point_coordinates=None,
411                                   verbose = False, expand_search = True,
412                                   max_points_per_cell=30,
413                                   data_origin = None,
414                                   precrop = False):
415        """Build final coefficient matrix"""
416
417
418        if self.alpha <> 0:
419            if verbose: print 'Building smoothing matrix'
420            self.build_smoothing_matrix_D()
421
422        if point_coordinates is not None:
423            if self.alpha <> 0:
424                self.B = self.AtA + self.alpha*self.D
425            else:
426                self.B = self.AtA
427
428            #Convert self.B matrix to CSR format for faster matrix vector
429            self.B = Sparse_CSR(self.B)
430
431    def build_interpolation_matrix_A(self, point_coordinates,
432                                     verbose = False, expand_search = True,
433                                     max_points_per_cell=30,
434                                     data_origin = None,
435                                     precrop = False,
436                                     interp_only = False):
437        """Build n x m interpolation matrix, where
438        n is the number of data points and
439        m is the number of basis functions phi_k (one per vertex)
440
441        This algorithm uses a quad tree data structure for fast binning of data points
442        origin is a 3-tuple consisting of UTM zone, easting and northing.
443        If specified coordinates are assumed to be relative to this origin.
444
445        This one will override any data_origin that may be specified in
446        interpolation instance
447
448        """
449
450
451
452        #FIXME (Ole): Check that this function is memeory efficient.
453        #6 million datapoints and 300000 basis functions
454        #causes out-of-memory situation
455        #First thing to check is whether there is room for self.A and self.AtA
456        #
457        #Maybe we need some sort of blocking
458
459        from pyvolution.quad import build_quadtree
460        from utilities.numerical_tools import ensure_numeric
461        from utilities.polygon import inside_polygon
462       
463
464        if data_origin is None:
465            data_origin = self.data_origin #Use the one from
466                                           #interpolation instance
467
468        #Convert input to Numeric arrays just in case.
469        point_coordinates = ensure_numeric(point_coordinates, Float)
470
471        #Keep track of discarded points (if any).
472        #This is only registered if precrop is True
473        self.cropped_points = False
474
475        #Shift data points to same origin as mesh (if specified)
476
477        #FIXME this will shift if there was no geo_ref.
478        #But all this should be removed anyhow.
479        #change coords before this point
480        mesh_origin = self.mesh.geo_reference.get_origin()
481        if point_coordinates is not None:
482            if data_origin is not None:
483                if mesh_origin is not None:
484
485                    #Transformation:
486                    #
487                    #Let x_0 be the reference point of the point coordinates
488                    #and xi_0 the reference point of the mesh.
489                    #
490                    #A point coordinate (x + x_0) is then made relative
491                    #to xi_0 by
492                    #
493                    # x_new = x + x_0 - xi_0
494                    #
495                    #and similarly for eta
496
497                    x_offset = data_origin[1] - mesh_origin[1]
498                    y_offset = data_origin[2] - mesh_origin[2]
499                else: #Shift back to a zero origin
500                    x_offset = data_origin[1]
501                    y_offset = data_origin[2]
502
503                point_coordinates[:,0] += x_offset
504                point_coordinates[:,1] += y_offset
505            else:
506                if mesh_origin is not None:
507                    #Use mesh origin for data points
508                    point_coordinates[:,0] -= mesh_origin[1]
509                    point_coordinates[:,1] -= mesh_origin[2]
510
511
512
513        #Remove points falling outside mesh boundary
514        #This reduced one example from 1356 seconds to 825 seconds
515
516       
517        if precrop is True:
518            from Numeric import take
519
520            if verbose: print 'Getting boundary polygon'
521            P = self.mesh.get_boundary_polygon()
522
523            if verbose: print 'Getting indices inside mesh boundary'
524            indices = inside_polygon(point_coordinates, P, verbose = verbose)
525
526
527            if len(indices) != point_coordinates.shape[0]:
528                self.cropped_points = True
529                if verbose:
530                    print 'Done - %d points outside mesh have been cropped.'\
531                          %(point_coordinates.shape[0] - len(indices))
532
533            point_coordinates = take(point_coordinates, indices)
534            self.point_indices = indices
535
536
537
538
539        #Build n x m interpolation matrix
540        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
541        n = point_coordinates.shape[0]     #Nbr of data points
542
543        if verbose: print 'Number of datapoints: %d' %n
544        if verbose: print 'Number of basis functions: %d' %m
545
546        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
547        #However, Sparse_CSR does not have the same methods as Sparse yet
548        #The tests will reveal what needs to be done
549
550        #
551        #self.A = Sparse_CSR(Sparse(n,m))
552        #self.AtA = Sparse_CSR(Sparse(m,m))
553        self.A = Sparse(n,m)
554        self.AtA = Sparse(m,m)
555
556        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
557        root = build_quadtree(self.mesh,
558                              max_points_per_cell = max_points_per_cell)
559        #root.show()
560        self.expanded_quad_searches = []
561        #Compute matrix elements
562        for i in range(n):
563            #For each data_coordinate point
564
565            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
566            x = point_coordinates[i]
567
568            #Find vertices near x
569            candidate_vertices = root.search(x[0], x[1])
570            is_more_elements = True
571
572            element_found, sigma0, sigma1, sigma2, k = \
573                self.search_triangles_of_vertices(candidate_vertices, x)
574            first_expansion = True
575            while not element_found and is_more_elements and expand_search:
576                #if verbose: print 'Expanding search'
577                if first_expansion == True:
578                    self.expanded_quad_searches.append(1)
579                    first_expansion = False
580                else:
581                    end = len(self.expanded_quad_searches) - 1
582                    assert end >= 0
583                    self.expanded_quad_searches[end] += 1
584                candidate_vertices, branch = root.expand_search()
585                if branch == []:
586                    # Searching all the verts from the root cell that haven't
587                    # been searched.  This is the last try
588                    element_found, sigma0, sigma1, sigma2, k = \
589                      self.search_triangles_of_vertices(candidate_vertices, x)
590                    is_more_elements = False
591                else:
592                    element_found, sigma0, sigma1, sigma2, k = \
593                      self.search_triangles_of_vertices(candidate_vertices, x)
594
595               
596            #Update interpolation matrix A if necessary
597            if element_found is True:
598                #Assign values to matrix A
599
600                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
601                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
602                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
603
604                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
605                js     = [j0,j1,j2]
606
607                for j in js:
608                    self.A[i,j] = sigmas[j]
609                    for k in js:
610                        if interp_only == False:
611                            self.AtA[j,k] += sigmas[j]*sigmas[k]
612            else:
613                pass
614                #Ok if there is no triangle for datapoint
615                #(as in brute force version)
616                #raise 'Could not find triangle for point', x
617
618
619
620    def search_triangles_of_vertices(self, candidate_vertices, x):
621            #Find triangle containing x:
622            element_found = False
623
624            # This will be returned if element_found = False
625            sigma2 = -10.0
626            sigma0 = -10.0
627            sigma1 = -10.0
628            k = -10.0
629            #print "*$* candidate_vertices", candidate_vertices
630            #For all vertices in same cell as point x
631            for v in candidate_vertices:
632                #FIXME (DSG-DSG): this catches verts with no triangle.
633                #Currently pmesh is producing these.
634                #this should be stopped,
635                if self.mesh.vertexlist[v] is None:
636                    continue
637                #for each triangle id (k) which has v as a vertex
638                for k, _ in self.mesh.vertexlist[v]:
639
640                    #Get the three vertex_points of candidate triangle
641                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
642                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
643                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
644
645                    #print "PDSG - k", k
646                    #print "PDSG - xi0", xi0
647                    #print "PDSG - xi1", xi1
648                    #print "PDSG - xi2", xi2
649                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
650                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
651
652                    #Get the three normals
653                    n0 = self.mesh.get_normal(k, 0)
654                    n1 = self.mesh.get_normal(k, 1)
655                    n2 = self.mesh.get_normal(k, 2)
656
657
658                    #Compute interpolation
659                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
660                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
661                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
662
663                    #print "PDSG - sigma0", sigma0
664                    #print "PDSG - sigma1", sigma1
665                    #print "PDSG - sigma2", sigma2
666
667                    #FIXME: Maybe move out to test or something
668                    epsilon = 1.0e-6
669                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
670
671                    #Check that this triangle contains the data point
672
673                    #Sigmas can get negative within
674                    #machine precision on some machines (e.g nautilus)
675                    #Hence the small eps
676                    eps = 1.0e-15
677                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
678                        element_found = True
679                        break
680
681                if element_found is True:
682                    #Don't look for any other triangle
683                    break
684            return element_found, sigma0, sigma1, sigma2, k
685
686
687
688    def build_interpolation_matrix_A_brute(self, point_coordinates):
689        """Build n x m interpolation matrix, where
690        n is the number of data points and
691        m is the number of basis functions phi_k (one per vertex)
692
693        This is the brute force which is too slow for large problems,
694        but could be used for testing
695        """
696
697        from util import ensure_numeric
698
699        #Convert input to Numeric arrays
700        point_coordinates = ensure_numeric(point_coordinates, Float)
701
702        #Build n x m interpolation matrix
703        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
704        n = point_coordinates.shape[0]     #Nbr of data points
705
706        self.A = Sparse(n,m)
707        self.AtA = Sparse(m,m)
708
709        #Compute matrix elements
710        for i in range(n):
711            #For each data_coordinate point
712
713            x = point_coordinates[i]
714            element_found = False
715            k = 0
716            while not element_found and k < len(self.mesh):
717                #For each triangle (brute force)
718                #FIXME: Real algorithm should only visit relevant triangles
719
720                #Get the three vertex_points
721                xi0 = self.mesh.get_vertex_coordinate(k, 0)
722                xi1 = self.mesh.get_vertex_coordinate(k, 1)
723                xi2 = self.mesh.get_vertex_coordinate(k, 2)
724
725                #Get the three normals
726                n0 = self.mesh.get_normal(k, 0)
727                n1 = self.mesh.get_normal(k, 1)
728                n2 = self.mesh.get_normal(k, 2)
729
730                #Compute interpolation
731                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
732                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
733                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
734
735                #FIXME: Maybe move out to test or something
736                epsilon = 1.0e-6
737                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
738
739                #Check that this triangle contains data point
740                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
741                    element_found = True
742                    #Assign values to matrix A
743
744                    j0 = self.mesh.triangles[k,0] #Global vertex id
745                    #self.A[i, j0] = sigma0
746
747                    j1 = self.mesh.triangles[k,1] #Global vertex id
748                    #self.A[i, j1] = sigma1
749
750                    j2 = self.mesh.triangles[k,2] #Global vertex id
751                    #self.A[i, j2] = sigma2
752
753                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
754                    js     = [j0,j1,j2]
755
756                    for j in js:
757                        self.A[i,j] = sigmas[j]
758                        for k in js:
759                            self.AtA[j,k] += sigmas[j]*sigmas[k]
760                k = k+1
761
762
763
764    def get_A(self):
765        return self.A.todense()
766
767    def get_B(self):
768        return self.B.todense()
769
770    def get_D(self):
771        return self.D.todense()
772
773        #FIXME: Remember to re-introduce the 1/n factor in the
774        #interpolation term
775
776    def build_smoothing_matrix_D(self):
777        """Build m x m smoothing matrix, where
778        m is the number of basis functions phi_k (one per vertex)
779
780        The smoothing matrix is defined as
781
782        D = D1 + D2
783
784        where
785
786        [D1]_{k,l} = \int_\Omega
787           \frac{\partial \phi_k}{\partial x}
788           \frac{\partial \phi_l}{\partial x}\,
789           dx dy
790
791        [D2]_{k,l} = \int_\Omega
792           \frac{\partial \phi_k}{\partial y}
793           \frac{\partial \phi_l}{\partial y}\,
794           dx dy
795
796
797        The derivatives \frac{\partial \phi_k}{\partial x},
798        \frac{\partial \phi_k}{\partial x} for a particular triangle
799        are obtained by computing the gradient a_k, b_k for basis function k
800        """
801
802        #FIXME: algorithm might be optimised by computing local 9x9
803        #"element stiffness matrices:
804
805        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
806
807        self.D = Sparse(m,m)
808
809        #For each triangle compute contributions to D = D1+D2
810        for i in range(len(self.mesh)):
811
812            #Get area
813            area = self.mesh.areas[i]
814
815            #Get global vertex indices
816            v0 = self.mesh.triangles[i,0]
817            v1 = self.mesh.triangles[i,1]
818            v2 = self.mesh.triangles[i,2]
819
820            #Get the three vertex_points
821            xi0 = self.mesh.get_vertex_coordinate(i, 0)
822            xi1 = self.mesh.get_vertex_coordinate(i, 1)
823            xi2 = self.mesh.get_vertex_coordinate(i, 2)
824
825            #Compute gradients for each vertex
826            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
827                              1, 0, 0)
828
829            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
830                              0, 1, 0)
831
832            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
833                              0, 0, 1)
834
835            #Compute diagonal contributions
836            self.D[v0,v0] += (a0*a0 + b0*b0)*area
837            self.D[v1,v1] += (a1*a1 + b1*b1)*area
838            self.D[v2,v2] += (a2*a2 + b2*b2)*area
839
840            #Compute contributions for basis functions sharing edges
841            e01 = (a0*a1 + b0*b1)*area
842            self.D[v0,v1] += e01
843            self.D[v1,v0] += e01
844
845            e12 = (a1*a2 + b1*b2)*area
846            self.D[v1,v2] += e12
847            self.D[v2,v1] += e12
848
849            e20 = (a2*a0 + b2*b0)*area
850            self.D[v2,v0] += e20
851            self.D[v0,v2] += e20
852
853
854    def fit(self, z):
855        """Fit a smooth surface to given 1d array of data points z.
856
857        The smooth surface is computed at each vertex in the underlying
858        mesh using the formula given in the module doc string.
859
860        Pre Condition:
861          self.A, self.AtA and self.B have been initialised
862
863        Inputs:
864          z: Single 1d vector or array of data at the point_coordinates.
865        """
866
867        #Convert input to Numeric arrays
868        from pyvolution.util import ensure_numeric
869        z = ensure_numeric(z, Float)
870
871        if len(z.shape) > 1 :
872            raise VectorShapeError, 'Can only deal with 1d data vector'
873
874        if self.point_indices is not None:
875            #Remove values for any points that were outside mesh
876            z = take(z, self.point_indices)
877
878        #Compute right hand side based on data
879        #FIXME (DSG-DsG): could Sparse_CSR be used here?  Use this format
880        # after a matrix is built, before calcs.
881        Atz = self.A.trans_mult(z)
882
883
884        #Check sanity
885        n, m = self.A.shape
886        if n<m and self.alpha == 0.0:
887            msg = 'ERROR (least_squares): Too few data points\n'
888            msg += 'There are only %d data points and alpha == 0. ' %n
889            msg += 'Need at least %d\n' %m
890            msg += 'Alternatively, set smoothing parameter alpha to a small '
891            msg += 'positive value,\ne.g. 1.0e-3.'
892            raise msg
893
894
895
896        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
897        #FIXME: Should we store the result here for later use? (ON)
898
899
900    def fit_points(self, z, verbose=False):
901        """Like fit, but more robust when each point has two or more attributes
902        FIXME (Ole): The name fit_points doesn't carry any meaning
903        for me. How about something like fit_multiple or fit_columns?
904        """
905
906        try:
907            if verbose: print 'Solving penalised least_squares problem'
908            return self.fit(z)
909        except VectorShapeError, e:
910            # broadcasting is not supported.
911
912            #Convert input to Numeric arrays
913            from util import ensure_numeric
914            z = ensure_numeric(z, Float)
915
916            #Build n x m interpolation matrix
917            m = self.mesh.coordinates.shape[0] #Number of vertices
918            n = z.shape[1]                     #Number of data points
919
920            f = zeros((m,n), Float) #Resulting columns
921
922            for i in range(z.shape[1]):
923                f[:,i] = self.fit(z[:,i])
924
925            return f
926
927
928    def interpolate(self, f):
929        """Evaluate smooth surface f at data points implied in self.A.
930
931        The mesh values representing a smooth surface are
932        assumed to be specified in f. This argument could,
933        for example have been obtained from the method self.fit()
934
935        Pre Condition:
936          self.A has been initialised
937
938        Inputs:
939          f: Vector or array of data at the mesh vertices.
940          If f is an array, interpolation will be done for each column as
941          per underlying matrix-matrix multiplication
942
943        Output:
944          Interpolated values at data points implied in self.A
945
946        """
947
948        return self.A * f
949
950    def cull_outsiders(self, f):
951        pass
952
953
954
955
956class Interpolation_function:
957    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
958    which is interpolated from time series defined at vertices of
959    triangular mesh (such as those stored in sww files)
960
961    Let m be the number of vertices, n the number of triangles
962    and p the number of timesteps.
963
964    Mandatory input
965        time:               px1 array of monotonously increasing times (Float)
966        quantities:         Dictionary of arrays or 1 array (Float)
967                            The arrays must either have dimensions pxm or mx1.
968                            The resulting function will be time dependent in
969                            the former case while it will be constan with
970                            respect to time in the latter case.
971       
972    Optional input:
973        quantity_names:     List of keys into the quantities dictionary
974        vertex_coordinates: mx2 array of coordinates (Float)
975        triangles:          nx3 array of indices into vertex_coordinates (Int)
976        interpolation_points: Nx2 array of coordinates to be interpolated to
977        verbose:            Level of reporting
978   
979   
980    The quantities returned by the callable object are specified by
981    the list quantities which must contain the names of the
982    quantities to be returned and also reflect the order, e.g. for
983    the shallow water wave equation, on would have
984    quantities = ['stage', 'xmomentum', 'ymomentum']
985
986    The parameter interpolation_points decides at which points interpolated
987    quantities are to be computed whenever object is called.
988    If None, return average value
989    """
990
991   
992   
993    def __init__(self,
994                 time,
995                 quantities,
996                 quantity_names = None, 
997                 vertex_coordinates = None,
998                 triangles = None,
999                 interpolation_points = None,
1000                 verbose = False):
1001        """Initialise object and build spatial interpolation if required
1002        """
1003
1004        from Numeric import array, zeros, Float, alltrue, concatenate,\
1005             reshape, ArrayType
1006
1007
1008        from util import mean, ensure_numeric
1009        from config import time_format
1010        import types
1011
1012
1013
1014        #Check temporal info
1015        time = ensure_numeric(time)       
1016        msg = 'Time must be a monotonuosly '
1017        msg += 'increasing sequence %s' %time
1018        assert alltrue(time[1:] - time[:-1] >= 0 ), msg
1019
1020
1021        #Check if quantities is a single array only
1022        if type(quantities) != types.DictType:
1023            quantities = ensure_numeric(quantities)
1024            quantity_names = ['Attribute']
1025
1026            #Make it a dictionary
1027            quantities = {quantity_names[0]: quantities}
1028
1029
1030        #Use keys if no names are specified
1031        if quantity_names is None:
1032            quantity_names = quantities.keys()
1033
1034
1035        #Check spatial info
1036        if vertex_coordinates is None:
1037            self.spatial = False
1038        else:   
1039            vertex_coordinates = ensure_numeric(vertex_coordinates)
1040
1041            assert triangles is not None, 'Triangles array must be specified'
1042            triangles = ensure_numeric(triangles)
1043            self.spatial = True           
1044           
1045
1046 
1047        #Save for use with statistics
1048        self.quantity_names = quantity_names       
1049        self.quantities = quantities       
1050        self.vertex_coordinates = vertex_coordinates
1051        self.interpolation_points = interpolation_points
1052        self.T = time[:]  # Time assumed to be relative to starttime
1053        self.index = 0    # Initial time index
1054        self.precomputed_values = {}
1055           
1056
1057
1058        #Precomputed spatial interpolation if requested
1059        if interpolation_points is not None:
1060            if self.spatial is False:
1061                raise 'Triangles and vertex_coordinates must be specified'
1062           
1063            try:
1064                self.interpolation_points = ensure_numeric(interpolation_points)
1065            except:
1066                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1067                      'or a list of points\n'
1068                msg += 'I got: %s.' %(str(self.interpolation_points)[:60] +\
1069                                      '...')
1070                raise msg
1071
1072
1073            m = len(self.interpolation_points)
1074            p = len(self.T)
1075           
1076            for name in quantity_names:
1077                self.precomputed_values[name] = zeros((p, m), Float)
1078
1079            #Build interpolator
1080            interpol = Interpolation(vertex_coordinates,
1081                                     triangles,
1082                                     point_coordinates = \
1083                                     self.interpolation_points,
1084                                     alpha = 0,
1085                                     precrop = False, 
1086                                     verbose = verbose)
1087
1088            if verbose: print 'Interpolate'
1089            for i, t in enumerate(self.T):
1090                #Interpolate quantities at this timestep
1091                if verbose and i%((p+10)/10)==0:
1092                    print ' time step %d of %d' %(i, p)
1093                   
1094                for name in quantity_names:
1095                    if len(quantities[name].shape) == 2:
1096                        result = interpol.interpolate(quantities[name][i,:])
1097                    else:
1098                       #Assume no time dependency
1099                       result = interpol.interpolate(quantities[name][:])
1100                       
1101                    self.precomputed_values[name][i, :] = result
1102                   
1103                       
1104
1105            #Report
1106            if verbose:
1107                print self.statistics()
1108                #self.print_statistics()
1109           
1110        else:
1111            #Store quantitites as is
1112            for name in quantity_names:
1113                self.precomputed_values[name] = quantities[name]
1114
1115
1116        #else:
1117        #    #Return an average, making this a time series
1118        #    for name in quantity_names:
1119        #        self.values[name] = zeros(len(self.T), Float)
1120        #
1121        #    if verbose: print 'Compute mean values'
1122        #    for i, t in enumerate(self.T):
1123        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1124        #        for name in quantity_names:
1125        #           self.values[name][i] = mean(quantities[name][i,:])
1126
1127
1128
1129
1130    def __repr__(self):
1131        #return 'Interpolation function (spatio-temporal)'
1132        return self.statistics()
1133   
1134
1135    def __call__(self, t, point_id = None, x = None, y = None):
1136        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1137
1138        Inputs:
1139          t: time - Model time. Must lie within existing timesteps
1140          point_id: index of one of the preprocessed points.
1141          x, y:     Overrides location, point_id ignored
1142         
1143          If spatial info is present and all of x,y,point_id
1144          are None an exception is raised
1145                   
1146          If no spatial info is present, point_id and x,y arguments are ignored
1147          making f a function of time only.
1148
1149         
1150          FIXME: point_id could also be a slice
1151          FIXME: What if x and y are vectors?
1152          FIXME: What about f(x,y) without t?
1153        """
1154
1155        from math import pi, cos, sin, sqrt
1156        from Numeric import zeros, Float
1157        from util import mean       
1158
1159        if self.spatial is True:
1160            if point_id is None:
1161                if x is None or y is None:
1162                    msg = 'Either point_id or x and y must be specified'
1163                    raise msg
1164            else:
1165                if self.interpolation_points is None:
1166                    msg = 'Interpolation_function must be instantiated ' +\
1167                          'with a list of interpolation points before parameter ' +\
1168                          'point_id can be used'
1169                    raise msg
1170
1171
1172        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1173        msg += ' does not match model time: %s\n' %t
1174        if t < self.T[0]: raise msg
1175        if t > self.T[-1]: raise msg
1176
1177        oldindex = self.index #Time index
1178
1179        #Find current time slot
1180        while t > self.T[self.index]: self.index += 1
1181        while t < self.T[self.index]: self.index -= 1
1182
1183        if t == self.T[self.index]:
1184            #Protect against case where t == T[-1] (last time)
1185            # - also works in general when t == T[i]
1186            ratio = 0
1187        else:
1188            #t is now between index and index+1
1189            ratio = (t - self.T[self.index])/\
1190                    (self.T[self.index+1] - self.T[self.index])
1191
1192        #Compute interpolated values
1193        q = zeros(len(self.quantity_names), Float)
1194
1195        for i, name in enumerate(self.quantity_names):
1196            Q = self.precomputed_values[name]
1197
1198            if self.spatial is False:
1199                #If there is no spatial info               
1200                assert len(Q.shape) == 1
1201
1202                Q0 = Q[self.index]
1203                if ratio > 0: Q1 = Q[self.index+1]
1204
1205            else:
1206                if x is not None and y is not None:
1207                    #Interpolate to x, y
1208                   
1209                    raise 'x,y interpolation not yet implemented'
1210                else:
1211                    #Use precomputed point
1212                    Q0 = Q[self.index, point_id]
1213                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1214
1215            #Linear temporal interpolation   
1216            if ratio > 0:
1217                q[i] = Q0 + ratio*(Q1 - Q0)
1218            else:
1219                q[i] = Q0
1220
1221
1222        #Return vector of interpolated values
1223        #if len(q) == 1:
1224        #    return q[0]
1225        #else:
1226        #    return q
1227
1228
1229        #Return vector of interpolated values
1230        #FIXME:
1231        if self.spatial is True:
1232            return q
1233        else:
1234            #Replicate q according to x and y
1235            #This is e.g used for Wind_stress
1236            if x is None or y is None: 
1237                return q
1238            else:
1239                try:
1240                    N = len(x)
1241                except:
1242                    return q
1243                else:
1244                    from Numeric import ones, Float
1245                    #x is a vector - Create one constant column for each value
1246                    N = len(x)
1247                    assert len(y) == N, 'x and y must have same length'
1248                    res = []
1249                    for col in q:
1250                        res.append(col*ones(N, Float))
1251                       
1252                return res
1253
1254
1255    def statistics(self):
1256        """Output statistics about interpolation_function
1257        """
1258       
1259        vertex_coordinates = self.vertex_coordinates
1260        interpolation_points = self.interpolation_points               
1261        quantity_names = self.quantity_names
1262        quantities = self.quantities
1263        precomputed_values = self.precomputed_values                 
1264               
1265        x = vertex_coordinates[:,0]
1266        y = vertex_coordinates[:,1]               
1267
1268        str =  '------------------------------------------------\n'
1269        str += 'Interpolation_function (spatio-temporal) statistics:\n'
1270        str += '  Extent:\n'
1271        str += '    x in [%f, %f], len(x) == %d\n'\
1272               %(min(x), max(x), len(x))
1273        str += '    y in [%f, %f], len(y) == %d\n'\
1274               %(min(y), max(y), len(y))
1275        str += '    t in [%f, %f], len(t) == %d\n'\
1276               %(min(self.T), max(self.T), len(self.T))
1277        str += '  Quantities:\n'
1278        for name in quantity_names:
1279            q = quantities[name][:].flat
1280            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1281
1282        if interpolation_points is not None:   
1283            str += '  Interpolation points (xi, eta):'\
1284                   ' number of points == %d\n' %interpolation_points.shape[0]
1285            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1286                                            max(interpolation_points[:,0]))
1287            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1288                                             max(interpolation_points[:,1]))
1289            str += '  Interpolated quantities (over all timesteps):\n'
1290       
1291            for name in quantity_names:
1292                q = precomputed_values[name][:].flat
1293                str += '    %s at interpolation points in [%f, %f]\n'\
1294                       %(name, min(q), max(q))
1295        str += '------------------------------------------------\n'
1296
1297        return str
1298
1299        #FIXME: Delete
1300        #print '------------------------------------------------'
1301        #print 'Interpolation_function statistics:'
1302        #print '  Extent:'
1303        #print '    x in [%f, %f], len(x) == %d'\
1304        #      %(min(x), max(x), len(x))
1305        #print '    y in [%f, %f], len(y) == %d'\
1306        #      %(min(y), max(y), len(y))
1307        #print '    t in [%f, %f], len(t) == %d'\
1308        #      %(min(self.T), max(self.T), len(self.T))
1309        #print '  Quantities:'
1310        #for name in quantity_names:
1311        #    q = quantities[name][:].flat
1312        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1313        #print '  Interpolation points (xi, eta):'\
1314        #      ' number of points == %d ' %interpolation_points.shape[0]
1315        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1316        #                             max(interpolation_points[:,0]))
1317        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1318        #                              max(interpolation_points[:,1]))
1319        #print '  Interpolated quantities (over all timesteps):'
1320        #
1321        #for name in quantity_names:
1322        #    q = precomputed_values[name][:].flat
1323        #    print '    %s at interpolation points in [%f, %f]'\
1324        #          %(name, min(q), max(q))
1325        #print '------------------------------------------------'
1326
1327
1328#-------------------------------------------------------------
1329if __name__ == "__main__":
1330    """
1331    Load in a mesh and data points with attributes.
1332    Fit the attributes to the mesh.
1333    Save a new mesh file.
1334    """
1335    import os, sys
1336    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1337            %os.path.basename(sys.argv[0])
1338
1339    if len(sys.argv) < 4:
1340        print usage
1341    else:
1342        mesh_file = sys.argv[1]
1343        point_file = sys.argv[2]
1344        mesh_output_file = sys.argv[3]
1345
1346        expand_search = False
1347        if len(sys.argv) > 4:
1348            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1349                expand_search = True
1350            else:
1351                expand_search = False
1352
1353        verbose = False
1354        if len(sys.argv) > 5:
1355            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1356                verbose = False
1357            else:
1358                verbose = True
1359
1360        if len(sys.argv) > 6:
1361            alpha = sys.argv[6]
1362        else:
1363            alpha = DEFAULT_ALPHA
1364
1365        # This is used more for testing
1366        if len(sys.argv) > 7:
1367            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1368                display_errors = False
1369            else:
1370                display_errors = True
1371           
1372        t0 = time.time()
1373        try:
1374            fit_to_mesh_file(mesh_file,
1375                         point_file,
1376                         mesh_output_file,
1377                         alpha,
1378                         verbose= verbose,
1379                         expand_search = expand_search,
1380                         display_errors = display_errors)
1381        except IOError,e:
1382            import sys; sys.exit(1)
1383
1384        print 'That took %.2f seconds' %(time.time()-t0)
1385
Note: See TracBrowser for help on using the repository browser.