source: inundation/fit_interpolate/fit.py @ 3116

Last change on this file since 3116 was 3116, checked in by ole, 18 years ago

Supporting modifications for parallel wollongong

File size: 20.5 KB
Line 
1"""Least squares fitting.
2
3   Implements a penalised least-squares fit.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18
19   TO DO
20   * test geo_ref, geo_spatial
21"""
22
23from Numeric import zeros, Float, ArrayType,take
24
25from geospatial_data.geospatial_data import Geospatial_data, ensure_absolute
26from fit_interpolate.general_fit_interpolate import FitInterpolate
27from utilities.sparse import Sparse, Sparse_CSR
28from utilities.polygon import in_and_outside_polygon
29from fit_interpolate.search_functions import search_tree_of_vertices
30from utilities.cg_solve import conjugate_gradient
31from utilities.numerical_tools import ensure_numeric, gradient
32
33import exceptions
class ToFewPointsError(Exception):
    """Raised by Fit.fit when alpha == 0 and there are fewer data points
    than mesh vertices.
    """
class VertsWithNoTrianglesError(Exception):
    """Error for meshes containing vertices that belong to no triangle.

    Within this file it is only referenced by a commented-out raise in
    Fit.fit, which currently prints a warning instead.
    """
36
# Smoothing parameter used when the caller does not supply alpha
# (Fit.__init__ with alpha=None, and the default of fit_to_mesh and
# fit_to_mesh_file).
DEFAULT_ALPHA = 0.001
38
39
class Fit(FitInterpolate):
    """Penalised least-squares fit of point data to the vertices of a mesh.

    Data is accumulated into the matrices AtA and Atz (one or more calls
    to build_fit_subset, or a single call to fit with data), optionally
    combined with the smoothing matrix D as B = AtA + alpha*D, and the
    resulting system is handed to conjugate_gradient.
    """

    def __init__(self,
                 vertex_coordinates,
                 triangles,
                 mesh_origin=None,
                 alpha=None,
                 verbose=False,
                 max_vertices_per_cell=30):
        """
        Fit data at points to the vertices of a mesh.

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a Numeric array) of
              integers representing indices of all vertices in the mesh.

          mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified vertex coordinates are assumed to be
              relative to their respective origins.

          alpha: Smoothing parameter. None selects DEFAULT_ALPHA.
              If alpha is zero no smoothing matrix is built.

          verbose: If true, progress messages are printed.

          max_vertices_per_cell: Number of vertices in a quad tree cell
          at which the cell is split into 4.

          Note: Don't supply a vertex coords as a geospatial object and
              a mesh origin, since geospatial has its own mesh origin.
        """
        if alpha is None:
            self.alpha = DEFAULT_ALPHA
        else:
            self.alpha = alpha

        FitInterpolate.__init__(self,
                                vertex_coordinates,
                                triangles,
                                mesh_origin,
                                verbose,
                                max_vertices_per_cell)

        # Accumulators; allocated lazily by _build_matrix_AtA_Atz once the
        # shape of the data is known.
        self.AtA = None
        self.Atz = None

        # Total number of data points passed to _build_matrix_AtA_Atz so far.
        self.point_count = 0

        if self.alpha != 0:
            if verbose: print('Building smoothing matrix')
            self._build_smoothing_matrix_D()

    def _build_coefficient_matrix_B(self,
                                    verbose=False):
        """
        Build final coefficient matrix B = AtA + alpha*D (or AtA alone
        when alpha is zero) and store it in self.B in CSR form.

        Preconditions
        If alpha is not zero, matrix D has been built
        Matrix AtA has been built
        """
        if self.alpha != 0:
            self.B = self.AtA + self.alpha*self.D
        else:
            self.B = self.AtA

        # Convert self.B matrix to CSR format for faster matrix vector
        self.B = Sparse_CSR(self.B)

    def _build_smoothing_matrix_D(self):
        r"""Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k

        The result is stored in self.D.
        """

        #FIXME: algorithm might be optimised by computing local 9x9
        #"element stiffness matrices:

        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)

        self.D = Sparse(m, m)

        # Nodal values (q0, q1, q2) passed to gradient() to select each of
        # the three basis functions in turn: 1 at its own vertex, 0 at the
        # other two.
        nodal_values = ((1, 0, 0), (0, 1, 0), (0, 0, 1))

        # For each triangle compute contributions to D = D1+D2
        for i in range(len(self.mesh)):
            area = self.mesh.areas[i]

            # Global vertex indices of this triangle
            verts = [self.mesh.triangles[i, 0],
                     self.mesh.triangles[i, 1],
                     self.mesh.triangles[i, 2]]

            # The three vertex points
            xi0 = self.mesh.get_vertex_coordinate(i, 0)
            xi1 = self.mesh.get_vertex_coordinate(i, 1)
            xi2 = self.mesh.get_vertex_coordinate(i, 2)
            corners = (xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])

            # Gradient (a_k, b_k) of each local basis function
            grads = [gradient(*(corners + q)) for q in nodal_values]

            # Add all nine local contributions: p == q gives the diagonal
            # terms, p != q the (symmetric) terms for basis functions
            # sharing an edge.
            for p in range(3):
                a_p, b_p = grads[p]
                for q in range(3):
                    a_q, b_q = grads[q]
                    self.D[verts[p], verts[q]] += (a_p*a_q + b_p*b_q)*area

    def get_D(self):
        """Return the smoothing matrix D converted with its todense()."""
        return self.D.todense()

    def _build_matrix_AtA_Atz(self,
                              point_coordinates,
                              z,
                              verbose=False):
        """Build:
        AtA  m x m  interpolation matrix, and,
        Atz  m x a  interpolation matrix where,
        m is the number of basis functions phi_k (one per vertex)
        a is the number of data attributes

        This algorithm uses a quad tree data structure for fast binning of
        data points.

        If AtA is None, the matrices AtA and Atz are created.

        This function can be called again and again, with sub-sets of
        the point coordinates.  Call fit to get the results.

        Only points reported as inside the mesh boundary polygon
        contribute; all points are counted in self.point_count.

        Preconditions
        z and points are numeric
        Point_coordindates and mesh vertices have the same origin.

        The number of attributes of the data points does not change

        Raises
        AssertionError if z and point_coordinates differ in length.
        Exception if no triangle is found for an inside point.
        """
        # Checked on every call, and before any state is modified, so a
        # mismatched subset cannot leave half-initialised accumulators.
        msg = 'z and point_coordinates must have the same number of rows'
        assert z.shape[0] == point_coordinates.shape[0], msg

        if self.AtA is None:
            # AtA and Atz need to be initialised.
            m = self.mesh.coordinates.shape[0] #Nbr of vertices
            if len(z.shape) > 1:
                self.Atz = zeros((m, z.shape[1]), Float)
            else:
                self.Atz = zeros((m,), Float)

            self.AtA = Sparse(m, m)

        self.point_count += point_coordinates.shape[0]

        if verbose: print('Getting indices inside mesh boundary')

        self.inside_poly_indices, self.outside_poly_indices = \
                     in_and_outside_polygon(point_coordinates,
                                            self.mesh.get_boundary_polygon(),
                                            closed=True, verbose=verbose)

        n = len(self.inside_poly_indices)

        # Compute matrix elements for points inside the mesh
        for i in self.inside_poly_indices:
            # i is an index into point_coordinates, not a running count.
            if verbose and i % ((n+10)//10) == 0:
                print('Doing %d of %d' % (i, n))

            x = point_coordinates[i]
            element_found, sigma0, sigma1, sigma2, k = \
                           search_tree_of_vertices(self.root, self.mesh, x)

            if element_found is not True:
                msg = 'Could not find triangle for point %s' % (x,)
                raise Exception(msg)

            j0 = self.mesh.triangles[k, 0] #Global vertex id for sigma0
            j1 = self.mesh.triangles[k, 1] #Global vertex id for sigma1
            j2 = self.mesh.triangles[k, 2] #Global vertex id for sigma2

            sigmas = {j0: sigma0, j1: sigma1, j2: sigma2}
            js = [j0, j1, j2]

            for j in js:
                self.Atz[j] += sigmas[j]*z[i]
                for l in js:
                    self.AtA[j, l] += sigmas[j]*sigmas[l]

    def fit(self, point_coordinates=None, z=None,
            verbose=False,
            point_origin=None):
        """Fit a smooth surface to given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
          point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 Numeric array or a Geospatial_data object
              If None, the data previously accumulated with
              build_fit_subset is used.
          z: Single 1d vector or array of data at the point_coordinates.
          verbose: If true, progress messages are printed.
          point_origin: Passed to ensure_absolute as geo_reference for
              point_coordinates.

        Returns the result of conjugate_gradient applied to B and Atz.

        Raises
        AssertionError if point_coordinates is None and no data has
            been accumulated.
        ToFewPointsError if alpha == 0 and there are fewer data points
            than vertices.
        """
        if point_coordinates is None:
            msg = 'No data accumulated: supply point_coordinates and z, '
            msg += 'or call build_fit_subset first.'
            assert self.AtA is not None, msg
            assert self.Atz is not None, msg
        else:
            point_coordinates = ensure_absolute(point_coordinates,
                                                geo_reference=point_origin)
            self.build_fit_subset(point_coordinates, z, verbose)

        # Check sanity
        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
        n = self.point_count
        if n < m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There are only %d data points and alpha == 0. ' %n
            msg += 'Need at least %d\n' %m
            msg += 'Alternatively, set smoothing parameter alpha to a small '
            msg += 'positive value,\ne.g. 1.0e-3.'
            raise ToFewPointsError(msg)

        self._build_coefficient_matrix_B(verbose)

        loners = self.mesh.get_lone_vertices()
        # FIXME  - make this as error message.
        # test with
        # Not_yet_test_smooth_att_to_mesh_with_excess_verts.
        if len(loners) > 0:
            msg = 'WARNING: (least_squares): \nVertices with no triangles\n'
            msg += 'All vertices should be part of a triangle.\n'
            msg += 'In the future this will be inforced.\n'
            msg += 'The following vertices are not part of a triangle;\n'
            msg += str(loners)
            print(msg)
            #raise VertsWithNoTrianglesError(msg)

        # NOTE(review): Atz is passed twice; the second occurrence is
        # presumably the solver's starting guess - confirm against
        # utilities.cg_solve.conjugate_gradient.
        return conjugate_gradient(self.B, self.Atz, self.Atz,
                                  imax=2*len(self.Atz))

    def build_fit_subset(self, point_coordinates, z,
                         verbose=False):
        """Add a subset of data points to the accumulated AtA and Atz.

        No system is solved here; call fit() with no point_coordinates
        afterwards to obtain the fitted values.

        Inputs:
          point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 Numeric array or a Geospatial_data object
          z: Single 1d vector or array of data at the point_coordinates.
          verbose: If true, progress messages are printed.
        """
        #Note: Don't get the z info from Geospatial_data.attributes yet.
        # If we did fit would have to handle attribute title info.

        #FIXME(DSG-DSG): Check that the vert and point coords
        #have the same zone.
        if isinstance(point_coordinates, Geospatial_data):
            point_coordinates = point_coordinates.get_data_points(
                absolute=True)

        # Convert input to Numeric arrays
        z = ensure_numeric(z, Float)
        point_coordinates = ensure_numeric(point_coordinates, Float)

        self._build_matrix_AtA_Atz(point_coordinates, z, verbose)
369
370
371############################################################################
372
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,
                point_attributes,
                alpha = DEFAULT_ALPHA,
                verbose = False,
                acceptable_overshoot = 1.01,
                mesh_origin = None,
                data_origin = None,
                use_cache = False):
    """
    Fit a smooth surface to a triangulation,
    given data points with attributes.

    A Fit object is built for the mesh (through the cache when use_cache
    is True) and its fit method is applied to the data; the result of
    that call is returned.

        Inputs:
        vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a Numeric array) of
          integers representing indices of all vertices in the mesh.

          point_coordinates: List of coordinate pairs [x, y] of data points
          (or an nx2 Numeric array)

          point_attributes: Vector or array of data at the
                            point_coordinates.

          alpha: Smoothing parameter.

          acceptable overshoot: controls the allowed factor by which fitted
          values may exceed the value of input data. The lower limit is
          defined as min(z) - acceptable_overshoot*delta z and upper limit
          as max(z) + acceptable_overshoot*delta z
          (Accepted but not used by this function at present.)

          mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified vertex coordinates are assumed to be
              relative to their respective origins.

          data_origin: Handed to Fit.fit as point_origin.

          use_cache: Only the value True routes construction of the Fit
              object through cache().
    """
    # The same keyword arguments configure Fit on both paths.
    fit_kwargs = {'verbose': verbose,
                  'mesh_origin': mesh_origin,
                  'alpha':alpha}

    if use_cache is True:
        # NOTE(review): no import of `cache` is visible in this file -
        # confirm it is in scope before relying on use_cache=True.
        interp = cache(_fit,
                       (vertex_coordinates,
                        triangles),
                       fit_kwargs,
                       verbose = verbose)
    else:
        interp = Fit(vertex_coordinates,
                     triangles,
                     **fit_kwargs)

    # Add the value checking stuff that's in least squares.
    # Maybe this stuff should get pushed down into Fit.
    # at least be a method of Fit.
    # Or integrate it into the fit method, saving the max and min's
    # as att's.
    return interp.fit(point_coordinates,
                      point_attributes,
                      point_origin = data_origin,
                      verbose = verbose)
448
449
def _as_plain_list(value):
    """Return value.tolist() if value is a Numeric array, else value itself."""
    if isinstance(value, ArrayType):
        return value.tolist()
    return value


def _origin_or_default(data_dict):
    """Return the origin of data_dict['geo_reference'], or (56, 0, 0) when
    that entry is missing or None.
    """
    if 'geo_reference' in data_dict and \
           data_dict['geo_reference'] is not None:
        return data_dict['geo_reference'].get_origin()
    return (56, 0, 0) #FIXME(DSG-DSG)


def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
                     alpha=DEFAULT_ALPHA, verbose= False,
                     display_errors = True):
    """
    Given a mesh file (tsh) and a point attribute file, fit
    point attributes to the mesh and write a mesh file with the
    results.

    The fitted attributes are appended after any vertex attributes
    already present in the mesh file.

    The origins of the mesh and of the data are taken from the
    'geo_reference' entry of the respective loaded dictionaries;
    when that entry is absent or None, (56, 0, 0) is used.

    Inputs:
      mesh_file: Name of the mesh file to read.
      point_file: Name of the points file to read.
      mesh_output_file: Name of the mesh file to write.
      alpha: Smoothing parameter passed to fit_to_mesh.
      verbose: If true, progress messages are printed.
      display_errors: If true, a message is printed before an IOError
          is re-raised.

    NOTE: Throws IOErrors, for a variety of file problems.  The original
    exception is re-raised, so its message is available to the caller.
    """

    # Question
    # should data_origin and mesh_origin be passed in?
    # No they should be in the data structure
    #
    #Should the origin of the mesh be changed using this function?
    # That is overloading this function.  Have it as a seperate
    # method, at least initially.

    # sys.exc_info() is used in the handlers below instead of binding the
    # exception in the except clause, so the clauses parse under both the
    # old "except X, e" and the newer "except X as e" grammars.
    import sys

    from load_mesh.loadASCII import import_mesh_file, \
                 import_points_file, export_mesh_file, \
                 concatinate_attributelist

    # FIXME: Use geospatial instead of import_points_file
    try:
        mesh_dict = import_mesh_file(mesh_file)
    except IOError:
        if display_errors:
            print('Could not load bad file.  %s' % (sys.exc_info()[1],))
        raise  # Re-raise the original exception, message intact

    vertex_coordinates = mesh_dict['vertices']
    triangles = mesh_dict['triangles']
    old_point_attributes = _as_plain_list(mesh_dict['vertex_attributes'])
    old_title_list = _as_plain_list(mesh_dict['vertex_attribute_titles'])

    if verbose: print('tsh file %s loaded' %mesh_file)

    # load in the points file
    try:
        point_dict = import_points_file(point_file, verbose=verbose)
    except IOError:
        if display_errors:
            print('Could not load bad file.  %s' % (sys.exc_info()[1],))
        raise  # Re-raise the original exception, message intact

    point_coordinates = point_dict['pointlist']
    title_list, point_attributes = \
                concatinate_attributelist(point_dict['attributelist'])

    data_origin = _origin_or_default(point_dict)
    mesh_origin = _origin_or_default(mesh_dict)

    if verbose: print("points file loaded")
    if verbose: print("fitting to mesh")
    f = fit_to_mesh(vertex_coordinates,
                    triangles,
                    point_coordinates,
                    point_attributes,
                    alpha = alpha,
                    verbose = verbose,
                    data_origin = data_origin,
                    mesh_origin = mesh_origin)
    if verbose: print("finished fitting to mesh")

    # convert array to list of lists
    new_point_attributes = f.tolist()
    #FIXME have this overwrite attributes with the same title - DSG
    #Put the newer attributes last
    if old_title_list != []:
        old_title_list.extend(title_list)
        #FIXME can this be done a faster way? - DSG
        for i in range(len(old_point_attributes)):
            old_point_attributes[i].extend(new_point_attributes[i])
        mesh_dict['vertex_attributes'] = old_point_attributes
        mesh_dict['vertex_attribute_titles'] = old_title_list
    else:
        mesh_dict['vertex_attributes'] = new_point_attributes
        mesh_dict['vertex_attribute_titles'] = title_list

    if verbose: print("exporting to file  %s" % (mesh_output_file,))

    try:
        export_mesh_file(mesh_output_file, mesh_dict)
    except IOError:
        if display_errors:
            print('Could not write file.  %s' % (sys.exc_info()[1],))
        raise  # Re-raise the original exception, message intact
557
558
def _fit(*args, **kwargs):
    """Build and return a Fit object from the given arguments.

    Private function for use with caching: a plain function is handed to
    the cache instead of the class, because classes may change their
    byte code between runs which is annoying.
    """
    fitter = Fit(*args, **kwargs)
    return fitter
565
Note: See TracBrowser for help on using the repository browser.