source: anuga_core/source/anuga/fit_interpolate/fit.py @ 4150

Last change on this file since 4150 was 4138, checked in by duncan, 18 years ago

tweak and bug fix

File size: 22.8 KB
Line 
1"""Least squares fitting.
2
3   Implements a penalised least-squares fit.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18
19   TO DO
20   * test geo_ref, geo_spatial
21
22   IDEAS
23   * (DSG-) Change the interface of fit, so a domain object can
24      be passed in. (I don't know if this is feasible). If could
25      save time/memory.
26"""
27import types
28
29from Numeric import zeros, Float, ArrayType,take
30
31from anuga.caching import cache           
32from anuga.geospatial_data.geospatial_data import Geospatial_data, ensure_absolute
33from anuga.fit_interpolate.general_fit_interpolate import FitInterpolate
34from anuga.utilities.sparse import Sparse, Sparse_CSR
35from anuga.utilities.polygon import in_and_outside_polygon
36from anuga.fit_interpolate.search_functions import search_tree_of_vertices
37from anuga.utilities.cg_solve import conjugate_gradient
38from anuga.utilities.numerical_tools import ensure_numeric, gradient
39
import exceptions


class ToFewPointsError(exceptions.Exception):
    """Raised by Fit.fit when there are fewer data points than mesh
    vertices and the smoothing parameter alpha is zero.
    """
    pass


class VertsWithNoTrianglesError(exceptions.Exception):
    """Reserved for meshes containing vertices that belong to no
    triangle (currently only a warning is printed by Fit.fit).
    """
    pass


# Default smoothing parameter used when Fit is constructed with alpha=None
DEFAULT_ALPHA = 0.001
45
46
47class Fit(FitInterpolate):
48   
49    def __init__(self,
50                 vertex_coordinates,
51                 triangles,
52                 mesh_origin=None,
53                 alpha = None,
54                 verbose=False,
55                 max_vertices_per_cell=30):
56
57
58        """
59        Fit data at points to the vertices of a mesh.
60
61        Inputs:
62
63          vertex_coordinates: List of coordinate pairs [xi, eta] of
64              points constituting a mesh (or an m x 2 Numeric array or
65              a geospatial object)
66              Points may appear multiple times
67              (e.g. if vertices have discontinuities)
68
69          triangles: List of 3-tuples (or a Numeric array) of
70              integers representing indices of all vertices in the mesh.
71
72          mesh_origin: A geo_reference object or 3-tuples consisting of
73              UTM zone, easting and northing.
74              If specified vertex coordinates are assumed to be
75              relative to their respective origins.
76
77          max_vertices_per_cell: Number of vertices in a quad tree cell
78          at which the cell is split into 4.
79
80          Note: Don't supply a vertex coords as a geospatial object and
81              a mesh origin, since geospatial has its own mesh origin.
82
83
84        Usage,
85        To use this in a blocking way, call  build_fit_subset, with z info,
86        and then fit, with no point coord, z info.
87       
88        """
89        # Initialise variabels
90
91        if alpha is None:
92
93            self.alpha = DEFAULT_ALPHA
94        else:   
95            self.alpha = alpha
96        FitInterpolate.__init__(self,
97                 vertex_coordinates,
98                 triangles,
99                 mesh_origin,
100                 verbose,
101                 max_vertices_per_cell)
102       
103        m = self.mesh.number_of_nodes # Nbr of basis functions (vertices)
104       
105        self.AtA = None
106        self.Atz = None
107
108        self.point_count = 0
109        if self.alpha <> 0:
110            if verbose: print 'Building smoothing matrix'
111            self._build_smoothing_matrix_D()
112           
113    def _build_coefficient_matrix_B(self,
114                                  verbose = False):
115        """
116        Build final coefficient matrix
117
118        Precon
119        If alpha is not zero, matrix D has been built
120        Matrix Ata has been built
121        """
122
123        if self.alpha <> 0:
124            #if verbose: print 'Building smoothing matrix'
125            #self._build_smoothing_matrix_D()
126            self.B = self.AtA + self.alpha*self.D
127        else:
128            self.B = self.AtA
129
130        #Convert self.B matrix to CSR format for faster matrix vector
131        self.B = Sparse_CSR(self.B)
132
133    def _build_smoothing_matrix_D(self):
134        """Build m x m smoothing matrix, where
135        m is the number of basis functions phi_k (one per vertex)
136
137        The smoothing matrix is defined as
138
139        D = D1 + D2
140
141        where
142
143        [D1]_{k,l} = \int_\Omega
144           \frac{\partial \phi_k}{\partial x}
145           \frac{\partial \phi_l}{\partial x}\,
146           dx dy
147
148        [D2]_{k,l} = \int_\Omega
149           \frac{\partial \phi_k}{\partial y}
150           \frac{\partial \phi_l}{\partial y}\,
151           dx dy
152
153
154        The derivatives \frac{\partial \phi_k}{\partial x},
155        \frac{\partial \phi_k}{\partial x} for a particular triangle
156        are obtained by computing the gradient a_k, b_k for basis function k
157        """
158       
159        #FIXME: algorithm might be optimised by computing local 9x9
160        #"element stiffness matrices:
161
162        m = self.mesh.number_of_nodes # Nbr of basis functions (1/vertex)
163
164        self.D = Sparse(m,m)
165
166        #For each triangle compute contributions to D = D1+D2
167        for i in range(len(self.mesh)):
168
169            #Get area
170            area = self.mesh.areas[i]
171
172            #Get global vertex indices
173            v0 = self.mesh.triangles[i,0]
174            v1 = self.mesh.triangles[i,1]
175            v2 = self.mesh.triangles[i,2]
176
177            #Get the three vertex_points
178            xi0 = self.mesh.get_vertex_coordinate(i, 0)
179            xi1 = self.mesh.get_vertex_coordinate(i, 1)
180            xi2 = self.mesh.get_vertex_coordinate(i, 2)
181
182            #Compute gradients for each vertex
183            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
184                              1, 0, 0)
185
186            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
187                              0, 1, 0)
188
189            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
190                              0, 0, 1)
191
192            #Compute diagonal contributions
193            self.D[v0,v0] += (a0*a0 + b0*b0)*area
194            self.D[v1,v1] += (a1*a1 + b1*b1)*area
195            self.D[v2,v2] += (a2*a2 + b2*b2)*area
196
197            #Compute contributions for basis functions sharing edges
198            e01 = (a0*a1 + b0*b1)*area
199            self.D[v0,v1] += e01
200            self.D[v1,v0] += e01
201
202            e12 = (a1*a2 + b1*b2)*area
203            self.D[v1,v2] += e12
204            self.D[v2,v1] += e12
205
206            e20 = (a2*a0 + b2*b0)*area
207            self.D[v2,v0] += e20
208            self.D[v0,v2] += e20
209
210
211    def get_D(self):
212        return self.D.todense()
213
214
215    def _build_matrix_AtA_Atz(self,
216                              point_coordinates,
217                              z,
218                              verbose = False):
219        """Build:
220        AtA  m x m  interpolation matrix, and,
221        Atz  m x a  interpolation matrix where,
222        m is the number of basis functions phi_k (one per vertex)
223        a is the number of data attributes
224
225        This algorithm uses a quad tree data structure for fast binning of
226        data points.
227
228        If Ata is None, the matrices AtA and Atz are created.
229
230        This function can be called again and again, with sub-sets of
231        the point coordinates.  Call fit to get the results.
232       
233        Preconditions
234        z and points are numeric
235        Point_coordindates and mesh vertices have the same origin.
236
237        The number of attributes of the data points does not change
238        """
239        #Build n x m interpolation matrix
240
241        if self.AtA == None:
242            # AtA and Atz need to be initialised.
243            m = self.mesh.number_of_nodes
244            if len(z.shape) > 1:
245                att_num = z.shape[1]
246                self.Atz = zeros((m,att_num), Float)
247            else:
248                att_num = 1
249                self.Atz = zeros((m,), Float)
250            assert z.shape[0] == point_coordinates.shape[0] 
251
252            self.AtA = Sparse(m,m)
253            # The memory damage has been done by now.
254           
255        self.point_count += point_coordinates.shape[0]
256        #print "_build_matrix_AtA_Atz - self.point_count", self.point_count
257        if verbose: print 'Getting indices inside mesh boundary'
258        #print 'point_coordinates.shape', point_coordinates.shape         
259        #print 'self.mesh.get_boundary_polygon()',\
260        #      self.mesh.get_boundary_polygon()
261
262        # Why are these global?
263        self.inside_poly_indices, self.outside_poly_indices  = \
264                     in_and_outside_polygon(point_coordinates,
265                                            self.mesh.get_boundary_polygon(),
266                                            closed = True, verbose = verbose)
267        #print "self.inside_poly_indices",self.inside_poly_indices
268        #print "self.outside_poly_indices",self.outside_poly_indices
269
270       
271        n = len(self.inside_poly_indices)
272        if verbose: print 'Building fitting matrix from %d points' %n       
273        #Compute matrix elements for points inside the mesh
274        for k, i in enumerate(self.inside_poly_indices):
275            #For each data_coordinate point
276            if verbose and k%((n+10)/10)==0: print 'Doing %d of %d' %(k, n)
277            x = point_coordinates[i]
278            element_found, sigma0, sigma1, sigma2, k = \
279                           search_tree_of_vertices(self.root, self.mesh, x)
280           
281            if element_found is True:
282                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
283                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
284                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
285
286                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
287                js     = [j0,j1,j2]
288
289                for j in js:
290                    self.Atz[j] +=  sigmas[j]*z[i]
291                    #print "self.Atz building", self.Atz
292                    #print "self.Atz[j]", self.Atz[j]
293                    #print " sigmas[j]", sigmas[j]
294                    #print "z[i]",z[i]
295                    #print "result", sigmas[j]*z[i]
296                   
297                    for k in js:
298                        self.AtA[j,k] += sigmas[j]*sigmas[k]
299            else:
300                msg = 'Could not find triangle for point', x
301                raise Exception(msg)
302   
303       
304    def fit(self, point_coordinates_or_filename=None, z=None,
305            verbose=False,
306            point_origin=None,
307            attribute_name=None,
308            max_read_lines=500):
309        """Fit a smooth surface to given 1d array of data points z.
310
311        The smooth surface is computed at each vertex in the underlying
312        mesh using the formula given in the module doc string.
313
314        Inputs:
315        point_coordinates: The co-ordinates of the data points.
316              List of coordinate pairs [x, y] of
317              data points or an nx2 Numeric array or a Geospatial_data object
318          z: Single 1d vector or array of data at the point_coordinates.
319         
320        """
321
322        # use blocking to load in the point info
323        if type(point_coordinates_or_filename) == types.StringType:
324            msg = "Don't set a point origin when reading from a file"
325            assert point_origin is None, msg
326            filename = point_coordinates_or_filename
327            for i,geo_block in  enumerate(Geospatial_data(filename,
328                                              max_read_lines=max_read_lines,
329                                              load_file_now=False)):
330                if verbose is True and 0 == i%200: # round every 5 minutes
331                    print 'Block %i' %i
332                # build the array
333                points = geo_block.get_data_points(absolute=True)
334                z = geo_block.get_attributes(attribute_name=attribute_name)
335                self.build_fit_subset(points, z)
336            point_coordinates = None
337        else:
338            point_coordinates =  point_coordinates_or_filename
339           
340        if point_coordinates is None:
341            assert self.AtA <> None
342            assert self.Atz <> None
343            #FIXME (DSG) - do  a message
344        else:
345            point_coordinates = ensure_absolute(point_coordinates,
346                                                geo_reference=point_origin)
347            #if isinstance(point_coordinates,Geospatial_data) and z is None:
348            # z will come from the geo-ref
349            self.build_fit_subset(point_coordinates, z, verbose)
350
351        #Check sanity
352        m = self.mesh.number_of_nodes # Nbr of basis functions (1/vertex)
353        n = self.point_count
354        if n<m and self.alpha == 0.0:
355            msg = 'ERROR (least_squares): Too few data points\n'
356            msg += 'There are only %d data points and alpha == 0. ' %n
357            msg += 'Need at least %d\n' %m
358            msg += 'Alternatively, set smoothing parameter alpha to a small '
359            msg += 'positive value,\ne.g. 1.0e-3.'
360            raise ToFewPointsError(msg)
361
362        self._build_coefficient_matrix_B(verbose)
363        loners = self.mesh.get_lone_vertices()
364        # FIXME  - make this as error message.
365        # test with
366        # Not_yet_test_smooth_att_to_mesh_with_excess_verts.
367        if len(loners)>0:
368            msg = 'WARNING: (least_squares): \nVertices with no triangles\n'
369            msg += 'All vertices should be part of a triangle.\n'
370            msg += 'In the future this will be inforced.\n'
371            msg += 'The following vertices are not part of a triangle;\n'
372            msg += str(loners)
373            print msg
374            #raise VertsWithNoTrianglesError(msg)
375       
376       
377        return conjugate_gradient(self.B, self.Atz, self.Atz,
378                                  imax=2*len(self.Atz) )
379
380       
381    def build_fit_subset(self, point_coordinates, z=None, attribute_name=None,
382                              verbose=False):
383        """Fit a smooth surface to given 1d array of data points z.
384
385        The smooth surface is computed at each vertex in the underlying
386        mesh using the formula given in the module doc string.
387
388        Inputs:
389        point_coordinates: The co-ordinates of the data points.
390              List of coordinate pairs [x, y] of
391              data points or an nx2 Numeric array or a Geospatial_data object
392        z: Single 1d vector or array of data at the point_coordinates.
393        attribute_name: Used to get the z values from the
394              geospatial object if no attribute_name is specified,
395              it's a bit of a lucky dip as to what attributes you get.
396              If there is only one attribute it will be that one.
397
398        """
399
400        #FIXME(DSG-DSG): Check that the vert and point coords
401        #have the same zone.
402        if isinstance(point_coordinates,Geospatial_data):
403            point_coordinates = point_coordinates.get_data_points( \
404                absolute = True)
405       
406        #Convert input to Numeric arrays
407        if z is not None:
408            z = ensure_numeric(z, Float)
409        else:
410            msg = 'z not specified'
411            assert isinstance(point_coordinates,Geospatial_data), msg
412            z = point_coordinates.get_attributes(attribute_name)
413           
414        point_coordinates = ensure_numeric(point_coordinates, Float)
415
416        self._build_matrix_AtA_Atz(point_coordinates, z, verbose)
417
418
419############################################################################
420
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates, # this can also be a .csv/.txt file name
                point_attributes=None,
                alpha=DEFAULT_ALPHA,
                verbose=False,
                acceptable_overshoot=1.01,
                mesh_origin=None,
                data_origin=None,
                max_read_lines=None,
                attribute_name=None,
                use_cache = False):
    """
    Fit a smooth surface to a triangulation,
    given data points with attributes.


        Inputs:
        vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a Numeric array) of
          integers representing indices of all vertices in the mesh.

          point_coordinates: List of coordinate pairs [x, y] of data points
          (or an nx2 Numeric array), or the name of a .csv/.txt file.

          point_attributes: Vector or array of data at the
                            point_coordinates.

          alpha: Smoothing parameter.

          acceptable overshoot: controls the allowed factor by which fitted
          values may exceed the value of input data. The lower limit is
          defined as min(z) - acceptable_overshoot*delta z and upper limit
          as max(z) + acceptable_overshoot*delta z
          (not yet applied here; see FIXME below).

          mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified vertex coordinates are assumed to be
              relative to their respective origins.

          use_cache: If True, cache the construction of the Fit object.

    Returns the array of fitted vertex attributes.
    """
    # This is a wrapper for Fit: build (or fetch from cache) the Fit
    # object, then delegate the actual work to its fit method.
    if use_cache is True:
        interp = cache(_fit,
                       (vertex_coordinates, triangles),
                       {'verbose': verbose,
                        'mesh_origin': mesh_origin,
                        'alpha': alpha},
                       verbose = verbose)
    else:
        interp = Fit(vertex_coordinates,
                     triangles,
                     verbose=verbose,
                     mesh_origin=mesh_origin,
                     alpha=alpha)

    vertex_attributes = interp.fit(point_coordinates,
                                   point_attributes,
                                   point_origin=data_origin,
                                   max_read_lines=max_read_lines,
                                   attribute_name=attribute_name,
                                   verbose=verbose)

    # FIXME: Add the value checking stuff that's in least squares
    # (clipping by acceptable_overshoot).
    # Maybe this stuff should get pushed down into Fit,
    # at least be a method of Fit, or integrated into the fit method,
    # saving the max and min's as attributes.
    return vertex_attributes
500
501
502def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
503                     alpha=DEFAULT_ALPHA, verbose= False,
504                     display_errors = True):
505    """
506    Given a mesh file (tsh) and a point attribute file (xya), fit
507    point attributes to the mesh and write a mesh file with the
508    results.
509
510
511    If data_origin is not None it is assumed to be
512    a 3-tuple with geo referenced
513    UTM coordinates (zone, easting, northing)
514
515    NOTE: Throws IOErrors, for a variety of file problems.
516   
517    """
518
519    # Question
520    # should data_origin and mesh_origin be passed in?
521    # No they should be in the data structure
522    #
523    #Should the origin of the mesh be changed using this function?
524    # That is overloading this function.  Have it as a seperate
525    # method, at least initially.
526   
527    from load_mesh.loadASCII import import_mesh_file, \
528                 import_points_file, export_mesh_file, \
529                 concatinate_attributelist
530
531    # FIXME: Use geospatial instead of import_points_file
532    try:
533        mesh_dict = import_mesh_file(mesh_file)
534    except IOError,e:
535        if display_errors:
536            print "Could not load bad file. ", e
537        raise IOError  #Re-raise exception
538       
539    vertex_coordinates = mesh_dict['vertices']
540    triangles = mesh_dict['triangles']
541    if type(mesh_dict['vertex_attributes']) == ArrayType:
542        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
543    else:
544        old_point_attributes = mesh_dict['vertex_attributes']
545
546    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
547        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
548    else:
549        old_title_list = mesh_dict['vertex_attribute_titles']
550
551    if verbose: print 'tsh file %s loaded' %mesh_file
552
553    # load in the .pts file
554    try:
555        point_dict = import_points_file(point_file, verbose=verbose)
556    except IOError,e:
557        if display_errors:
558            print "Could not load bad file. ", e
559        raise IOError  #Re-raise exception 
560
561    point_coordinates = point_dict['pointlist']
562    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
563
564    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
565        data_origin = point_dict['geo_reference'].get_origin()
566    else:
567        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
568
569    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
570        mesh_origin = mesh_dict['geo_reference'].get_origin()
571    else:
572        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
573
574    if verbose: print "points file loaded"
575    if verbose: print "fitting to mesh"
576    f = fit_to_mesh(vertex_coordinates,
577                    triangles,
578                    point_coordinates,
579                    point_attributes,
580                    alpha = alpha,
581                    verbose = verbose,
582                    data_origin = data_origin,
583                    mesh_origin = mesh_origin)
584    if verbose: print "finished fitting to mesh"
585
586    # convert array to list of lists
587    new_point_attributes = f.tolist()
588    #FIXME have this overwrite attributes with the same title - DSG
589    #Put the newer attributes last
590    if old_title_list <> []:
591        old_title_list.extend(title_list)
592        #FIXME can this be done a faster way? - DSG
593        for i in range(len(old_point_attributes)):
594            old_point_attributes[i].extend(new_point_attributes[i])
595        mesh_dict['vertex_attributes'] = old_point_attributes
596        mesh_dict['vertex_attribute_titles'] = old_title_list
597    else:
598        mesh_dict['vertex_attributes'] = new_point_attributes
599        mesh_dict['vertex_attribute_titles'] = title_list
600
601    if verbose: print "exporting to file ", mesh_output_file
602
603    try:
604        export_mesh_file(mesh_output_file, mesh_dict)
605    except IOError,e:
606        if display_errors:
607            print "Could not write file. ", e
608        raise IOError
609
610
def _fit(*args, **kwargs):
    """Construct and return a Fit instance.

    Module-level wrapper used as the cache key function: classes may
    change their byte code between runs, which upsets caching, so the
    cache is keyed on this plain function instead.
    """
    fit_object = Fit(*args, **kwargs)
    return fit_object
617
Note: See TracBrowser for help on using the repository browser.