source: inundation/fit_interpolate/fit.py @ 3030

Last change on this file since 3030 was 3030, checked in by duncan, 18 years ago

improve warning message

File size: 20.5 KB
Line 
1"""Least squares fitting.
2
3   Implements a penalised least-squares fit.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18
19   TO DO
20   * test geo_ref, geo_spatial
21"""
22
23from Numeric import zeros, Float, ArrayType,take
24
25from geospatial_data.geospatial_data import Geospatial_data, ensure_absolute
26from fit_interpolate.general_fit_interpolate import FitInterpolate
27from utilities.sparse import Sparse, Sparse_CSR
28from utilities.polygon import in_and_outside_polygon
29from fit_interpolate.search_functions import search_tree_of_vertices
30from utilities.cg_solve import conjugate_gradient
31from utilities.numerical_tools import ensure_numeric, gradient
32
33import exceptions
34class ToFewPointsError(exceptions.Exception): pass
35class VertsWithNoTrianglesError(exceptions.Exception): pass
36
# Smoothing parameter applied when a caller does not supply alpha.
DEFAULT_ALPHA = 1.0e-3
38
39
40class Fit(FitInterpolate):
41   
42    def __init__(self,
43                 vertex_coordinates,
44                 triangles,
45                 mesh_origin=None,
46                 alpha = None,
47                 verbose=False,
48                 max_vertices_per_cell=30):
49
50
51        """
52        Fit data at points to the vertices of a mesh.
53
54        Inputs:
55
56          vertex_coordinates: List of coordinate pairs [xi, eta] of
57              points constituting a mesh (or an m x 2 Numeric array or
58              a geospatial object)
59              Points may appear multiple times
60              (e.g. if vertices have discontinuities)
61
62          triangles: List of 3-tuples (or a Numeric array) of
63              integers representing indices of all vertices in the mesh.
64
65          mesh_origin: A geo_reference object or 3-tuples consisting of
66              UTM zone, easting and northing.
67              If specified vertex coordinates are assumed to be
68              relative to their respective origins.
69
70          max_vertices_per_cell: Number of vertices in a quad tree cell
71          at which the cell is split into 4.
72
73          Note: Don't supply a vertex coords as a geospatial object and
74              a mesh origin, since geospatial has its own mesh origin.
75        """
76        # Initialise variabels
77
78        if alpha is None:
79
80            self.alpha = DEFAULT_ALPHA
81        else:   
82            self.alpha = alpha
83        FitInterpolate.__init__(self,
84                 vertex_coordinates,
85                 triangles,
86                 mesh_origin,
87                 verbose,
88                 max_vertices_per_cell)
89       
90        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (vertices)
91       
92        self.AtA = None
93        self.Atz = None
94
95        self.point_count = 0
96        if self.alpha <> 0:
97            if verbose: print 'Building smoothing matrix'
98            self._build_smoothing_matrix_D()
99           
100    def _build_coefficient_matrix_B(self,
101                                  verbose = False):
102        """
103        Build final coefficient matrix
104
105        Precon
106        If alpha is not zero, matrix D has been built
107        Matrix Ata has been built
108        """
109
110        if self.alpha <> 0:
111            #if verbose: print 'Building smoothing matrix'
112            #self._build_smoothing_matrix_D()
113            self.B = self.AtA + self.alpha*self.D
114        else:
115            self.B = self.AtA
116
117        #Convert self.B matrix to CSR format for faster matrix vector
118        self.B = Sparse_CSR(self.B)
119
120    def _build_smoothing_matrix_D(self):
121        """Build m x m smoothing matrix, where
122        m is the number of basis functions phi_k (one per vertex)
123
124        The smoothing matrix is defined as
125
126        D = D1 + D2
127
128        where
129
130        [D1]_{k,l} = \int_\Omega
131           \frac{\partial \phi_k}{\partial x}
132           \frac{\partial \phi_l}{\partial x}\,
133           dx dy
134
135        [D2]_{k,l} = \int_\Omega
136           \frac{\partial \phi_k}{\partial y}
137           \frac{\partial \phi_l}{\partial y}\,
138           dx dy
139
140
141        The derivatives \frac{\partial \phi_k}{\partial x},
142        \frac{\partial \phi_k}{\partial x} for a particular triangle
143        are obtained by computing the gradient a_k, b_k for basis function k
144        """
145       
146        #FIXME: algorithm might be optimised by computing local 9x9
147        #"element stiffness matrices:
148
149        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
150
151        self.D = Sparse(m,m)
152
153        #For each triangle compute contributions to D = D1+D2
154        for i in range(len(self.mesh)):
155
156            #Get area
157            area = self.mesh.areas[i]
158
159            #Get global vertex indices
160            v0 = self.mesh.triangles[i,0]
161            v1 = self.mesh.triangles[i,1]
162            v2 = self.mesh.triangles[i,2]
163
164            #Get the three vertex_points
165            xi0 = self.mesh.get_vertex_coordinate(i, 0)
166            xi1 = self.mesh.get_vertex_coordinate(i, 1)
167            xi2 = self.mesh.get_vertex_coordinate(i, 2)
168
169            #Compute gradients for each vertex
170            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
171                              1, 0, 0)
172
173            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
174                              0, 1, 0)
175
176            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
177                              0, 0, 1)
178
179            #Compute diagonal contributions
180            self.D[v0,v0] += (a0*a0 + b0*b0)*area
181            self.D[v1,v1] += (a1*a1 + b1*b1)*area
182            self.D[v2,v2] += (a2*a2 + b2*b2)*area
183
184            #Compute contributions for basis functions sharing edges
185            e01 = (a0*a1 + b0*b1)*area
186            self.D[v0,v1] += e01
187            self.D[v1,v0] += e01
188
189            e12 = (a1*a2 + b1*b2)*area
190            self.D[v1,v2] += e12
191            self.D[v2,v1] += e12
192
193            e20 = (a2*a0 + b2*b0)*area
194            self.D[v2,v0] += e20
195            self.D[v0,v2] += e20
196
197
198    def get_D(self):
199        return self.D.todense()
200
201
202    def _build_matrix_AtA_Atz(self,
203                              point_coordinates,
204                              z,
205                              verbose = False):
206        """Build:
207        AtA  m x m  interpolation matrix, and,
208        Atz  m x a  interpolation matrix where,
209        m is the number of basis functions phi_k (one per vertex)
210        a is the number of data attributes
211
212        This algorithm uses a quad tree data structure for fast binning of
213        data points.
214
215        If Ata is None, the matrices AtA and Atz are created.
216
217        This function can be called again and again, with sub-sets of
218        the point coordinates.  Call fit to get the results.
219       
220        Preconditions
221        z and points are numeric
222        Point_coordindates and mesh vertices have the same origin.
223
224        The number of attributes of the data points does not change
225        """
226        #Build n x m interpolation matrix
227
228        if self.AtA == None:
229            # AtA and Atz need ot be initialised.
230            m = self.mesh.coordinates.shape[0] #Nbr of vertices
231            if len(z.shape) > 1:
232                att_num = z.shape[1]
233                self.Atz = zeros((m,att_num), Float)
234            else:
235                att_num = 1
236                self.Atz = zeros((m,), Float)
237            assert z.shape[0] == point_coordinates.shape[0] 
238
239            self.AtA = Sparse(m,m)
240        self.point_count += point_coordinates.shape[0]
241        #print "_build_matrix_AtA_Atz - self.point_count", self.point_count
242        if verbose: print 'Getting indices inside mesh boundary'
243        #print "self.mesh.get_boundary_polygon()",self.mesh.get_boundary_polygon()
244        self.inside_poly_indices, self.outside_poly_indices  = \
245                     in_and_outside_polygon(point_coordinates,
246                                            self.mesh.get_boundary_polygon(),
247                                            closed = True, verbose = verbose)
248        #print "self.inside_poly_indices",self.inside_poly_indices
249        #print "self.outside_poly_indices",self.outside_poly_indices
250
251       
252        n = len(self.inside_poly_indices)
253        #Compute matrix elements for points inside the mesh
254        for i in self.inside_poly_indices:
255            #For each data_coordinate point
256            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
257            x = point_coordinates[i]
258            element_found, sigma0, sigma1, sigma2, k = \
259                           search_tree_of_vertices(self.root, self.mesh, x)
260           
261            if element_found is True:
262                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
263                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
264                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
265
266                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
267                js     = [j0,j1,j2]
268
269                for j in js:
270                    self.Atz[j] +=  sigmas[j]*z[i]
271                    #print "self.Atz building", self.Atz
272                    #print "self.Atz[j]", self.Atz[j]
273                    #print " sigmas[j]", sigmas[j]
274                    #print "z[i]",z[i]
275                    #print "result", sigmas[j]*z[i]
276                   
277                    for k in js:
278                        self.AtA[j,k] += sigmas[j]*sigmas[k]
279            else:
280                msg = 'Could not find triangle for point', x
281                raise Exception(msg)
282   
283       
284    def fit(self, point_coordinates=None, z=None,
285                              verbose = False,
286                              point_origin = None):
287        """Fit a smooth surface to given 1d array of data points z.
288
289        The smooth surface is computed at each vertex in the underlying
290        mesh using the formula given in the module doc string.
291
292        Inputs:
293        point_coordinates: The co-ordinates of the data points.
294              List of coordinate pairs [x, y] of
295              data points or an nx2 Numeric array or a Geospatial_data object
296          z: Single 1d vector or array of data at the point_coordinates.
297         
298        """
299        if point_coordinates is None:
300            assert self.AtA <> None
301            assert self.Atz <> None
302            #FIXME (DSG) - do  a message
303        else:
304            point_coordinates = ensure_absolute(point_coordinates,
305                                                geo_reference=point_origin)
306            self.build_fit_subset(point_coordinates, z, verbose)
307
308        #Check sanity
309        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
310        n = self.point_count
311        if n<m and self.alpha == 0.0:
312            msg = 'ERROR (least_squares): Too few data points\n'
313            msg += 'There are only %d data points and alpha == 0. ' %n
314            msg += 'Need at least %d\n' %m
315            msg += 'Alternatively, set smoothing parameter alpha to a small '
316            msg += 'positive value,\ne.g. 1.0e-3.'
317            raise ToFewPointsError(msg)
318
319        self._build_coefficient_matrix_B(verbose)
320        loners = self.mesh.get_lone_vertices()
321        # FIXME  - make this as error message.
322        # test with
323        # Not_yet_test_smooth_att_to_mesh_with_excess_verts.
324        if len(loners)>0:
325            msg = 'WARNING: (least_squares): \nVertices with no triangles\n'
326            msg += 'All vertices should be part of a triangle.\n'
327            msg += 'In the future this will be inforced.\n'
328            msg += 'The following vertices are not part of a triangle;\n'
329            msg += str(loners)
330            print msg
331            #raise VertsWithNoTrianglesError(msg)
332       
333       
334        return conjugate_gradient(self.B, self.Atz, self.Atz,
335                                  imax=2*len(self.Atz) )
336
337       
338    def build_fit_subset(self, point_coordinates, z,
339                              verbose = False):
340        """Fit a smooth surface to given 1d array of data points z.
341
342        The smooth surface is computed at each vertex in the underlying
343        mesh using the formula given in the module doc string.
344
345        Inputs:
346        point_coordinates: The co-ordinates of the data points.
347              List of coordinate pairs [x, y] of
348              data points or an nx2 Numeric array or a Geospatial_data object
349          z: Single 1d vector or array of data at the point_coordinates.
350
351        """
352        #Note: Don't get the z info from Geospatial_data.attributes yet.
353        # If we did fit would have to handle attribute title info.
354
355        #FIXME(DSG-DSG): Check that the vert and point coords
356        #have the same zone.
357        if isinstance(point_coordinates,Geospatial_data):
358            point_coordinates = point_coordinates.get_data_points( \
359                absolute = True)
360       
361        #Convert input to Numeric arrays
362        z = ensure_numeric(z, Float)
363        point_coordinates = ensure_numeric(point_coordinates, Float)
364
365        self._build_matrix_AtA_Atz(point_coordinates, z, verbose)
366
367
368############################################################################
369
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,
                point_attributes,
                alpha = DEFAULT_ALPHA,
                verbose = False,
                acceptable_overshoot = 1.01,
                mesh_origin = None,
                data_origin = None,
                use_cache = False):
    """
    Fit point data with attributes to the vertices of a triangulation.

    Convenience wrapper: builds a Fit object (optionally through the
    cache) and returns the result of its fit method.

        Inputs:
          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a Numeric array) of
              integers representing indices of all vertices in the mesh.

          point_coordinates: List of coordinate pairs [x, y] of data points
              (or an nx2 Numeric array)

          point_attributes: Vector or array of data at the
              point_coordinates.

          alpha: Smoothing parameter.

          verbose: Passed on to Fit and to its fit method.

          acceptable_overshoot: Accepted for compatibility; it is not
              used by the code in this function.

          mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified vertex coordinates are assumed to be
              relative to their respective origins.

          data_origin: Passed to the fit method as point_origin.

          use_cache: If True, the Fit object is obtained through cache.

    """
    # Keyword arguments shared by both ways of constructing the Fit object.
    fit_kwargs = {'verbose': verbose,
                  'mesh_origin': mesh_origin,
                  'alpha':alpha}

    if use_cache is True:
        # NOTE(review): `cache` is not imported in the part of the file
        # visible here - confirm it is in scope before relying on use_cache.
        interp = cache(_fit,
                       (vertex_coordinates, triangles),
                       fit_kwargs,
                       verbose = verbose)
    else:
        interp = Fit(vertex_coordinates, triangles, **fit_kwargs)

    # Value checking (limits on the fitted values) is not done here yet;
    # it may be moved into Fit, e.g. as a method or as part of fit.
    return interp.fit(point_coordinates,
                      point_attributes,
                      point_origin = data_origin,
                      verbose = verbose)
445
446
447def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
448                     alpha=DEFAULT_ALPHA, verbose= False,
449                     display_errors = True):
450    """
451    Given a mesh file (tsh) and a point attribute file (xya), fit
452    point attributes to the mesh and write a mesh file with the
453    results.
454
455
456    If data_origin is not None it is assumed to be
457    a 3-tuple with geo referenced
458    UTM coordinates (zone, easting, northing)
459
460    NOTE: Throws IOErrors, for a variety of file problems.
461   
462    """
463
464    # Question
465    # should data_origin and mesh_origin be passed in?
466    # No they should be in the data structure
467    #
468    #Should the origin of the mesh be changed using this function?
469    # That is overloading this function.  Have it as a seperate
470    # method, at least initially.
471   
472    from load_mesh.loadASCII import import_mesh_file, \
473                 import_points_file, export_mesh_file, \
474                 concatinate_attributelist
475
476    # FIXME: Use geospatial instead of import_points_file
477    try:
478        mesh_dict = import_mesh_file(mesh_file)
479    except IOError,e:
480        if display_errors:
481            print "Could not load bad file. ", e
482        raise IOError  #Re-raise exception
483       
484    vertex_coordinates = mesh_dict['vertices']
485    triangles = mesh_dict['triangles']
486    if type(mesh_dict['vertex_attributes']) == ArrayType:
487        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
488    else:
489        old_point_attributes = mesh_dict['vertex_attributes']
490
491    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
492        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
493    else:
494        old_title_list = mesh_dict['vertex_attribute_titles']
495
496    if verbose: print 'tsh file %s loaded' %mesh_file
497
498    # load in the .pts file
499    try:
500        point_dict = import_points_file(point_file, verbose=verbose)
501    except IOError,e:
502        if display_errors:
503            print "Could not load bad file. ", e
504        raise IOError  #Re-raise exception 
505
506    point_coordinates = point_dict['pointlist']
507    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
508
509    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
510        data_origin = point_dict['geo_reference'].get_origin()
511    else:
512        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
513
514    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
515        mesh_origin = mesh_dict['geo_reference'].get_origin()
516    else:
517        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
518
519    if verbose: print "points file loaded"
520    if verbose: print "fitting to mesh"
521    f = fit_to_mesh(vertex_coordinates,
522                    triangles,
523                    point_coordinates,
524                    point_attributes,
525                    alpha = alpha,
526                    verbose = verbose,
527                    data_origin = data_origin,
528                    mesh_origin = mesh_origin)
529    if verbose: print "finished fitting to mesh"
530
531    # convert array to list of lists
532    new_point_attributes = f.tolist()
533    #FIXME have this overwrite attributes with the same title - DSG
534    #Put the newer attributes last
535    if old_title_list <> []:
536        old_title_list.extend(title_list)
537        #FIXME can this be done a faster way? - DSG
538        for i in range(len(old_point_attributes)):
539            old_point_attributes[i].extend(new_point_attributes[i])
540        mesh_dict['vertex_attributes'] = old_point_attributes
541        mesh_dict['vertex_attribute_titles'] = old_title_list
542    else:
543        mesh_dict['vertex_attributes'] = new_point_attributes
544        mesh_dict['vertex_attribute_titles'] = title_list
545
546    if verbose: print "exporting to file ", mesh_output_file
547
548    try:
549        export_mesh_file(mesh_output_file, mesh_dict)
550    except IOError,e:
551        if display_errors:
552            print "Could not write file. ", e
553        raise IOError
554
555
def _fit(*args, **kwargs):
    """Create a Fit object from the given arguments.

    Private module-level function for use with caching, so that a plain
    function rather than the class is handed to the cache (classes may
    change their byte code between runs).
    """
    fitter = Fit(*args, **kwargs)
    return fitter
562
Note: See TracBrowser for help on using the repository browser.