source: inundation/fit_interpolate/fit.py @ 3455

Last change on this file since 3455 was 3455, checked in by duncan, 18 years ago

comments

File size: 20.7 KB
Line 
1"""Least squares fitting.
2
3   Implements a penalised least-squares fit.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6 (the module default, DEFAULT_ALPHA, is 1.0e-3)
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18
19   TO DO
20   * test geo_ref, geo_spatial
21
22   IDEAS
23   * (DSG-) Change the interface of fit so that a domain object can
24      be passed in (it is not yet known whether this is feasible). This
25      could save time.
26"""
27
28from Numeric import zeros, Float, ArrayType,take
29
30from geospatial_data.geospatial_data import Geospatial_data, ensure_absolute
31from fit_interpolate.general_fit_interpolate import FitInterpolate
32from utilities.sparse import Sparse, Sparse_CSR
33from utilities.polygon import in_and_outside_polygon
34from fit_interpolate.search_functions import search_tree_of_vertices
35from utilities.cg_solve import conjugate_gradient
36from utilities.numerical_tools import ensure_numeric, gradient
37
38import exceptions
class ToFewPointsError(Exception):
    """Raised when there are too few data points for an unsmoothed fit."""
    pass


class VertsWithNoTrianglesError(Exception):
    """Error type for mesh vertices that are not part of any triangle."""
    pass


# Smoothing parameter used when the caller does not supply one.
DEFAULT_ALPHA = 0.001
43
44
class Fit(FitInterpolate):
    """Penalised least-squares fit of point data to the vertices of a mesh.

    Data may be supplied in a single call to fit(), or accumulated over
    several calls to build_fit_subset() followed by fit() with no points.
    """

    def __init__(self,
                 vertex_coordinates,
                 triangles,
                 mesh_origin=None,
                 alpha = None,
                 verbose=False,
                 max_vertices_per_cell=30):
        """
        Fit data at points to the vertices of a mesh.

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a Numeric array) of
              integers representing indices of all vertices in the mesh.

          mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified vertex coordinates are assumed to be
              relative to their respective origins.

          alpha: Smoothing parameter (must be >= 0). If None,
              DEFAULT_ALPHA is used. A value of 0 disables smoothing,
              in which case the smoothing matrix D is not built.

          verbose: If True, print progress messages.

          max_vertices_per_cell: Number of vertices in a quad tree cell
          at which the cell is split into 4.

          Note: Don't supply a vertex coords as a geospatial object and
              a mesh origin, since geospatial has its own mesh origin.

        Raises ValueError if alpha is negative.
        """
        if alpha is None:
            alpha = DEFAULT_ALPHA
        if alpha < 0:
            # The module documentation states that a negative alpha is
            # not allowed; enforce it instead of silently accepting it.
            raise ValueError('Smoothing parameter alpha must be >= 0, '
                             'got %s' % str(alpha))
        self.alpha = alpha

        FitInterpolate.__init__(self,
                                vertex_coordinates,
                                triangles,
                                mesh_origin,
                                verbose,
                                max_vertices_per_cell)

        # Accumulators for the normal equations. They are created lazily
        # by _build_matrix_AtA_Atz, once the number of attributes is known.
        self.AtA = None
        self.Atz = None

        # Total number of data points passed in so far (this includes
        # points that turn out to lie outside the mesh boundary).
        self.point_count = 0

        if self.alpha != 0:
            if verbose:
                print('Building smoothing matrix')
            self._build_smoothing_matrix_D()

    def _build_coefficient_matrix_B(self,
                                    verbose = False):
        """Build the final coefficient matrix B = AtA + alpha*D.

        The result is stored in self.B in CSR format.

        Preconditions:
          Matrix AtA has been built.
          If alpha is not zero, matrix D has been built.

        verbose is accepted for interface compatibility and is not used.
        """
        if self.alpha != 0:
            B = self.AtA + self.alpha*self.D
        else:
            B = self.AtA

        # CSR format gives faster matrix-vector products in the solver.
        self.B = Sparse_CSR(B)

    def _build_smoothing_matrix_D(self):
        r"""Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k

        The result is stored in self.D (a Sparse matrix).
        """
        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)

        self.D = Sparse(m, m)

        # For each triangle, add its 3x3 local contribution to D = D1+D2.
        for i in range(len(self.mesh)):
            area = self.mesh.areas[i]

            # Global vertex indices of triangle i
            v0 = self.mesh.triangles[i, 0]
            v1 = self.mesh.triangles[i, 1]
            v2 = self.mesh.triangles[i, 2]

            # Coordinates of the three vertices
            xi0 = self.mesh.get_vertex_coordinate(i, 0)
            xi1 = self.mesh.get_vertex_coordinate(i, 1)
            xi2 = self.mesh.get_vertex_coordinate(i, 2)
            x0, y0 = xi0[0], xi0[1]
            x1, y1 = xi1[0], xi1[1]
            x2, y2 = xi2[0], xi2[1]

            # Gradient (a_k, b_k) of each local linear basis function,
            # i.e. the function that is 1 at vertex k and 0 at the others.
            a0, b0 = gradient(x0, y0, x1, y1, x2, y2, 1, 0, 0)
            a1, b1 = gradient(x0, y0, x1, y1, x2, y2, 0, 1, 0)
            a2, b2 = gradient(x0, y0, x1, y1, x2, y2, 0, 0, 1)

            # Local element matrix: entry (k, l) is
            # (a_k*a_l + b_k*b_l)*area, covering both the diagonal terms
            # and the symmetric terms for basis functions sharing an edge.
            local = ((v0, a0, b0), (v1, a1, b1), (v2, a2, b2))
            for row, a_row, b_row in local:
                for col, a_col, b_col in local:
                    self.D[row, col] += (a_row*a_col + b_row*b_col)*area

    def get_D(self):
        """Return the smoothing matrix D as a dense matrix."""
        return self.D.todense()

    def _build_matrix_AtA_Atz(self,
                              point_coordinates,
                              z,
                              verbose = False):
        """Build:
        AtA  m x m  interpolation matrix, and,
        Atz  m x a  interpolation matrix where,
        m is the number of basis functions phi_k (one per vertex)
        a is the number of data attributes

        This algorithm uses a quad tree data structure for fast binning of
        data points.

        If AtA is None, the matrices AtA and Atz are created.

        This function can be called again and again, with sub-sets of
        the point coordinates.  Call fit to get the results.

        Preconditions
        z and points are numeric
        Point_coordinates and mesh vertices have the same origin.

        The number of attributes of the data points does not change

        Raises an Exception if a point inside the mesh boundary cannot be
        located in any triangle.
        """
        # Every batch (not just the first) must supply one z row per point.
        assert z.shape[0] == point_coordinates.shape[0], \
               'z has %d rows but %d points were given' \
               % (z.shape[0], point_coordinates.shape[0])

        if self.AtA is None:
            # First batch: size the accumulators from the mesh and from
            # the number of attributes in z.
            m = self.mesh.coordinates.shape[0] #Nbr of vertices
            if len(z.shape) > 1:
                self.Atz = zeros((m, z.shape[1]), Float)
            else:
                self.Atz = zeros((m,), Float)
            self.AtA = Sparse(m, m)

        self.point_count += point_coordinates.shape[0]

        if verbose:
            print('Getting indices inside mesh boundary')

        self.inside_poly_indices, self.outside_poly_indices = \
            in_and_outside_polygon(point_coordinates,
                                   self.mesh.get_boundary_polygon(),
                                   closed = True, verbose = verbose)

        n = len(self.inside_poly_indices)
        report_interval = (n + 10)//10   # always >= 1, so no modulo by zero

        # Compute matrix elements for points inside the mesh
        for count, i in enumerate(self.inside_poly_indices):
            if verbose and count % report_interval == 0:
                print('Doing %d of %d' % (count, n))

            x = point_coordinates[i]
            element_found, sigma0, sigma1, sigma2, k = \
                search_tree_of_vertices(self.root, self.mesh, x)

            if element_found is not True:
                msg = 'Could not find triangle for point %s' % str(x)
                raise Exception(msg)

            # Pair each vertex of triangle k with its barycentric weight.
            contributions = ((self.mesh.triangles[k, 0], sigma0),
                             (self.mesh.triangles[k, 1], sigma1),
                             (self.mesh.triangles[k, 2], sigma2))

            for row, sigma_row in contributions:
                self.Atz[row] += sigma_row*z[i]
                for col, sigma_col in contributions:
                    self.AtA[row, col] += sigma_row*sigma_col

    def fit(self, point_coordinates=None, z=None,
            verbose = False,
            point_origin = None):
        """Fit a smooth surface to given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
        point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 Numeric array or a Geospatial_data object
              If None, the data already accumulated with
              build_fit_subset is used.
          z: Single 1d vector or array of data at the point_coordinates.
          verbose: If True, print progress messages.
          point_origin: Geo reference of point_coordinates; passed to
              ensure_absolute.

        Returns the values fitted at the mesh vertices, as computed by
        conjugate_gradient for B x = Atz.

        Raises ToFewPointsError if alpha is 0 and fewer data points than
        mesh vertices have been supplied.
        """
        if point_coordinates is None:
            msg = 'No data points have been supplied. Pass ' \
                  'point_coordinates and z to fit, or call ' \
                  'build_fit_subset first.'
            assert self.AtA is not None, msg
            assert self.Atz is not None, msg
        else:
            point_coordinates = ensure_absolute(point_coordinates,
                                                geo_reference=point_origin)
            self.build_fit_subset(point_coordinates, z, verbose)

        # Check sanity: without smoothing we need at least one point per
        # basis function.
        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
        n = self.point_count
        if n < m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There are only %d data points and alpha == 0. ' %n
            msg += 'Need at least %d\n' %m
            msg += 'Alternatively, set smoothing parameter alpha to a small '
            msg += 'positive value,\ne.g. 1.0e-3.'
            raise ToFewPointsError(msg)

        self._build_coefficient_matrix_B(verbose)

        # FIXME - make this an error (VertsWithNoTrianglesError).
        # Test with Not_yet_test_smooth_att_to_mesh_with_excess_verts.
        loners = self.mesh.get_lone_vertices()
        if len(loners) > 0:
            msg = 'WARNING: (least_squares): \nVertices with no triangles\n'
            msg += 'All vertices should be part of a triangle.\n'
            msg += 'In the future this will be inforced.\n'
            msg += 'The following vertices are not part of a triangle;\n'
            msg += str(loners)
            print(msg)

        # Atz doubles as the initial guess for the iterative solver.
        return conjugate_gradient(self.B, self.Atz, self.Atz,
                                  imax=2*len(self.Atz))

    def build_fit_subset(self, point_coordinates, z,
                         verbose = False):
        """Add a subset of data points to the fit.

        The contributions of the points are accumulated into the matrices
        AtA and Atz. No system is solved here; call fit() without
        point_coordinates once all subsets have been added.

        Inputs:
        point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 Numeric array or a Geospatial_data object
              (the absolute data points of a Geospatial_data are used)
          z: Single 1d vector or array of data at the point_coordinates.
          verbose: If True, print progress messages.
        """
        # Note: z is not taken from Geospatial_data.attributes yet.
        # If it were, fit would have to handle attribute title info.

        # FIXME(DSG-DSG): Check that the vert and point coords
        # have the same zone.
        if isinstance(point_coordinates, Geospatial_data):
            point_coordinates = point_coordinates.get_data_points(
                absolute = True)

        z = ensure_numeric(z, Float)
        point_coordinates = ensure_numeric(point_coordinates, Float)

        self._build_matrix_AtA_Atz(point_coordinates, z, verbose)
374
375
376############################################################################
377
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,
                point_attributes,
                alpha = DEFAULT_ALPHA,
                verbose = False,
                acceptable_overshoot = 1.01,
                mesh_origin = None,
                data_origin = None,
                use_cache = False):
    """
    Fit a smooth surface to a triangulation,
    given data points with attributes.


        Inputs:
        vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a Numeric array) of
          integers representing indices of all vertices in the mesh.

          point_coordinates: List of coordinate pairs [x, y] of data points
          (or an nx2 Numeric array)

          point_attributes: Vector or array of data at the
                            point_coordinates.

          alpha: Smoothing parameter.

          verbose: If True, print progress messages.

          acceptable_overshoot: intended to control the allowed factor by
          which fitted values may exceed the value of input data (lower
          limit min(z) - acceptable_overshoot*delta z, upper limit
          max(z) + acceptable_overshoot*delta z).
          NOTE: currently not used; no overshoot checking is performed.

          mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified vertex coordinates are assumed to be
              relative to their respective origins.

          data_origin: Geo reference of point_coordinates; passed to
              Fit.fit as point_origin.

          use_cache: If True, the Fit object is obtained through the
              caching module's cache() applied to _fit.

        Returns the fitted attribute values at the mesh vertices.
    """
    if use_cache is True:
        # cache was called here without ever being imported (NameError).
        # Import it locally so only callers that request caching need
        # the caching package.
        from caching import cache
        interp = cache(_fit,
                       (vertex_coordinates,
                        triangles),
                       {'verbose': verbose,
                        'mesh_origin': mesh_origin,
                        'alpha': alpha},
                       verbose = verbose)
    else:
        interp = Fit(vertex_coordinates,
                     triangles,
                     verbose = verbose,
                     mesh_origin = mesh_origin,
                     alpha = alpha)

    vertex_attributes = interp.fit(point_coordinates,
                                   point_attributes,
                                   point_origin = data_origin,
                                   verbose = verbose)

    # TODO: Add the value checking (acceptable_overshoot) that exists in
    # least squares. It could be pushed down into Fit, e.g. as a method,
    # or integrated into Fit.fit by saving the max and min values as
    # attributes.

    return vertex_attributes
453
454
def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
                     alpha=DEFAULT_ALPHA, verbose= False,
                     display_errors = True):
    """
    Given a mesh file (tsh) and a point attribute file (xya), fit
    point attributes to the mesh and write a mesh file with the
    results.

    Inputs:
      mesh_file: Name of the mesh file to read.
      point_file: Name of the point attribute file to read.
      mesh_output_file: Name of the mesh file to write.
      alpha: Smoothing parameter passed to fit_to_mesh.
      verbose: If True, print progress messages.
      display_errors: If True, print a message before re-raising an
          IOError.

    The geo references of the points and of the mesh are read from the
    files. If a file has no geo reference, the origin (56, 0, 0)
    (UTM zone, easting, northing) is assumed.

    The fitted attributes are appended after any existing vertex
    attributes of the mesh.

    NOTE: Throws IOErrors, for a variety of file problems.
    """
    # Question
    # should data_origin and mesh_origin be passed in?
    # No they should be in the data structure
    #
    # Should the origin of the mesh be changed using this function?
    # That is overloading this function.  Have it as a separate
    # method, at least initially.

    # sys.exc_info() gives access to the active exception without the
    # `except X as e` syntax, so this works on old and new Pythons alike.
    import sys
    from load_mesh.loadASCII import import_mesh_file, \
                 import_points_file, export_mesh_file, \
                 concatinate_attributelist

    # FIXME: Use geospatial instead of import_points_file
    try:
        mesh_dict = import_mesh_file(mesh_file)
    except IOError:
        if display_errors:
            print('Could not load bad file.  %s' % sys.exc_info()[1])
        raise  # re-raise the original IOError, keeping message and traceback

    vertex_coordinates = mesh_dict['vertices']
    triangles = mesh_dict['triangles']

    old_point_attributes = mesh_dict['vertex_attributes']
    if isinstance(old_point_attributes, ArrayType):
        old_point_attributes = old_point_attributes.tolist()

    old_title_list = mesh_dict['vertex_attribute_titles']
    if isinstance(old_title_list, ArrayType):
        old_title_list = old_title_list.tolist()

    if verbose: print('tsh file %s loaded' % mesh_file)

    # load in the .pts file
    try:
        point_dict = import_points_file(point_file, verbose=verbose)
    except IOError:
        if display_errors:
            print('Could not load bad file.  %s' % sys.exc_info()[1])
        raise  # re-raise the original IOError, keeping message and traceback

    point_coordinates = point_dict['pointlist']
    title_list, point_attributes = \
        concatinate_attributelist(point_dict['attributelist'])

    if point_dict.get('geo_reference') is not None:
        data_origin = point_dict['geo_reference'].get_origin()
    else:
        data_origin = (56, 0, 0) #FIXME(DSG-DSG)

    if mesh_dict.get('geo_reference') is not None:
        mesh_origin = mesh_dict['geo_reference'].get_origin()
    else:
        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)

    if verbose: print('points file loaded')
    if verbose: print('fitting to mesh')
    f = fit_to_mesh(vertex_coordinates,
                    triangles,
                    point_coordinates,
                    point_attributes,
                    alpha = alpha,
                    verbose = verbose,
                    data_origin = data_origin,
                    mesh_origin = mesh_origin)
    if verbose: print('finished fitting to mesh')

    # convert array to list of lists
    new_point_attributes = f.tolist()

    # FIXME have this overwrite attributes with the same title - DSG
    # Put the newer attributes last
    if old_title_list:
        old_title_list.extend(title_list)
        # Indexing (rather than zip) keeps an IndexError if the fitted
        # result has fewer rows than the existing attributes.
        for i, row in enumerate(old_point_attributes):
            row.extend(new_point_attributes[i])
        mesh_dict['vertex_attributes'] = old_point_attributes
        mesh_dict['vertex_attribute_titles'] = old_title_list
    else:
        mesh_dict['vertex_attributes'] = new_point_attributes
        mesh_dict['vertex_attribute_titles'] = title_list

    if verbose: print('exporting to file  %s' % mesh_output_file)

    try:
        export_mesh_file(mesh_output_file, mesh_dict)
    except IOError:
        if display_errors:
            print('Could not write file.  %s' % sys.exc_info()[1])
        raise  # re-raise the original IOError, keeping message and traceback
562
563
def _fit(*args, **kwargs):
    """Construct and return a Fit from the given arguments.

    Private module-level function for use with caching. A plain function
    is cached instead of the class itself, because classes may change
    their byte code between runs.
    """
    fit_object = Fit(*args, **kwargs)
    return fit_object
570
Note: See TracBrowser for help on using the repository browser.