source: inundation/fit_interpolate/fit.py @ 2942

Last change on this file since 2942 was 2939, checked in by duncan, 19 years ago

adding fit_to_mesh_file

File size: 19.9 KB
Line 
1"""Least squares fitting.
2
3   Implements a penalised least-squares fit.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18
19   TO DO
20   * test geo_ref, geo_spatial
21"""
22
23from Numeric import zeros, Float, ArrayType
24
25from geospatial_data.geospatial_data import Geospatial_data, ensure_absolute
26from fit_interpolate.general_fit_interpolate import FitInterpolate
27from utilities.sparse import Sparse, Sparse_CSR
28from utilities.polygon import in_and_outside_polygon
29from fit_interpolate.search_functions import search_tree_of_vertices
30from utilities.cg_solve import conjugate_gradient
31from utilities.numerical_tools import ensure_numeric, gradient
32
33import exceptions
class ToFewPointsError(Exception):
    """Raised when a fit is requested with too few data points.

    Fit.fit raises this when the number of accumulated data points is
    smaller than the number of mesh vertices and alpha is zero.

    Derives from the builtin Exception, which is the same object the
    Python 2 'exceptions' module exposes as exceptions.Exception, so
    existing 'except' clauses keep working.
    """
    pass
35
# Smoothing parameter used when the caller does not supply alpha
# (Fit.__init__ falls back to it; fit_to_mesh and fit_to_mesh_file use it
# as their default argument).
DEFAULT_ALPHA = 1.0e-3
37
38
39class Fit(FitInterpolate):
40   
41    def __init__(self,
42                 vertex_coordinates,
43                 triangles,
44                 mesh_origin=None,
45                 alpha = None,
46                 verbose=False,
47                 max_vertices_per_cell=30):
48
49
50        """
51        Fit data at points to the vertices of a mesh.
52
53        Inputs:
54
55          vertex_coordinates: List of coordinate pairs [xi, eta] of
56              points constituting a mesh (or an m x 2 Numeric array or
57              a geospatial object)
58              Points may appear multiple times
59              (e.g. if vertices have discontinuities)
60
61          triangles: List of 3-tuples (or a Numeric array) of
62              integers representing indices of all vertices in the mesh.
63
64          mesh_origin: A geo_reference object or 3-tuples consisting of
65              UTM zone, easting and northing.
66              If specified vertex coordinates are assumed to be
67              relative to their respective origins.
68
69          max_vertices_per_cell: Number of vertices in a quad tree cell
70          at which the cell is split into 4.
71
72          Note: Don't supply a vertex coords as a geospatial object and
73              a mesh origin, since geospatial has its own mesh origin.
74        """
75
76        # Initialise variabels
77        #self._A_can_be_reused = False
78        #self._point_coordinates = None
79
80        if alpha is None:
81
82            self.alpha = DEFAULT_ALPHA
83        else:   
84            self.alpha = alpha
85        FitInterpolate.__init__(self,
86                 vertex_coordinates,
87                 triangles,
88                 mesh_origin,
89                 verbose,
90                 max_vertices_per_cell)
91       
92        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (vertices)
93       
94        self.AtA = None
95        self.Atz = None
96
97        self.point_count = 0
98        if self.alpha <> 0:
99            if verbose: print 'Building smoothing matrix'
100            self._build_smoothing_matrix_D()
101           
102    def _build_coefficient_matrix_B(self,
103                                  verbose = False):
104        """
105        Build final coefficient matrix
106
107        Precon
108        If alpha is not zero, matrix D has been built
109        Matrix Ata has been built
110        """
111
112        if self.alpha <> 0:
113            #if verbose: print 'Building smoothing matrix'
114            #self._build_smoothing_matrix_D()
115            self.B = self.AtA + self.alpha*self.D
116        else:
117            self.B = self.AtA
118
119        #Convert self.B matrix to CSR format for faster matrix vector
120        self.B = Sparse_CSR(self.B)
121
122    def _build_smoothing_matrix_D(self):
123        """Build m x m smoothing matrix, where
124        m is the number of basis functions phi_k (one per vertex)
125
126        The smoothing matrix is defined as
127
128        D = D1 + D2
129
130        where
131
132        [D1]_{k,l} = \int_\Omega
133           \frac{\partial \phi_k}{\partial x}
134           \frac{\partial \phi_l}{\partial x}\,
135           dx dy
136
137        [D2]_{k,l} = \int_\Omega
138           \frac{\partial \phi_k}{\partial y}
139           \frac{\partial \phi_l}{\partial y}\,
140           dx dy
141
142
143        The derivatives \frac{\partial \phi_k}{\partial x},
144        \frac{\partial \phi_k}{\partial x} for a particular triangle
145        are obtained by computing the gradient a_k, b_k for basis function k
146        """
147       
148        #FIXME: algorithm might be optimised by computing local 9x9
149        #"element stiffness matrices:
150
151        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
152
153        self.D = Sparse(m,m)
154
155        #For each triangle compute contributions to D = D1+D2
156        for i in range(len(self.mesh)):
157
158            #Get area
159            area = self.mesh.areas[i]
160
161            #Get global vertex indices
162            v0 = self.mesh.triangles[i,0]
163            v1 = self.mesh.triangles[i,1]
164            v2 = self.mesh.triangles[i,2]
165
166            #Get the three vertex_points
167            xi0 = self.mesh.get_vertex_coordinate(i, 0)
168            xi1 = self.mesh.get_vertex_coordinate(i, 1)
169            xi2 = self.mesh.get_vertex_coordinate(i, 2)
170
171            #Compute gradients for each vertex
172            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
173                              1, 0, 0)
174
175            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
176                              0, 1, 0)
177
178            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
179                              0, 0, 1)
180
181            #Compute diagonal contributions
182            self.D[v0,v0] += (a0*a0 + b0*b0)*area
183            self.D[v1,v1] += (a1*a1 + b1*b1)*area
184            self.D[v2,v2] += (a2*a2 + b2*b2)*area
185
186            #Compute contributions for basis functions sharing edges
187            e01 = (a0*a1 + b0*b1)*area
188            self.D[v0,v1] += e01
189            self.D[v1,v0] += e01
190
191            e12 = (a1*a2 + b1*b2)*area
192            self.D[v1,v2] += e12
193            self.D[v2,v1] += e12
194
195            e20 = (a2*a0 + b2*b0)*area
196            self.D[v2,v0] += e20
197            self.D[v0,v2] += e20
198
199
200    def get_D(self):
201        return self.D.todense()
202
203
204    def _build_matrix_AtA_Atz(self,
205                              point_coordinates,
206                              z,
207                              verbose = False):
208        """Build:
209        AtA  m x m  interpolation matrix, and,
210        Atz  m x a  interpolation matrix where,
211        m is the number of basis functions phi_k (one per vertex)
212        a is the number of data attributes
213
214        This algorithm uses a quad tree data structure for fast binning of
215        data points.
216
217        If Ata is None, the matrices AtA and Atz are created.
218
219        This function can be called again and again, with sub-sets of
220        the point coordinates.  Call fit to get the results.
221       
222        Preconditions
223        z and points are numeric
224        Point_coordindates and mesh vertices have the same origin.
225
226        The number of attributes of the data points does not change
227        """
228        #Build n x m interpolation matrix
229
230        if self.AtA == None:
231            # AtA and Atz need ot be initialised.
232            m = self.mesh.coordinates.shape[0] #Nbr of vertices
233            if len(z.shape) > 1:
234                att_num = z.shape[1]
235                self.Atz = zeros((m,att_num), Float)
236            else:
237                att_num = 1
238                self.Atz = zeros((m,), Float)
239            assert z.shape[0] == point_coordinates.shape[0] 
240
241            self.AtA = Sparse(m,m)
242        self.point_count += point_coordinates.shape[0]
243        #print "_build_matrix_AtA_Atz - self.point_count", self.point_count
244        if verbose: print 'Getting indices inside mesh boundary'
245        #print "self.mesh.get_boundary_polygon()",self.mesh.get_boundary_polygon()
246        self.inside_poly_indices, self.outside_poly_indices  = \
247                     in_and_outside_polygon(point_coordinates,
248                                            self.mesh.get_boundary_polygon(),
249                                            closed = True, verbose = verbose)
250        #print "self.inside_poly_indices",self.inside_poly_indices
251        #print "self.outside_poly_indices",self.outside_poly_indices
252
253       
254        n = len(self.inside_poly_indices)
255        #Compute matrix elements for points inside the mesh
256        for i in self.inside_poly_indices:
257            #For each data_coordinate point
258            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
259            x = point_coordinates[i]
260            element_found, sigma0, sigma1, sigma2, k = \
261                           search_tree_of_vertices(self.root, self.mesh, x)
262           
263            if element_found is True:
264                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
265                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
266                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
267
268                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
269                js     = [j0,j1,j2]
270
271                for j in js:
272                    self.Atz[j] +=  sigmas[j]*z[i]
273                    #print "self.Atz building", self.Atz
274                    #print "self.Atz[j]", self.Atz[j]
275                    #print " sigmas[j]", sigmas[j]
276                    #print "z[i]",z[i]
277                    #print "result", sigmas[j]*z[i]
278                   
279                    for k in js:
280                        self.AtA[j,k] += sigmas[j]*sigmas[k]
281            else:
282                msg = 'Could not find triangle for point', x
283                raise Exception(msg)
284   
285       
286    def fit(self, point_coordinates=None, z=None,
287                              verbose = False,
288                              point_origin = None):
289        """Fit a smooth surface to given 1d array of data points z.
290
291        The smooth surface is computed at each vertex in the underlying
292        mesh using the formula given in the module doc string.
293
294        Inputs:
295        point_coordinates: The co-ordinates of the data points.
296              List of coordinate pairs [x, y] of
297              data points or an nx2 Numeric array or a Geospatial_data object
298          z: Single 1d vector or array of data at the point_coordinates.
299         
300        """
301        if point_coordinates is None:
302            assert self.AtA <> None
303            assert self.Atz <> None
304            #FIXME (DSG) - do  a message
305        else:
306            point_coordinates = ensure_absolute(point_coordinates,
307                                                geo_reference=point_origin)
308            self.build_fit_subset(point_coordinates, z, verbose)
309       
310        #Check sanity
311        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
312        n = self.point_count
313        if n<m and self.alpha == 0.0:
314            msg = 'ERROR (least_squares): Too few data points\n'
315            msg += 'There are only %d data points and alpha == 0. ' %n
316            msg += 'Need at least %d\n' %m
317            msg += 'Alternatively, set smoothing parameter alpha to a small '
318            msg += 'positive value,\ne.g. 1.0e-3.'
319            raise ToFewPointsError(msg)
320
321        self._build_coefficient_matrix_B(verbose)
322        return conjugate_gradient(self.B, self.Atz, self.Atz,
323                                  imax=2*len(self.Atz) )
324
325       
326    def build_fit_subset(self, point_coordinates, z,
327                              verbose = False):
328        """Fit a smooth surface to given 1d array of data points z.
329
330        The smooth surface is computed at each vertex in the underlying
331        mesh using the formula given in the module doc string.
332
333        Inputs:
334        point_coordinates: The co-ordinates of the data points.
335              List of coordinate pairs [x, y] of
336              data points or an nx2 Numeric array or a Geospatial_data object
337          z: Single 1d vector or array of data at the point_coordinates.
338
339        """
340        #Note: Don't get the z info from Geospatial_data.attributes yet.
341        # If we did fit would have to handle attribute title info.
342
343        #FIXME(DSG-DSG): Check that the vert and point coords
344        #have the same zone.
345        if isinstance(point_coordinates,Geospatial_data):
346            point_coordinates = point_coordinates.get_data_points( \
347                absolute = True)
348       
349        #Convert input to Numeric arrays
350        z = ensure_numeric(z, Float)
351        point_coordinates = ensure_numeric(point_coordinates, Float)
352
353        self._build_matrix_AtA_Atz(point_coordinates, z, verbose)
354
355
356############################################################################
357
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,
                point_attributes,
                alpha = DEFAULT_ALPHA,
                verbose = False,
                acceptable_overshoot = 1.01,
                mesh_origin = None,
                data_origin = None,
                use_cache = False):
    """
    Fit a smooth surface to a triangulation,
    given data points with attributes.

    This is a thin wrapper: it creates a Fit object for the mesh and
    returns the result of its fit method.

        Inputs:
        vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a Numeric array) of
          integers representing indices of all vertices in the mesh.

          point_coordinates: List of coordinate pairs [x, y] of data points
          (or an nx2 Numeric array)

          point_attributes: Vector or array of data at the
                            point_coordinates.

          alpha: Smoothing parameter.

          verbose: If true, progress messages are printed.

          acceptable_overshoot: Accepted for compatibility; it is not used
          anywhere in this function.

          mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified vertex coordinates are assumed to be
              relative to their respective origins.

          data_origin: Handed to Fit.fit as point_origin.

          use_cache: If True (the literal True), the Fit object is obtained
          through cache(_fit, ...) instead of being constructed directly.

    """
    # The same keyword options configure the Fit object on both paths
    fit_options = {'verbose': verbose,
                   'mesh_origin': mesh_origin,
                   'alpha':alpha}

    if use_cache is True:
        # NOTE(review): no import of 'cache' is visible in this file -
        # confirm where the name comes from before relying on this branch.
        interp = cache(_fit,
                       (vertex_coordinates,
                        triangles),
                       fit_options,
                       verbose = verbose)
    else:
        interp = Fit(vertex_coordinates,
                     triangles,
                     **fit_options)

    # Value checking (min/max overshoot, as done in least squares) is not
    # done here yet. It may belong in Fit, either as a separate method or
    # as part of the fit method.
    return interp.fit(point_coordinates,
                      point_attributes,
                      point_origin = data_origin,
                      verbose = verbose)
433
434
435def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
436                     alpha=DEFAULT_ALPHA, verbose= False,
437                     display_errors = True):
438    """
439    Given a mesh file (tsh) and a point attribute file (xya), fit
440    point attributes to the mesh and write a mesh file with the
441    results.
442
443
444    If data_origin is not None it is assumed to be
445    a 3-tuple with geo referenced
446    UTM coordinates (zone, easting, northing)
447
448    NOTE: Throws IOErrors, for a variety of file problems.
449   
450    """
451
452    # Question
453    # should data_origin and mesh_origin be passed in?
454    # No they should be in the data structure
455    #
456    #Should the origin of the mesh be changed using this function?
457    # That is overloading this function.  Have it as a seperate
458    # method, at least initially.
459   
460    from load_mesh.loadASCII import import_mesh_file, \
461                 import_points_file, export_mesh_file, \
462                 concatinate_attributelist
463
464    # FIXME: Use geospatial instead of import_points_file
465    try:
466        mesh_dict = import_mesh_file(mesh_file)
467    except IOError,e:
468        if display_errors:
469            print "Could not load bad file. ", e
470        raise IOError  #Re-raise exception
471       
472    vertex_coordinates = mesh_dict['vertices']
473    triangles = mesh_dict['triangles']
474    if type(mesh_dict['vertex_attributes']) == ArrayType:
475        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
476    else:
477        old_point_attributes = mesh_dict['vertex_attributes']
478
479    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
480        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
481    else:
482        old_title_list = mesh_dict['vertex_attribute_titles']
483
484    if verbose: print 'tsh file %s loaded' %mesh_file
485
486    # load in the .pts file
487    try:
488        point_dict = import_points_file(point_file, verbose=verbose)
489    except IOError,e:
490        if display_errors:
491            print "Could not load bad file. ", e
492        raise IOError  #Re-raise exception 
493
494    point_coordinates = point_dict['pointlist']
495    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
496
497    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
498        data_origin = point_dict['geo_reference'].get_origin()
499    else:
500        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
501
502    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
503        mesh_origin = mesh_dict['geo_reference'].get_origin()
504    else:
505        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
506
507    if verbose: print "points file loaded"
508    if verbose: print "fitting to mesh"
509    f = fit_to_mesh(vertex_coordinates,
510                    triangles,
511                    point_coordinates,
512                    point_attributes,
513                    alpha = alpha,
514                    verbose = verbose,
515                    data_origin = data_origin,
516                    mesh_origin = mesh_origin)
517    if verbose: print "finished fitting to mesh"
518
519    # convert array to list of lists
520    new_point_attributes = f.tolist()
521    #FIXME have this overwrite attributes with the same title - DSG
522    #Put the newer attributes last
523    if old_title_list <> []:
524        old_title_list.extend(title_list)
525        #FIXME can this be done a faster way? - DSG
526        for i in range(len(old_point_attributes)):
527            old_point_attributes[i].extend(new_point_attributes[i])
528        mesh_dict['vertex_attributes'] = old_point_attributes
529        mesh_dict['vertex_attribute_titles'] = old_title_list
530    else:
531        mesh_dict['vertex_attributes'] = new_point_attributes
532        mesh_dict['vertex_attribute_titles'] = title_list
533
534    if verbose: print "exporting to file ", mesh_output_file
535
536    try:
537        export_mesh_file(mesh_output_file, mesh_dict)
538    except IOError,e:
539        if display_errors:
540            print "Could not write file. ", e
541        raise IOError
542
543
def _fit(*args, **kwargs):
    """Build a Fit object from the given arguments.

    Private function for use with caching: a plain function is handed
    to the cache instead of the class, because classes may change their
    byte code between runs which is annoying.
    """

    fit_object = Fit(*args, **kwargs)
    return fit_object
550
Note: See TracBrowser for help on using the repository browser.