source: anuga_core/source/anuga/fit_interpolate/fit.py @ 5485

Last change on this file since 5485 was 5352, checked in by ole, 16 years ago

Moved default smoothing parameter from fit to config.
Removed print statements in data_manager.

File size: 25.4 KB
"""Least squares fitting.

   Implements a penalised least-squares fit of point data onto the mesh.

   The penalty term (or smoothing term) is controlled by the smoothing
   parameter alpha.
   With a value of alpha=0, the fit function will attempt
   to interpolate as closely as possible in the least-squares sense.
   With values alpha > 0, a certain amount of smoothing will be applied.
   A positive alpha is essential in cases where there are too few
   data points.
   A negative alpha is not allowed.
   A typical value of alpha is 1.0e-6.


   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
   Geoscience Australia, 2004.

   TO DO
   * test geo_ref, geo_spatial

   IDEAS
   * (DSG-) Change the interface of fit, so a domain object can
      be passed in. (I don't know if this is feasible.) It could
      save time/memory.
"""
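# Note (editor's sketch of the formulation implemented below): the fit finds
# vertex values w minimising a penalised least-squares functional of the form
#
#     J(w) = |A*w - z|^2 + alpha * w^T D w
#
# where each row of A holds the linear interpolation weights of one data point
# within its containing triangle and D is the gradient-based smoothing matrix
# assembled in _build_smoothing_matrix_D.  The normal equations
# (AtA + alpha*D) w = Atz are built up below and solved by conjugate gradients.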
import types

from Numeric import zeros, Float, ArrayType, take, Int

from anuga.abstract_2d_finite_volumes.neighbour_mesh import Mesh
from anuga.caching import cache
from anuga.geospatial_data.geospatial_data import Geospatial_data, \
     ensure_absolute
from anuga.fit_interpolate.general_fit_interpolate import FitInterpolate
from anuga.utilities.sparse import Sparse, Sparse_CSR
from anuga.utilities.polygon import in_and_outside_polygon
from anuga.fit_interpolate.search_functions import search_tree_of_vertices

from anuga.utilities.cg_solve import conjugate_gradient
from anuga.utilities.numerical_tools import ensure_numeric, gradient
from anuga.config import default_smoothing_parameter as DEFAULT_ALPHA

import exceptions
class TooFewPointsError(exceptions.Exception): pass
class VertsWithNoTrianglesError(exceptions.Exception): pass

#DEFAULT_ALPHA = 0.001

class Fit(FitInterpolate):

    def __init__(self,
                 vertex_coordinates=None,
                 triangles=None,
                 mesh=None,
                 mesh_origin=None,
                 alpha=None,
                 verbose=False,
                 max_vertices_per_cell=None):
        """
        Fit data at points to the vertices of a mesh.

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a Numeric array) of
              integers representing indices of all vertices in the mesh.

          mesh_origin: A geo_reference object or 3-tuple consisting of
              UTM zone, easting and northing.
              If specified, vertex coordinates are assumed to be
              relative to their respective origins.

          max_vertices_per_cell: Number of vertices in a quad tree cell
              at which the cell is split into 4.

          Note: Don't supply vertex coordinates as a geospatial object
              together with a mesh origin, since the geospatial object
              carries its own origin.


        Usage:
        To use this in a blocking way, call build_fit_subset with the
        point and attribute (z) data, possibly repeatedly, and then call
        fit with no point coordinates or z data.

        """
        # Initialise variables
        if alpha is None:
            self.alpha = DEFAULT_ALPHA
        else:
            self.alpha = alpha

        FitInterpolate.__init__(self,
                 vertex_coordinates,
                 triangles,
                 mesh,
                 mesh_origin,
                 verbose,
                 max_vertices_per_cell)

        m = self.mesh.number_of_nodes # Nbr of basis functions (vertices)

        self.AtA = None
        self.Atz = None

        self.point_count = 0
        if self.alpha != 0:
            if verbose: print 'Building smoothing matrix'
            self._build_smoothing_matrix_D()

        self.mesh_boundary_polygon = self.mesh.get_boundary_polygon()

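    # Example of the blocking usage described above (illustrative sketch
    # only; 'verts', 'tris' and the point/attribute blocks are hypothetical
    # user data):
    #
    #     fit = Fit(vertex_coordinates=verts, triangles=tris, alpha=1.0e-6)
    #     fit.build_fit_subset(points_block_1, z_block_1)
    #     fit.build_fit_subset(points_block_2, z_block_2)
    #     vertex_values = fit.fit()   # uses the accumulated AtA and Atz
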
    def _build_coefficient_matrix_B(self, verbose=False):
        """
        Build final coefficient matrix.

        Preconditions:
        If alpha is not zero, matrix D has been built.
        Matrix AtA has been built.
        """

        if self.alpha != 0:
            #if verbose: print 'Building smoothing matrix'
            #self._build_smoothing_matrix_D()
            self.B = self.AtA + self.alpha*self.D
        else:
            self.B = self.AtA

        # Convert self.B to CSR format for faster matrix-vector products
        self.B = Sparse_CSR(self.B)

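    # Note (editor's comment): B is the coefficient matrix of the normal
    # equations (AtA + alpha*D) w = Atz, which fit() below solves with
    # conjugate_gradient.
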
    def _build_smoothing_matrix_D(self):
        """Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x} and
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k
        """

        #FIXME: algorithm might be optimised by computing local 9x9
        #"element stiffness matrices"

        m = self.mesh.number_of_nodes # Nbr of basis functions (1/vertex)

        self.D = Sparse(m,m)

        #For each triangle compute contributions to D = D1+D2
        for i in range(len(self.mesh)):

            #Get area
            area = self.mesh.areas[i]

            #Get global vertex indices
            v0 = self.mesh.triangles[i,0]
            v1 = self.mesh.triangles[i,1]
            v2 = self.mesh.triangles[i,2]

            #Get the three vertex points
            xi0 = self.mesh.get_vertex_coordinate(i, 0)
            xi1 = self.mesh.get_vertex_coordinate(i, 1)
            xi2 = self.mesh.get_vertex_coordinate(i, 2)

            #Compute gradients for each vertex
            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              1, 0, 0)

            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 1, 0)

            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 0, 1)

            #Compute diagonal contributions
            self.D[v0,v0] += (a0*a0 + b0*b0)*area
            self.D[v1,v1] += (a1*a1 + b1*b1)*area
            self.D[v2,v2] += (a2*a2 + b2*b2)*area

            #Compute contributions for basis functions sharing edges
            e01 = (a0*a1 + b0*b1)*area
            self.D[v0,v1] += e01
            self.D[v1,v0] += e01

            e12 = (a1*a2 + b1*b2)*area
            self.D[v1,v2] += e12
            self.D[v2,v1] += e12

            e20 = (a2*a0 + b2*b0)*area
            self.D[v2,v0] += e20
            self.D[v0,v2] += e20

    def get_D(self):
        return self.D.todense()

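    # Example (illustrative sketch): the assembled smoothing matrix can be
    # inspected via get_D(), e.g. to check that it is symmetric:
    #
    #     D = fit.get_D()                    # dense m x m matrix
    #     assert allclose(D, transpose(D))   # allclose/transpose from Numeric
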

    def _build_matrix_AtA_Atz(self,
                              point_coordinates,
                              z,
                              verbose=False):
        """Build:
        AtA  m x m  interpolation matrix, and
        Atz  m x a  matrix, where
        m is the number of basis functions phi_k (one per vertex)
        a is the number of data attributes

        This algorithm uses a quad tree data structure for fast binning of
        data points.

        If AtA is None, the matrices AtA and Atz are created.

        This function can be called again and again, with sub-sets of
        the point coordinates.  Call fit to get the results.

        Preconditions:
        z and points are numeric
        point_coordinates and mesh vertices have the same origin.

        The number of attributes of the data points does not change
        """
        #Build n x m interpolation matrix

        if self.AtA is None:
            # AtA and Atz need to be initialised.
            m = self.mesh.number_of_nodes
            if len(z.shape) > 1:
                att_num = z.shape[1]
                self.Atz = zeros((m,att_num), Float)
            else:
                att_num = 1
                self.Atz = zeros((m,), Float)
            assert z.shape[0] == point_coordinates.shape[0]

            AtA = Sparse(m,m)
            # The memory damage has been done by now.
        else:
            AtA = self.AtA #Did this for speed, did ~nothing
        self.point_count += point_coordinates.shape[0]

        #if verbose: print 'Getting indices inside mesh boundary'

        inside_poly_indices, outside_poly_indices = \
                     in_and_outside_polygon(point_coordinates,
                                            self.mesh_boundary_polygon,
                                            closed=True,
                                            verbose=False) # Too much output if True
        #print "self.inside_poly_indices", self.inside_poly_indices
        #print "self.outside_poly_indices", self.outside_poly_indices


        n = len(inside_poly_indices)
        #if verbose: print 'Building fitting matrix from %d points' %n

        #Compute matrix elements for points inside the mesh
        triangles = self.mesh.triangles #Did this for speed, did ~nothing
        for d, i in enumerate(inside_poly_indices):
            # For each data_coordinate point
            # if verbose and d%((n+10)/10)==0: print 'Doing %d of %d' %(d, n)
            x = point_coordinates[i]
            element_found, sigma0, sigma1, sigma2, k = \
                           search_tree_of_vertices(self.root, self.mesh, x)

            if element_found is True:
                j0 = triangles[k,0] #Global vertex id for sigma0
                j1 = triangles[k,1] #Global vertex id for sigma1
                j2 = triangles[k,2] #Global vertex id for sigma2

                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
                js     = [j0,j1,j2]

                for j in js:
                    self.Atz[j] += sigmas[j]*z[i]
                    #print "self.Atz building", self.Atz
                    #print "self.Atz[j]", self.Atz[j]
                    #print "sigmas[j]", sigmas[j]
                    #print "z[i]", z[i]
                    #print "result", sigmas[j]*z[i]

                    for k in js:
                        AtA[j,k] += sigmas[j]*sigmas[k]
            else:
                msg = 'Could not find triangle for point', x
                raise Exception(msg)
            self.AtA = AtA

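    # Note (editor's comment): the n x m matrix A is never formed explicitly;
    # for each data point the three interpolation weights sigma0, sigma1,
    # sigma2 of its containing triangle are accumulated directly into AtA
    # and Atz above.
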
    def fit(self, point_coordinates_or_filename=None, z=None,
            verbose=False,
            point_origin=None,
            attribute_name=None,
            max_read_lines=500):
        """Fit a smooth surface to given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
          point_coordinates_or_filename: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of data points, an nx2 Numeric
              array, a Geospatial_data object or a points file filename.
          z: Single 1d vector or array of data at the point_coordinates.

        """
        # Use blocking to load in the point info
        if type(point_coordinates_or_filename) == types.StringType:
            msg = "Don't set a point origin when reading from a file"
            assert point_origin is None, msg
            filename = point_coordinates_or_filename

            G_data = Geospatial_data(filename,
                                     max_read_lines=max_read_lines,
                                     load_file_now=False,
                                     verbose=verbose)

            for i, geo_block in enumerate(G_data):
                if verbose is True and 0 == i%200:
                    # The time this will take is dependent on the
                    # number of triangles

                    print 'Processing Block %d' %i
                    # FIXME (Ole): It would be good to say how many blocks
                    # there are here. But this is no longer necessary
                    # for pts files as they are reported in geospatial_data.
                    # I suggest deleting this verbose output and making
                    # Geospatial_data more informative for txt files.
                    #
                    # I still think so (12/12/7, Ole).


                # Build the array
                points = geo_block.get_data_points(absolute=True)
                #print "fit points", points
                z = geo_block.get_attributes(attribute_name=attribute_name)
                self.build_fit_subset(points, z, verbose=verbose)

            point_coordinates = None
        else:
            point_coordinates = point_coordinates_or_filename

        if point_coordinates is None:
            assert self.AtA is not None
            assert self.Atz is not None
            #FIXME (DSG) - do a message
        else:
            point_coordinates = ensure_absolute(point_coordinates,
                                                geo_reference=point_origin)
            #if isinstance(point_coordinates, Geospatial_data) and z is None:
            # z will come from the geo-ref
            self.build_fit_subset(point_coordinates, z, verbose=verbose)

        # Check sanity
        m = self.mesh.number_of_nodes # Nbr of basis functions (1/vertex)
        n = self.point_count
        if n < m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There are only %d data points and alpha == 0. ' %n
            msg += 'Need at least %d\n' %m
            msg += 'Alternatively, set smoothing parameter alpha to a small '
            msg += 'positive value,\ne.g. 1.0e-3.'
            raise TooFewPointsError(msg)

        self._build_coefficient_matrix_B(verbose)
        loners = self.mesh.get_lone_vertices()
        # FIXME - make this an error message.
        # test with
        # Not_yet_test_smooth_att_to_mesh_with_excess_verts.
        if len(loners) > 0:
            msg = 'WARNING (least_squares): \nVertices with no triangles\n'
            msg += 'All vertices should be part of a triangle.\n'
            msg += 'In the future this will be enforced.\n'
            msg += 'The following vertices are not part of a triangle;\n'
            msg += str(loners)
            print msg
            #raise VertsWithNoTrianglesError(msg)


        return conjugate_gradient(self.B, self.Atz, self.Atz,
                                  imax=2*len(self.Atz))

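    # Example of fitting directly from a points file (illustrative sketch;
    # 'points.csv' and the 'elevation' attribute are hypothetical):
    #
    #     fit = Fit(mesh=mesh, alpha=1.0e-6)
    #     vertex_values = fit.fit('points.csv',
    #                             attribute_name='elevation',
    #                             max_read_lines=500)
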
    def build_fit_subset(self, point_coordinates, z=None, attribute_name=None,
                         verbose=False):
        """Fit a smooth surface to given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
          point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of data points, an nx2 Numeric
              array or a Geospatial_data object.
          z: Single 1d vector or array of data at the point_coordinates.
          attribute_name: Used to get the z values from the geospatial
              object. If no attribute_name is specified, an arbitrary
              attribute may be chosen; if there is only one attribute
              it will be that one.

        """

        # FIXME(DSG-DSG): Check that the vert and point coords
        # have the same zone.
        if isinstance(point_coordinates, Geospatial_data):
            point_coordinates = point_coordinates.get_data_points( \
                absolute=True)

        # Convert input to Numeric arrays
        if z is not None:
            z = ensure_numeric(z, Float)
        else:
            msg = 'z not specified'
            assert isinstance(point_coordinates, Geospatial_data), msg
            z = point_coordinates.get_attributes(attribute_name)

        point_coordinates = ensure_numeric(point_coordinates, Float)
        self._build_matrix_AtA_Atz(point_coordinates, z, verbose)

############################################################################

def fit_to_mesh(point_coordinates, # this can also be a points file name
                vertex_coordinates=None,
                triangles=None,
                mesh=None,
                point_attributes=None,
                alpha=DEFAULT_ALPHA,
                verbose=False,
                acceptable_overshoot=1.01,
                mesh_origin=None,
                data_origin=None,
                max_read_lines=None,
                attribute_name=None,
                use_cache=False):
    """Wrapper around internal function _fit_to_mesh for use with caching.

    """

    args = (point_coordinates, )
    kwargs = {'vertex_coordinates': vertex_coordinates,
              'triangles': triangles,
              'mesh': mesh,
              'point_attributes': point_attributes,
              'alpha': alpha,
              'verbose': verbose,
              'acceptable_overshoot': acceptable_overshoot,
              'mesh_origin': mesh_origin,
              'data_origin': data_origin,
              'max_read_lines': max_read_lines,
              'attribute_name': attribute_name,
              'use_cache': use_cache
              }

    if use_cache is True:
        if isinstance(point_coordinates, basestring):
            # We assume that point_coordinates is the name of a .csv/.txt
            # file which must be passed onto caching as a dependency
            # (in case it has changed on disk)
            dep = [point_coordinates]
        else:
            dep = None

        return cache(_fit_to_mesh,
                     args, kwargs,
                     verbose=verbose,
                     compression=False,
                     dependencies=dep)
    else:
        return _fit_to_mesh(*args, **kwargs)

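# Example (illustrative sketch; 'points.pts' is a hypothetical points file,
# 'verts' and 'tris' hypothetical mesh data):
#
#     vertex_values = fit_to_mesh('points.pts',
#                                 vertex_coordinates=verts,
#                                 triangles=tris,
#                                 alpha=1.0e-6,
#                                 use_cache=True)
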
def _fit_to_mesh(point_coordinates, # this can also be a points file name
                 vertex_coordinates=None,
                 triangles=None,
                 mesh=None,
                 point_attributes=None,
                 alpha=DEFAULT_ALPHA,
                 verbose=False,
                 acceptable_overshoot=1.01,
                 mesh_origin=None,
                 data_origin=None,
                 max_read_lines=None,
                 attribute_name=None,
                 use_cache=False):
    """
    Fit a smooth surface to a triangulation,
    given data points with attributes.

    Inputs:
      vertex_coordinates: List of coordinate pairs [xi, eta] of
          points constituting a mesh (or an m x 2 Numeric array or
          a geospatial object)
          Points may appear multiple times
          (e.g. if vertices have discontinuities)

      triangles: List of 3-tuples (or a Numeric array) of
          integers representing indices of all vertices in the mesh.

      point_coordinates: List of coordinate pairs [x, y] of data points
          (or an nx2 Numeric array). This can also be a .csv/.txt/.pts
          file name.

      alpha: Smoothing parameter.

      acceptable_overshoot: NOT IMPLEMENTED.
          Controls the allowed factor by which fitted values
          may exceed the value of input data. The lower limit is defined
          as min(z) - acceptable_overshoot*delta z and the upper limit
          as max(z) + acceptable_overshoot*delta z.

      mesh_origin: A geo_reference object or 3-tuple consisting of
          UTM zone, easting and northing.
          If specified, vertex coordinates are assumed to be
          relative to their respective origins.

      point_attributes: Vector or array of data at the point_coordinates.

    """

    # Duncan and Ole think that this isn't worth caching.
    # Caching happens at the higher level anyway.


    if mesh is None:
        # FIXME(DSG): Throw errors if triangles or vertex_coordinates
        # are None

        #Convert input to Numeric arrays
        triangles = ensure_numeric(triangles, Int)
        vertex_coordinates = ensure_absolute(vertex_coordinates,
                                             geo_reference=mesh_origin)

        if verbose: print 'FitInterpolate: Building mesh'
        mesh = Mesh(vertex_coordinates, triangles)
        mesh.check_integrity()

    interp = Fit(mesh=mesh,
                 verbose=verbose,
                 alpha=alpha)

    vertex_attributes = interp.fit(point_coordinates,
                                   point_attributes,
                                   point_origin=data_origin,
                                   max_read_lines=max_read_lines,
                                   attribute_name=attribute_name,
                                   verbose=verbose)


    # Add the value checking stuff that's in least squares.
    # Maybe this stuff should get pushed down into Fit,
    # or at least be a method of Fit.
    # Or integrate it into the fit method, saving the max and min
    # as attributes.

    return vertex_attributes

#def _fit(*args, **kwargs):
#    """Private function for use with caching. Reason is that classes
#    may change their byte code between runs which is annoying.
#    """
#
#    return Fit(*args, **kwargs)

def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
                     alpha=DEFAULT_ALPHA, verbose=False,
                     expand_search=False,
                     precrop=False,
                     display_errors=True):
    """
    Given a mesh file (tsh) and a point attribute file, fit
    point attributes to the mesh and write a mesh file with the
    results.

    Note: the points file needs titles.  If you want anuga to use the tsh
    file, make sure the title is elevation.

    NOTE: Throws IOErrors for a variety of file problems.

    """

    from load_mesh.loadASCII import import_mesh_file, \
         export_mesh_file, concatinate_attributelist


    try:
        mesh_dict = import_mesh_file(mesh_file)
    except IOError, e:
        if display_errors:
            print "Could not load bad file. ", e
        raise IOError  #Could not load bad mesh file.

    vertex_coordinates = mesh_dict['vertices']
    triangles = mesh_dict['triangles']
    if type(mesh_dict['vertex_attributes']) == ArrayType:
        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
    else:
        old_point_attributes = mesh_dict['vertex_attributes']

    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
    else:
        old_title_list = mesh_dict['vertex_attribute_titles']

    if verbose: print 'tsh file %s loaded' %mesh_file

    # Load in the points file
    try:
        geo = Geospatial_data(point_file, verbose=verbose)
    except IOError, e:
        if display_errors:
            print "Could not load bad file. ", e
        raise IOError  #Re-raise exception

    point_coordinates = geo.get_data_points(absolute=True)
    title_list, point_attributes = concatinate_attributelist( \
        geo.get_all_attributes())

    if mesh_dict.has_key('geo_reference') and \
           mesh_dict['geo_reference'] is not None:
        mesh_origin = mesh_dict['geo_reference'].get_origin()
    else:
        mesh_origin = None

    if verbose: print "points file loaded"
    if verbose: print "fitting to mesh"
    f = fit_to_mesh(point_coordinates,
                    vertex_coordinates,
                    triangles,
                    None,
                    point_attributes,
                    alpha=alpha,
                    verbose=verbose,
                    data_origin=None,
                    mesh_origin=mesh_origin)
    if verbose: print "finished fitting to mesh"

    # Convert array to list of lists
    new_point_attributes = f.tolist()
    #FIXME have this overwrite attributes with the same title - DSG
    #Put the newer attributes last
    if old_title_list != []:
        old_title_list.extend(title_list)
        #FIXME can this be done a faster way? - DSG
        for i in range(len(old_point_attributes)):
            old_point_attributes[i].extend(new_point_attributes[i])
        mesh_dict['vertex_attributes'] = old_point_attributes
        mesh_dict['vertex_attribute_titles'] = old_title_list
    else:
        mesh_dict['vertex_attributes'] = new_point_attributes
        mesh_dict['vertex_attribute_titles'] = title_list

    if verbose: print "exporting to file ", mesh_output_file

    try:
        export_mesh_file(mesh_output_file, mesh_dict)
    except IOError, e:
        if display_errors:
            print "Could not write file. ", e
        raise IOError
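
# Example (illustrative sketch; the file names are hypothetical):
#
#     fit_to_mesh_file('mesh.tsh', 'elevation_points.csv', 'mesh_fitted.tsh',
#                      alpha=1.0e-6, verbose=True)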