source: inundation/pyvolution/least_squares.py @ 2598

Last change on this file since 2598 was 2585, checked in by ole, 18 years ago

Added statistics for precropped points in fit_to_mesh

File size: 49.5 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20#import exceptions
21#class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from pyvolution.mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from utilities.sparse import Sparse, Sparse_CSR
29from utilities.cg_solve import conjugate_gradient, VectorShapeError
30from utilities.numerical_tools import ensure_numeric, mean, gradient
31
32
33from coordinate_transforms.geo_reference import Geo_reference
34
35import time
36
37
38
39
40DEFAULT_ALPHA = 0.001
41
42def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
43                     alpha=DEFAULT_ALPHA, verbose= False,
44                     expand_search = False,
45                     data_origin = None,
46                     mesh_origin = None,
47                     precrop = False,
48                     display_errors = True):
49    """
50    Given a mesh file (tsh) and a point attribute file (xya), fit
51    point attributes to the mesh and write a mesh file with the
52    results.
53
54
55    If data_origin is not None it is assumed to be
56    a 3-tuple with geo referenced
57    UTM coordinates (zone, easting, northing)
58
59    NOTE: Throws IOErrors, for a variety of file problems.
60   
61    mesh_origin is the same but refers to the input tsh file.
62    FIXME: When the tsh format contains it own origin, these parameters can go.
63    FIXME: And both origins should be obtained from the specified files.
64    """
65
66    from load_mesh.loadASCII import import_mesh_file, \
67                 import_points_file, export_mesh_file, \
68                 concatinate_attributelist
69
70
71    try:
72        mesh_dict = import_mesh_file(mesh_file)
73    except IOError,e:
74        if display_errors:
75            print "Could not load bad file. ", e
76        raise IOError  #Re-raise exception
77       
78    vertex_coordinates = mesh_dict['vertices']
79    triangles = mesh_dict['triangles']
80    if type(mesh_dict['vertex_attributes']) == ArrayType:
81        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
82    else:
83        old_point_attributes = mesh_dict['vertex_attributes']
84
85    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
86        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
87    else:
88        old_title_list = mesh_dict['vertex_attribute_titles']
89
90    if verbose: print 'tsh file %s loaded' %mesh_file
91
92    # load in the .pts file
93    try:
94        point_dict = import_points_file(point_file, verbose=verbose)
95    except IOError,e:
96        if display_errors:
97            print "Could not load bad file. ", e
98        raise IOError  #Re-raise exception 
99
100    point_coordinates = point_dict['pointlist']
101    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
102
103    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
104        data_origin = point_dict['geo_reference'].get_origin()
105    else:
106        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
107
108    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
109        mesh_origin = mesh_dict['geo_reference'].get_origin()
110    else:
111        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
112
113    if verbose: print "points file loaded"
114    if verbose: print "fitting to mesh"
115    f = fit_to_mesh(vertex_coordinates,
116                    triangles,
117                    point_coordinates,
118                    point_attributes,
119                    alpha = alpha,
120                    verbose = verbose,
121                    expand_search = expand_search,
122                    data_origin = data_origin,
123                    mesh_origin = mesh_origin,
124                    precrop = precrop)
125    if verbose: print "finished fitting to mesh"
126
127    # convert array to list of lists
128    new_point_attributes = f.tolist()
129    #FIXME have this overwrite attributes with the same title - DSG
130    #Put the newer attributes last
131    if old_title_list <> []:
132        old_title_list.extend(title_list)
133        #FIXME can this be done a faster way? - DSG
134        for i in range(len(old_point_attributes)):
135            old_point_attributes[i].extend(new_point_attributes[i])
136        mesh_dict['vertex_attributes'] = old_point_attributes
137        mesh_dict['vertex_attribute_titles'] = old_title_list
138    else:
139        mesh_dict['vertex_attributes'] = new_point_attributes
140        mesh_dict['vertex_attribute_titles'] = title_list
141
142    #FIXME (Ole): Remember to output mesh_origin as well
143    if verbose: print "exporting to file ", mesh_output_file
144
145    try:
146        export_mesh_file(mesh_output_file, mesh_dict)
147    except IOError,e:
148        if display_errors:
149            print "Could not write file. ", e
150        raise IOError
151
152def fit_to_mesh(vertex_coordinates,
153                triangles,
154                point_coordinates,
155                point_attributes,
156                alpha = DEFAULT_ALPHA,
157                verbose = False,
158                expand_search = False,
159                data_origin = None,
160                mesh_origin = None,
161                precrop = False,
162                use_cache = False):
163    """
164    Fit a smooth surface to a triangulation,
165    given data points with attributes.
166
167
168        Inputs:
169
170          vertex_coordinates: List of coordinate pairs [xi, eta] of points
171          constituting mesh (or a an m x 2 Numeric array)
172
173          triangles: List of 3-tuples (or a Numeric array) of
174          integers representing indices of all vertices in the mesh.
175
176          point_coordinates: List of coordinate pairs [x, y] of data points
177          (or an nx2 Numeric array)
178
179          alpha: Smoothing parameter.
180
181          point_attributes: Vector or array of data at the point_coordinates.
182
183          data_origin and mesh_origin are 3-tuples consisting of
184          UTM zone, easting and northing. If specified
185          point coordinates and vertex coordinates are assumed to be
186          relative to their respective origins.
187
188    """
189
190    if use_cache is True:
191        from caching.caching import cache
192        interp = cache(_interpolation,
193                       (vertex_coordinates,
194                        triangles,
195                        point_coordinates),
196                       {'alpha': alpha,
197                        'verbose': verbose,
198                        'expand_search': expand_search,
199                        'data_origin': data_origin,
200                        'mesh_origin': mesh_origin,
201                        'precrop': precrop},
202                       verbose = verbose)       
203       
204    else:
205        interp = Interpolation(vertex_coordinates,
206                               triangles,
207                               point_coordinates,
208                               alpha = alpha,
209                               verbose = verbose,
210                               expand_search = expand_search,
211                               data_origin = data_origin,
212                               mesh_origin = mesh_origin,
213                               precrop = precrop)
214
215    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
216   
217    if verbose:
218   
219        point_coordinates = ensure_numeric(point_coordinates)
220        vertex_coordinates = ensure_numeric(vertex_coordinates)
221               
222        X = point_coordinates[:,0]
223        Y = point_coordinates[:,1]     
224        Z = point_attributes   
225           
226        print '+------------------------------------------------'
227        print 'Least squares statistics'
228        print '+------------------------------------------------'       
229        print 'points: %d points' %(len(Z))
230        print '    x in [%f, %f]'%(min(X), max(X))
231        print '    y in [%f, %f]'%(min(Y), max(Y))
232        print '    z in [%f, %f]'%(min(Z), max(Z))
233        print
234
235        indices = interp.point_indices
236        if indices is not None:
237            X = take(X, indices)
238            Y = take(Y, indices)       
239            Z = take(Z, indices)       
240            print 'Cropped points: %d points' %(len(Z))
241            print '    x in [%f, %f]'%(min(X), max(X))
242            print '    y in [%f, %f]'%(min(Y), max(Y))
243            print '    z in [%f, %f]'%(min(Z), max(Z))
244            print
245       
246        Xi = vertex_coordinates[:,0]
247        Eta = vertex_coordinates[:,1]   
248        Zeta = vertex_attributes               
249        print 'Mesh: %d vertices' %(len(Zeta))
250        print '    xi in [%f, %f]'%(min(Xi), max(Xi))
251        print '    eta in [%f, %f]'%(min(Eta), max(Eta))
252        print '    zeta in [%f, %f]'%(min(Zeta), max(Zeta))
253        print '+------------------------------------------------'
254
255    return vertex_attributes
256
257
258
259def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
260                    verbose = False, reduction = 1):
261    """Fits attributes from pts file to MxN rectangular mesh
262
263    Read pts file and create rectangular mesh of resolution MxN such that
264    it covers all points specified in pts file.
265
266    FIXME: This may be a temporary function until we decide on
267    netcdf formats etc
268
269    FIXME: Uses elevation hardwired
270    """
271
272    import  mesh_factory
273    from load_mesh.loadASCII import import_points_file
274   
275    if verbose: print 'Read pts'
276    points_dict = import_points_file(pts_name)
277    #points, attributes = util.read_xya(pts_name)
278
279    #Reduce number of points a bit
280    points = points_dict['pointlist'][::reduction]
281    elevation = points_dict['attributelist']['elevation']  #Must be elevation
282    elevation = elevation[::reduction]
283
284    if verbose: print 'Got %d data points' %len(points)
285
286    if verbose: print 'Create mesh'
287    #Find extent
288    max_x = min_x = points[0][0]
289    max_y = min_y = points[0][1]
290    for point in points[1:]:
291        x = point[0]
292        if x > max_x: max_x = x
293        if x < min_x: min_x = x
294        y = point[1]
295        if y > max_y: max_y = y
296        if y < min_y: min_y = y
297
298    #Create appropriate mesh
299    vertex_coordinates, triangles, boundary =\
300         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
301                                (min_x, min_y))
302
303    #Fit attributes to mesh
304    vertex_attributes = fit_to_mesh(vertex_coordinates,
305                        triangles,
306                        points,
307                        elevation, alpha=alpha, verbose=verbose)
308
309
310
311    return vertex_coordinates, triangles, boundary, vertex_attributes
312
313
314def _interpolation(*args, **kwargs):
315    """Private function for use with caching. Reason is that classes
316    may change their byte code between runs which is annoying.
317    """
318   
319    return Interpolation(*args, **kwargs)
320
321
322class Interpolation:
323
324    def __init__(self,
325                 vertex_coordinates,
326                 triangles,
327                 point_coordinates = None,
328                 alpha = None,
329                 verbose = False,
330                 expand_search = True,
331                 interp_only = False,
332                 max_points_per_cell = 30,
333                 mesh_origin = None,
334                 data_origin = None,
335                 precrop = False):
336
337
338        """ Build interpolation matrix mapping from
339        function values at vertices to function values at data points
340
341        Inputs:
342
343          vertex_coordinates: List of coordinate pairs [xi, eta] of
344          points constituting mesh (or a an m x 2 Numeric array)
345          Points may appear multiple times
346          (e.g. if vertices have discontinuities)
347
348          triangles: List of 3-tuples (or a Numeric array) of
349          integers representing indices of all vertices in the mesh.
350
351          point_coordinates: List of coordinate pairs [x, y] of
352          data points (or an nx2 Numeric array)
353          If point_coordinates is absent, only smoothing matrix will
354          be built
355
356          alpha: Smoothing parameter
357
358          data_origin and mesh_origin are 3-tuples consisting of
359          UTM zone, easting and northing. If specified
360          point coordinates and vertex coordinates are assumed to be
361          relative to their respective origins.
362
363        """
364
365        #Convert input to Numeric arrays
366        triangles = ensure_numeric(triangles, Int)
367        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
368
369        #Build underlying mesh
370        if verbose: print 'Building mesh'
371        #self.mesh = General_mesh(vertex_coordinates, triangles,
372        #FIXME: Trying the normal mesh while testing precrop,
373        #       The functionality of boundary_polygon is needed for that
374
375        #FIXME - geo ref does not have to go into mesh.
376        # Change the point co-ords to conform to the
377        # mesh co-ords early in the code
378        if mesh_origin is None:
379            geo = None
380        else:
381            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
382        self.mesh = Mesh(vertex_coordinates, triangles,
383                         geo_reference = geo)
384       
385        self.mesh.check_integrity()
386
387        self.data_origin = data_origin
388
389        self.point_indices = None
390
391        #Smoothing parameter
392        if alpha is None:
393            self.alpha = DEFAULT_ALPHA
394        else:   
395            self.alpha = alpha
396
397
398        if point_coordinates is not None:
399            if verbose: print 'Building interpolation matrix'
400            self.build_interpolation_matrix_A(point_coordinates,
401                                              verbose = verbose,
402                                              expand_search = expand_search,
403                                              interp_only = interp_only, 
404                                              max_points_per_cell =\
405                                              max_points_per_cell,
406                                              data_origin = data_origin,
407                                              precrop = precrop)
408        #Build coefficient matrices
409        if interp_only == False:
410            self.build_coefficient_matrix_B(point_coordinates,
411                                        verbose = verbose,
412                                        expand_search = expand_search,
413                                        max_points_per_cell =\
414                                        max_points_per_cell,
415                                        data_origin = data_origin,
416                                        precrop = precrop)
417
418    def set_point_coordinates(self, point_coordinates,
419                              data_origin = None,
420                              verbose = False,
421                              precrop = True):
422        """
423        A public interface to setting the point co-ordinates.
424        """
425        if point_coordinates is not None:
426            if verbose: print 'Building interpolation matrix'
427            self.build_interpolation_matrix_A(point_coordinates,
428                                              verbose = verbose,
429                                              data_origin = data_origin,
430                                              precrop = precrop)
431        self.build_coefficient_matrix_B(point_coordinates, data_origin)
432
433    def build_coefficient_matrix_B(self, point_coordinates=None,
434                                   verbose = False, expand_search = True,
435                                   max_points_per_cell=30,
436                                   data_origin = None,
437                                   precrop = False):
438        """Build final coefficient matrix"""
439
440
441        if self.alpha <> 0:
442            if verbose: print 'Building smoothing matrix'
443            self.build_smoothing_matrix_D()
444
445        if point_coordinates is not None:
446            if self.alpha <> 0:
447                self.B = self.AtA + self.alpha*self.D
448            else:
449                self.B = self.AtA
450
451            #Convert self.B matrix to CSR format for faster matrix vector
452            self.B = Sparse_CSR(self.B)
453
454    def build_interpolation_matrix_A(self, point_coordinates,
455                                     verbose = False, expand_search = True,
456                                     max_points_per_cell=30,
457                                     data_origin = None,
458                                     precrop = False,
459                                     interp_only = False):
460        """Build n x m interpolation matrix, where
461        n is the number of data points and
462        m is the number of basis functions phi_k (one per vertex)
463
464        This algorithm uses a quad tree data structure for fast binning of data points
465        origin is a 3-tuple consisting of UTM zone, easting and northing.
466        If specified coordinates are assumed to be relative to this origin.
467
468        This one will override any data_origin that may be specified in
469        interpolation instance
470
471        """
472
473
474
475        #FIXME (Ole): Check that this function is memeory efficient.
476        #6 million datapoints and 300000 basis functions
477        #causes out-of-memory situation
478        #First thing to check is whether there is room for self.A and self.AtA
479        #
480        #Maybe we need some sort of blocking
481
482        from pyvolution.quad import build_quadtree
483        from utilities.polygon import inside_polygon
484       
485
486        if data_origin is None:
487            data_origin = self.data_origin #Use the one from
488                                           #interpolation instance
489
490        #Convert input to Numeric arrays just in case.
491        point_coordinates = ensure_numeric(point_coordinates, Float)
492
493        #Keep track of discarded points (if any).
494        #This is only registered if precrop is True
495        self.cropped_points = False
496
497        #Shift data points to same origin as mesh (if specified)
498
499        #FIXME this will shift if there was no geo_ref.
500        #But all this should be removed anyhow.
501        #change coords before this point
502        mesh_origin = self.mesh.geo_reference.get_origin()
503        if point_coordinates is not None:
504            if data_origin is not None:
505                if mesh_origin is not None:
506
507                    #Transformation:
508                    #
509                    #Let x_0 be the reference point of the point coordinates
510                    #and xi_0 the reference point of the mesh.
511                    #
512                    #A point coordinate (x + x_0) is then made relative
513                    #to xi_0 by
514                    #
515                    # x_new = x + x_0 - xi_0
516                    #
517                    #and similarly for eta
518
519                    x_offset = data_origin[1] - mesh_origin[1]
520                    y_offset = data_origin[2] - mesh_origin[2]
521                else: #Shift back to a zero origin
522                    x_offset = data_origin[1]
523                    y_offset = data_origin[2]
524
525                point_coordinates[:,0] += x_offset
526                point_coordinates[:,1] += y_offset
527            else:
528                if mesh_origin is not None:
529                    #Use mesh origin for data points
530                    point_coordinates[:,0] -= mesh_origin[1]
531                    point_coordinates[:,1] -= mesh_origin[2]
532
533
534
535        #Remove points falling outside mesh boundary
536        #This reduced one example from 1356 seconds to 825 seconds
537
538       
539        if precrop is True:
540            from Numeric import take
541
542            if verbose: print 'Getting boundary polygon'
543            P = self.mesh.get_boundary_polygon()
544
545            if verbose: print 'Getting indices inside mesh boundary'
546            indices = inside_polygon(point_coordinates, P, verbose = verbose)
547
548
549            if len(indices) != point_coordinates.shape[0]:
550                self.cropped_points = True
551                if verbose:
552                    print 'Done - %d points outside mesh have been cropped.'\
553                          %(point_coordinates.shape[0] - len(indices))
554
555            point_coordinates = take(point_coordinates, indices)
556            self.point_indices = indices
557
558
559
560
561        #Build n x m interpolation matrix
562        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
563        n = point_coordinates.shape[0]     #Nbr of data points
564
565        if verbose: print 'Number of datapoints: %d' %n
566        if verbose: print 'Number of basis functions: %d' %m
567
568        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
569        #However, Sparse_CSR does not have the same methods as Sparse yet
570        #The tests will reveal what needs to be done
571
572        #
573        #self.A = Sparse_CSR(Sparse(n,m))
574        #self.AtA = Sparse_CSR(Sparse(m,m))
575        self.A = Sparse(n,m)
576        self.AtA = Sparse(m,m)
577
578        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
579        root = build_quadtree(self.mesh,
580                              max_points_per_cell = max_points_per_cell)
581        #root.show()
582        self.expanded_quad_searches = []
583        #Compute matrix elements
584        for i in range(n):
585            #For each data_coordinate point
586
587            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
588            x = point_coordinates[i]
589
590            #Find vertices near x
591            candidate_vertices = root.search(x[0], x[1])
592            is_more_elements = True
593
594            element_found, sigma0, sigma1, sigma2, k = \
595                self.search_triangles_of_vertices(candidate_vertices, x)
596            first_expansion = True
597            while not element_found and is_more_elements and expand_search:
598                #if verbose: print 'Expanding search'
599                if first_expansion == True:
600                    self.expanded_quad_searches.append(1)
601                    first_expansion = False
602                else:
603                    end = len(self.expanded_quad_searches) - 1
604                    assert end >= 0
605                    self.expanded_quad_searches[end] += 1
606                candidate_vertices, branch = root.expand_search()
607                if branch == []:
608                    # Searching all the verts from the root cell that haven't
609                    # been searched.  This is the last try
610                    element_found, sigma0, sigma1, sigma2, k = \
611                      self.search_triangles_of_vertices(candidate_vertices, x)
612                    is_more_elements = False
613                else:
614                    element_found, sigma0, sigma1, sigma2, k = \
615                      self.search_triangles_of_vertices(candidate_vertices, x)
616
617               
618            #Update interpolation matrix A if necessary
619            if element_found is True:
620                #Assign values to matrix A
621
622                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
623                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
624                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
625
626                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
627                js     = [j0,j1,j2]
628
629                for j in js:
630                    self.A[i,j] = sigmas[j]
631                    for k in js:
632                        if interp_only == False:
633                            self.AtA[j,k] += sigmas[j]*sigmas[k]
634            else:
635                pass
636                #Ok if there is no triangle for datapoint
637                #(as in brute force version)
638                #raise 'Could not find triangle for point', x
639
640
641
642    def search_triangles_of_vertices(self, candidate_vertices, x):
643            #Find triangle containing x:
644            element_found = False
645
646            # This will be returned if element_found = False
647            sigma2 = -10.0
648            sigma0 = -10.0
649            sigma1 = -10.0
650            k = -10.0
651            #print "*$* candidate_vertices", candidate_vertices
652            #For all vertices in same cell as point x
653            for v in candidate_vertices:
654                #FIXME (DSG-DSG): this catches verts with no triangle.
655                #Currently pmesh is producing these.
656                #this should be stopped,
657                if self.mesh.vertexlist[v] is None:
658                    continue
659                #for each triangle id (k) which has v as a vertex
660                for k, _ in self.mesh.vertexlist[v]:
661
662                    #Get the three vertex_points of candidate triangle
663                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
664                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
665                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
666
667                    #print "PDSG - k", k
668                    #print "PDSG - xi0", xi0
669                    #print "PDSG - xi1", xi1
670                    #print "PDSG - xi2", xi2
671                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
672                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
673
674                    #Get the three normals
675                    n0 = self.mesh.get_normal(k, 0)
676                    n1 = self.mesh.get_normal(k, 1)
677                    n2 = self.mesh.get_normal(k, 2)
678
679
680                    #Compute interpolation
681                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
682                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
683                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
684
685                    #print "PDSG - sigma0", sigma0
686                    #print "PDSG - sigma1", sigma1
687                    #print "PDSG - sigma2", sigma2
688
689                    #FIXME: Maybe move out to test or something
690                    epsilon = 1.0e-6
691                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
692
693                    #Check that this triangle contains the data point
694
695                    #Sigmas can get negative within
696                    #machine precision on some machines (e.g nautilus)
697                    #Hence the small eps
698                    eps = 1.0e-15
699                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
700                        element_found = True
701                        break
702
703                if element_found is True:
704                    #Don't look for any other triangle
705                    break
706            return element_found, sigma0, sigma1, sigma2, k
707
708
709
710    def build_interpolation_matrix_A_brute(self, point_coordinates):
711        """Build n x m interpolation matrix, where
712        n is the number of data points and
713        m is the number of basis functions phi_k (one per vertex)
714
715        This is the brute force which is too slow for large problems,
716        but could be used for testing
717        """
718
719
720        #Convert input to Numeric arrays
721        point_coordinates = ensure_numeric(point_coordinates, Float)
722
723        #Build n x m interpolation matrix
724        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
725        n = point_coordinates.shape[0]     #Nbr of data points
726
727        self.A = Sparse(n,m)
728        self.AtA = Sparse(m,m)
729
730        #Compute matrix elements
731        for i in range(n):
732            #For each data_coordinate point
733
734            x = point_coordinates[i]
735            element_found = False
736            k = 0
737            while not element_found and k < len(self.mesh):
738                #For each triangle (brute force)
739                #FIXME: Real algorithm should only visit relevant triangles
740
741                #Get the three vertex_points
742                xi0 = self.mesh.get_vertex_coordinate(k, 0)
743                xi1 = self.mesh.get_vertex_coordinate(k, 1)
744                xi2 = self.mesh.get_vertex_coordinate(k, 2)
745
746                #Get the three normals
747                n0 = self.mesh.get_normal(k, 0)
748                n1 = self.mesh.get_normal(k, 1)
749                n2 = self.mesh.get_normal(k, 2)
750
751                #Compute interpolation
752                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
753                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
754                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
755
756                #FIXME: Maybe move out to test or something
757                epsilon = 1.0e-6
758                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
759
760                #Check that this triangle contains data point
761                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
762                    element_found = True
763                    #Assign values to matrix A
764
765                    j0 = self.mesh.triangles[k,0] #Global vertex id
766                    #self.A[i, j0] = sigma0
767
768                    j1 = self.mesh.triangles[k,1] #Global vertex id
769                    #self.A[i, j1] = sigma1
770
771                    j2 = self.mesh.triangles[k,2] #Global vertex id
772                    #self.A[i, j2] = sigma2
773
774                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
775                    js     = [j0,j1,j2]
776
777                    for j in js:
778                        self.A[i,j] = sigmas[j]
779                        for k in js:
780                            self.AtA[j,k] += sigmas[j]*sigmas[k]
781                k = k+1
782
783
784
785    def get_A(self):
786        return self.A.todense()
787
788    def get_B(self):
789        return self.B.todense()
790
791    def get_D(self):
792        return self.D.todense()
793
794        #FIXME: Remember to re-introduce the 1/n factor in the
795        #interpolation term
796
797    def build_smoothing_matrix_D(self):
798        """Build m x m smoothing matrix, where
799        m is the number of basis functions phi_k (one per vertex)
800
801        The smoothing matrix is defined as
802
803        D = D1 + D2
804
805        where
806
807        [D1]_{k,l} = \int_\Omega
808           \frac{\partial \phi_k}{\partial x}
809           \frac{\partial \phi_l}{\partial x}\,
810           dx dy
811
812        [D2]_{k,l} = \int_\Omega
813           \frac{\partial \phi_k}{\partial y}
814           \frac{\partial \phi_l}{\partial y}\,
815           dx dy
816
817
818        The derivatives \frac{\partial \phi_k}{\partial x},
819        \frac{\partial \phi_k}{\partial x} for a particular triangle
820        are obtained by computing the gradient a_k, b_k for basis function k
821        """
822
823        #FIXME: algorithm might be optimised by computing local 9x9
824        #"element stiffness matrices:
825
826        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
827
828        self.D = Sparse(m,m)
829
830        #For each triangle compute contributions to D = D1+D2
831        for i in range(len(self.mesh)):
832
833            #Get area
834            area = self.mesh.areas[i]
835
836            #Get global vertex indices
837            v0 = self.mesh.triangles[i,0]
838            v1 = self.mesh.triangles[i,1]
839            v2 = self.mesh.triangles[i,2]
840
841            #Get the three vertex_points
842            xi0 = self.mesh.get_vertex_coordinate(i, 0)
843            xi1 = self.mesh.get_vertex_coordinate(i, 1)
844            xi2 = self.mesh.get_vertex_coordinate(i, 2)
845
846            #Compute gradients for each vertex
847            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
848                              1, 0, 0)
849
850            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
851                              0, 1, 0)
852
853            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
854                              0, 0, 1)
855
856            #Compute diagonal contributions
857            self.D[v0,v0] += (a0*a0 + b0*b0)*area
858            self.D[v1,v1] += (a1*a1 + b1*b1)*area
859            self.D[v2,v2] += (a2*a2 + b2*b2)*area
860
861            #Compute contributions for basis functions sharing edges
862            e01 = (a0*a1 + b0*b1)*area
863            self.D[v0,v1] += e01
864            self.D[v1,v0] += e01
865
866            e12 = (a1*a2 + b1*b2)*area
867            self.D[v1,v2] += e12
868            self.D[v2,v1] += e12
869
870            e20 = (a2*a0 + b2*b0)*area
871            self.D[v2,v0] += e20
872            self.D[v0,v2] += e20
873
874
875    def fit(self, z):
876        """Fit a smooth surface to given 1d array of data points z.
877
878        The smooth surface is computed at each vertex in the underlying
879        mesh using the formula given in the module doc string.
880
881        Pre Condition:
882          self.A, self.AtA and self.B have been initialised
883
884        Inputs:
885          z: Single 1d vector or array of data at the point_coordinates.
886        """
887
888        #Convert input to Numeric arrays
889        z = ensure_numeric(z, Float)
890
891        if len(z.shape) > 1 :
892            raise VectorShapeError, 'Can only deal with 1d data vector'
893
894        if self.point_indices is not None:
895            #Remove values for any points that were outside mesh
896            z = take(z, self.point_indices)
897
898        #Compute right hand side based on data
899        #FIXME (DSG-DsG): could Sparse_CSR be used here?  Use this format
900        # after a matrix is built, before calcs.
901        Atz = self.A.trans_mult(z)
902
903
904        #Check sanity
905        n, m = self.A.shape
906        if n<m and self.alpha == 0.0:
907            msg = 'ERROR (least_squares): Too few data points\n'
908            msg += 'There are only %d data points and alpha == 0. ' %n
909            msg += 'Need at least %d\n' %m
910            msg += 'Alternatively, set smoothing parameter alpha to a small '
911            msg += 'positive value,\ne.g. 1.0e-3.'
912            raise msg
913
914
915
916        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
917        #FIXME: Should we store the result here for later use? (ON)
918
919
920    def fit_points(self, z, verbose=False):
921        """Like fit, but more robust when each point has two or more attributes
922        FIXME (Ole): The name fit_points doesn't carry any meaning
923        for me. How about something like fit_multiple or fit_columns?
924        """
925
926        try:
927            if verbose: print 'Solving penalised least_squares problem'
928            return self.fit(z)
929        except VectorShapeError, e:
930            # broadcasting is not supported.
931
932            #Convert input to Numeric arrays
933            z = ensure_numeric(z, Float)
934
935            #Build n x m interpolation matrix
936            m = self.mesh.coordinates.shape[0] #Number of vertices
937            n = z.shape[1]                     #Number of data points
938
939            f = zeros((m,n), Float) #Resulting columns
940
941            for i in range(z.shape[1]):
942                f[:,i] = self.fit(z[:,i])
943
944            return f
945
946
947    def interpolate(self, f):
948        """Evaluate smooth surface f at data points implied in self.A.
949
950        The mesh values representing a smooth surface are
951        assumed to be specified in f. This argument could,
952        for example have been obtained from the method self.fit()
953
954        Pre Condition:
955          self.A has been initialised
956
957        Inputs:
958          f: Vector or array of data at the mesh vertices.
959          If f is an array, interpolation will be done for each column as
960          per underlying matrix-matrix multiplication
961
962        Output:
963          Interpolated values at data points implied in self.A
964
965        """
966
967        return self.A * f
968
969    def cull_outsiders(self, f):
970        pass
971
972
973
974
975class Interpolation_function:
976    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
977    which is interpolated from time series defined at vertices of
978    triangular mesh (such as those stored in sww files)
979
980    Let m be the number of vertices, n the number of triangles
981    and p the number of timesteps.
982
983    Mandatory input
984        time:               px1 array of monotonously increasing times (Float)
985        quantities:         Dictionary of arrays or 1 array (Float)
986                            The arrays must either have dimensions pxm or mx1.
987                            The resulting function will be time dependent in
988                            the former case while it will be constan with
989                            respect to time in the latter case.
990       
991    Optional input:
992        quantity_names:     List of keys into the quantities dictionary
993        vertex_coordinates: mx2 array of coordinates (Float)
994        triangles:          nx3 array of indices into vertex_coordinates (Int)
995        interpolation_points: Nx2 array of coordinates to be interpolated to
996        verbose:            Level of reporting
997   
998   
999    The quantities returned by the callable object are specified by
1000    the list quantities which must contain the names of the
1001    quantities to be returned and also reflect the order, e.g. for
1002    the shallow water wave equation, on would have
1003    quantities = ['stage', 'xmomentum', 'ymomentum']
1004
1005    The parameter interpolation_points decides at which points interpolated
1006    quantities are to be computed whenever object is called.
1007    If None, return average value
1008    """
1009
1010   
1011   
1012    def __init__(self,
1013                 time,
1014                 quantities,
1015                 quantity_names = None, 
1016                 vertex_coordinates = None,
1017                 triangles = None,
1018                 interpolation_points = None,
1019                 verbose = False):
1020        """Initialise object and build spatial interpolation if required
1021        """
1022
1023        from Numeric import array, zeros, Float, alltrue, concatenate,\
1024             reshape, ArrayType
1025
1026
1027        from config import time_format
1028        import types
1029
1030
1031
1032        #Check temporal info
1033        time = ensure_numeric(time)       
1034        msg = 'Time must be a monotonuosly '
1035        msg += 'increasing sequence %s' %time
1036        assert alltrue(time[1:] - time[:-1] >= 0 ), msg
1037
1038
1039        #Check if quantities is a single array only
1040        if type(quantities) != types.DictType:
1041            quantities = ensure_numeric(quantities)
1042            quantity_names = ['Attribute']
1043
1044            #Make it a dictionary
1045            quantities = {quantity_names[0]: quantities}
1046
1047
1048        #Use keys if no names are specified
1049        if quantity_names is None:
1050            quantity_names = quantities.keys()
1051
1052
1053        #Check spatial info
1054        if vertex_coordinates is None:
1055            self.spatial = False
1056        else:   
1057            vertex_coordinates = ensure_numeric(vertex_coordinates)
1058
1059            assert triangles is not None, 'Triangles array must be specified'
1060            triangles = ensure_numeric(triangles)
1061            self.spatial = True           
1062           
1063
1064 
1065        #Save for use with statistics
1066        self.quantity_names = quantity_names       
1067        self.quantities = quantities       
1068        self.vertex_coordinates = vertex_coordinates
1069        self.interpolation_points = interpolation_points
1070        self.T = time[:]  # Time assumed to be relative to starttime
1071        self.index = 0    # Initial time index
1072        self.precomputed_values = {}
1073           
1074
1075
1076        #Precomputed spatial interpolation if requested
1077        if interpolation_points is not None:
1078            if self.spatial is False:
1079                raise 'Triangles and vertex_coordinates must be specified'
1080           
1081            try:
1082                self.interpolation_points = ensure_numeric(interpolation_points)
1083            except:
1084                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1085                      'or a list of points\n'
1086                msg += 'I got: %s.' %(str(self.interpolation_points)[:60] +\
1087                                      '...')
1088                raise msg
1089
1090
1091            m = len(self.interpolation_points)
1092            p = len(self.T)
1093           
1094            for name in quantity_names:
1095                self.precomputed_values[name] = zeros((p, m), Float)
1096
1097            #Build interpolator
1098            interpol = Interpolation(vertex_coordinates,
1099                                     triangles,
1100                                     point_coordinates = \
1101                                     self.interpolation_points,
1102                                     alpha = 0,
1103                                     precrop = False, 
1104                                     verbose = verbose)
1105
1106            if verbose: print 'Interpolate'
1107            for i, t in enumerate(self.T):
1108                #Interpolate quantities at this timestep
1109                if verbose and i%((p+10)/10)==0:
1110                    print ' time step %d of %d' %(i, p)
1111                   
1112                for name in quantity_names:
1113                    if len(quantities[name].shape) == 2:
1114                        result = interpol.interpolate(quantities[name][i,:])
1115                    else:
1116                       #Assume no time dependency
1117                       result = interpol.interpolate(quantities[name][:])
1118                       
1119                    self.precomputed_values[name][i, :] = result
1120                   
1121                       
1122
1123            #Report
1124            if verbose:
1125                print self.statistics()
1126                #self.print_statistics()
1127           
1128        else:
1129            #Store quantitites as is
1130            for name in quantity_names:
1131                self.precomputed_values[name] = quantities[name]
1132
1133
1134        #else:
1135        #    #Return an average, making this a time series
1136        #    for name in quantity_names:
1137        #        self.values[name] = zeros(len(self.T), Float)
1138        #
1139        #    if verbose: print 'Compute mean values'
1140        #    for i, t in enumerate(self.T):
1141        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1142        #        for name in quantity_names:
1143        #           self.values[name][i] = mean(quantities[name][i,:])
1144
1145
1146
1147
1148    def __repr__(self):
1149        #return 'Interpolation function (spatio-temporal)'
1150        return self.statistics()
1151   
1152
1153    def __call__(self, t, point_id = None, x = None, y = None):
1154        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1155
1156        Inputs:
1157          t: time - Model time. Must lie within existing timesteps
1158          point_id: index of one of the preprocessed points.
1159          x, y:     Overrides location, point_id ignored
1160         
1161          If spatial info is present and all of x,y,point_id
1162          are None an exception is raised
1163                   
1164          If no spatial info is present, point_id and x,y arguments are ignored
1165          making f a function of time only.
1166
1167         
1168          FIXME: point_id could also be a slice
1169          FIXME: What if x and y are vectors?
1170          FIXME: What about f(x,y) without t?
1171        """
1172
1173        from math import pi, cos, sin, sqrt
1174        from Numeric import zeros, Float
1175        from utilities.numerical_tools import mean       
1176
1177        if self.spatial is True:
1178            if point_id is None:
1179                if x is None or y is None:
1180                    msg = 'Either point_id or x and y must be specified'
1181                    raise msg
1182            else:
1183                if self.interpolation_points is None:
1184                    msg = 'Interpolation_function must be instantiated ' +\
1185                          'with a list of interpolation points before parameter ' +\
1186                          'point_id can be used'
1187                    raise msg
1188
1189
1190        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1191        msg += ' does not match model time: %s\n' %t
1192        if t < self.T[0]: raise msg
1193        if t > self.T[-1]: raise msg
1194
1195        oldindex = self.index #Time index
1196
1197        #Find current time slot
1198        while t > self.T[self.index]: self.index += 1
1199        while t < self.T[self.index]: self.index -= 1
1200
1201        if t == self.T[self.index]:
1202            #Protect against case where t == T[-1] (last time)
1203            # - also works in general when t == T[i]
1204            ratio = 0
1205        else:
1206            #t is now between index and index+1
1207            ratio = (t - self.T[self.index])/\
1208                    (self.T[self.index+1] - self.T[self.index])
1209
1210        #Compute interpolated values
1211        q = zeros(len(self.quantity_names), Float)
1212
1213        for i, name in enumerate(self.quantity_names):
1214            Q = self.precomputed_values[name]
1215
1216            if self.spatial is False:
1217                #If there is no spatial info               
1218                assert len(Q.shape) == 1
1219
1220                Q0 = Q[self.index]
1221                if ratio > 0: Q1 = Q[self.index+1]
1222
1223            else:
1224                if x is not None and y is not None:
1225                    #Interpolate to x, y
1226                   
1227                    raise 'x,y interpolation not yet implemented'
1228                else:
1229                    #Use precomputed point
1230                    Q0 = Q[self.index, point_id]
1231                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1232
1233            #Linear temporal interpolation   
1234            if ratio > 0:
1235                q[i] = Q0 + ratio*(Q1 - Q0)
1236            else:
1237                q[i] = Q0
1238
1239
1240        #Return vector of interpolated values
1241        #if len(q) == 1:
1242        #    return q[0]
1243        #else:
1244        #    return q
1245
1246
1247        #Return vector of interpolated values
1248        #FIXME:
1249        if self.spatial is True:
1250            return q
1251        else:
1252            #Replicate q according to x and y
1253            #This is e.g used for Wind_stress
1254            if x is None or y is None: 
1255                return q
1256            else:
1257                try:
1258                    N = len(x)
1259                except:
1260                    return q
1261                else:
1262                    from Numeric import ones, Float
1263                    #x is a vector - Create one constant column for each value
1264                    N = len(x)
1265                    assert len(y) == N, 'x and y must have same length'
1266                    res = []
1267                    for col in q:
1268                        res.append(col*ones(N, Float))
1269                       
1270                return res
1271
1272
1273    def statistics(self):
1274        """Output statistics about interpolation_function
1275        """
1276       
1277        vertex_coordinates = self.vertex_coordinates
1278        interpolation_points = self.interpolation_points               
1279        quantity_names = self.quantity_names
1280        quantities = self.quantities
1281        precomputed_values = self.precomputed_values                 
1282               
1283        x = vertex_coordinates[:,0]
1284        y = vertex_coordinates[:,1]               
1285
1286        str =  '------------------------------------------------\n'
1287        str += 'Interpolation_function (spatio-temporal) statistics:\n'
1288        str += '  Extent:\n'
1289        str += '    x in [%f, %f], len(x) == %d\n'\
1290               %(min(x), max(x), len(x))
1291        str += '    y in [%f, %f], len(y) == %d\n'\
1292               %(min(y), max(y), len(y))
1293        str += '    t in [%f, %f], len(t) == %d\n'\
1294               %(min(self.T), max(self.T), len(self.T))
1295        str += '  Quantities:\n'
1296        for name in quantity_names:
1297            q = quantities[name][:].flat
1298            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1299
1300        if interpolation_points is not None:   
1301            str += '  Interpolation points (xi, eta):'\
1302                   ' number of points == %d\n' %interpolation_points.shape[0]
1303            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1304                                            max(interpolation_points[:,0]))
1305            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1306                                             max(interpolation_points[:,1]))
1307            str += '  Interpolated quantities (over all timesteps):\n'
1308       
1309            for name in quantity_names:
1310                q = precomputed_values[name][:].flat
1311                str += '    %s at interpolation points in [%f, %f]\n'\
1312                       %(name, min(q), max(q))
1313        str += '------------------------------------------------\n'
1314
1315        return str
1316
1317        #FIXME: Delete
1318        #print '------------------------------------------------'
1319        #print 'Interpolation_function statistics:'
1320        #print '  Extent:'
1321        #print '    x in [%f, %f], len(x) == %d'\
1322        #      %(min(x), max(x), len(x))
1323        #print '    y in [%f, %f], len(y) == %d'\
1324        #      %(min(y), max(y), len(y))
1325        #print '    t in [%f, %f], len(t) == %d'\
1326        #      %(min(self.T), max(self.T), len(self.T))
1327        #print '  Quantities:'
1328        #for name in quantity_names:
1329        #    q = quantities[name][:].flat
1330        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1331        #print '  Interpolation points (xi, eta):'\
1332        #      ' number of points == %d ' %interpolation_points.shape[0]
1333        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1334        #                             max(interpolation_points[:,0]))
1335        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1336        #                              max(interpolation_points[:,1]))
1337        #print '  Interpolated quantities (over all timesteps):'
1338        #
1339        #for name in quantity_names:
1340        #    q = precomputed_values[name][:].flat
1341        #    print '    %s at interpolation points in [%f, %f]'\
1342        #          %(name, min(q), max(q))
1343        #print '------------------------------------------------'
1344
1345
1346#-------------------------------------------------------------
1347if __name__ == "__main__":
1348    """
1349    Load in a mesh and data points with attributes.
1350    Fit the attributes to the mesh.
1351    Save a new mesh file.
1352    """
1353    import os, sys
1354    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1355            %os.path.basename(sys.argv[0])
1356
1357    if len(sys.argv) < 4:
1358        print usage
1359    else:
1360        mesh_file = sys.argv[1]
1361        point_file = sys.argv[2]
1362        mesh_output_file = sys.argv[3]
1363
1364        expand_search = False
1365        if len(sys.argv) > 4:
1366            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1367                expand_search = True
1368            else:
1369                expand_search = False
1370
1371        verbose = False
1372        if len(sys.argv) > 5:
1373            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1374                verbose = False
1375            else:
1376                verbose = True
1377
1378        if len(sys.argv) > 6:
1379            alpha = sys.argv[6]
1380        else:
1381            alpha = DEFAULT_ALPHA
1382
1383        # This is used more for testing
1384        if len(sys.argv) > 7:
1385            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1386                display_errors = False
1387            else:
1388                display_errors = True
1389           
1390        t0 = time.time()
1391        try:
1392            fit_to_mesh_file(mesh_file,
1393                         point_file,
1394                         mesh_output_file,
1395                         alpha,
1396                         verbose= verbose,
1397                         expand_search = expand_search,
1398                         display_errors = display_errors)
1399        except IOError,e:
1400            import sys; sys.exit(1)
1401
1402        print 'That took %.2f seconds' %(time.time()-t0)
1403
Note: See TracBrowser for help on using the repository browser.