source: trunk/anuga_core/source/anuga/fit_interpolate/fit.py @ 7998

Last change on this file since 7998 was 7804, checked in by James Hudson, 15 years ago

Fixed up failing tests, updated user guide with new API (first few chapters only).

1"""Least squares fitting.
2
3   Implements a penalised least-squares fit.
4   putting point data onto the mesh.
5
6   The penalty term (or smoothing term) is controlled by the smoothing
7   parameter alpha.
8   With a value of alpha=0, the fit function will attempt
9   to interpolate as closely as possible in the least-squares sense.
10   With values alpha > 0, a certain amount of smoothing will be applied.
11   A positive alpha is essential in cases where there are too few
12   data points.
13   A negative alpha is not allowed.
14   A typical value of alpha is 1.0e-6
15

   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
   Geoscience Australia, 2004.

   TO DO
   * test geo_ref, geo_spatial

   IDEAS
   * (DSG-) Change the interface of fit, so a domain object can
      be passed in. (I don't know if this is feasible). It could
      save time/memory.
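
   EXAMPLE
   An illustrative sketch only (the points file name and the alpha value
   are hypothetical), showing the intended use of the fit_to_mesh
   function defined in this module:

       from anuga.fit_interpolate.fit import fit_to_mesh

       vertex_attributes = fit_to_mesh('elevation.pts',      # hypothetical points file
                                       vertex_coordinates,   # m x 2 array of mesh vertex coords
                                       triangles,            # mesh connectivity
                                       alpha=1.0e-6,         # smoothing parameter
                                       verbose=True)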
27"""
import types

from anuga.abstract_2d_finite_volumes.neighbour_mesh import Mesh
from anuga.caching import cache
from anuga.geospatial_data.geospatial_data import Geospatial_data, \
     ensure_absolute
from anuga.fit_interpolate.general_fit_interpolate import FitInterpolate
from anuga.utilities.sparse import Sparse, Sparse_CSR
from anuga.geometry.polygon import inside_polygon, is_inside_polygon
from anuga.pmesh.mesh_quadtree import MeshQuadtree

from anuga.utilities.cg_solve import conjugate_gradient
from anuga.utilities.numerical_tools import ensure_numeric, gradient
from anuga.config import default_smoothing_parameter as DEFAULT_ALPHA
import anuga.utilities.log as log

import exceptions
class TooFewPointsError(exceptions.Exception): pass
class VertsWithNoTrianglesError(exceptions.Exception): pass

import numpy as num


class Fit(FitInterpolate):

    def __init__(self,
                 vertex_coordinates=None,
                 triangles=None,
                 mesh=None,
                 mesh_origin=None,
                 alpha=None,
                 verbose=False):


        """
        Fit data at points to the vertices of a mesh.

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a numeric array) of
              integers representing indices of all vertices in the mesh.

          mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified, vertex coordinates are assumed to be
              relative to their respective origins.

          Note: Don't supply vertex coordinates as a geospatial object and
              a mesh origin at the same time, since the geospatial object
              carries its own origin.


        Usage:
        To use this in a blocking way, call build_fit_subset with point and
        z data (possibly repeatedly), and then call fit with no point
        coordinates or z data.

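        Example (an illustrative sketch only; verts, tris and data_chunks
        are hypothetical):

            fit_object = Fit(vertex_coordinates=verts, triangles=tris,
                             alpha=1.0e-6)
            for points, z in data_chunks:
                fit_object.build_fit_subset(points, z)
            vertex_values = fit_object.fit()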
        """
        # Initialise variables
        if alpha is None:
            self.alpha = DEFAULT_ALPHA
        else:
            self.alpha = alpha

        FitInterpolate.__init__(self,
                 vertex_coordinates,
                 triangles,
                 mesh,
                 mesh_origin,
                 verbose)

        m = self.mesh.number_of_nodes # Nbr of basis functions (vertices)

        self.AtA = None
        self.Atz = None

        self.point_count = 0
        if self.alpha != 0:
            if verbose: log.critical('Building smoothing matrix')
            self._build_smoothing_matrix_D()

        bd_poly = self.mesh.get_boundary_polygon()
        self.mesh_boundary_polygon = ensure_numeric(bd_poly)

    def _build_coefficient_matrix_B(self,
                                  verbose=False):
        """
        Build the final coefficient matrix B.

        Preconditions:
        If alpha is not zero, matrix D has been built.
        Matrix AtA has been built.
        """

        if self.alpha != 0:
            #if verbose: log.critical('Building smoothing matrix')
            #self._build_smoothing_matrix_D()
            self.B = self.AtA + self.alpha*self.D
        else:
            self.B = self.AtA

        # Convert self.B matrix to CSR format for faster matrix-vector products
        self.B = Sparse_CSR(self.B)

    def _build_smoothing_matrix_D(self):
        """Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k.
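
        Since each phi_k is linear on a triangle, its gradient (a_k, b_k)
        is constant there, so a triangle T with area |T| contributes

            (a_k*a_l + b_k*b_l) * |T|

        to [D]_{k,l} for each pair of its vertices k, l; this is what the
        loop below accumulates.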
        """

        # FIXME: algorithm might be optimised by computing local 9x9
        # "element stiffness matrices"

        m = self.mesh.number_of_nodes # Nbr of basis functions (1/vertex)

        self.D = Sparse(m,m)

        # For each triangle compute contributions to D = D1+D2
        for i in range(len(self.mesh)):

            # Get area
            area = self.mesh.areas[i]

            # Get global vertex indices
            v0 = self.mesh.triangles[i,0]
            v1 = self.mesh.triangles[i,1]
            v2 = self.mesh.triangles[i,2]

            # Get the three vertex_points
            xi0 = self.mesh.get_vertex_coordinate(i, 0)
            xi1 = self.mesh.get_vertex_coordinate(i, 1)
            xi2 = self.mesh.get_vertex_coordinate(i, 2)

            # Compute gradients for each vertex
            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              1, 0, 0)

            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 1, 0)

            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 0, 1)

            # Compute diagonal contributions
            self.D[v0,v0] += (a0*a0 + b0*b0)*area
            self.D[v1,v1] += (a1*a1 + b1*b1)*area
            self.D[v2,v2] += (a2*a2 + b2*b2)*area

            # Compute contributions for basis functions sharing edges
            e01 = (a0*a1 + b0*b1)*area
            self.D[v0,v1] += e01
            self.D[v1,v0] += e01

            e12 = (a1*a2 + b1*b2)*area
            self.D[v1,v2] += e12
            self.D[v2,v1] += e12

            e20 = (a2*a0 + b2*b0)*area
            self.D[v2,v0] += e20
            self.D[v0,v2] += e20

    def get_D(self):
        return self.D.todense()



    def _build_matrix_AtA_Atz(self,
                              point_coordinates,
                              z,
                              verbose=False):
        """Build:
        AtA  m x m  interpolation matrix, and
        Atz  m x a  interpolation matrix, where
        m is the number of basis functions phi_k (one per vertex)
        a is the number of data attributes

        This algorithm uses a quad tree data structure for fast binning of
        data points.

        If AtA is None, the matrices AtA and Atz are created.

        This function can be called again and again, with sub-sets of
        the point coordinates.  Call fit to get the results.

        Preconditions:
        z and points are numeric.
        Point coordinates and mesh vertices have the same origin.

        The number of attributes of the data points does not change.
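
        Conceptually, A is the n x m matrix with A[i,j] = phi_j(x_i), the
        value of basis function j (i.e. the barycentric weight sigma_j) at
        data point x_i. This routine accumulates AtA = A^T A and
        Atz = A^T z without ever forming A explicitly.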
        """

        # Build n x m interpolation matrix
        if self.AtA is None:
            # AtA and Atz need to be initialised.
            m = self.mesh.number_of_nodes
            if len(z.shape) > 1:
                att_num = z.shape[1]
                self.Atz = num.zeros((m,att_num), num.float)
            else:
                att_num = 1
                self.Atz = num.zeros((m,), num.float)
            assert z.shape[0] == point_coordinates.shape[0]

            AtA = Sparse(m,m)
            # The memory damage has been done by now.
        else:
            AtA = self.AtA # Did this for speed, did ~nothing
        self.point_count += point_coordinates.shape[0]


        inside_indices = inside_polygon(point_coordinates,
                                        self.mesh_boundary_polygon,
                                        closed=True,
                                        verbose=False) # Suppress output

        n = len(inside_indices)

        # Compute matrix elements for points inside the mesh
        triangles = self.mesh.triangles # Shorthand
        for d, i in enumerate(inside_indices):
            # For each data_coordinate point
            # if verbose and d%((n+10)/10)==0: log.critical('Doing %d of %d'
                                                            # %(d, n))
            x = point_coordinates[i]

            element_found, sigma0, sigma1, sigma2, k = \
                           self.root.search_fast(x)

            if element_found is True:
                j0 = triangles[k,0] # Global vertex id for sigma0
                j1 = triangles[k,1] # Global vertex id for sigma1
                j2 = triangles[k,2] # Global vertex id for sigma2

                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
                js     = [j0,j1,j2]

                for j in js:
                    self.Atz[j] += sigmas[j]*z[i]

                    for k in js:
                        AtA[j,k] += sigmas[j]*sigmas[k]
            else:
                flag = is_inside_polygon(x,
                                         self.mesh_boundary_polygon,
                                         closed=True,
                                         verbose=False) # Suppress output
                msg = 'Point (%f, %f) is not inside mesh boundary' % tuple(x)
                assert flag is True, msg

                # data point has fallen within a hole - so ignore it.

        self.AtA = AtA


    def fit(self, point_coordinates_or_filename=None, z=None,
            verbose=False,
            point_origin=None,
            attribute_name=None,
            max_read_lines=500):
        """Fit a smooth surface to a given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
        point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 numeric array or a Geospatial_data object
              or a points file name.
        z: Single 1d vector or array of data at the point_coordinates.

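        The returned vertex values w solve the linear system

            (AtA + alpha*D) w = Atz

        by conjugate gradients, where D is the smoothing matrix built in
        _build_smoothing_matrix_D.

        Example (an illustrative sketch only; the file name is hypothetical):

            fit_object = Fit(mesh=mesh, alpha=1.0e-6)
            vertex_values = fit_object.fit('bathymetry.pts', verbose=True)
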
        """

        # Use blocking to load in the point info
        if type(point_coordinates_or_filename) == types.StringType:
            msg = "Don't set a point origin when reading from a file"
            assert point_origin is None, msg
            filename = point_coordinates_or_filename

            G_data = Geospatial_data(filename,
                                     max_read_lines=max_read_lines,
                                     load_file_now=False,
                                     verbose=verbose)

            for i, geo_block in enumerate(G_data):
                if verbose is True and 0 == i%200:
                    # The time this will take
                    # depends on the number of triangles

                    log.critical('Processing Block %d' % i)
                    # FIXME (Ole): It would be good to say how many blocks
                    # there are here. But this is no longer necessary
                    # for pts files as they are reported in geospatial_data
                    # I suggest deleting this verbose output and making
                    # Geospatial_data more informative for txt files.
                    #
                    # I still think so (12/12/7, Ole).



                # Build the array

                points = geo_block.get_data_points(absolute=True)
                z = geo_block.get_attributes(attribute_name=attribute_name)
                self.build_fit_subset(points, z, verbose=verbose)

                # FIXME(Ole): I thought this test would make sense here
                # See test_fitting_example_that_crashed_2 in test_shallow_water_domain.py
                # Committed 11 March 2009
                msg = 'Matrix AtA was not built'
                assert self.AtA is not None, msg

            point_coordinates = None
        else:
            point_coordinates = point_coordinates_or_filename

        if point_coordinates is None:
            if verbose: log.critical('Warning: no data points in fit')
            msg = 'No interpolation matrix.'
            assert self.AtA is not None, msg
            assert self.Atz is not None

            # FIXME (DSG) - do a message
        else:
            point_coordinates = ensure_absolute(point_coordinates,
                                                geo_reference=point_origin)
            # if isinstance(point_coordinates,Geospatial_data) and z is None:
            # z will come from the geo-ref
            self.build_fit_subset(point_coordinates, z, verbose=verbose)

        # Check sanity
        m = self.mesh.number_of_nodes # Nbr of basis functions (1/vertex)
        n = self.point_count
        if n < m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There are only %d data points and alpha == 0. ' % n
            msg += 'Need at least %d\n' % m
            msg += 'Alternatively, set smoothing parameter alpha to a small '
            msg += 'positive value,\ne.g. 1.0e-3.'
            raise TooFewPointsError(msg)

        self._build_coefficient_matrix_B(verbose)
        loners = self.mesh.get_lone_vertices()
        # FIXME - make this an error message.
        # test with
        # Not_yet_test_smooth_att_to_mesh_with_excess_verts.
        if len(loners) > 0:
            msg = 'WARNING: (least_squares): \nVertices with no triangles\n'
            msg += 'All vertices should be part of a triangle.\n'
            msg += 'In the future this will be enforced.\n'
            msg += 'The following vertices are not part of a triangle:\n'
            msg += str(loners)
            log.critical(msg)
            #raise VertsWithNoTrianglesError(msg)


        return conjugate_gradient(self.B, self.Atz, self.Atz,
                                  imax=2*len(self.Atz))


    def build_fit_subset(self, point_coordinates, z=None, attribute_name=None,
                              verbose=False):
        """Fit a smooth surface to a given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
        point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 numeric array or a Geospatial_data object
        z: Single 1d vector or array of data at the point_coordinates.
        attribute_name: Used to get the z values from the
              geospatial object. If no attribute_name is specified,
              it's a bit of a lucky dip as to which attribute you get.
              If there is only one attribute, it will be that one.

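        Example (an illustrative sketch only; fit_object is an existing Fit
        instance, geo_data a Geospatial_data object and 'elevation' a
        hypothetical attribute name):

            fit_object.build_fit_subset(geo_data, attribute_name='elevation')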
        """

        # Convert input to numeric arrays
        if z is not None:
            z = ensure_numeric(z, num.float)
        else:
            msg = 'z not specified'
            assert isinstance(point_coordinates, Geospatial_data), msg
            z = point_coordinates.get_attributes(attribute_name)

        # FIXME(DSG-DSG): Check that the vert and point coords
        # have the same zone.
        if isinstance(point_coordinates, Geospatial_data):
            point_coordinates = point_coordinates.get_data_points(absolute=True)

        point_coordinates = ensure_numeric(point_coordinates, num.float)
        self._build_matrix_AtA_Atz(point_coordinates, z, verbose)


############################################################################

def fit_to_mesh(point_coordinates, # this can also be a points file name
                vertex_coordinates=None,
                triangles=None,
                mesh=None,
                point_attributes=None,
                alpha=DEFAULT_ALPHA,
                verbose=False,
                mesh_origin=None,
                data_origin=None,
                max_read_lines=None,
                attribute_name=None,
                use_cache=False):
    """Wrapper around internal function _fit_to_mesh for use with caching.

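    Example (an illustrative sketch only; the file name is hypothetical):

        vertex_attributes = fit_to_mesh('elevation.pts',
                                        mesh=mesh,
                                        alpha=1.0e-6,
                                        use_cache=True,  # reuse cached result if inputs unchanged
                                        verbose=True)
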
    """

    args = (point_coordinates, )
    kwargs = {'vertex_coordinates': vertex_coordinates,
              'triangles': triangles,
              'mesh': mesh,
              'point_attributes': point_attributes,
              'alpha': alpha,
              'verbose': verbose,
              'mesh_origin': mesh_origin,
              'data_origin': data_origin,
              'max_read_lines': max_read_lines,
              'attribute_name': attribute_name
              }

    if use_cache is True:
        if isinstance(point_coordinates, basestring):
            # We assume that point_coordinates is the name of a .csv/.txt
            # file which must be passed onto caching as a dependency
            # (in case it has changed on disk)
            dep = [point_coordinates]
        else:
            dep = None


        #from caching import myhash
        #import copy
        #print args
        #print kwargs
        #print 'hashing:'
        #print 'args', myhash( (args, kwargs) )
        #print 'again', myhash( copy.deepcopy( (args, kwargs)) )

        #print 'mesh hash', myhash( kwargs['mesh'] )

        #print '-------------------------'
        #print 'vertices hash', myhash( kwargs['mesh'].nodes )
        #print 'triangles hash', myhash( kwargs['mesh'].triangles )
        #print '-------------------------'

        #for key in mesh.__dict__:
        #    print key, myhash(mesh.__dict__[key])

        #for key in mesh.quantities.keys():
        #    print key, myhash(mesh.quantities[key])

        #import sys; sys.exit()

        return cache(_fit_to_mesh,
                     args, kwargs,
                     verbose=verbose,
                     compression=False,
                     dependencies=dep)
    else:
        return _fit_to_mesh(*args, **kwargs)

def _fit_to_mesh(point_coordinates, # this can also be a points file name
                 vertex_coordinates=None,
                 triangles=None,
                 mesh=None,
                 point_attributes=None,
                 alpha=DEFAULT_ALPHA,
                 verbose=False,
                 mesh_origin=None,
                 data_origin=None,
                 max_read_lines=None,
                 attribute_name=None):
    """
    Fit a smooth surface to a triangulation,
    given data points with attributes.


    Inputs:

      vertex_coordinates: List of coordinate pairs [xi, eta] of
          points constituting a mesh (or an m x 2 numeric array or
          a geospatial object)
          Points may appear multiple times
          (e.g. if vertices have discontinuities)

      triangles: List of 3-tuples (or a numeric array) of
          integers representing indices of all vertices in the mesh.

      point_coordinates: List of coordinate pairs [x, y] of data points
          (or an nx2 numeric array). This can also be a .csv/.txt/.pts
          file name.

      alpha: Smoothing parameter.

      mesh_origin: A geo_reference object or 3-tuples consisting of
          UTM zone, easting and northing.
          If specified, vertex coordinates are assumed to be
          relative to their respective origins.

      point_attributes: Vector or array of data at the
          point_coordinates.

    """

    if mesh is None:
        # FIXME(DSG): Throw errors if triangles or vertex_coordinates
        # are None

        # Convert input to numeric arrays
        triangles = ensure_numeric(triangles, num.int)
        vertex_coordinates = ensure_absolute(vertex_coordinates,
                                             geo_reference=mesh_origin)

        if verbose: log.critical('FitInterpolate: Building mesh')
        mesh = Mesh(vertex_coordinates, triangles)
        mesh.check_integrity()


    interp = Fit(mesh=mesh,
                 verbose=verbose,
                 alpha=alpha)

    vertex_attributes = interp.fit(point_coordinates,
                                   point_attributes,
                                   point_origin=data_origin,
                                   max_read_lines=max_read_lines,
                                   attribute_name=attribute_name,
                                   verbose=verbose)


    # Add the value checking stuff that's in least squares.
    # Maybe this stuff should get pushed down into Fit.
    # At least be a method of Fit.
    # Or integrate it into the fit method, saving the max and min values
    # as attributes.

    return vertex_attributes


#def _fit(*args, **kwargs):
#    """Private function for use with caching. Reason is that classes
#    may change their byte code between runs which is annoying.
#    """
#
#    return Fit(*args, **kwargs)


def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
                     alpha=DEFAULT_ALPHA, verbose=False,
                     expand_search=False,
                     precrop=False,
                     display_errors=True):
    """
    Given a mesh file (tsh) and a point attribute file, fit
    point attributes to the mesh and write a mesh file with the
    results.

    Note: the points file needs titles. If you want ANUGA to use the tsh file,
    make sure the title is elevation.

    NOTE: Throws IOError for a variety of file problems.

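    Example (an illustrative sketch only; the file names are hypothetical):

        fit_to_mesh_file('domain.tsh', 'elevation.csv', 'domain_fitted.tsh',
                         alpha=1.0e-6, verbose=True)
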
    """

    from load_mesh.loadASCII import import_mesh_file, \
         export_mesh_file, concatinate_attributelist


    try:
        mesh_dict = import_mesh_file(mesh_file)
    except IOError, e:
        if display_errors:
            log.critical("Could not load bad file: %s" % str(e))
        raise IOError  # Could not load bad mesh file.

    vertex_coordinates = mesh_dict['vertices']
    triangles = mesh_dict['triangles']
    if isinstance(mesh_dict['vertex_attributes'], num.ndarray):
        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
    else:
        old_point_attributes = mesh_dict['vertex_attributes']

    if isinstance(mesh_dict['vertex_attribute_titles'], num.ndarray):
        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
    else:
        old_title_list = mesh_dict['vertex_attribute_titles']

    if verbose: log.critical('tsh file %s loaded' % mesh_file)

    # Load in the points file
    try:
        geo = Geospatial_data(point_file, verbose=verbose)
    except IOError, e:
        if display_errors:
            log.critical("Could not load bad file: %s" % str(e))
        raise IOError  # Re-raise exception

    point_coordinates = geo.get_data_points(absolute=True)
    title_list, point_attributes = concatinate_attributelist(
        geo.get_all_attributes())

    if mesh_dict.has_key('geo_reference') and \
           mesh_dict['geo_reference'] is not None:
        mesh_origin = mesh_dict['geo_reference'].get_origin()
    else:
        mesh_origin = None

    if verbose: log.critical("points file loaded")
    if verbose: log.critical("fitting to mesh")
    f = fit_to_mesh(point_coordinates,
                    vertex_coordinates,
                    triangles,
                    None,
                    point_attributes,
                    alpha=alpha,
                    verbose=verbose,
                    data_origin=None,
                    mesh_origin=mesh_origin)
    if verbose: log.critical("finished fitting to mesh")

    # Convert array to list of lists
    new_point_attributes = f.tolist()
    # FIXME have this overwrite attributes with the same title - DSG
    # Put the newer attributes last
    if old_title_list != []:
        old_title_list.extend(title_list)
        # FIXME can this be done a faster way? - DSG
        for i in range(len(old_point_attributes)):
            old_point_attributes[i].extend(new_point_attributes[i])
        mesh_dict['vertex_attributes'] = old_point_attributes
        mesh_dict['vertex_attribute_titles'] = old_title_list
    else:
        mesh_dict['vertex_attributes'] = new_point_attributes
        mesh_dict['vertex_attribute_titles'] = title_list

    if verbose: log.critical("exporting to file %s" % mesh_output_file)

    try:
        export_mesh_file(mesh_output_file, mesh_dict)
    except IOError, e:
        if display_errors:
            log.critical("Could not write file %s" % str(e))
        raise IOError