source: anuga_core/source/anuga/fit_interpolate/fit.py @ 5793

1"""Least squares fitting.
2
3   Implements a penalised least-squares fit.
4   putting point data onto the mesh.
5
6   The penalty term (or smoothing term) is controlled by the smoothing
7   parameter alpha.
8   With a value of alpha=0, the fit function will attempt
9   to interpolate as closely as possible in the least-squares sense.
10   With values alpha > 0, a certain amount of smoothing will be applied.
11   A positive alpha is essential in cases where there are too few
12   data points.
13   A negative alpha is not allowed.
14   A typical value of alpha is 1.0e-6
15
16
17   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
18   Geoscience Australia, 2004.
19
20   TO DO
21   * test geo_ref, geo_spatial
22
23   IDEAS
24   * (DSG-) Change the interface of fit, so a domain object can
25      be passed in. (I don't know if this is feasible). If could
26      save time/memory.
27"""
import types

from Numeric import zeros, Float, ArrayType, take, Int

from anuga.abstract_2d_finite_volumes.neighbour_mesh import Mesh
from anuga.caching import cache
from anuga.geospatial_data.geospatial_data import Geospatial_data, \
     ensure_absolute
from anuga.fit_interpolate.general_fit_interpolate import FitInterpolate
from anuga.utilities.sparse import Sparse, Sparse_CSR
from anuga.utilities.polygon import in_and_outside_polygon
from anuga.fit_interpolate.search_functions import search_tree_of_vertices

from anuga.utilities.cg_solve import conjugate_gradient
from anuga.utilities.numerical_tools import ensure_numeric, gradient
from anuga.config import default_smoothing_parameter as DEFAULT_ALPHA

import exceptions
class TooFewPointsError(exceptions.Exception): pass
class VertsWithNoTrianglesError(exceptions.Exception): pass

#DEFAULT_ALPHA = 0.001


class Fit(FitInterpolate):

    def __init__(self,
                 vertex_coordinates=None,
                 triangles=None,
                 mesh=None,
                 mesh_origin=None,
                 alpha=None,
                 verbose=False,
                 max_vertices_per_cell=None):
        """
        Fit data at points to the vertices of a mesh.

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object).
              Points may appear multiple times
              (e.g. if vertices have discontinuities).

          triangles: List of 3-tuples (or a Numeric array) of
              integers representing indices of all vertices in the mesh.

          mesh_origin: A geo_reference object or 3-tuple consisting of
              UTM zone, easting and northing.
              If specified, vertex coordinates are assumed to be
              relative to their respective origins.

          max_vertices_per_cell: Number of vertices in a quad tree cell
              at which the cell is split into 4 subcells.

          Note: Don't supply vertex coordinates as a geospatial object
              together with a mesh origin, since the geospatial object
              has its own geo reference.


        Usage:
        To use this in a blocking way, call build_fit_subset with the
        point coordinates and z values, and then call fit with no point
        coordinates or z values.
        """
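        # Usage sketch (illustrative only; the coordinates and values below
        # are made up and a real call would use an actual mesh and data set):
        #
        #   fitter = Fit(vertex_coordinates=[[0.0, 0.0], [2.0, 0.0], [0.0, 2.0]],
        #                triangles=[[0, 1, 2]],
        #                alpha=1.0e-6)
        #   fitter.build_fit_subset([[0.2, 0.2], [0.5, 1.0]], z=[1.0, 4.0])
        #   vertex_values = fitter.fit()   # one fitted value per mesh vertex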
        # Initialise variables
        if alpha is None:
            self.alpha = DEFAULT_ALPHA
        else:
            self.alpha = alpha

        FitInterpolate.__init__(self,
                                vertex_coordinates,
                                triangles,
                                mesh,
                                mesh_origin,
                                verbose,
                                max_vertices_per_cell)

        m = self.mesh.number_of_nodes # Nbr of basis functions (vertices)

        self.AtA = None
        self.Atz = None

        self.point_count = 0
        if self.alpha != 0:
            if verbose: print 'Building smoothing matrix'
            self._build_smoothing_matrix_D()

        self.mesh_boundary_polygon = self.mesh.get_boundary_polygon()

    def _build_coefficient_matrix_B(self,
                                    verbose=False):
        """
        Build the final coefficient matrix.

        Preconditions:
        If alpha is not zero, matrix D has been built.
        Matrix AtA has been built.
        """

        if self.alpha != 0:
            #if verbose: print 'Building smoothing matrix'
            #self._build_smoothing_matrix_D()
            self.B = self.AtA + self.alpha*self.D
        else:
            self.B = self.AtA

        # Convert self.B matrix to CSR format for faster matrix-vector products
        self.B = Sparse_CSR(self.B)

    def _build_smoothing_matrix_D(self):
        """Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex).

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k.
        """
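        # Worked example (an illustrative sketch, not part of the original
        # code): on the unit right triangle with vertices (0,0), (1,0) and
        # (0,1) the linear basis functions are
        #
        #   phi_0 = 1 - x - y,   phi_1 = x,   phi_2 = y
        #
        # so their gradients are (-1,-1), (1,0) and (0,1) and the area is 0.5.
        # The element contributions are therefore
        #
        #   D[0,0] += ((-1)*(-1) + (-1)*(-1))*0.5 = 1.0
        #   D[1,1] += (1*1 + 0*0)*0.5 = 0.5
        #   D[0,1] += ((-1)*1 + (-1)*0)*0.5 = -0.5
        #
        # which is exactly what the loop below accumulates, triangle by
        # triangle.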

        # FIXME: algorithm might be optimised by computing local 9x9
        # "element stiffness matrices"

        m = self.mesh.number_of_nodes # Nbr of basis functions (1/vertex)

        self.D = Sparse(m,m)

        # For each triangle compute contributions to D = D1+D2
        for i in range(len(self.mesh)):

            # Get area
            area = self.mesh.areas[i]

            # Get global vertex indices
            v0 = self.mesh.triangles[i,0]
            v1 = self.mesh.triangles[i,1]
            v2 = self.mesh.triangles[i,2]

            # Get the three vertex points
            xi0 = self.mesh.get_vertex_coordinate(i, 0)
            xi1 = self.mesh.get_vertex_coordinate(i, 1)
            xi2 = self.mesh.get_vertex_coordinate(i, 2)

            # Compute gradients for each vertex
            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              1, 0, 0)

            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 1, 0)

            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 0, 1)

            # Compute diagonal contributions
            self.D[v0,v0] += (a0*a0 + b0*b0)*area
            self.D[v1,v1] += (a1*a1 + b1*b1)*area
            self.D[v2,v2] += (a2*a2 + b2*b2)*area

            # Compute contributions for basis functions sharing edges
            e01 = (a0*a1 + b0*b1)*area
            self.D[v0,v1] += e01
            self.D[v1,v0] += e01

            e12 = (a1*a2 + b1*b2)*area
            self.D[v1,v2] += e12
            self.D[v2,v1] += e12

            e20 = (a2*a0 + b2*b0)*area
            self.D[v2,v0] += e20
            self.D[v0,v2] += e20

    def get_D(self):
        return self.D.todense()


    def _build_matrix_AtA_Atz(self,
                              point_coordinates,
                              z,
                              verbose=False):
        """Build:
        AtA  m x m  interpolation matrix, and
        Atz  m x a  right-hand-side matrix, where
        m is the number of basis functions phi_k (one per vertex) and
        a is the number of data attributes.

        This algorithm uses a quad tree data structure for fast binning of
        data points.

        If AtA is None, the matrices AtA and Atz are created.

        This function can be called again and again, with sub-sets of
        the point coordinates.  Call fit to get the results.

        Preconditions:
        z and points are numeric.
        Point_coordinates and mesh vertices have the same origin.

        The number of attributes of the data points does not change.
        """
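        # Illustrative note (not from the original source): each data point
        # falling in triangle k with barycentric coordinates
        # (sigma0, sigma1, sigma2) contributes one row of the (implicit)
        # interpolation matrix A, with those sigmas in the columns of the
        # triangle's three vertices.  For a point at the centroid,
        # sigma0 = sigma1 = sigma2 = 1/3, so it adds 1/9 to each of the nine
        # AtA entries coupling those vertices and z/3 to each of their Atz
        # entries.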
        # Build n x m interpolation matrix contributions (AtA and Atz)

        if self.AtA is None:
            # AtA and Atz need to be initialised.
            m = self.mesh.number_of_nodes
            if len(z.shape) > 1:
                att_num = z.shape[1]
                self.Atz = zeros((m,att_num), Float)
            else:
                att_num = 1
                self.Atz = zeros((m,), Float)
            assert z.shape[0] == point_coordinates.shape[0]

            AtA = Sparse(m,m)
            # The memory damage has been done by now.
        else:
            AtA = self.AtA # Did this for speed, did ~nothing
        self.point_count += point_coordinates.shape[0]

        #if verbose: print 'Getting indices inside mesh boundary'

        inside_poly_indices, outside_poly_indices = \
                     in_and_outside_polygon(point_coordinates,
                                            self.mesh_boundary_polygon,
                                            closed=True,
                                            verbose=False) # There's too much output if True
        #print "self.inside_poly_indices", self.inside_poly_indices
        #print "self.outside_poly_indices", self.outside_poly_indices


        n = len(inside_poly_indices)
        #if verbose: print 'Building fitting matrix from %d points' %n

        # Compute matrix elements for points inside the mesh
        triangles = self.mesh.triangles # Did this for speed, did ~nothing
        for d, i in enumerate(inside_poly_indices):
            # For each data_coordinate point
            # if verbose and d%((n+10)/10)==0: print 'Doing %d of %d' %(d, n)
            x = point_coordinates[i]
            element_found, sigma0, sigma1, sigma2, k = \
                           search_tree_of_vertices(self.root, self.mesh, x)

            if element_found is True:
                j0 = triangles[k,0] # Global vertex id for sigma0
                j1 = triangles[k,1] # Global vertex id for sigma1
                j2 = triangles[k,2] # Global vertex id for sigma2

                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
                js     = [j0,j1,j2]

                for j in js:
                    self.Atz[j] += sigmas[j]*z[i]
                    #print "self.Atz building", self.Atz
                    #print "self.Atz[j]", self.Atz[j]
                    #print " sigmas[j]", sigmas[j]
                    #print "z[i]", z[i]
                    #print "result", sigmas[j]*z[i]

                    for k in js:
                        AtA[j,k] += sigmas[j]*sigmas[k]
            else:
                msg = 'Could not find triangle for point %s' % str(x)
                raise Exception(msg)

        # Keep a reference to the accumulated matrix for subsequent blocks
        self.AtA = AtA

    def fit(self, point_coordinates_or_filename=None, z=None,
            verbose=False,
            point_origin=None,
            attribute_name=None,
            max_read_lines=500):
        """Fit a smooth surface to a given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
          point_coordinates_or_filename: The co-ordinates of the data points.
              A list of coordinate pairs [x, y], an n x 2 Numeric array,
              a Geospatial_data object, or the name of a points file.
          z: Single 1d vector or array of data at the point_coordinates.
        """
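        # File-based usage sketch (illustrative only; 'bathymetry.pts' is a
        # made-up file name and some_mesh a stand-in for an existing mesh):
        #
        #   fitter = Fit(mesh=some_mesh, alpha=1.0e-6)
        #   vertex_values = fitter.fit('bathymetry.pts',
        #                              attribute_name='elevation',
        #                              max_read_lines=500)
        #
        # The file is read in blocks of max_read_lines points; each block is
        # folded into AtA and Atz via build_fit_subset before the final solve.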
        # Use blocking to load in the point info
        if type(point_coordinates_or_filename) == types.StringType:
            msg = "Don't set a point origin when reading from a file"
            assert point_origin is None, msg
            filename = point_coordinates_or_filename

            G_data = Geospatial_data(filename,
                                     max_read_lines=max_read_lines,
                                     load_file_now=False,
                                     verbose=verbose)

            for i, geo_block in enumerate(G_data):
                if verbose is True and i % 200 == 0:
                    # The time this will take
                    # is dependent on the # of triangles

                    print 'Processing Block %d' %i
                    # FIXME (Ole): It would be good to say how many blocks
                    # there are here. But this is no longer necessary
                    # for pts files as they are reported in geospatial_data
                    # I suggest deleting this verbose output and making
                    # Geospatial_data more informative for txt files.
                    #
                    # I still think so (12/12/7, Ole).


                # Build the array

                points = geo_block.get_data_points(absolute=True)
                #print "fit points", points
                z = geo_block.get_attributes(attribute_name=attribute_name)
                self.build_fit_subset(points, z, verbose=verbose)


            point_coordinates = None
        else:
            point_coordinates = point_coordinates_or_filename

        if point_coordinates is None:
            print 'Warning: no data points in fit'
            assert self.AtA is not None, 'no interpolation matrix'
            assert self.Atz is not None

            # FIXME (DSG) - do a message
        else:
            point_coordinates = ensure_absolute(point_coordinates,
                                                geo_reference=point_origin)
            #if isinstance(point_coordinates, Geospatial_data) and z is None:
            # z will come from the geo-ref
            self.build_fit_subset(point_coordinates, z, verbose)

        # Check sanity
        m = self.mesh.number_of_nodes # Nbr of basis functions (1/vertex)
        n = self.point_count
        if n < m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There are only %d data points and alpha == 0. ' %n
            msg += 'Need at least %d\n' %m
            msg += 'Alternatively, set smoothing parameter alpha to a small '
            msg += 'positive value,\ne.g. 1.0e-3.'
            raise TooFewPointsError(msg)

        self._build_coefficient_matrix_B(verbose)
        loners = self.mesh.get_lone_vertices()
        # FIXME - make this an error message.
        # Test with
        # Not_yet_test_smooth_att_to_mesh_with_excess_verts.
        if len(loners) > 0:
            msg = 'WARNING: (least_squares): \nVertices with no triangles\n'
            msg += 'All vertices should be part of a triangle.\n'
            msg += 'In the future this will be enforced.\n'
            msg += 'The following vertices are not part of a triangle;\n'
            msg += str(loners)
            print msg
            #raise VertsWithNoTrianglesError(msg)


        return conjugate_gradient(self.B, self.Atz, self.Atz,
                                  imax=2*len(self.Atz))


    def build_fit_subset(self, point_coordinates, z=None, attribute_name=None,
                         verbose=False):
        """Add a block of data points to the fit.

        Builds the contributions to the matrices AtA and Atz from the given
        points; the smooth surface itself is computed when fit is called.

        Inputs:
          point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of data points,
              an n x 2 Numeric array or a Geospatial_data object.
          z: Single 1d vector or array of data at the point_coordinates.
          attribute_name: Used to get the z values from the
              geospatial object. If no attribute_name is specified,
              it's a bit of a lucky dip as to which attribute you get.
              If there is only one attribute it will be that one.
        """

        # FIXME(DSG-DSG): Check that the vert and point coords
        # have the same zone.
        if isinstance(point_coordinates, Geospatial_data):
            point_coordinates = point_coordinates.get_data_points(
                absolute=True)

        # Convert input to Numeric arrays
        if z is not None:
            z = ensure_numeric(z, Float)
        else:
            msg = 'z not specified'
            assert isinstance(point_coordinates, Geospatial_data), msg
            z = point_coordinates.get_attributes(attribute_name)

        point_coordinates = ensure_numeric(point_coordinates, Float)
        self._build_matrix_AtA_Atz(point_coordinates, z, verbose)


############################################################################

def fit_to_mesh(point_coordinates, # this can also be a points file name
                vertex_coordinates=None,
                triangles=None,
                mesh=None,
                point_attributes=None,
                alpha=DEFAULT_ALPHA,
                verbose=False,
                acceptable_overshoot=1.01,
                mesh_origin=None,
                data_origin=None,
                max_read_lines=None,
                attribute_name=None,
                use_cache=False):
    """Wrapper around internal function _fit_to_mesh for use with caching.
    """
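    # Caching sketch (illustrative only; 'points.csv' is a made-up file name):
    #
    #   quantity = fit_to_mesh('points.csv',
    #                          vertex_coordinates, triangles,
    #                          alpha=0.0001,
    #                          attribute_name='elevation',
    #                          use_cache=True)
    #
    # With use_cache=True the file name is passed to cache() as a dependency,
    # so the fit is redone only if the points file (or the arguments) change.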

    args = (point_coordinates, )
    kwargs = {'vertex_coordinates': vertex_coordinates,
              'triangles': triangles,
              'mesh': mesh,
              'point_attributes': point_attributes,
              'alpha': alpha,
              'verbose': verbose,
              'acceptable_overshoot': acceptable_overshoot,
              'mesh_origin': mesh_origin,
              'data_origin': data_origin,
              'max_read_lines': max_read_lines,
              'attribute_name': attribute_name,
              'use_cache': use_cache
              }

    if use_cache is True:
        if isinstance(point_coordinates, basestring):
            # We assume that point_coordinates is the name of a .csv/.txt
            # file which must be passed onto caching as a dependency
            # (in case it has changed on disk)
            dep = [point_coordinates]
        else:
            dep = None

        return cache(_fit_to_mesh,
                     args, kwargs,
                     verbose=verbose,
                     compression=False,
                     dependencies=dep)
    else:
        return _fit_to_mesh(*args, **kwargs)

def _fit_to_mesh(point_coordinates, # this can also be a points file name
                 vertex_coordinates=None,
                 triangles=None,
                 mesh=None,
                 point_attributes=None,
                 alpha=DEFAULT_ALPHA,
                 verbose=False,
                 acceptable_overshoot=1.01,
                 mesh_origin=None,
                 data_origin=None,
                 max_read_lines=None,
                 attribute_name=None,
                 use_cache=False):
    """
    Fit a smooth surface to a triangulation,
    given data points with attributes.

    Inputs:
          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object).
              Points may appear multiple times
              (e.g. if vertices have discontinuities).

          triangles: List of 3-tuples (or a Numeric array) of
              integers representing indices of all vertices in the mesh.

          point_coordinates: List of coordinate pairs [x, y] of data points
              (or an n x 2 Numeric array). This can also be a .csv/.txt/.pts
              file name.

          alpha: Smoothing parameter.

          acceptable_overshoot: NOT IMPLEMENTED.
              Controls the allowed factor by which fitted values may exceed
              the value of the input data. The lower limit is defined as
              min(z) - acceptable_overshoot*delta z and the upper limit as
              max(z) + acceptable_overshoot*delta z.

          mesh_origin: A geo_reference object or 3-tuple consisting of
              UTM zone, easting and northing.
              If specified, vertex coordinates are assumed to be
              relative to their respective origins.

          point_attributes: Vector or array of data at the
              point_coordinates.
    """
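    # In-memory sketch (illustrative only; the small arrays below are made up):
    #
    #   vertex_coordinates = [[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]]
    #   triangles = [[0, 1, 2]]
    #   point_coordinates = [[1.0, 1.0], [2.0, 3.0], [4.0, 1.0]]
    #   point_attributes = [3.2, 2.9, 3.0]
    #
    #   vertex_values = fit_to_mesh(point_coordinates,
    #                               vertex_coordinates,
    #                               triangles,
    #                               point_attributes=point_attributes,
    #                               alpha=1.0e-6)
    #
    # The result has one fitted value per mesh vertex.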

    # Duncan and Ole think that this isn't worth caching.
    # Caching happens at the higher level anyway.


    if mesh is None:
        # FIXME(DSG): Throw errors if triangles or vertex_coordinates
        # are None

        # Convert input to Numeric arrays
        triangles = ensure_numeric(triangles, Int)
        vertex_coordinates = ensure_absolute(vertex_coordinates,
                                             geo_reference=mesh_origin)

        if verbose: print 'FitInterpolate: Building mesh'
        mesh = Mesh(vertex_coordinates, triangles)
        mesh.check_integrity()

    interp = Fit(mesh=mesh,
                 verbose=verbose,
                 alpha=alpha)

    vertex_attributes = interp.fit(point_coordinates,
                                   point_attributes,
                                   point_origin=data_origin,
                                   max_read_lines=max_read_lines,
                                   attribute_name=attribute_name,
                                   verbose=verbose)


    # Add the value checking stuff that's in least squares.
    # Maybe this stuff should get pushed down into Fit,
    # or at least be a method of Fit.
    # Or integrate it into the fit method, saving the max and min
    # as attributes.

    return vertex_attributes


#def _fit(*args, **kwargs):
#    """Private function for use with caching. Reason is that classes
#    may change their byte code between runs which is annoying.
#    """
#
#    return Fit(*args, **kwargs)

def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
                     alpha=DEFAULT_ALPHA, verbose=False,
                     expand_search=False,
                     precrop=False,
                     display_errors=True):
    """
    Given a mesh file (tsh) and a point attribute file, fit
    point attributes to the mesh and write a mesh file with the
    results.

    Note: the points file needs titles.  If you want anuga to use the tsh
    file, make sure the title is elevation.

    NOTE: Throws IOError for a variety of file problems.
    """
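    # Usage sketch (illustrative only; the file names below are made up):
    #
    #   fit_to_mesh_file('mesh.tsh', 'bathymetry.csv', 'mesh_fitted.tsh',
    #                    alpha=0.0001, verbose=True)
    #
    # This reads mesh.tsh and bathymetry.csv, fits the point attributes to
    # the mesh vertices and writes the result to mesh_fitted.tsh.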

    from load_mesh.loadASCII import import_mesh_file, \
         export_mesh_file, concatinate_attributelist


    try:
        mesh_dict = import_mesh_file(mesh_file)
    except IOError, e:
        if display_errors:
            print "Could not load bad file. ", e
        raise IOError  # Could not load bad mesh file.

    vertex_coordinates = mesh_dict['vertices']
    triangles = mesh_dict['triangles']
    if type(mesh_dict['vertex_attributes']) == ArrayType:
        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
    else:
        old_point_attributes = mesh_dict['vertex_attributes']

    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
    else:
        old_title_list = mesh_dict['vertex_attribute_titles']

    if verbose: print 'tsh file %s loaded' %mesh_file

    # Load in the points file
    try:
        geo = Geospatial_data(point_file, verbose=verbose)
    except IOError, e:
        if display_errors:
            print "Could not load bad file. ", e
        raise IOError  # Re-raise exception

    point_coordinates = geo.get_data_points(absolute=True)
    title_list, point_attributes = concatinate_attributelist(
        geo.get_all_attributes())

    if mesh_dict.has_key('geo_reference') and \
       mesh_dict['geo_reference'] is not None:
        mesh_origin = mesh_dict['geo_reference'].get_origin()
    else:
        mesh_origin = None

    if verbose: print "points file loaded"
    if verbose: print "fitting to mesh"
    f = fit_to_mesh(point_coordinates,
                    vertex_coordinates,
                    triangles,
                    None,  # mesh
                    point_attributes,
                    alpha=alpha,
                    verbose=verbose,
                    data_origin=None,
                    mesh_origin=mesh_origin)
    if verbose: print "finished fitting to mesh"

    # Convert array to list of lists
    new_point_attributes = f.tolist()
    # FIXME have this overwrite attributes with the same title - DSG
    # Put the newer attributes last
    if old_title_list != []:
        old_title_list.extend(title_list)
        # FIXME can this be done a faster way? - DSG
        for i in range(len(old_point_attributes)):
            old_point_attributes[i].extend(new_point_attributes[i])
        mesh_dict['vertex_attributes'] = old_point_attributes
        mesh_dict['vertex_attribute_titles'] = old_title_list
    else:
        mesh_dict['vertex_attributes'] = new_point_attributes
        mesh_dict['vertex_attribute_titles'] = title_list

    if verbose: print "exporting to file ", mesh_output_file

    try:
        export_mesh_file(mesh_output_file, mesh_dict)
    except IOError, e:
        if display_errors:
            print "Could not write file. ", e
        raise IOError