source: anuga_core/source/anuga/fit_interpolate/fit.py @ 4861

Last change on this file since 4861 was 4861, checked in by duncan, 16 years ago

Fit speed up for when there is a lot of blocking

File size: 25.4 KB
1"""Least squares fitting.
2
3   Implements a penalised least-squares fit.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
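
   Example sketch (the names points, verts, tris and z are illustrative
   only; they are not defined in this module):

     from anuga.fit_interpolate.fit import fit_to_mesh
     vertex_values = fit_to_mesh(points,
                                 vertex_coordinates=verts,
                                 triangles=tris,
                                 point_attributes=z,
                                 alpha=1.0e-6)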


   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
   Geoscience Australia, 2004.

   TO DO
   * test geo_ref, geo_spatial

   IDEAS
   * (DSG-) Change the interface of fit, so a domain object can
      be passed in. (I don't know if this is feasible.) It could
      save time/memory.
"""
import types

from Numeric import zeros, Float, ArrayType, take, Int

from anuga.abstract_2d_finite_volumes.neighbour_mesh import Mesh
from anuga.caching import cache
from anuga.geospatial_data.geospatial_data import Geospatial_data, \
     ensure_absolute
from anuga.fit_interpolate.general_fit_interpolate import FitInterpolate
from anuga.utilities.sparse import Sparse, Sparse_CSR
from anuga.utilities.polygon import in_and_outside_polygon
from anuga.fit_interpolate.search_functions import search_tree_of_vertices

from anuga.utilities.cg_solve import conjugate_gradient
from anuga.utilities.numerical_tools import ensure_numeric, gradient

import exceptions
class TooFewPointsError(exceptions.Exception): pass
class VertsWithNoTrianglesError(exceptions.Exception): pass

DEFAULT_ALPHA = 0.001


class Fit(FitInterpolate):

    def __init__(self,
                 vertex_coordinates=None,
                 triangles=None,
                 mesh=None,
                 mesh_origin=None,
                 alpha=None,
                 verbose=False,
                 max_vertices_per_cell=None):


        """
        Fit data at points to the vertices of a mesh.

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a Numeric array) of
              integers representing indices of all vertices in the mesh.

          mesh_origin: A geo_reference object or 3-tuple consisting of
              UTM zone, easting and northing.
              If specified, vertex coordinates are assumed to be
              relative to their respective origins.

          max_vertices_per_cell: Number of vertices in a quad tree cell
              at which the cell is split into 4.

          Note: Don't supply vertex coordinates as a geospatial object
              and a mesh origin, since a geospatial object carries its
              own origin.


        Usage:
        To use this in a blocking way, call build_fit_subset repeatedly
        with point and z info, and then call fit with no point or z info.
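
        Example sketch (the variable names are illustrative only):

          fitter = Fit(vertex_coordinates=verts, triangles=tris,
                       alpha=1.0e-6)
          fitter.build_fit_subset(points_block_1, z_block_1)
          fitter.build_fit_subset(points_block_2, z_block_2)
          vertex_values = fitter.fit()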

        """
        # Initialise variables
        if alpha is None:
            self.alpha = DEFAULT_ALPHA
        else:
            self.alpha = alpha

        FitInterpolate.__init__(self,
                 vertex_coordinates,
                 triangles,
                 mesh,
                 mesh_origin,
                 verbose,
                 max_vertices_per_cell)

        m = self.mesh.number_of_nodes # Nbr of basis functions (vertices)

        self.AtA = None
        self.Atz = None

        self.point_count = 0
        if self.alpha <> 0:
            if verbose: print 'Building smoothing matrix'
            self._build_smoothing_matrix_D()

        self.mesh_boundary_polygon = self.mesh.get_boundary_polygon()

    def _build_coefficient_matrix_B(self,
                                    verbose=False):
        """
        Build the final coefficient matrix B = AtA + alpha*D.

        Preconditions:
        If alpha is not zero, matrix D has been built.
        Matrix AtA has been built.
        """

        if self.alpha <> 0:
            #if verbose: print 'Building smoothing matrix'
            #self._build_smoothing_matrix_D()
            self.B = self.AtA + self.alpha*self.D
        else:
            self.B = self.AtA

        # Convert self.B matrix to CSR format for faster matrix-vector products
        self.B = Sparse_CSR(self.B)

    def _build_smoothing_matrix_D(self):
        """Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k.
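
        Since each basis function is linear over a triangle, its gradient
        (a_k, b_k) is constant on that triangle, so a triangle of area A
        contributes (a_k*a_l + b_k*b_l)*A to D[k,l]; this is what the loop
        below accumulates.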
        """

        #FIXME: algorithm might be optimised by computing local 9x9
        # element stiffness matrices.

        m = self.mesh.number_of_nodes # Nbr of basis functions (1/vertex)

        self.D = Sparse(m,m)

        #For each triangle compute contributions to D = D1+D2
        for i in range(len(self.mesh)):

            #Get area
            area = self.mesh.areas[i]

            #Get global vertex indices
            v0 = self.mesh.triangles[i,0]
            v1 = self.mesh.triangles[i,1]
            v2 = self.mesh.triangles[i,2]

            #Get the three vertex_points
            xi0 = self.mesh.get_vertex_coordinate(i, 0)
            xi1 = self.mesh.get_vertex_coordinate(i, 1)
            xi2 = self.mesh.get_vertex_coordinate(i, 2)

            #Compute gradients for each vertex
            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              1, 0, 0)

            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 1, 0)

            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 0, 1)

            #Compute diagonal contributions
            self.D[v0,v0] += (a0*a0 + b0*b0)*area
            self.D[v1,v1] += (a1*a1 + b1*b1)*area
            self.D[v2,v2] += (a2*a2 + b2*b2)*area

            #Compute contributions for basis functions sharing edges
            e01 = (a0*a1 + b0*b1)*area
            self.D[v0,v1] += e01
            self.D[v1,v0] += e01

            e12 = (a1*a2 + b1*b2)*area
            self.D[v1,v2] += e12
            self.D[v2,v1] += e12

            e20 = (a2*a0 + b2*b0)*area
            self.D[v2,v0] += e20
            self.D[v0,v2] += e20

    def get_D(self):
        return self.D.todense()


    def _build_matrix_AtA_Atz(self,
                              point_coordinates,
                              z,
                              verbose=False):
        """Build:
        AtA  m x m  interpolation matrix, and
        Atz  m x a  interpolation matrix, where
        m is the number of basis functions phi_k (one per vertex)
        a is the number of data attributes

        This algorithm uses a quad tree data structure for fast binning of
        data points.

        If AtA is None, the matrices AtA and Atz are created.

        This function can be called again and again, with sub-sets of
        the point coordinates.  Call fit to get the results.

        Preconditions
        z and points are numeric
        Point_coordinates and mesh vertices have the same origin.

        The number of attributes of the data points does not change
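
        In terms of the n x m interpolation matrix A with
        A[i,k] = sigma_k(x_i) (the interpolation weight of vertex k at
        data point x_i), the loop below accumulates
        AtA[j,k] = sum_i sigma_j(x_i)*sigma_k(x_i) and
        Atz[j]   = sum_i sigma_j(x_i)*z[i].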
        """
        #Build n x m interpolation matrix

        if self.AtA == None:
            # AtA and Atz need to be initialised.
            m = self.mesh.number_of_nodes
            if len(z.shape) > 1:
                att_num = z.shape[1]
                self.Atz = zeros((m,att_num), Float)
            else:
                att_num = 1
                self.Atz = zeros((m,), Float)
            assert z.shape[0] == point_coordinates.shape[0]

            AtA = Sparse(m,m)
            # The memory damage has been done by now.
        else:
            AtA = self.AtA #Did this for speed, did ~nothing
        self.point_count += point_coordinates.shape[0]
        #print "_build_matrix_AtA_Atz - self.point_count", self.point_count
        if verbose: print 'Getting indices inside mesh boundary'
        #print 'point_coordinates.shape', point_coordinates.shape
        #print 'self.mesh.get_boundary_polygon()',\
        #      self.mesh.get_boundary_polygon()

        inside_poly_indices, outside_poly_indices = \
                     in_and_outside_polygon(point_coordinates,
                                            self.mesh_boundary_polygon,
                                            closed=True, verbose=verbose)
        #print "self.inside_poly_indices",self.inside_poly_indices
        #print "self.outside_poly_indices",self.outside_poly_indices


        n = len(inside_poly_indices)
        if verbose: print 'Building fitting matrix from %d points' %n
        #Compute matrix elements for points inside the mesh
        triangles = self.mesh.triangles #Did this for speed, did ~nothing
        for d, i in enumerate(inside_poly_indices):
            #For each data_coordinate point
            if verbose and d%((n+10)/10)==0: print 'Doing %d of %d' %(d, n)
            x = point_coordinates[i]
            element_found, sigma0, sigma1, sigma2, k = \
                           search_tree_of_vertices(self.root, self.mesh, x)

            if element_found is True:
                j0 = triangles[k,0] #Global vertex id for sigma0
                j1 = triangles[k,1] #Global vertex id for sigma1
                j2 = triangles[k,2] #Global vertex id for sigma2

                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
                js     = [j0,j1,j2]

                for j in js:
                    self.Atz[j] += sigmas[j]*z[i]
                    #print "self.Atz building", self.Atz
                    #print "self.Atz[j]", self.Atz[j]
                    #print " sigmas[j]", sigmas[j]
                    #print "z[i]",z[i]
                    #print "result", sigmas[j]*z[i]

                    for k in js:
                        AtA[j,k] += sigmas[j]*sigmas[k]
            else:
                msg = 'Could not find triangle for point %s' % str(x)
                raise Exception(msg)
            self.AtA = AtA

    def fit(self, point_coordinates_or_filename=None, z=None,
            verbose=False,
            point_origin=None,
            attribute_name=None,
            max_read_lines=500):
        """Fit a smooth surface to given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
        point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of data points, or an nx2
              Numeric array, or a Geospatial_data object, or a points
              file name.
        z: Single 1d vector or array of data at the point_coordinates.

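        Example sketch (the object, file and attribute names are
        illustrative only):

          fitter = Fit(mesh=mesh, alpha=1.0e-6)
          vertex_values = fitter.fit('points.csv',
                                     attribute_name='elevation',
                                     max_read_lines=500)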
        """
        # use blocking to load in the point info
        if type(point_coordinates_or_filename) == types.StringType:
            msg = "Don't set a point origin when reading from a file"
            assert point_origin is None, msg
            filename = point_coordinates_or_filename

            G_data = Geospatial_data(filename,
                                     max_read_lines=max_read_lines,
                                     load_file_now=False,
                                     verbose=verbose)

            for i, geo_block in enumerate(G_data):
                if verbose is True and 0 == i%200:
                    # The time this will take
                    # is dependent on the number of triangles

                    print 'Processing Block %d' %i
                    # FIXME (Ole): It would be good to say how many blocks
                    # there are here. But this is no longer necessary
                    # for pts files as they are reported in geospatial_data
                    # I suggest deleting this verbose output and making
                    # Geospatial_data more informative for txt files.

                # Build the array
                points = geo_block.get_data_points(absolute=True)
                #print "fit points", points
                z = geo_block.get_attributes(attribute_name=attribute_name)
                self.build_fit_subset(points, z, verbose=verbose)

            point_coordinates = None
        else:
            point_coordinates = point_coordinates_or_filename

        if point_coordinates is None:
            assert self.AtA <> None
            assert self.Atz <> None
            #FIXME (DSG) - add a message to these asserts
        else:
            point_coordinates = ensure_absolute(point_coordinates,
                                                geo_reference=point_origin)
            #if isinstance(point_coordinates,Geospatial_data) and z is None:
            # z will come from the geo-ref
            self.build_fit_subset(point_coordinates, z, verbose)

        #Check sanity
        m = self.mesh.number_of_nodes # Nbr of basis functions (1/vertex)
        n = self.point_count
        if n<m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There are only %d data points and alpha == 0. ' %n
            msg += 'Need at least %d\n' %m
            msg += 'Alternatively, set smoothing parameter alpha to a small '
            msg += 'positive value,\ne.g. 1.0e-3.'
            raise TooFewPointsError(msg)

        self._build_coefficient_matrix_B(verbose)
        loners = self.mesh.get_lone_vertices()
        # FIXME - make this an error message.
        # test with
        # Not_yet_test_smooth_att_to_mesh_with_excess_verts.
        if len(loners)>0:
            msg = 'WARNING: (least_squares): \nVertices with no triangles\n'
            msg += 'All vertices should be part of a triangle.\n'
            msg += 'In the future this will be enforced.\n'
            msg += 'The following vertices are not part of a triangle:\n'
            msg += str(loners)
            print msg
            #raise VertsWithNoTrianglesError(msg)


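        # Solve the regularised normal equations B x = Atz, where
        # B = AtA + alpha*D was assembled above, using conjugate gradients.
        # (The repeated Atz argument below is assumed to act as the initial
        # guess for the solution vector; see anuga.utilities.cg_solve.)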
        return conjugate_gradient(self.B, self.Atz, self.Atz,
                                  imax=2*len(self.Atz))


    def build_fit_subset(self, point_coordinates, z=None, attribute_name=None,
                         verbose=False):
        """Add a subset of the data points to the fit.

        Builds the matrices AtA and Atz incrementally from the given
        points; call fit (with no point info) once all subsets have been
        added to compute the smooth surface at the mesh vertices.

        Inputs:
        point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of data points, or an nx2
              Numeric array, or a Geospatial_data object.
        z: Single 1d vector or array of data at the point_coordinates.
        attribute_name: Used to get the z values from the geospatial
              object. If no attribute_name is specified it is undefined
              which attribute is used; if there is only one attribute it
              will be that one.

        """

        # FIXME(DSG-DSG): Check that the vert and point coords
        # have the same zone.
        if isinstance(point_coordinates,Geospatial_data):
            point_coordinates = point_coordinates.get_data_points( \
                absolute = True)

        # Convert input to Numeric arrays
        if z is not None:
            z = ensure_numeric(z, Float)
        else:
            msg = 'z not specified'
            assert isinstance(point_coordinates,Geospatial_data), msg
            z = point_coordinates.get_attributes(attribute_name)

        point_coordinates = ensure_numeric(point_coordinates, Float)
        self._build_matrix_AtA_Atz(point_coordinates, z, verbose)


############################################################################

def fit_to_mesh(point_coordinates, # this can also be a points file name
                vertex_coordinates=None,
                triangles=None,
                mesh=None,
                point_attributes=None,
                alpha=DEFAULT_ALPHA,
                verbose=False,
                acceptable_overshoot=1.01,
                mesh_origin=None,
                data_origin=None,
                max_read_lines=None,
                attribute_name=None,
                use_cache=False):
    """Wrapper around internal function _fit_to_mesh for use with caching.

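    Example sketch (the variable names and file name are illustrative
    only):

      vertex_values = fit_to_mesh('points.csv',
                                  vertex_coordinates=verts,
                                  triangles=tris,
                                  alpha=1.0e-6,
                                  attribute_name='elevation',
                                  use_cache=True)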
    """

    args = (point_coordinates, )
    kwargs = {'vertex_coordinates': vertex_coordinates,
              'triangles': triangles,
              'mesh': mesh,
              'point_attributes': point_attributes,
              'alpha': alpha,
              'verbose': verbose,
              'acceptable_overshoot': acceptable_overshoot,
              'mesh_origin': mesh_origin,
              'data_origin': data_origin,
              'max_read_lines': max_read_lines,
              'attribute_name': attribute_name,
              'use_cache': use_cache
              }

    if use_cache is True:
        if isinstance(point_coordinates, basestring):
            # We assume that point_coordinates is the name of a .csv/.txt
            # file which must be passed onto caching as a dependency
            # (in case it has changed on disk)
            dep = [point_coordinates]
        else:
            dep = None

        return cache(_fit_to_mesh,
                     args, kwargs,
                     verbose=verbose,
                     compression=False,
                     dependencies=dep)
    else:
        return apply(_fit_to_mesh,
                     args, kwargs)

def _fit_to_mesh(point_coordinates, # this can also be a points file name
                 vertex_coordinates=None,
                 triangles=None,
                 mesh=None,
                 point_attributes=None,
                 alpha=DEFAULT_ALPHA,
                 verbose=False,
                 acceptable_overshoot=1.01,
                 mesh_origin=None,
                 data_origin=None,
                 max_read_lines=None,
                 attribute_name=None,
                 use_cache=False):
    """
    Fit a smooth surface to a triangulation,
    given data points with attributes.

    Inputs:

        vertex_coordinates: List of coordinate pairs [xi, eta] of
            points constituting a mesh (or an m x 2 Numeric array or
            a geospatial object)
            Points may appear multiple times
            (e.g. if vertices have discontinuities)

        triangles: List of 3-tuples (or a Numeric array) of
            integers representing indices of all vertices in the mesh.

        point_coordinates: List of coordinate pairs [x, y] of data points
            (or an nx2 Numeric array). This can also be a .csv/.txt/.pts
            file name.

        alpha: Smoothing parameter.

        acceptable_overshoot: NOT IMPLEMENTED.
            Controls the allowed factor by which fitted values
            may exceed the value of input data. The lower limit is defined
            as min(z) - acceptable_overshoot*delta z and the upper limit
            as max(z) + acceptable_overshoot*delta z.

        mesh_origin: A geo_reference object or 3-tuple consisting of
            UTM zone, easting and northing.
            If specified, vertex coordinates are assumed to be
            relative to their respective origins.

        point_attributes: Vector or array of data at the
            point_coordinates.

    """

    # Duncan and Ole think that this isn't worth caching.
    # Caching happens at the higher level anyway.

    if mesh is None:
        # FIXME(DSG): Throw errors if triangles or vertex_coordinates
        # are None

        #Convert input to Numeric arrays
        triangles = ensure_numeric(triangles, Int)
        vertex_coordinates = ensure_absolute(vertex_coordinates,
                                             geo_reference=mesh_origin)

        if verbose: print 'FitInterpolate: Building mesh'
        mesh = Mesh(vertex_coordinates, triangles)
        mesh.check_integrity()

    interp = Fit(mesh=mesh,
                 verbose=verbose,
                 alpha=alpha)

    vertex_attributes = interp.fit(point_coordinates,
                                   point_attributes,
                                   point_origin=data_origin,
                                   max_read_lines=max_read_lines,
                                   attribute_name=attribute_name,
                                   verbose=verbose)

    # Add the value checking stuff that's in least squares.
    # Maybe this stuff should get pushed down into Fit,
    # or at least be a method of Fit.
    # Or integrate it into the fit method, saving the max and min values
    # as attributes.

    return vertex_attributes


#def _fit(*args, **kwargs):
#    """Private function for use with caching. Reason is that classes
#    may change their byte code between runs which is annoying.
#    """
#
#    return Fit(*args, **kwargs)


def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
                     alpha=DEFAULT_ALPHA, verbose=False,
                     expand_search=False,
                     precrop=False,
                     display_errors=True):
    """
    Given a mesh file (tsh) and a point attribute file, fit
    point attributes to the mesh and write a mesh file with the
    results.

    Note: the points file needs column titles.  If you want ANUGA to use
    the tsh file, make sure the title is 'elevation'.

    NOTE: Throws IOError for a variety of file problems.

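    Example sketch (file names are illustrative only):

      fit_to_mesh_file('mesh.tsh', 'points.csv', 'mesh_fitted.tsh',
                       alpha=1.0e-6, verbose=True)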
    """

    from load_mesh.loadASCII import import_mesh_file, \
         export_mesh_file, concatinate_attributelist


    try:
        mesh_dict = import_mesh_file(mesh_file)
    except IOError, e:
        if display_errors:
            print "Could not load bad file. ", e
        raise IOError  #Could not load bad mesh file.

    vertex_coordinates = mesh_dict['vertices']
    triangles = mesh_dict['triangles']
    if type(mesh_dict['vertex_attributes']) == ArrayType:
        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
    else:
        old_point_attributes = mesh_dict['vertex_attributes']

    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
    else:
        old_title_list = mesh_dict['vertex_attribute_titles']

    if verbose: print 'tsh file %s loaded' %mesh_file

    # load in the points file
    try:
        geo = Geospatial_data(point_file, verbose=verbose)
    except IOError, e:
        if display_errors:
            print "Could not load bad file. ", e
        raise IOError  #Re-raise exception

    point_coordinates = geo.get_data_points(absolute=True)
    title_list, point_attributes = concatinate_attributelist( \
        geo.get_all_attributes())

    if mesh_dict.has_key('geo_reference') and \
           not mesh_dict['geo_reference'] is None:
        mesh_origin = mesh_dict['geo_reference'].get_origin()
    else:
        mesh_origin = None

    if verbose: print "points file loaded"
    if verbose: print "fitting to mesh"
    f = fit_to_mesh(point_coordinates,
                    vertex_coordinates,
                    triangles,
                    None,
                    point_attributes,
                    alpha=alpha,
                    verbose=verbose,
                    data_origin=None,
                    mesh_origin=mesh_origin)
    if verbose: print "finished fitting to mesh"

    # convert array to list of lists
    new_point_attributes = f.tolist()
    #FIXME have this overwrite attributes with the same title - DSG
    #Put the newer attributes last
    if old_title_list <> []:
        old_title_list.extend(title_list)
        #FIXME can this be done a faster way? - DSG
        for i in range(len(old_point_attributes)):
            old_point_attributes[i].extend(new_point_attributes[i])
        mesh_dict['vertex_attributes'] = old_point_attributes
        mesh_dict['vertex_attribute_titles'] = old_title_list
    else:
        mesh_dict['vertex_attributes'] = new_point_attributes
        mesh_dict['vertex_attribute_titles'] = title_list

    if verbose: print "exporting to file ", mesh_output_file

    try:
        export_mesh_file(mesh_output_file, mesh_dict)
    except IOError, e:
        if display_errors:
            print "Could not write file. ", e
        raise IOError