source: trunk/anuga_core/source/anuga/fit_interpolate/fit.py @ 8466

Last change on this file since 8466 was 8125, checked in by wilsonr, 14 years ago

Changes to address ticket 360.

File size: 25.0 KB
1"""Least squares fitting.
2
3   Implements a penalised least-squares fit.
4   putting point data onto the mesh.
5
6   The penalty term (or smoothing term) is controlled by the smoothing
7   parameter alpha.
8   With a value of alpha=0, the fit function will attempt
9   to interpolate as closely as possible in the least-squares sense.
10   With values alpha > 0, a certain amount of smoothing will be applied.
11   A positive alpha is essential in cases where there are too few
12   data points.
13   A negative alpha is not allowed.
14   A typical value of alpha is 1.0e-6
15
16
17   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
18   Geoscience Australia, 2004.
19
20   TO DO
21   * test geo_ref, geo_spatial
22
23   IDEAS
24   * (DSG-) Change the interface of fit, so a domain object can
25      be passed in. (I don't know if this is feasible). If could
26      save time/memory.
27"""

from anuga.abstract_2d_finite_volumes.neighbour_mesh import Mesh
from anuga.caching import cache
from anuga.geospatial_data.geospatial_data import Geospatial_data, \
     ensure_absolute
from anuga.fit_interpolate.general_fit_interpolate import FitInterpolate
from anuga.utilities.sparse import Sparse, Sparse_CSR
from anuga.geometry.polygon import inside_polygon, is_inside_polygon
from anuga.pmesh.mesh_quadtree import MeshQuadtree

from anuga.utilities.cg_solve import conjugate_gradient
from anuga.utilities.numerical_tools import ensure_numeric, gradient
from anuga.config import default_smoothing_parameter as DEFAULT_ALPHA
import anuga.utilities.log as log

import exceptions
class TooFewPointsError(exceptions.Exception): pass
class VertsWithNoTrianglesError(exceptions.Exception): pass

import numpy as num


class Fit(FitInterpolate):

    def __init__(self,
                 vertex_coordinates=None,
                 triangles=None,
                 mesh=None,
                 mesh_origin=None,
                 alpha=None,
                 verbose=False):

        """
        Fit data at points to the vertices of a mesh.

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 numeric array or
              a geospatial object).
              Points may appear multiple times
              (e.g. if vertices have discontinuities).

          triangles: List of 3-tuples (or a numeric array) of
              integers representing indices of all vertices in the mesh.

          mesh_origin: A geo_reference object or 3-tuple consisting of
              UTM zone, easting and northing.
              If specified, vertex coordinates are assumed to be
              relative to their respective origins.

          Note: Don't supply vertex coordinates as a geospatial object
              together with a mesh origin, since a geospatial object
              carries its own origin.

        Usage:
        To use this in a blocking way, call build_fit_subset with point and
        z data (possibly repeatedly), and then call fit with no point
        coordinates or z values.  See the illustrative sketch after this
        constructor.
        """
        # Initialise variables
        if alpha is None:
            self.alpha = DEFAULT_ALPHA
        else:
            self.alpha = alpha

        FitInterpolate.__init__(self,
                 vertex_coordinates,
                 triangles,
                 mesh,
                 mesh_origin,
                 verbose)

        m = self.mesh.number_of_nodes # Nbr of basis functions (vertices)

        self.AtA = None
        self.Atz = None

        self.point_count = 0
        if self.alpha != 0:
            if verbose: log.critical('Building smoothing matrix')
            self._build_smoothing_matrix_D()

        bd_poly = self.mesh.get_boundary_polygon()
        self.mesh_boundary_polygon = ensure_numeric(bd_poly)
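
    # Illustrative sketch of the blocking usage described in the constructor
    # docstring (added for clarity; not part of the original source, and the
    # coordinates/values are made up):
    #
    #     fit_object = Fit(vertex_coordinates=[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
    #                      triangles=[[0, 1, 2]],
    #                      alpha=1.0e-6)
    #     for points, z in [([[0.2, 0.2]], [1.0]),
    #                       ([[0.4, 0.1], [0.1, 0.4]], [2.0, 3.0])]:
    #         fit_object.build_fit_subset(points, z)   # accumulate AtA and Atz
    #     vertex_values = fit_object.fit()             # solve for vertex values
    #
    # Alternatively, fit() accepts a points file name and then does the
    # blocking internally via Geospatial_data.
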
    def _build_coefficient_matrix_B(self,
                                    verbose=False):
        """
        Build final coefficient matrix.

        Preconditions:
        If alpha is not zero, matrix D has been built.
        Matrix AtA has been built.
        """

        if self.alpha != 0:
            #if verbose: log.critical('Building smoothing matrix')
            #self._build_smoothing_matrix_D()
            self.B = self.AtA + self.alpha*self.D
        else:
            self.B = self.AtA

        # Convert self.B matrix to CSR format for faster matrix vector products
        self.B = Sparse_CSR(self.B)

    def _build_smoothing_matrix_D(self):
        """Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k.
        """
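
        # Worked example added for illustration (not in the original source):
        # for the unit right triangle with vertices (0,0), (1,0), (0,1),
        # area = 0.5 and the basis gradients computed below are
        #     (a0, b0) = (-1, -1),  (a1, b1) = (1, 0),  (a2, b2) = (0, 1)
        # so this triangle contributes
        #     D[v0,v0] += 1.0,  D[v1,v1] += 0.5,  D[v2,v2] += 0.5
        #     D[v0,v1] = D[v1,v0] += -0.5
        #     D[v1,v2] = D[v2,v1] +=  0.0
        #     D[v2,v0] = D[v0,v2] += -0.5
        # Each row of the element contribution sums to zero, as expected for
        # a pure smoothing (gradient penalty) term.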

        # FIXME: algorithm might be optimised by computing local 9x9
        # element stiffness matrices.

        m = self.mesh.number_of_nodes # Nbr of basis functions (1/vertex)

        self.D = Sparse(m,m)

        # For each triangle compute contributions to D = D1+D2
        for i in range(len(self.mesh)):

            # Get area
            area = self.mesh.areas[i]

            # Get global vertex indices
            v0 = self.mesh.triangles[i,0]
            v1 = self.mesh.triangles[i,1]
            v2 = self.mesh.triangles[i,2]

            # Get the three vertex_points
            xi0 = self.mesh.get_vertex_coordinate(i, 0)
            xi1 = self.mesh.get_vertex_coordinate(i, 1)
            xi2 = self.mesh.get_vertex_coordinate(i, 2)

            # Compute gradients for each vertex
            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              1, 0, 0)

            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 1, 0)

            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 0, 1)

            # Compute diagonal contributions
            self.D[v0,v0] += (a0*a0 + b0*b0)*area
            self.D[v1,v1] += (a1*a1 + b1*b1)*area
            self.D[v2,v2] += (a2*a2 + b2*b2)*area

            # Compute contributions for basis functions sharing edges
            e01 = (a0*a1 + b0*b1)*area
            self.D[v0,v1] += e01
            self.D[v1,v0] += e01

            e12 = (a1*a2 + b1*b2)*area
            self.D[v1,v2] += e12
            self.D[v2,v1] += e12

            e20 = (a2*a0 + b2*b0)*area
            self.D[v2,v0] += e20
            self.D[v0,v2] += e20

    def get_D(self):
        return self.D.todense()


    def _build_matrix_AtA_Atz(self,
                              point_coordinates,
                              z,
                              verbose=False):
        """Build:
        AtA  m x m  interpolation matrix, and
        Atz  m x a  interpolation matrix, where
        m is the number of basis functions phi_k (one per vertex) and
        a is the number of data attributes

        This algorithm uses a quad tree data structure for fast binning of
        data points.

        If AtA is None, the matrices AtA and Atz are created.

        This function can be called again and again, with sub-sets of
        the point coordinates.  Call fit to get the results.

        Preconditions
        z and points are numeric
        Point coordinates and mesh vertices have the same origin.

        The number of attributes of the data points does not change
        """
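
        # Worked mini-case added for illustration (not in the original
        # source): a data point lying exactly on vertex j has barycentric
        # weights sigma = (1, 0, 0) with respect to its triangle, so the
        # loop below simply performs
        #
        #     AtA[j, j] += 1.0
        #     Atz[j]    += z[i]
        #
        # Once all points have been binned, fit() solves the normal
        # equations (AtA + alpha*D) w = Atz for the vertex values w
        # (see _build_coefficient_matrix_B and fit).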

        # Build n x m interpolation matrix
        if self.AtA is None:
            # AtA and Atz need to be initialised.
            m = self.mesh.number_of_nodes
            if len(z.shape) > 1:
                att_num = z.shape[1]
                self.Atz = num.zeros((m,att_num), num.float)
            else:
                att_num = 1
                self.Atz = num.zeros((m,), num.float)
            assert z.shape[0] == point_coordinates.shape[0]

            AtA = Sparse(m,m)
            # The memory damage has been done by now.
        else:
            AtA = self.AtA # Did this for speed, did ~nothing
        self.point_count += point_coordinates.shape[0]


        inside_indices = inside_polygon(point_coordinates,
                                        self.mesh_boundary_polygon,
                                        closed=True,
                                        verbose=False) # Suppress output

        n = len(inside_indices)

        # Compute matrix elements for points inside the mesh
        triangles = self.mesh.triangles # Shorthand
        for d, i in enumerate(inside_indices):
            # For each data_coordinate point
            # if verbose and d%((n+10)/10)==0: log.critical('Doing %d of %d'
            #                                               %(d, n))
            x = point_coordinates[i]

            element_found, sigma0, sigma1, sigma2, k = \
                           self.root.search_fast(x)

            if element_found is True:
                j0 = triangles[k,0] # Global vertex id for sigma0
                j1 = triangles[k,1] # Global vertex id for sigma1
                j2 = triangles[k,2] # Global vertex id for sigma2

                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
                js     = [j0,j1,j2]

                for j in js:
                    self.Atz[j] += sigmas[j]*z[i]

                    for k in js:
                        AtA[j,k] += sigmas[j]*sigmas[k]
            else:
                flag = is_inside_polygon(x,
                                         self.mesh_boundary_polygon,
                                         closed=True,
                                         verbose=False) # Suppress output
                msg = 'Point (%f, %f) is not inside mesh boundary' % tuple(x)
                assert flag is True, msg

                # data point has fallen within a hole - so ignore it.

        self.AtA = AtA

    def fit(self, point_coordinates_or_filename=None, z=None,
            verbose=False,
            point_origin=None,
            attribute_name=None,
            max_read_lines=500):
        """Fit a smooth surface to a given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
        point_coordinates_or_filename: The co-ordinates of the data points.
              A list of coordinate pairs [x, y] of data points, an nx2
              numeric array, a Geospatial_data object, or a points file name.
        z: Single 1d vector or array of data at the point coordinates.

        """

        # Use blocking to load in the point info
        if isinstance(point_coordinates_or_filename, basestring):
            msg = "Don't set a point origin when reading from a file"
            assert point_origin is None, msg
            filename = point_coordinates_or_filename

            G_data = Geospatial_data(filename,
                                     max_read_lines=max_read_lines,
                                     load_file_now=False,
                                     verbose=verbose)

            for i, geo_block in enumerate(G_data):
                if verbose is True and 0 == i%200:
                    # The time this will take
                    # is dependent on the number of triangles

                    log.critical('Processing Block %d' % i)
                    # FIXME (Ole): It would be good to say how many blocks
                    # there are here. But this is no longer necessary
                    # for pts files as they are reported in geospatial_data
                    # I suggest deleting this verbose output and make
                    # Geospatial_data more informative for txt files.
                    #
                    # I still think so (12/12/7, Ole).


                # Build the array

                points = geo_block.get_data_points(absolute=True)
                z = geo_block.get_attributes(attribute_name=attribute_name)
                self.build_fit_subset(points, z, verbose=verbose)

                # FIXME(Ole): I thought this test would make sense here
                # See test_fitting_example_that_crashed_2 in test_shallow_water_domain.py
                # Committed 11 March 2009
                msg = 'Matrix AtA was not built'
                assert self.AtA is not None, msg

            point_coordinates = None
        else:
            point_coordinates = point_coordinates_or_filename

        if point_coordinates is None:
            if verbose: log.critical('Warning: no data points in fit')
            msg = 'No interpolation matrix.'
            assert self.AtA is not None, msg
            assert self.Atz is not None

            # FIXME (DSG) - do a message
        else:
            point_coordinates = ensure_absolute(point_coordinates,
                                                geo_reference=point_origin)
            # if isinstance(point_coordinates,Geospatial_data) and z is None:
            # z will come from the geo-ref
            self.build_fit_subset(point_coordinates, z, verbose)

        # Check sanity
        m = self.mesh.number_of_nodes # Nbr of basis functions (1/vertex)
        n = self.point_count
        if n < m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There are only %d data points and alpha == 0. ' % n
            msg += 'Need at least %d\n' % m
            msg += 'Alternatively, set smoothing parameter alpha to a small '
            msg += 'positive value,\ne.g. 1.0e-3.'
            raise TooFewPointsError(msg)

        self._build_coefficient_matrix_B(verbose)
        loners = self.mesh.get_lone_vertices()
        # FIXME - make this an error message.
        # test with
        # Not_yet_test_smooth_att_to_mesh_with_excess_verts.
        if len(loners) > 0:
            msg = 'WARNING: (least_squares): \nVertices with no triangles\n'
            msg += 'All vertices should be part of a triangle.\n'
            msg += 'In the future this will be enforced.\n'
            msg += 'The following vertices are not part of a triangle;\n'
            msg += str(loners)
            log.critical(msg)
            #raise VertsWithNoTrianglesError(msg)


        return conjugate_gradient(self.B, self.Atz, self.Atz,
                                  imax=2*len(self.Atz))


    def build_fit_subset(self, point_coordinates, z=None, attribute_name=None,
                         verbose=False):
        """Add a subset of data points to the fit.

        This accumulates the contributions of the given points to the
        interpolation matrices AtA and Atz.  The smooth surface itself is
        computed at each vertex in the underlying mesh when fit is called,
        using the formula given in the module doc string.

        Inputs:
        point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 numeric array or a Geospatial_data object
        z: Single 1d vector or array of data at the point_coordinates.
        attribute_name: Used to get the z values from the
              geospatial object.  If no attribute_name is specified,
              it's a bit of a lucky dip as to which attribute you get.
              If there is only one attribute it will be that one.

        """

        # FIXME(DSG-DSG): Check that the vert and point coords
        # have the same zone.
        if isinstance(point_coordinates, Geospatial_data):
            point_coordinates = point_coordinates.get_data_points(
                absolute=True)

        # Convert input to numeric arrays
        if z is not None:
            z = ensure_numeric(z, num.float)
        else:
            msg = 'z not specified'
            assert isinstance(point_coordinates, Geospatial_data), msg
            z = point_coordinates.get_attributes(attribute_name)

        point_coordinates = ensure_numeric(point_coordinates, num.float)
        self._build_matrix_AtA_Atz(point_coordinates, z, verbose)

############################################################################

def fit_to_mesh(point_coordinates, # this can also be a points file name
                vertex_coordinates=None,
                triangles=None,
                mesh=None,
                point_attributes=None,
                alpha=DEFAULT_ALPHA,
                verbose=False,
                mesh_origin=None,
                data_origin=None,
                max_read_lines=None,
                attribute_name=None,
                use_cache=False):
    """Wrapper around internal function _fit_to_mesh for use with caching.
    """
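
    # Illustrative sketch added for clarity (not part of the original source;
    # the file name and mesh object are hypothetical): when point_coordinates
    # is a points file and use_cache is True, the file is registered as a
    # cache dependency, so the fit is recomputed if the file changes on disk.
    #
    #     vertex_values = fit_to_mesh('elevation.pts',
    #                                 mesh=some_existing_mesh,
    #                                 alpha=1.0e-6,
    #                                 use_cache=True,
    #                                 verbose=True)
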
    args = (point_coordinates, )
    kwargs = {'vertex_coordinates': vertex_coordinates,
              'triangles': triangles,
              'mesh': mesh,
              'point_attributes': point_attributes,
              'alpha': alpha,
              'verbose': verbose,
              'mesh_origin': mesh_origin,
              'data_origin': data_origin,
              'max_read_lines': max_read_lines,
              'attribute_name': attribute_name
              }

    if use_cache is True:
        if isinstance(point_coordinates, basestring):
            # We assume that point_coordinates is the name of a .csv/.txt
            # file which must be passed onto caching as a dependency
            # (in case it has changed on disk)
            dep = [point_coordinates]
        else:
            dep = None


        #from caching import myhash
        #import copy
        #print args
        #print kwargs
        #print 'hashing:'
        #print 'args', myhash( (args, kwargs) )
        #print 'again', myhash( copy.deepcopy( (args, kwargs)) )

        #print 'mesh hash', myhash( kwargs['mesh'] )

        #print '-------------------------'
        #print 'vertices hash', myhash( kwargs['mesh'].nodes )
        #print 'triangles hash', myhash( kwargs['mesh'].triangles )
        #print '-------------------------'

        #for key in mesh.__dict__:
        #    print key, myhash(mesh.__dict__[key])

        #for key in mesh.quantities.keys():
        #    print key, myhash(mesh.quantities[key])

        #import sys; sys.exit()

        return cache(_fit_to_mesh,
                     args, kwargs,
                     verbose=verbose,
                     compression=False,
                     dependencies=dep)
    else:
        return _fit_to_mesh(*args, **kwargs)

def _fit_to_mesh(point_coordinates, # this can also be a points file name
                 vertex_coordinates=None,
                 triangles=None,
                 mesh=None,
                 point_attributes=None,
                 alpha=DEFAULT_ALPHA,
                 verbose=False,
                 mesh_origin=None,
                 data_origin=None,
                 max_read_lines=None,
                 attribute_name=None):
    """
    Fit a smooth surface to a triangulation,
    given data points with attributes.

    Inputs:
          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a numeric array) of
              integers representing indices of all vertices in the mesh.

          point_coordinates: List of coordinate pairs [x, y] of data points
              (or an nx2 numeric array). This can also be a .csv/.txt/.pts
              file name.

          alpha: Smoothing parameter.

          mesh_origin: A geo_reference object or 3-tuple consisting of
              UTM zone, easting and northing.
              If specified, vertex coordinates are assumed to be
              relative to their respective origins.

          point_attributes: Vector or array of data at the
              point_coordinates.
    """

    if mesh is None:
        # FIXME(DSG): Throw errors if triangles or vertex_coordinates
        # are None

        # Convert input to numeric arrays
        triangles = ensure_numeric(triangles, num.int)
        vertex_coordinates = ensure_absolute(vertex_coordinates,
                                             geo_reference=mesh_origin)

        if verbose: log.critical('FitInterpolate: Building mesh')
        mesh = Mesh(vertex_coordinates, triangles)
        mesh.check_integrity()


    interp = Fit(mesh=mesh,
                 verbose=verbose,
                 alpha=alpha)

    vertex_attributes = interp.fit(point_coordinates,
                                   point_attributes,
                                   point_origin=data_origin,
                                   max_read_lines=max_read_lines,
                                   attribute_name=attribute_name,
                                   verbose=verbose)


    # Add the value checking stuff that's in least squares.
    # Maybe this stuff should get pushed down into Fit.
    # at least be a method of Fit.
    # Or integrate it into the fit method, saving the max and min values
    # as attributes.

    return vertex_attributes


#def _fit(*args, **kwargs):
#    """Private function for use with caching. Reason is that classes
#    may change their byte code between runs which is annoying.
#    """
#
#    return Fit(*args, **kwargs)

def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
                     alpha=DEFAULT_ALPHA, verbose=False,
                     expand_search=False,
                     precrop=False,
                     display_errors=True):
    """
    Given a mesh file (tsh) and a point attribute file, fit
    point attributes to the mesh and write a mesh file with the
    results.

    Note: the points file needs titles.  If you want anuga to use the tsh
    file, make sure the title is elevation.

    NOTE: raises IOError for a variety of file problems.
    """
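
    # Illustrative call added for clarity (not from the original source; the
    # file names are hypothetical):
    #
    #     fit_to_mesh_file('mesh.tsh', 'elevation.csv', 'mesh_fitted.tsh',
    #                      alpha=1.0e-6, verbose=True)
    #
    # This fits the attributes in elevation.csv to the mesh in mesh.tsh and
    # writes the mesh, with the fitted values appended as vertex attributes,
    # to mesh_fitted.tsh.
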

    from load_mesh.loadASCII import import_mesh_file, \
         export_mesh_file, concatinate_attributelist


    try:
        mesh_dict = import_mesh_file(mesh_file)
    except IOError, e:
        if display_errors:
            log.critical("Could not load bad file: %s" % str(e))
        raise IOError  # Could not load bad mesh file.

    vertex_coordinates = mesh_dict['vertices']
    triangles = mesh_dict['triangles']
    if isinstance(mesh_dict['vertex_attributes'], num.ndarray):
        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
    else:
        old_point_attributes = mesh_dict['vertex_attributes']

    if isinstance(mesh_dict['vertex_attribute_titles'], num.ndarray):
        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
    else:
        old_title_list = mesh_dict['vertex_attribute_titles']

    if verbose: log.critical('tsh file %s loaded' % mesh_file)

    # load in the points file
    try:
        geo = Geospatial_data(point_file, verbose=verbose)
    except IOError, e:
        if display_errors:
            log.critical("Could not load bad file: %s" % str(e))
        raise IOError  # Re-raise exception

    point_coordinates = geo.get_data_points(absolute=True)
    title_list, point_attributes = concatinate_attributelist(
        geo.get_all_attributes())

    if mesh_dict.has_key('geo_reference') and \
           mesh_dict['geo_reference'] is not None:
        mesh_origin = mesh_dict['geo_reference'].get_origin()
    else:
        mesh_origin = None

    if verbose: log.critical("points file loaded")
    if verbose: log.critical("fitting to mesh")
    f = fit_to_mesh(point_coordinates,
                    vertex_coordinates,
                    triangles,
                    None,
                    point_attributes,
                    alpha=alpha,
                    verbose=verbose,
                    data_origin=None,
                    mesh_origin=mesh_origin)
    if verbose: log.critical("finished fitting to mesh")

    # convert array to list of lists
    new_point_attributes = f.tolist()
    # FIXME have this overwrite attributes with the same title - DSG
    # Put the newer attributes last
    if old_title_list != []:
        old_title_list.extend(title_list)
        # FIXME can this be done a faster way? - DSG
        for i in range(len(old_point_attributes)):
            old_point_attributes[i].extend(new_point_attributes[i])
        mesh_dict['vertex_attributes'] = old_point_attributes
        mesh_dict['vertex_attribute_titles'] = old_title_list
    else:
        mesh_dict['vertex_attributes'] = new_point_attributes
        mesh_dict['vertex_attribute_titles'] = title_list

    if verbose: log.critical("exporting to file %s" % mesh_output_file)

    try:
        export_mesh_file(mesh_output_file, mesh_dict)
    except IOError, e:
        if display_errors:
            log.critical("Could not write file %s" % str(e))
        raise IOError