source: anuga_core/source/anuga/fit_interpolate/fit.py @ 4425

Last change on this file since 4425 was 4425, checked in by ole, 17 years ago

Added points datafile as a caching dependency.
Also implemented that caching now throws an error if dependency file has been removed after caching has taken place.

File size: 20.7 KB
Line 
1"""Least squares fitting.
2
3   Implements a penalised least-squares fit.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18
19   TO DO
20   * test geo_ref, geo_spatial
21
22   IDEAS
23   * (DSG-) Change the interface of fit, so a domain object can
24      be passed in. (I don't know if this is feasible). It could
25      save time/memory.
26"""
27import types
28
29from Numeric import zeros, Float, ArrayType,take
30
31from anuga.caching import cache           
32from anuga.geospatial_data.geospatial_data import Geospatial_data, \
33     ensure_absolute
34from anuga.fit_interpolate.general_fit_interpolate import FitInterpolate
35from anuga.utilities.sparse import Sparse, Sparse_CSR
36from anuga.utilities.polygon import in_and_outside_polygon
37from anuga.fit_interpolate.search_functions import search_tree_of_vertices
38from anuga.utilities.cg_solve import conjugate_gradient
39from anuga.utilities.numerical_tools import ensure_numeric, gradient
40
41import exceptions
class ToFewPointsError(Exception):
    """Raised by Fit.fit when there are fewer data points than mesh
    vertices and no smoothing (alpha == 0) is applied.
    """
class VertsWithNoTrianglesError(Exception):
    """Signals that the mesh contains vertices that belong to no triangle.

    Currently not raised in this module: Fit.fit only prints a warning.
    """
44
# Smoothing parameter used when the caller does not supply one
# (Fit.__init__ with alpha=None, and the default argument of fit_to_mesh).
# NOTE(review): the module docstring quotes 1.0e-6 as a typical alpha;
# this default is considerably larger - confirm which is intended.
DEFAULT_ALPHA = 0.001
46
47
class Fit(FitInterpolate):
    """Penalised least-squares fit of point data to the vertices of a mesh.

    Data points are accumulated into the matrices AtA and Atz (possibly
    in several blocks, see build_fit_subset) and the system
    (AtA + alpha*D) x = Atz is then solved by fit().
    """

    def __init__(self,
                 vertex_coordinates,
                 triangles,
                 mesh_origin=None,
                 alpha=None,
                 verbose=False,
                 max_vertices_per_cell=30):
        """
        Fit data at points to the vertices of a mesh.

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a Numeric array) of
              integers representing indices of all vertices in the mesh.

          mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified vertex coordinates are assumed to be
              relative to their respective origins.

          alpha: Smoothing parameter. None selects DEFAULT_ALPHA;
              0 disables smoothing (no smoothing matrix is built).

          verbose: If true, progress messages are printed.

          max_vertices_per_cell: Number of vertices in a quad tree cell
              at which the cell is split into 4.

          Note: Don't supply a vertex coords as a geospatial object and
              a mesh origin, since geospatial has its own mesh origin.

        Usage,
        To use this in a blocking way, call build_fit_subset, with z info,
        and then fit, with no point coord, z info.
        """
        # alpha is set before the base class initialiser runs, as in the
        # original ordering.
        if alpha is None:
            self.alpha = DEFAULT_ALPHA
        else:
            self.alpha = alpha

        FitInterpolate.__init__(self,
                                vertex_coordinates,
                                triangles,
                                mesh_origin,
                                verbose,
                                max_vertices_per_cell)

        # Accumulators, created lazily by _build_matrix_AtA_Atz
        self.AtA = None
        self.Atz = None
        self.point_count = 0

        if self.alpha != 0:
            if verbose:
                print('Building smoothing matrix')
            self._build_smoothing_matrix_D()

    def _build_coefficient_matrix_B(self,
                                    verbose=False):
        """
        Build final coefficient matrix B = AtA + alpha*D (or AtA alone
        when alpha is zero) and store it in CSR format as self.B.

        Preconditions
        If alpha is not zero, matrix D has been built
        Matrix AtA has been built
        """

        if self.alpha != 0:
            self.B = self.AtA + self.alpha*self.D
        else:
            self.B = self.AtA

        # Convert self.B matrix to CSR format for faster matrix vector
        # products
        self.B = Sparse_CSR(self.B)

    def _build_smoothing_matrix_D(self):
        r"""Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k

        The result is stored in self.D.
        """

        # FIXME: algorithm might be optimised by computing local 9x9
        # "element stiffness matrices"

        mesh = self.mesh
        m = mesh.number_of_nodes  # Nbr of basis functions (1/vertex)

        self.D = Sparse(m, m)

        # For each triangle compute contributions to D = D1+D2
        for i in range(len(mesh)):

            area = mesh.areas[i]

            # Global vertex indices of triangle i
            v0 = mesh.triangles[i, 0]
            v1 = mesh.triangles[i, 1]
            v2 = mesh.triangles[i, 2]

            # The three vertex points
            xi0 = mesh.get_vertex_coordinate(i, 0)
            xi1 = mesh.get_vertex_coordinate(i, 1)
            xi2 = mesh.get_vertex_coordinate(i, 2)

            # Gradient of each basis function: the function is 1 at its
            # own vertex and 0 at the other two.
            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              1, 0, 0)

            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 1, 0)

            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 0, 1)

            # Diagonal contributions
            self.D[v0, v0] += (a0*a0 + b0*b0)*area
            self.D[v1, v1] += (a1*a1 + b1*b1)*area
            self.D[v2, v2] += (a2*a2 + b2*b2)*area

            # Contributions for basis functions sharing edges
            # (added symmetrically)
            e01 = (a0*a1 + b0*b1)*area
            self.D[v0, v1] += e01
            self.D[v1, v0] += e01

            e12 = (a1*a2 + b1*b2)*area
            self.D[v1, v2] += e12
            self.D[v2, v1] += e12

            e20 = (a2*a0 + b2*b0)*area
            self.D[v2, v0] += e20
            self.D[v0, v2] += e20

    def get_D(self):
        """Return the smoothing matrix D in dense form."""
        return self.D.todense()

    def _build_matrix_AtA_Atz(self,
                              point_coordinates,
                              z,
                              verbose=False):
        """Build:
        AtA  m x m  interpolation matrix, and,
        Atz  m x a  interpolation matrix where,
        m is the number of basis functions phi_k (one per vertex)
        a is the number of data attributes

        This algorithm uses a quad tree data structure for fast binning of
        data points.

        If AtA is None, the matrices AtA and Atz are created.

        This function can be called again and again, with sub-sets of
        the point coordinates.  Call fit to get the results.

        Preconditions
        z and points are numeric
        Point_coordindates and mesh vertices have the same origin.

        The number of attributes of the data points does not change

        Raises an Exception if a point inside the mesh boundary polygon
        cannot be located in any triangle.
        """

        # Checked for every block, not just the first one
        assert z.shape[0] == point_coordinates.shape[0]

        # Identity test: '== None' would invoke the comparison of the
        # sparse matrix object rather than test for "not yet created".
        if self.AtA is None:
            # AtA and Atz need to be initialised.
            m = self.mesh.number_of_nodes
            if len(z.shape) > 1:
                att_num = z.shape[1]
                self.Atz = zeros((m, att_num), Float)
            else:
                self.Atz = zeros((m,), Float)

            self.AtA = Sparse(m, m)
            # The memory damage has been done by now.

        # All supplied points are counted, including any that turn out
        # to lie outside the mesh boundary.
        self.point_count += point_coordinates.shape[0]

        if verbose:
            print('Getting indices inside mesh boundary')

        # Points outside the boundary polygon are ignored
        inside_poly_indices, _ = \
            in_and_outside_polygon(point_coordinates,
                                   self.mesh.get_boundary_polygon(),
                                   closed=True, verbose=verbose)

        n = len(inside_poly_indices)
        if verbose:
            print('Building fitting matrix from %d points' % n)

        # Report progress roughly ten times over the run
        step = (n + 10) // 10

        # Compute matrix elements for points inside the mesh
        for count, i in enumerate(inside_poly_indices):
            if verbose and count % step == 0:
                print('Doing %d of %d' % (count, n))

            x = point_coordinates[i]
            element_found, sigma0, sigma1, sigma2, triangle_id = \
                search_tree_of_vertices(self.root, self.mesh, x)

            if element_found is not True:
                msg = 'Could not find triangle for point %s' % (x,)
                raise Exception(msg)

            j0 = self.mesh.triangles[triangle_id, 0]  # Vertex id for sigma0
            j1 = self.mesh.triangles[triangle_id, 1]  # Vertex id for sigma1
            j2 = self.mesh.triangles[triangle_id, 2]  # Vertex id for sigma2

            sigmas = {j0: sigma0, j1: sigma1, j2: sigma2}
            js = [j0, j1, j2]

            for j in js:
                self.Atz[j] += sigmas[j]*z[i]

                for l in js:
                    self.AtA[j, l] += sigmas[j]*sigmas[l]

    def fit(self, point_coordinates_or_filename=None, z=None,
            verbose=False,
            point_origin=None,
            attribute_name=None,
            max_read_lines=500):
        """Fit a smooth surface to given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
          point_coordinates_or_filename: The co-ordinates of the data
              points: a list of coordinate pairs [x, y], an nx2 Numeric
              array or a Geospatial_data object.
              Alternatively the name of a points file, which is then read
              and processed in blocks of max_read_lines lines.
              If None, data must already have been supplied through
              build_fit_subset.
          z: Single 1d vector or array of data at the point_coordinates.
              May be omitted when the points are a Geospatial_data object
              or a file name.
          verbose: If true, progress messages are printed.
          point_origin: geo reference of the points; must be None when
              reading from a file.
          attribute_name: Name of the attribute to fit when the data comes
              from a Geospatial_data object or a file.
          max_read_lines: Block size used when reading from a file.

        Returns the result of the conjugate gradient solve of
        B x = Atz.

        Raises ToFewPointsError if alpha is zero and there are fewer data
        points than mesh vertices.
        """
        # Use blocking to load in the point info.
        # basestring (rather than types.StringType) also accepts unicode
        # file names, matching the check made in fit_to_mesh.
        if isinstance(point_coordinates_or_filename, basestring):
            msg = "Don't set a point origin when reading from a file"
            assert point_origin is None, msg
            filename = point_coordinates_or_filename
            geo_blocks = Geospatial_data(filename,
                                         max_read_lines=max_read_lines,
                                         load_file_now=False,
                                         verbose=verbose)
            for i, geo_block in enumerate(geo_blocks):
                if verbose is True and i % 200 == 0:
                    # Occasional progress report; how long 200 blocks
                    # take depends on the number of triangles.
                    print('Block %i' % i)
                # Build the arrays for this block
                points = geo_block.get_data_points(absolute=True)
                z = geo_block.get_attributes(attribute_name=attribute_name)
                self.build_fit_subset(points, z)
            point_coordinates = None
        else:
            point_coordinates = point_coordinates_or_filename

        if point_coordinates is None:
            msg = 'No data points have been supplied to fit'
            assert self.AtA is not None, msg
            assert self.Atz is not None, msg
        else:
            if z is None and isinstance(point_coordinates, Geospatial_data):
                # Take z from the geospatial object while we still have
                # it; presumably ensure_absolute hands back plain
                # coordinates - verify against its implementation.
                z = point_coordinates.get_attributes(
                    attribute_name=attribute_name)

            point_coordinates = ensure_absolute(point_coordinates,
                                                geo_reference=point_origin)
            # Keyword arguments: the third positional parameter of
            # build_fit_subset is attribute_name, not verbose.
            self.build_fit_subset(point_coordinates, z,
                                  attribute_name=attribute_name,
                                  verbose=verbose)

        # Check sanity
        m = self.mesh.number_of_nodes  # Nbr of basis functions (1/vertex)
        n = self.point_count
        if n < m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There are only %d data points and alpha == 0. ' %n
            msg += 'Need at least %d\n' %m
            msg += 'Alternatively, set smoothing parameter alpha to a small '
            msg += 'positive value,\ne.g. 1.0e-3.'
            raise ToFewPointsError(msg)

        self._build_coefficient_matrix_B(verbose)
        loners = self.mesh.get_lone_vertices()
        # FIXME  - make this as error message.
        # test with
        # Not_yet_test_smooth_att_to_mesh_with_excess_verts.
        if len(loners) > 0:
            msg = 'WARNING: (least_squares): \nVertices with no triangles\n'
            msg += 'All vertices should be part of a triangle.\n'
            msg += 'In the future this will be inforced.\n'
            msg += 'The following vertices are not part of a triangle;\n'
            msg += str(loners)
            print(msg)
            #raise VertsWithNoTrianglesError(msg)

        # Atz serves both as right hand side and as initial guess
        return conjugate_gradient(self.B, self.Atz, self.Atz,
                                  imax=2*len(self.Atz))

    def build_fit_subset(self, point_coordinates, z=None, attribute_name=None,
                         verbose=False):
        """Add a block of data points to the matrices AtA and Atz.

        May be called repeatedly with sub-sets of the data; call fit
        (without points) afterwards to obtain the fitted surface.

        Inputs:
        point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 Numeric array or a Geospatial_data object
        z: Single 1d vector or array of data at the point_coordinates.
              May only be omitted if point_coordinates is a
              Geospatial_data object.
        attribute_name: Used to get the z values from the
              geospatial object if no attribute_name is specified,
              it's a bit of a lucky dip as to what attributes you get.
              If there is only one attribute it will be that one.
        verbose: If true, progress messages are printed.
        """

        # FIXME(DSG-DSG): Check that the vert and point coords
        # have the same zone.
        if isinstance(point_coordinates, Geospatial_data):
            # Fetch the attributes BEFORE the object is replaced by its
            # raw coordinates; afterwards they are no longer reachable.
            if z is None:
                z = point_coordinates.get_attributes(attribute_name)
            point_coordinates = point_coordinates.get_data_points(
                absolute=True)

        msg = 'z not specified'
        assert z is not None, msg

        # Convert input to Numeric arrays
        z = ensure_numeric(z, Float)
        point_coordinates = ensure_numeric(point_coordinates, Float)

        self._build_matrix_AtA_Atz(point_coordinates, z, verbose)
419
420
421############################################################################
422
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates, # this can also be a .csv/.txt file name
                point_attributes=None,
                alpha=DEFAULT_ALPHA,
                verbose=False,
                acceptable_overshoot=1.01,
                mesh_origin=None,
                data_origin=None,
                max_read_lines=None,
                attribute_name=None,
                use_cache = False):
    """Wrapper around internal function _fit_to_mesh for use with caching.

    All arguments are handed on unchanged to _fit_to_mesh; see that
    function for their meaning.

    If use_cache is True the call goes through anuga.caching.cache.
    When point_coordinates is a string it is taken to be a file name and
    is registered as a caching dependency.

    Returns whatever _fit_to_mesh returns.
    """

    args = (vertex_coordinates, triangles, point_coordinates)
    kwargs = {'point_attributes': point_attributes,
              'alpha': alpha,
              'verbose': verbose,
              'acceptable_overshoot': acceptable_overshoot,
              'mesh_origin': mesh_origin,
              'data_origin': data_origin,
              'max_read_lines': max_read_lines,
              'attribute_name': attribute_name,
              'use_cache': use_cache
              }

    if use_cache is not True:
        # Extended call syntax replaces the deprecated apply() builtin
        return _fit_to_mesh(*args, **kwargs)

    if isinstance(point_coordinates, basestring):
        # We assume that point_coordinates is the name of a .csv/.txt
        # file which must be passed onto caching as a dependency (in case
        # it has changed on disk)
        dep = [point_coordinates]
    else:
        dep = None

    return cache(_fit_to_mesh,
                 args, kwargs,
                 verbose=verbose,
                 compression=False,
                 dependencies=dep)
468
def _fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates, # this can also be a .csv/.txt file name
                point_attributes=None,
                alpha=DEFAULT_ALPHA,
                verbose=False,
                acceptable_overshoot=1.01,
                mesh_origin=None,
                data_origin=None,
                max_read_lines=None,
                attribute_name=None,
                use_cache = False):
    """
    Fit a smooth surface to a triangulation,
    given data points with attributes.

        Inputs:
          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a Numeric array) of
              integers representing indices of all vertices in the mesh.

          point_coordinates: List of coordinate pairs [x, y] of data points
              (or an nx2 Numeric array), or the name of a points file.

          point_attributes: Vector or array of data at the
              point_coordinates.

          alpha: Smoothing parameter.

          verbose: Passed on to Fit and to Fit.fit.

          acceptable_overshoot: Accepted for interface compatibility;
              not used by this function.

          mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified vertex coordinates are assumed to be
              relative to their respective origins.

          data_origin: Passed to Fit.fit as point_origin.

          max_read_lines, attribute_name: Passed on to Fit.fit.

          use_cache: If True the Fit object is obtained through the
              caching module rather than constructed directly.

    Returns the vertex attributes computed by Fit.fit.
    """
    # Options shared by both ways of obtaining the Fit object
    fit_options = {'verbose': verbose,
                   'mesh_origin': mesh_origin,
                   'alpha': alpha}
    mesh_args = (vertex_coordinates, triangles)

    if use_cache is True:
        interp = cache(_fit, mesh_args, fit_options, verbose=verbose)
    else:
        interp = Fit(*mesh_args, **fit_options)

    # Since this is a wrapper for fit, the geo_spatial attributes are
    # handled by Fit.fit itself.
    #
    # The value checking that lives in least squares is not done here.
    # It might be pushed down into Fit, as a method or as part of fit
    # (saving the max and min values as attributes).
    return interp.fit(point_coordinates,
                      point_attributes,
                      point_origin=data_origin,
                      max_read_lines=max_read_lines,
                      attribute_name=attribute_name,
                      verbose=verbose)
548
def _fit(*args, **kwargs):
    """Build a Fit object from the given arguments.

    Exists as a plain function so that it can be handed to the caching
    module: classes may change their byte code between runs, which
    is annoying for caching.
    """

    fit_object = Fit(*args, **kwargs)
    return fit_object
555
Note: See TracBrowser for help on using the repository browser.