source: anuga_core/source/anuga/fit_interpolate/fit.py @ 4596

Last change on this file since 4596 was 4589, checked in by ole, 18 years ago

Fixed up dangerous double use of one variable

File size: 24.7 KB
Line 
1"""Least squares fitting.
2
3   Implements a penalised least-squares fit.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18
19   TO DO
20   * test geo_ref, geo_spatial
21
22   IDEAS
   * (DSG-) Change the interface of fit, so a domain object can
      be passed in. (I don't know if this is feasible). It could
      save time/memory.
26"""
27import types
28
29from Numeric import zeros, Float, ArrayType,take
30
31from anuga.caching import cache           
32from anuga.geospatial_data.geospatial_data import Geospatial_data, \
33     ensure_absolute
34from anuga.fit_interpolate.general_fit_interpolate import FitInterpolate
35from anuga.utilities.sparse import Sparse, Sparse_CSR
36from anuga.utilities.polygon import in_and_outside_polygon
37from anuga.fit_interpolate.search_functions import search_tree_of_vertices
38from anuga.utilities.cg_solve import conjugate_gradient
39from anuga.utilities.numerical_tools import ensure_numeric, gradient
40
41import exceptions
# Raised by Fit.fit when alpha == 0 and there are fewer data points than
# mesh vertices.  (Name misspells 'Too'; kept for backward compatibility.)
class ToFewPointsError(exceptions.Exception): pass
# Currently only warned about in Fit.fit, never actually raised.
class VertsWithNoTrianglesError(exceptions.Exception): pass
44
45DEFAULT_ALPHA = 0.001
46
47
class Fit(FitInterpolate):
    """Penalised least-squares fit of scattered point data to mesh vertices.

    The fitted surface minimises the data misfit plus alpha times a
    smoothness penalty (see the module docstring).  The normal equations
    can be built incrementally with build_fit_subset (blocking) and then
    solved with fit, or fit can be called directly with the point data.
    """

    def __init__(self,
                 vertex_coordinates,
                 triangles,
                 mesh_origin=None,
                 alpha = None,
                 verbose=False,
                 max_vertices_per_cell=30):
        """
        Fit data at points to the vertices of a mesh.

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a Numeric array) of
              integers representing indices of all vertices in the mesh.

          mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified vertex coordinates are assumed to be
              relative to their respective origins.

          alpha: Smoothing parameter. DEFAULT_ALPHA is used if None.

          max_vertices_per_cell: Number of vertices in a quad tree cell
              at which the cell is split into 4.

          Note: Don't supply a vertex coords as a geospatial object and
              a mesh origin, since geospatial has its own mesh origin.

        Usage,
        To use this in a blocking way, call build_fit_subset, with z info,
        and then fit, with no point coord, z info.
        """
        # Initialise variables
        if alpha is None:
            self.alpha = DEFAULT_ALPHA
        else:
            self.alpha = alpha

        FitInterpolate.__init__(self,
                                vertex_coordinates,
                                triangles,
                                mesh_origin,
                                verbose,
                                max_vertices_per_cell)

        # Normal-equation matrices; built lazily by _build_matrix_AtA_Atz
        self.AtA = None
        self.Atz = None

        # Total number of data points processed so far (over all blocks)
        self.point_count = 0

        # The smoothing matrix D depends only on the mesh, so build it
        # once up front.  It is not needed when alpha == 0.
        if self.alpha != 0:
            if verbose: print('Building smoothing matrix')
            self._build_smoothing_matrix_D()

    def _build_coefficient_matrix_B(self,
                                  verbose = False):
        """
        Build final coefficient matrix B = AtA + alpha*D

        Precon
        If alpha is not zero, matrix D has been built
        Matrix AtA has been built
        """
        if self.alpha != 0:
            self.B = self.AtA + self.alpha*self.D
        else:
            self.B = self.AtA

        # Convert self.B matrix to CSR format for faster matrix vector
        # products in the conjugate gradient solver.
        self.B = Sparse_CSR(self.B)

    def _build_smoothing_matrix_D(self):
        """Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial x} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k
        """

        #FIXME: algorithm might be optimised by computing local 9x9
        #"element stiffness matrices:

        m = self.mesh.number_of_nodes # Nbr of basis functions (1/vertex)

        self.D = Sparse(m,m)

        # For each triangle compute contributions to D = D1+D2
        for i in range(len(self.mesh)):

            # Get area
            area = self.mesh.areas[i]

            # Get global vertex indices
            v0 = self.mesh.triangles[i,0]
            v1 = self.mesh.triangles[i,1]
            v2 = self.mesh.triangles[i,2]

            # Get the three vertex_points
            xi0 = self.mesh.get_vertex_coordinate(i, 0)
            xi1 = self.mesh.get_vertex_coordinate(i, 1)
            xi2 = self.mesh.get_vertex_coordinate(i, 2)

            # Compute gradients for each vertex basis function
            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              1, 0, 0)

            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 1, 0)

            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 0, 1)

            # Compute diagonal contributions
            self.D[v0,v0] += (a0*a0 + b0*b0)*area
            self.D[v1,v1] += (a1*a1 + b1*b1)*area
            self.D[v2,v2] += (a2*a2 + b2*b2)*area

            # Compute contributions for basis functions sharing edges
            # (symmetric off-diagonal entries)
            e01 = (a0*a1 + b0*b1)*area
            self.D[v0,v1] += e01
            self.D[v1,v0] += e01

            e12 = (a1*a2 + b1*b2)*area
            self.D[v1,v2] += e12
            self.D[v2,v1] += e12

            e20 = (a2*a0 + b2*b0)*area
            self.D[v2,v0] += e20
            self.D[v0,v2] += e20


    def get_D(self):
        """Return the smoothing matrix D as a dense matrix (for testing)."""
        return self.D.todense()


    def _build_matrix_AtA_Atz(self,
                              point_coordinates,
                              z,
                              verbose = False):
        """Build:
        AtA  m x m  interpolation matrix, and,
        Atz  m x a  interpolation matrix where,
        m is the number of basis functions phi_k (one per vertex)
        a is the number of data attributes

        This algorithm uses a quad tree data structure for fast binning of
        data points.

        If AtA is None, the matrices AtA and Atz are created.

        This function can be called again and again, with sub-sets of
        the point coordinates.  Call fit to get the results.

        Preconditions
        z and points are numeric
        Point_coordinates and mesh vertices have the same origin.

        The number of attributes of the data points does not change
        """
        if self.AtA is None:
            # AtA and Atz need to be initialised.
            m = self.mesh.number_of_nodes
            if len(z.shape) > 1:
                att_num = z.shape[1]
                self.Atz = zeros((m,att_num), Float)
            else:
                att_num = 1
                self.Atz = zeros((m,), Float)
            assert z.shape[0] == point_coordinates.shape[0]

            self.AtA = Sparse(m,m)
            # The memory damage has been done by now.

        self.point_count += point_coordinates.shape[0]

        if verbose: print('Getting indices inside mesh boundary')

        # Points outside the mesh boundary cannot be fitted and are ignored.
        inside_poly_indices, outside_poly_indices  = \
                     in_and_outside_polygon(point_coordinates,
                                            self.mesh.get_boundary_polygon(),
                                            closed = True, verbose = verbose)

        n = len(inside_poly_indices)
        if verbose: print('Building fitting matrix from %d points' %n)

        # Compute matrix elements for points inside the mesh
        for d, i in enumerate(inside_poly_indices):
            # For each data_coordinate point
            if verbose and d%((n+10)/10)==0: print('Doing %d of %d' %(d, n))
            x = point_coordinates[i]
            element_found, sigma0, sigma1, sigma2, k = \
                           search_tree_of_vertices(self.root, self.mesh, x)

            if element_found is True:
                j0 = self.mesh.triangles[k,0] # Global vertex id for sigma0
                j1 = self.mesh.triangles[k,1] # Global vertex id for sigma1
                j2 = self.mesh.triangles[k,2] # Global vertex id for sigma2

                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
                js     = [j0,j1,j2]

                for j in js:
                    self.Atz[j] +=  sigmas[j]*z[i]

                    # Loop variable renamed from 'k' to avoid shadowing
                    # the triangle index k obtained above.
                    for jj in js:
                        self.AtA[j,jj] += sigmas[j]*sigmas[jj]
            else:
                # Should not happen for points classified as inside the
                # boundary polygon; indicates an inconsistent mesh/search.
                # (Bug fix: msg used to be a tuple, not a string.)
                msg = 'Could not find triangle for point %s' % str(x)
                raise Exception(msg)


    def fit(self, point_coordinates_or_filename=None, z=None,
            verbose=False,
            point_origin=None,
            attribute_name=None,
            max_read_lines=500):
        """Fit a smooth surface to given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
        point_coordinates_or_filename: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 Numeric array or a Geospatial_data object
              or filename (txt, csv, pts?).
              If None, the data passed earlier via build_fit_subset is used.
        z: Single 1d vector or array of data at the point_coordinates.
        point_origin: geo_reference of the point coordinates (not allowed
              when reading from a file).
        attribute_name: attribute to fit when z comes from a geospatial
              object or file.
        max_read_lines: block size used when reading from a file.

        Returns the fitted vertex attributes.

        Raises ToFewPointsError when alpha == 0 and there are fewer data
        points than mesh vertices.
        """
        # Use blocking to load in the point info
        if isinstance(point_coordinates_or_filename, basestring):
            msg = "Don't set a point origin when reading from a file"
            assert point_origin is None, msg
            filename = point_coordinates_or_filename

            G_data = Geospatial_data(filename,
                                     max_read_lines=max_read_lines,
                                     load_file_now=False,
                                     verbose=verbose)
            for i, geo_block in enumerate(G_data):
                if verbose is True and 0 == i%200:
                    # FIXME (Ole): It would be good to say how many blocks
                    # there are here. But this is no longer necessary
                    # for pts files as they are reported in geospatial_data
                    # I suggest deleting this verbose output and make
                    # Geospatial_data more informative for txt files.
                    print('Processing Block %d' %i)

                # Build the array
                points = geo_block.get_data_points(absolute=True)
                z = geo_block.get_attributes(attribute_name=attribute_name)
                self.build_fit_subset(points, z)
            point_coordinates = None
        else:
            point_coordinates =  point_coordinates_or_filename

        if point_coordinates is None:
            # Data must have been supplied earlier via build_fit_subset
            assert self.AtA is not None
            assert self.Atz is not None
            #FIXME (DSG) - do  a message
        else:
            point_coordinates = ensure_absolute(point_coordinates,
                                                geo_reference=point_origin)
            # z will come from the geo-ref if it is None
            # (Bug fix: verbose used to be passed positionally into the
            # attribute_name slot of build_fit_subset.)
            self.build_fit_subset(point_coordinates, z, verbose=verbose)

        # Check sanity
        m = self.mesh.number_of_nodes # Nbr of basis functions (1/vertex)
        n = self.point_count
        if n < m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There are only %d data points and alpha == 0. ' %n
            msg += 'Need at least %d\n' %m
            msg += 'Alternatively, set smoothing parameter alpha to a small '
            msg += 'positive value,\ne.g. 1.0e-3.'
            raise ToFewPointsError(msg)

        self._build_coefficient_matrix_B(verbose)
        loners = self.mesh.get_lone_vertices()
        # FIXME  - make this as error message.
        # test with
        # Not_yet_test_smooth_att_to_mesh_with_excess_verts.
        if len(loners) > 0:
            msg = 'WARNING: (least_squares): \nVertices with no triangles\n'
            msg += 'All vertices should be part of a triangle.\n'
            msg += 'In the future this will be inforced.\n'
            msg += 'The following vertices are not part of a triangle;\n'
            msg += str(loners)
            print(msg)
            #raise VertsWithNoTrianglesError(msg)

        # Solve the (sparse, symmetric) normal equations B x = Atz
        return conjugate_gradient(self.B, self.Atz, self.Atz,
                                  imax=2*len(self.Atz) )


    def build_fit_subset(self, point_coordinates, z=None, attribute_name=None,
                              verbose=False):
        """Accumulate a subset of data points into the normal equations.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
        point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 Numeric array or a Geospatial_data object
        z: Single 1d vector or array of data at the point_coordinates.
        attribute_name: Used to get the z values from the
              geospatial object if no attribute_name is specified,
              it's a bit of a lucky dip as to what attributes you get.
              If there is only one attribute it will be that one.
        """
        #FIXME(DSG-DSG): Check that the vert and point coords
        #have the same zone.
        if isinstance(point_coordinates, Geospatial_data):
            if z is None:
                # Take the attributes from the geospatial object before the
                # coordinates are converted to a plain array below.
                # (Bug fix: the original converted first, which made this
                # branch unreachable.)
                z = point_coordinates.get_attributes(attribute_name)
            point_coordinates = point_coordinates.get_data_points( \
                absolute = True)

        # Convert input to Numeric arrays
        msg = 'z not specified'
        assert z is not None, msg
        z = ensure_numeric(z, Float)

        point_coordinates = ensure_numeric(point_coordinates, Float)

        self._build_matrix_AtA_Atz(point_coordinates, z, verbose)
431
432
433############################################################################
434
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates, # this can also be a .csv/.txt file name
                point_attributes=None,
                alpha=DEFAULT_ALPHA,
                verbose=False,
                acceptable_overshoot=1.01,
                mesh_origin=None,
                data_origin=None,
                max_read_lines=None,
                attribute_name=None,
                use_cache = False):
    """Wrapper around internal function _fit_to_mesh for use with caching.

    All arguments are forwarded to _fit_to_mesh; see that function for
    their meaning.  When use_cache is True the result is memoised via
    anuga.caching, with point_coordinates treated as a file dependency
    when it is a filename.
    """

    args = (vertex_coordinates, triangles, point_coordinates, )
    kwargs = {'point_attributes': point_attributes,
              'alpha': alpha,
              'verbose': verbose,
              'acceptable_overshoot': acceptable_overshoot,
              'mesh_origin': mesh_origin,
              'data_origin': data_origin,
              'max_read_lines': max_read_lines,
              'attribute_name': attribute_name,
              'use_cache':use_cache
              }

    if use_cache is True:
        if isinstance(point_coordinates, basestring):
            # We assume that point_coordinates is the name of a .csv/.txt
            # file which must be passed onto caching as a dependency
            # (in case it has changed on disk)
            dep = [point_coordinates]
        else:
            dep = None

        return cache(_fit_to_mesh,
                     args, kwargs,
                     verbose=verbose,
                     compression=False,
                     dependencies=dep)
    else:
        # Extended call syntax replaces the deprecated apply() builtin.
        return _fit_to_mesh(*args, **kwargs)
480
def _fit_to_mesh(vertex_coordinates,
                 triangles,
                 point_coordinates, # this can also be a points file name
                 point_attributes=None,
                 alpha=DEFAULT_ALPHA,
                 verbose=False,
                 acceptable_overshoot=1.01,
                 mesh_origin=None,
                 data_origin=None,
                 max_read_lines=None,
                 attribute_name=None,
                 use_cache = False):
    """Fit a smooth surface to a triangulation,
    given data points with attributes.

    Inputs:
        vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object).
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

        triangles: List of 3-tuples (or a Numeric array) of
              integers representing indices of all vertices in the mesh.

        point_coordinates: List of coordinate pairs [x, y] of data points
              (or an nx2 Numeric array). This can also be a
              .csv/.txt/.pts file name.

        alpha: Smoothing parameter.

        acceptable_overshoot: controls the allowed factor by which fitted
              values may exceed the value of input data. The lower limit
              is defined as min(z) - acceptable_overshoot*delta z and the
              upper limit as max(z) + acceptable_overshoot*delta z.

        mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified vertex coordinates are assumed to be
              relative to their respective origins.

        point_attributes: Vector or array of data at the
              point_coordinates.

    Returns the fitted vertex attributes.
    """
    # Since this is a wrapper for Fit, handle the geo_spatial attributes
    # here.  Build the Fit object, optionally through the cache so that
    # an unchanged mesh does not trigger a full rebuild.
    fit_kwargs = {'verbose': verbose,
                  'mesh_origin': mesh_origin,
                  'alpha': alpha}

    if use_cache is True:
        fitter = cache(_fit,
                       (vertex_coordinates, triangles),
                       fit_kwargs,
                       compression = False,
                       verbose = verbose)
    else:
        fitter = Fit(vertex_coordinates, triangles, **fit_kwargs)

    # Perform the actual fit and hand back the vertex attributes.
    # TODO: value-range checking (see least squares) could be pushed
    # down into Fit, e.g. saving max/min as attributes.
    return fitter.fit(point_coordinates,
                      point_attributes,
                      point_origin=data_origin,
                      max_read_lines=max_read_lines,
                      attribute_name=attribute_name,
                      verbose=verbose)
562
def _fit(*args, **kwargs):
    """Private function for use with caching. Reason is that classes
    may change their byte code between runs which is annoying.

    All arguments are passed straight through to the Fit constructor.
    """

    return Fit(*args, **kwargs)
569
570
571def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
572                     alpha=DEFAULT_ALPHA, verbose= False,
573                     expand_search = False,
574                     precrop = False,
575                     display_errors = True):
576    """
577    Given a mesh file (tsh) and a point attribute file (xya), fit
578    point attributes to the mesh and write a mesh file with the
579    results.
580
581    Note: the .xya files need titles.  If you want anuga to use the tsh file,
582    make sure the title is elevation.
583
584    NOTE: Throws IOErrors, for a variety of file problems.
585   
586    """
587
588    from load_mesh.loadASCII import import_mesh_file, \
589         export_mesh_file, concatinate_attributelist
590
591
592    try:
593        mesh_dict = import_mesh_file(mesh_file)
594    except IOError,e:
595        if display_errors:
596            print "Could not load bad file. ", e
597        raise IOError  #Could not load bad mesh file.
598   
599    vertex_coordinates = mesh_dict['vertices']
600    triangles = mesh_dict['triangles']
601    if type(mesh_dict['vertex_attributes']) == ArrayType:
602        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
603    else:
604        old_point_attributes = mesh_dict['vertex_attributes']
605
606    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
607        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
608    else:
609        old_title_list = mesh_dict['vertex_attribute_titles']
610
611    if verbose: print 'tsh file %s loaded' %mesh_file
612
613    # load in the points file
614    try:
615        geo = Geospatial_data(point_file, verbose=verbose)
616    except IOError,e:
617        if display_errors:
618            print "Could not load bad file. ", e
619        raise IOError  #Re-raise exception 
620
621    point_coordinates = geo.get_data_points(absolute=True)
622    title_list,point_attributes = concatinate_attributelist( \
623        geo.get_all_attributes())
624
625    if mesh_dict.has_key('geo_reference') and \
626           not mesh_dict['geo_reference'] is None:
627        mesh_origin = mesh_dict['geo_reference'].get_origin()
628    else:
629        mesh_origin = None
630
631    if verbose: print "points file loaded"
632    if verbose: print "fitting to mesh"
633    f = fit_to_mesh(vertex_coordinates,
634                    triangles,
635                    point_coordinates,
636                    point_attributes,
637                    alpha = alpha,
638                    verbose = verbose,
639                    data_origin = None,
640                    mesh_origin = mesh_origin)
641    if verbose: print "finished fitting to mesh"
642
643    # convert array to list of lists
644    new_point_attributes = f.tolist()
645    #FIXME have this overwrite attributes with the same title - DSG
646    #Put the newer attributes last
647    if old_title_list <> []:
648        old_title_list.extend(title_list)
649        #FIXME can this be done a faster way? - DSG
650        for i in range(len(old_point_attributes)):
651            old_point_attributes[i].extend(new_point_attributes[i])
652        mesh_dict['vertex_attributes'] = old_point_attributes
653        mesh_dict['vertex_attribute_titles'] = old_title_list
654    else:
655        mesh_dict['vertex_attributes'] = new_point_attributes
656        mesh_dict['vertex_attribute_titles'] = title_list
657
658    if verbose: print "exporting to file ", mesh_output_file
659
660    try:
661        export_mesh_file(mesh_output_file, mesh_dict)
662    except IOError,e:
663        if display_errors:
664            print "Could not write file. ", e
665        raise IOError
Note: See TracBrowser for help on using the repository browser.