source: trunk/anuga_core/source/anuga/fit_interpolate/fit.py @ 7753

Last change on this file since 7753 was 7751, checked in by James Hudson, 15 years ago

Refactorings from May ANUGA meeting.

File size: 25.5 KB
Line 
1"""Least squares fitting.
2
   Implements a penalised least-squares fit,
   putting point data onto the mesh.
5
6   The penalty term (or smoothing term) is controlled by the smoothing
7   parameter alpha.
8   With a value of alpha=0, the fit function will attempt
9   to interpolate as closely as possible in the least-squares sense.
10   With values alpha > 0, a certain amount of smoothing will be applied.
11   A positive alpha is essential in cases where there are too few
12   data points.
13   A negative alpha is not allowed.
14   A typical value of alpha is 1.0e-6
15
16
17   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
18   Geoscience Australia, 2004.
19
20   TO DO
21   * test geo_ref, geo_spatial
22
23   IDEAS
24   * (DSG-) Change the interface of fit, so a domain object can
25      be passed in. (I don't know if this is feasible). If could
26      save time/memory.
27"""
28import types
29
30from anuga.abstract_2d_finite_volumes.neighbour_mesh import Mesh
31from anuga.caching import cache           
32from anuga.geospatial_data.geospatial_data import Geospatial_data, \
33     ensure_absolute
34from anuga.fit_interpolate.general_fit_interpolate import FitInterpolate
35from anuga.utilities.sparse import Sparse, Sparse_CSR
36from anuga.geometry.polygon import inside_polygon, is_inside_polygon
37from anuga.pmesh.mesh_quadtree import MeshQuadtree
38
39from anuga.utilities.cg_solve import conjugate_gradient
40from anuga.utilities.numerical_tools import ensure_numeric, gradient
41from anuga.config import default_smoothing_parameter as DEFAULT_ALPHA
42import anuga.utilities.log as log
43
import exceptions

# Raised by Fit.fit when alpha == 0 and there are fewer data points than
# mesh vertices, i.e. the least-squares problem is under-determined.
class TooFewPointsError(exceptions.Exception): pass

# Signals mesh vertices that are not part of any triangle.  Currently only
# logged by Fit.fit (the raise is commented out there), never raised.
class VertsWithNoTrianglesError(exceptions.Exception): pass
47
48import numpy as num
49
50
class Fit(FitInterpolate):
    """Penalised least-squares fit of point data onto mesh vertices.

    The relative weight of the smoothing (penalty) term is controlled by
    the attribute alpha -- see the module doc string.

    Usage:
    To use this in a blocking way, call build_fit_subset with z info,
    and then fit with no point coord or z info.
    """

    def __init__(self,
                 vertex_coordinates=None,
                 triangles=None,
                 mesh=None,
                 mesh_origin=None,
                 alpha=None,
                 verbose=False):
        """Fit data at points to the vertices of a mesh.

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a numeric array) of
              integers representing indices of all vertices in the mesh.

          mesh: Alternative to vertex_coordinates/triangles; an already
              constructed mesh object.

          mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified vertex coordinates are assumed to be
              relative to their respective origins.

          alpha: Smoothing parameter; defaults to DEFAULT_ALPHA.

          verbose: If True, log progress information.

        Note: Don't supply vertex coords as a geospatial object and
            a mesh origin, since geospatial has its own mesh origin.
        """

        # Initialise variables
        if alpha is None:
            self.alpha = DEFAULT_ALPHA
        else:
            self.alpha = alpha

        FitInterpolate.__init__(self,
                                vertex_coordinates,
                                triangles,
                                mesh,
                                mesh_origin,
                                verbose)

        # AtA (normal matrix) and Atz (right hand side) are accumulated
        # incrementally by _build_matrix_AtA_Atz.
        self.AtA = None
        self.Atz = None

        self.point_count = 0

        # Only build the smoothing matrix when smoothing is requested.
        # (Was 'self.alpha <> 0' -- deprecated inequality operator.)
        if self.alpha != 0:
            if verbose: log.critical('Building smoothing matrix')
            self._build_smoothing_matrix_D()

        bd_poly = self.mesh.get_boundary_polygon()
        self.mesh_boundary_polygon = ensure_numeric(bd_poly)

    def _build_coefficient_matrix_B(self,
                                    verbose=False):
        """Build final coefficient matrix B = AtA + alpha*D.

        Preconditions:
        If alpha is not zero, matrix D has been built.
        Matrix AtA has been built.
        """

        if self.alpha != 0:
            self.B = self.AtA + self.alpha*self.D
        else:
            self.B = self.AtA

        # Convert self.B matrix to CSR format for faster matrix-vector
        # products inside the conjugate gradient solver.
        self.B = Sparse_CSR(self.B)

    def _build_smoothing_matrix_D(self):
        """Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k
        """

        # FIXME: algorithm might be optimised by computing local 9x9
        # "element stiffness matrices"

        m = self.mesh.number_of_nodes  # Nbr of basis functions (1/vertex)

        self.D = Sparse(m, m)

        # For each triangle compute contributions to D = D1+D2
        for i in range(len(self.mesh)):

            # Get area
            area = self.mesh.areas[i]

            # Get global vertex indices
            v0 = self.mesh.triangles[i, 0]
            v1 = self.mesh.triangles[i, 1]
            v2 = self.mesh.triangles[i, 2]

            # Get the three vertex points
            xi0 = self.mesh.get_vertex_coordinate(i, 0)
            xi1 = self.mesh.get_vertex_coordinate(i, 1)
            xi2 = self.mesh.get_vertex_coordinate(i, 2)

            # Compute gradients (a_k, b_k) of each linear basis function
            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              1, 0, 0)

            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 1, 0)

            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 0, 1)

            # Compute diagonal contributions
            self.D[v0, v0] += (a0*a0 + b0*b0)*area
            self.D[v1, v1] += (a1*a1 + b1*b1)*area
            self.D[v2, v2] += (a2*a2 + b2*b2)*area

            # Compute contributions for basis functions sharing edges
            e01 = (a0*a1 + b0*b1)*area
            self.D[v0, v1] += e01
            self.D[v1, v0] += e01

            e12 = (a1*a2 + b1*b2)*area
            self.D[v1, v2] += e12
            self.D[v2, v1] += e12

            e20 = (a2*a0 + b2*b0)*area
            self.D[v2, v0] += e20
            self.D[v0, v2] += e20

    def get_D(self):
        """Return the smoothing matrix D in dense form."""
        return self.D.todense()

    def _build_matrix_AtA_Atz(self,
                              point_coordinates,
                              z,
                              verbose=False):
        """Build:
        AtA  m x m  interpolation matrix, and,
        Atz  m x a  interpolation matrix where,
        m is the number of basis functions phi_k (one per vertex)
        a is the number of data attributes

        This algorithm uses a quad tree data structure for fast binning of
        data points.

        If AtA is None, the matrices AtA and Atz are created.

        This function can be called again and again, with sub-sets of
        the point coordinates.  Call fit to get the results.

        Preconditions:
        z and points are numeric.
        point_coordinates and mesh vertices have the same origin.

        The number of attributes of the data points does not change.
        """

        # Build n x m interpolation matrix
        if self.AtA is None:
            # AtA and Atz need to be initialised.
            m = self.mesh.number_of_nodes
            if len(z.shape) > 1:
                att_num = z.shape[1]
                self.Atz = num.zeros((m, att_num), num.float)
            else:
                att_num = 1
                self.Atz = num.zeros((m,), num.float)
            assert z.shape[0] == point_coordinates.shape[0]

            AtA = Sparse(m, m)
            # The memory damage has been done by now.
        else:
            AtA = self.AtA  # Did this for speed, did ~nothing
        self.point_count += point_coordinates.shape[0]

        inside_indices = inside_polygon(point_coordinates,
                                        self.mesh_boundary_polygon,
                                        closed=True,
                                        verbose=False)  # Suppress output

        # Compute matrix elements for points inside the mesh
        triangles = self.mesh.triangles  # Shorthand
        for i in inside_indices:
            # For each data_coordinate point
            x = point_coordinates[i]

            element_found, sigma0, sigma1, sigma2, k = \
                           self.root.search_fast(x)

            if element_found is True:
                j0 = triangles[k, 0]  # Global vertex id for sigma0
                j1 = triangles[k, 1]  # Global vertex id for sigma1
                j2 = triangles[k, 2]  # Global vertex id for sigma2

                sigmas = {j0: sigma0, j1: sigma1, j2: sigma2}
                js = [j0, j1, j2]

                for j in js:
                    self.Atz[j] += sigmas[j]*z[i]

                    # Inner loop variable renamed from 'k', which
                    # shadowed the triangle index found above.
                    for jj in js:
                        AtA[j, jj] += sigmas[j]*sigmas[jj]
            else:
                flag = is_inside_polygon(x,
                                         self.mesh_boundary_polygon,
                                         closed=True,
                                         verbose=False)  # Suppress output
                msg = 'Point (%f, %f) is not inside mesh boundary' % tuple(x)
                assert flag is True, msg

                # FIXME(Ole): This is the message referred to in ticket:314
                minx = min(self.mesh_boundary_polygon[:, 0])
                maxx = max(self.mesh_boundary_polygon[:, 0])
                miny = min(self.mesh_boundary_polygon[:, 1])
                maxy = max(self.mesh_boundary_polygon[:, 1])
                msg = 'Could not find triangle for point %s. ' % str(x)
                msg += 'Mesh boundary extent is (%.f, %.f), (%.f, %.f)'\
                    % (minx, maxx, miny, maxy)
                raise RuntimeError(msg)

        self.AtA = AtA

    def fit(self, point_coordinates_or_filename=None, z=None,
            verbose=False,
            point_origin=None,
            attribute_name=None,
            max_read_lines=500):
        """Fit a smooth surface to given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
        point_coordinates_or_filename: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 numeric array or a Geospatial_data object
              or points file filename
        z: Single 1d vector or array of data at the point_coordinates.

        Returns the fitted vertex attribute values (result of the
        conjugate gradient solve).

        Raises TooFewPointsError if alpha == 0 and there are fewer data
        points than mesh vertices.
        """

        # Use blocking to load in the point info
        # (was 'type(...) == types.StringType'; isinstance with basestring
        # also accepts unicode filenames)
        if isinstance(point_coordinates_or_filename, basestring):
            msg = "Don't set a point origin when reading from a file"
            assert point_origin is None, msg
            filename = point_coordinates_or_filename

            G_data = Geospatial_data(filename,
                                     max_read_lines=max_read_lines,
                                     load_file_now=False,
                                     verbose=verbose)

            for i, geo_block in enumerate(G_data):
                if verbose is True and 0 == i % 200:
                    # The time this will take
                    # is dependent on the # of Triangles
                    log.critical('Processing Block %d' % i)
                    # FIXME (Ole): It would be good to say how many blocks
                    # there are here. But this is no longer necessary
                    # for pts files as they are reported in geospatial_data
                    # I suggest deleting this verbose output and make
                    # Geospatial_data more informative for txt files.
                    #
                    # I still think so (12/12/7, Ole).

                # Build the array
                points = geo_block.get_data_points(absolute=True)
                z = geo_block.get_attributes(attribute_name=attribute_name)
                self.build_fit_subset(points, z, verbose=verbose)

                # FIXME(Ole): I thought this test would make sense here
                # See test_fitting_example_that_crashed_2 in
                # test_shallow_water_domain.py
                # Committed 11 March 2009
                msg = 'Matrix AtA was not built'
                assert self.AtA is not None, msg

            point_coordinates = None
        else:
            point_coordinates = point_coordinates_or_filename

        if point_coordinates is None:
            # All data (if any) came in via the blocked file path above;
            # the matrices must already exist.
            if verbose: log.critical('Warning: no data points in fit')
            msg = 'No interpolation matrix.'
            assert self.AtA is not None, msg
            assert self.Atz is not None

            # FIXME (DSG) - do a message
        else:
            point_coordinates = ensure_absolute(point_coordinates,
                                                geo_reference=point_origin)
            # if isinstance(point_coordinates, Geospatial_data) and z is None:
            # z will come from the geo-ref
            self.build_fit_subset(point_coordinates, z, verbose)

        # Check sanity
        m = self.mesh.number_of_nodes  # Nbr of basis functions (1/vertex)
        n = self.point_count
        if n < m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There are only %d data points and alpha == 0. ' % n
            msg += 'Need at least %d\n' % m
            msg += 'Alternatively, set smoothing parameter alpha to a small '
            msg += 'positive value,\ne.g. 1.0e-3.'
            raise TooFewPointsError(msg)

        self._build_coefficient_matrix_B(verbose)
        loners = self.mesh.get_lone_vertices()
        # FIXME - make this an error message.
        # test with
        # Not_yet_test_smooth_att_to_mesh_with_excess_verts.
        if len(loners) > 0:
            msg = 'WARNING: (least_squares): \nVertices with no triangles\n'
            msg += 'All vertices should be part of a triangle.\n'
            msg += 'In the future this will be inforced.\n'
            msg += 'The following vertices are not part of a triangle;\n'
            msg += str(loners)
            log.critical(msg)
            #raise VertsWithNoTrianglesError(msg)

        return conjugate_gradient(self.B, self.Atz, self.Atz,
                                  imax=2*len(self.Atz))

    def build_fit_subset(self, point_coordinates, z=None, attribute_name=None,
                         verbose=False):
        """Accumulate a subset of data points into the fit matrices.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
        point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 numeric array or a Geospatial_data object
        z: Single 1d vector or array of data at the point_coordinates.
        attribute_name: Used to get the z values from the
              geospatial object if no attribute_name is specified,
              it's a bit of a lucky dip as to what attributes you get.
              If there is only one attribute it will be that one.
        """

        # FIXME(DSG-DSG): Check that the vert and point coords
        # have the same zone.
        if isinstance(point_coordinates, Geospatial_data):
            point_coordinates = point_coordinates.get_data_points( \
                absolute=True)

        # Convert input to numeric arrays
        if z is not None:
            z = ensure_numeric(z, num.float)
        else:
            msg = 'z not specified'
            assert isinstance(point_coordinates, Geospatial_data), msg
            z = point_coordinates.get_attributes(attribute_name)

        point_coordinates = ensure_numeric(point_coordinates, num.float)
        self._build_matrix_AtA_Atz(point_coordinates, z, verbose)
456
457
458############################################################################
459
def fit_to_mesh(point_coordinates,  # this can also be a points file name
                vertex_coordinates=None,
                triangles=None,
                mesh=None,
                point_attributes=None,
                alpha=DEFAULT_ALPHA,
                verbose=False,
                mesh_origin=None,
                data_origin=None,
                max_read_lines=None,
                attribute_name=None,
                use_cache=False):
    """Wrapper around internal function _fit_to_mesh for use with caching.

    See _fit_to_mesh for the meaning of the arguments.

    If use_cache is True the (potentially expensive) fit is memoised via
    anuga.caching.  When point_coordinates is a file name, that file is
    registered as a cache dependency so the cached result is invalidated
    if the file changes on disk.
    """

    args = (point_coordinates,)
    kwargs = {'vertex_coordinates': vertex_coordinates,
              'triangles': triangles,
              'mesh': mesh,
              'point_attributes': point_attributes,
              'alpha': alpha,
              'verbose': verbose,
              'mesh_origin': mesh_origin,
              'data_origin': data_origin,
              'max_read_lines': max_read_lines,
              'attribute_name': attribute_name
              }

    if use_cache is True:
        if isinstance(point_coordinates, basestring):
            # We assume that point_coordinates is the name of a .csv/.txt
            # file which must be passed onto caching as a dependency
            # (in case it has changed on disk)
            dep = [point_coordinates]
        else:
            dep = None

        return cache(_fit_to_mesh,
                     args, kwargs,
                     verbose=verbose,
                     compression=False,
                     dependencies=dep)
    else:
        # apply() is deprecated; call directly with argument unpacking.
        return _fit_to_mesh(*args, **kwargs)
530
def _fit_to_mesh(point_coordinates,  # this can also be a points file name
                 vertex_coordinates=None,
                 triangles=None,
                 mesh=None,
                 point_attributes=None,
                 alpha=DEFAULT_ALPHA,
                 verbose=False,
                 mesh_origin=None,
                 data_origin=None,
                 max_read_lines=None,
                 attribute_name=None):
    """Fit a smooth surface to a triangulation,
    given data points with attributes.

    Inputs:
        point_coordinates: List of coordinate pairs [x, y] of data points
            (or an nx2 numeric array). This can also be a .csv/.txt/.pts
            file name.

        vertex_coordinates: List of coordinate pairs [xi, eta] of
            points constituting a mesh (or an m x 2 numeric array or
            a geospatial object)
            Points may appear multiple times
            (e.g. if vertices have discontinuities)

        triangles: List of 3-tuples (or a numeric array) of
            integers representing indices of all vertices in the mesh.

        mesh: Alternative to vertex_coordinates/triangles; used directly
            if supplied.

        point_attributes: Vector or array of data at the
            point_coordinates.

        alpha: Smoothing parameter.

        mesh_origin: A geo_reference object or 3-tuples consisting of
            UTM zone, easting and northing.
            If specified vertex coordinates are assumed to be
            relative to their respective origins.

    Returns the fitted vertex attribute values.
    """

    if mesh is None:
        # Build the mesh from its constituent parts.  Fail early with a
        # clear message instead of deep inside ensure_numeric
        # (resolves old FIXME(DSG)).
        if vertex_coordinates is None or triangles is None:
            msg = ('_fit_to_mesh: must supply either a mesh or both '
                   'vertex_coordinates and triangles.')
            raise ValueError(msg)

        # Convert input to numeric arrays
        triangles = ensure_numeric(triangles, num.int)
        vertex_coordinates = ensure_absolute(vertex_coordinates,
                                             geo_reference=mesh_origin)

        if verbose: log.critical('FitInterpolate: Building mesh')
        mesh = Mesh(vertex_coordinates, triangles)
        mesh.check_integrity()

    interp = Fit(mesh=mesh,
                 verbose=verbose,
                 alpha=alpha)

    vertex_attributes = interp.fit(point_coordinates,
                                   point_attributes,
                                   point_origin=data_origin,
                                   max_read_lines=max_read_lines,
                                   attribute_name=attribute_name,
                                   verbose=verbose)

    # FIXME: Add the value checking stuff that's in least squares.
    # Maybe this stuff should get pushed down into Fit,
    # at least be a method of Fit, or integrated into the fit method,
    # saving the max and min's as attributes.

    return vertex_attributes
606
607
608#def _fit(*args, **kwargs):
609#    """Private function for use with caching. Reason is that classes
610#    may change their byte code between runs which is annoying.
611#    """
612#   
613#    return Fit(*args, **kwargs)
614
615
def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
                     alpha=DEFAULT_ALPHA, verbose=False,
                     expand_search=False,
                     precrop=False,
                     display_errors=True):
    """
    Given a mesh file (tsh) and a point attribute file, fit
    point attributes to the mesh and write a mesh file with the
    results.

    Note: the points file needs titles.  If you want anuga to use the tsh file,
    make sure the title is elevation.

    NOTE: Throws IOErrors, for a variety of file problems.
    """

    from load_mesh.loadASCII import import_mesh_file, \
         export_mesh_file, concatinate_attributelist

    try:
        mesh_dict = import_mesh_file(mesh_file)
    except IOError as e:
        if display_errors:
            log.critical("Could not load bad file: %s" % str(e))
        # Re-raise the original error; 'raise IOError' discarded the message
        raise

    vertex_coordinates = mesh_dict['vertices']
    triangles = mesh_dict['triangles']

    # Attributes/titles may come back as numeric arrays; normalise to
    # (nested) lists so they can be extended below.
    if isinstance(mesh_dict['vertex_attributes'], num.ndarray):
        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
    else:
        old_point_attributes = mesh_dict['vertex_attributes']

    if isinstance(mesh_dict['vertex_attribute_titles'], num.ndarray):
        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
    else:
        old_title_list = mesh_dict['vertex_attribute_titles']

    if verbose: log.critical('tsh file %s loaded' % mesh_file)

    # Load in the points file
    try:
        geo = Geospatial_data(point_file, verbose=verbose)
    except IOError as e:
        if display_errors:
            log.critical("Could not load bad file: %s" % str(e))
        raise  # Re-raise original exception

    point_coordinates = geo.get_data_points(absolute=True)
    title_list, point_attributes = concatinate_attributelist( \
        geo.get_all_attributes())

    # ('has_key' is deprecated; use the 'in' operator)
    if 'geo_reference' in mesh_dict and \
           not mesh_dict['geo_reference'] is None:
        mesh_origin = mesh_dict['geo_reference'].get_origin()
    else:
        mesh_origin = None

    if verbose: log.critical("points file loaded")
    if verbose: log.critical("fitting to mesh")
    f = fit_to_mesh(point_coordinates,
                    vertex_coordinates,
                    triangles,
                    None,  # mesh: built from vertices/triangles above
                    point_attributes,
                    alpha=alpha,
                    verbose=verbose,
                    data_origin=None,
                    mesh_origin=mesh_origin)
    if verbose: log.critical("finished fitting to mesh")

    # Convert array to list of lists
    new_point_attributes = f.tolist()
    # FIXME have this overwrite attributes with the same title - DSG
    # Put the newer attributes last
    if old_title_list != []:
        old_title_list.extend(title_list)
        # FIXME can this be done a faster way? - DSG
        for i in range(len(old_point_attributes)):
            old_point_attributes[i].extend(new_point_attributes[i])
        mesh_dict['vertex_attributes'] = old_point_attributes
        mesh_dict['vertex_attribute_titles'] = old_title_list
    else:
        mesh_dict['vertex_attributes'] = new_point_attributes
        mesh_dict['vertex_attribute_titles'] = title_list

    if verbose: log.critical("exporting to file %s" % mesh_output_file)

    try:
        export_mesh_file(mesh_output_file, mesh_dict)
    except IOError as e:
        if display_errors:
            # NOTE(review): the ',' here relies on log.critical accepting
            # %-style lazy args; sibling calls use '%' -- verify.
            log.critical("Could not write file %s", str(e))
        raise
Note: See TracBrowser for help on using the repository browser.