source: anuga_core/source/anuga/fit_interpolate/fit.py @ 4180

Last change on this file since 4180 was 4175, checked in by duncan, 18 years ago

added blocking on netcdf files & a fix in using verbose

File size: 18.9 KB
Line 
1"""Least squares fitting.
2
3   Implements a penalised least-squares fit.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18
19   TO DO
20   * test geo_ref, geo_spatial
21
22   IDEAS
23   * (DSG-) Change the interface of fit, so a domain object can
24      be passed in. (I don't know if this is feasible). It could
25      save time/memory.
26"""
27import types
28
29from Numeric import zeros, Float, ArrayType,take
30
31from anuga.caching import cache           
32from anuga.geospatial_data.geospatial_data import Geospatial_data, \
33     ensure_absolute
34from anuga.fit_interpolate.general_fit_interpolate import FitInterpolate
35from anuga.utilities.sparse import Sparse, Sparse_CSR
36from anuga.utilities.polygon import in_and_outside_polygon
37from anuga.fit_interpolate.search_functions import search_tree_of_vertices
38from anuga.utilities.cg_solve import conjugate_gradient
39from anuga.utilities.numerical_tools import ensure_numeric, gradient
40
41import exceptions
class ToFewPointsError(Exception):
    """Raised by Fit.fit when alpha == 0 and fewer data points than mesh
    vertices have been supplied.
    """


class VertsWithNoTrianglesError(Exception):
    """Intended for meshes that contain vertices belonging to no triangle.

    Fit.fit currently only prints a warning for this condition; the raise
    is commented out there.
    """


# Smoothing parameter used by Fit when the caller passes alpha=None, and the
# default value of the alpha argument of fit_to_mesh.
DEFAULT_ALPHA = 0.001
46
47
class Fit(FitInterpolate):
    """Penalised least-squares fit of point data to the vertices of a mesh.

    The normal-equation pieces AtA and Atz are accumulated from one or more
    blocks of data points (build_fit_subset) and the system
    (AtA + alpha*D) x = Atz is solved in fit().
    """

    def __init__(self,
                 vertex_coordinates,
                 triangles,
                 mesh_origin=None,
                 alpha=None,
                 verbose=False,
                 max_vertices_per_cell=30):
        """Set up a fit of point data to the vertices of a mesh.

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a Numeric array) of
              integers representing indices of all vertices in the mesh.

          mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified vertex coordinates are assumed to be
              relative to their respective origins.

          alpha: Smoothing parameter.  None selects DEFAULT_ALPHA.
              When alpha is non-zero the smoothing matrix D is built here.

          verbose: If true, print progress messages.

          max_vertices_per_cell: Number of vertices in a quad tree cell
              at which the cell is split into 4.

          Note: Don't supply a vertex coords as a geospatial object and
              a mesh origin, since geospatial has its own mesh origin.

        Usage,
        To use this in a blocking way, call build_fit_subset, with z info,
        and then fit, with no point coord, z info.
        """
        if alpha is None:
            self.alpha = DEFAULT_ALPHA
        else:
            self.alpha = alpha

        FitInterpolate.__init__(self,
                                vertex_coordinates,
                                triangles,
                                mesh_origin,
                                verbose,
                                max_vertices_per_cell)

        # Accumulators; allocated by the first call to _build_matrix_AtA_Atz.
        self.AtA = None
        self.Atz = None

        # Number of data points handed to _build_matrix_AtA_Atz so far.
        self.point_count = 0

        if self.alpha != 0:
            if verbose:
                print('Building smoothing matrix')
            self._build_smoothing_matrix_D()

    def _build_coefficient_matrix_B(self,
                                    verbose=False):
        """Build the final coefficient matrix B and store it as self.B.

        B = AtA + alpha*D when alpha is non-zero, otherwise B = AtA.

        Preconditions:
          If alpha is not zero, matrix D has been built.
          Matrix AtA has been built.

        verbose is accepted for interface compatibility; it is not used.
        """
        if self.alpha != 0:
            self.B = self.AtA + self.alpha*self.D
        else:
            self.B = self.AtA

        # Convert B to CSR format for faster matrix vector products
        self.B = Sparse_CSR(self.B)

    def _build_smoothing_matrix_D(self):
        r"""Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy

        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k

        The result is stored in self.D.
        """
        # FIXME: algorithm might be optimised by computing local
        # element stiffness matrices (9 entries per triangle).

        m = self.mesh.number_of_nodes  # Nbr of basis functions (1/vertex)

        self.D = Sparse(m, m)

        # For each triangle compute contributions to D = D1+D2
        for i in range(len(self.mesh)):
            area = self.mesh.areas[i]

            # Global vertex indices of triangle i
            v0 = self.mesh.triangles[i, 0]
            v1 = self.mesh.triangles[i, 1]
            v2 = self.mesh.triangles[i, 2]

            # The three vertex points of triangle i
            xi0 = self.mesh.get_vertex_coordinate(i, 0)
            xi1 = self.mesh.get_vertex_coordinate(i, 1)
            xi2 = self.mesh.get_vertex_coordinate(i, 2)

            # Gradient (a_k, b_k) of each basis function: the function takes
            # the value 1 at its own vertex and 0 at the other two.
            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              1, 0, 0)

            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 1, 0)

            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 0, 1)

            # Diagonal contributions
            self.D[v0, v0] += (a0*a0 + b0*b0)*area
            self.D[v1, v1] += (a1*a1 + b1*b1)*area
            self.D[v2, v2] += (a2*a2 + b2*b2)*area

            # Contributions for basis functions sharing edges (symmetric)
            e01 = (a0*a1 + b0*b1)*area
            self.D[v0, v1] += e01
            self.D[v1, v0] += e01

            e12 = (a1*a2 + b1*b2)*area
            self.D[v1, v2] += e12
            self.D[v2, v1] += e12

            e20 = (a2*a0 + b2*b0)*area
            self.D[v2, v0] += e20
            self.D[v0, v2] += e20

    def get_D(self):
        """Return the smoothing matrix D in dense form.

        D is only built when alpha is non-zero (see __init__).
        """
        return self.D.todense()

    def _build_matrix_AtA_Atz(self,
                              point_coordinates,
                              z,
                              verbose=False):
        """Build:
        AtA  m x m  interpolation matrix, and,
        Atz  m x a  interpolation matrix where,
        m is the number of basis functions phi_k (one per vertex)
        a is the number of data attributes

        This algorithm uses a quad tree data structure for fast binning of
        data points.

        If AtA is None, the matrices AtA and Atz are created.

        This function can be called again and again, with sub-sets of
        the point coordinates.  Call fit to get the results.

        Preconditions
        z and points are numeric
        Point_coordindates and mesh vertices have the same origin.
        The number of attributes of the data points does not change

        Raises
        AssertionError if z and point_coordinates differ in length.
        Exception if no triangle can be found for a point that lies
        inside the mesh boundary polygon.
        """
        # Checked for every block, not only the one that allocates the
        # matrices, so a mismatched later block cannot slip through.
        msg = 'Number of attribute rows must equal the number of points'
        assert z.shape[0] == point_coordinates.shape[0], msg

        if self.AtA is None:
            # AtA and Atz need to be initialised.
            m = self.mesh.number_of_nodes
            if len(z.shape) > 1:
                self.Atz = zeros((m, z.shape[1]), Float)
            else:
                self.Atz = zeros((m,), Float)

            self.AtA = Sparse(m, m)

        # Every supplied point is counted, including points that turn out
        # to lie outside the mesh boundary.
        self.point_count += point_coordinates.shape[0]

        if verbose:
            print('Getting indices inside mesh boundary')

        # Kept as instance attributes, as in the original implementation.
        self.inside_poly_indices, self.outside_poly_indices = \
            in_and_outside_polygon(point_coordinates,
                                   self.mesh.get_boundary_polygon(),
                                   closed=True, verbose=verbose)

        n = len(self.inside_poly_indices)
        if verbose:
            print('Building fitting matrix from %d points' % n)

        progress_step = (n + 10) // 10

        # Compute matrix elements for points inside the mesh
        for count, i in enumerate(self.inside_poly_indices):
            if verbose and count % progress_step == 0:
                print('Doing %d of %d' % (count, n))

            x = point_coordinates[i]
            element_found, sigma0, sigma1, sigma2, triangle_id = \
                search_tree_of_vertices(self.root, self.mesh, x)

            if element_found is not True:
                msg = 'Could not find triangle for point %s' % str(x)
                raise Exception(msg)

            # Global vertex ids matching sigma0, sigma1 and sigma2
            j0 = self.mesh.triangles[triangle_id, 0]
            j1 = self.mesh.triangles[triangle_id, 1]
            j2 = self.mesh.triangles[triangle_id, 2]

            sigmas = {j0: sigma0, j1: sigma1, j2: sigma2}
            js = [j0, j1, j2]

            for row in js:
                self.Atz[row] += sigmas[row]*z[i]
                for col in js:
                    self.AtA[row, col] += sigmas[row]*sigmas[col]

    def fit(self, point_coordinates_or_filename=None, z=None,
            verbose=False,
            point_origin=None,
            attribute_name=None,
            max_read_lines=500):
        """Fit a smooth surface to given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh by solving (AtA + alpha*D) x = Atz.

        Inputs:
          point_coordinates_or_filename: One of
              - None: use only the data already accumulated through
                build_fit_subset,
              - a file name: the file is read through Geospatial_data in
                blocks of max_read_lines and each block is accumulated,
              - the co-ordinates of the data points: list of coordinate
                pairs [x, y], an nx2 Numeric array or a Geospatial_data
                object.
          z: Single 1d vector or array of data at the point coordinates.
              Ignored when a file name is given.  May be None when the
              points are a Geospatial_data object.
          verbose: If true, print progress messages.
          point_origin: Geo reference of the point coordinates.  Must be
              None when reading from a file.
          attribute_name: Name of the attribute to take from geospatial
              data (file blocks or a Geospatial_data object).
          max_read_lines: Block size used when reading from a file.

        Returns the result of the conjugate gradient solve.

        Raises ToFewPointsError if alpha == 0 and fewer points than mesh
        vertices were supplied.
        """
        # isinstance with StringTypes accepts unicode file names as well
        if isinstance(point_coordinates_or_filename, types.StringTypes):
            msg = "Don't set a point origin when reading from a file"
            assert point_origin is None, msg
            filename = point_coordinates_or_filename

            # Use blocking to load in the point info
            geo_blocks = Geospatial_data(filename,
                                         max_read_lines=max_read_lines,
                                         load_file_now=False,
                                         verbose=verbose)
            for i, geo_block in enumerate(geo_blocks):
                if verbose is True and 0 == i % 200:
                    print('Block %i' % i)

                points = geo_block.get_data_points(absolute=True)
                block_z = geo_block.get_attributes(
                    attribute_name=attribute_name)
                self.build_fit_subset(points, block_z)
            point_coordinates = None
        else:
            point_coordinates = point_coordinates_or_filename

        if point_coordinates is None:
            # Nothing new to add: data must already have been accumulated.
            msg = 'No data points have been supplied to the fit'
            assert self.AtA is not None, msg
            assert self.Atz is not None, msg
        else:
            # Take the attributes while the geospatial object is still at
            # hand; ensure_absolute may hand back bare coordinates
            # (NOTE(review): confirm against ensure_absolute).
            if z is None and isinstance(point_coordinates, Geospatial_data):
                z = point_coordinates.get_attributes(attribute_name)

            point_coordinates = ensure_absolute(point_coordinates,
                                                geo_reference=point_origin)

            # Keywords are required here: passed positionally, verbose
            # would land in build_fit_subset's attribute_name parameter.
            self.build_fit_subset(point_coordinates, z,
                                  attribute_name=attribute_name,
                                  verbose=verbose)

        # Check sanity
        m = self.mesh.number_of_nodes  # Nbr of basis functions (1/vertex)
        n = self.point_count
        if n < m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There are only %d data points and alpha == 0. ' % n
            msg += 'Need at least %d\n' % m
            msg += 'Alternatively, set smoothing parameter alpha to a small '
            msg += 'positive value,\ne.g. 1.0e-3.'
            raise ToFewPointsError(msg)

        self._build_coefficient_matrix_B(verbose)

        loners = self.mesh.get_lone_vertices()
        # FIXME  - make this as error message.
        # test with
        # Not_yet_test_smooth_att_to_mesh_with_excess_verts.
        if len(loners) > 0:
            msg = 'WARNING: (least_squares): \nVertices with no triangles\n'
            msg += 'All vertices should be part of a triangle.\n'
            msg += 'In the future this will be inforced.\n'
            msg += 'The following vertices are not part of a triangle;\n'
            msg += str(loners)
            print(msg)
            # raise VertsWithNoTrianglesError(msg)

        # Atz is passed twice - presumably as right hand side and as
        # starting guess; confirm against conjugate_gradient.
        return conjugate_gradient(self.B, self.Atz, self.Atz,
                                  imax=2*len(self.Atz))

    def build_fit_subset(self, point_coordinates, z=None, attribute_name=None,
                         verbose=False):
        """Accumulate one block of data points into AtA and Atz.

        May be called repeatedly with sub-sets of the data; call fit
        (with no points) afterwards to get the result.

        Inputs:
        point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 Numeric array or a Geospatial_data object
        z: Single 1d vector or array of data at the point_coordinates.
              May only be None when point_coordinates is a Geospatial_data
              object, in which case the values are taken from it.
        attribute_name: Used to get the z values from the
              geospatial object if no attribute_name is specified,
              it's a bit of a lucky dip as to what attributes you get.
              If there is only one attribute it will be that one.
        verbose: If true, print progress messages.

        Raises AssertionError if z is None and cannot be obtained from
        point_coordinates.
        """
        # FIXME(DSG-DSG): Check that the vert and point coords
        # have the same zone.
        if isinstance(point_coordinates, Geospatial_data):
            # Read the attributes before the geospatial object is replaced
            # by its bare coordinates.
            if z is None:
                z = point_coordinates.get_attributes(attribute_name)
            point_coordinates = point_coordinates.get_data_points(
                absolute=True)

        msg = 'z not specified'
        assert z is not None, msg

        # Convert input to Numeric arrays
        z = ensure_numeric(z, Float)
        point_coordinates = ensure_numeric(point_coordinates, Float)

        self._build_matrix_AtA_Atz(point_coordinates, z, verbose)
418
419
420############################################################################
421
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,  # may also be a .csv/.txt file name
                point_attributes=None,
                alpha=DEFAULT_ALPHA,
                verbose=False,
                acceptable_overshoot=1.01,
                mesh_origin=None,
                data_origin=None,
                max_read_lines=None,
                attribute_name=None,
                use_cache=False):
    """Fit a smooth surface to a triangulation from attributed data points.

    Convenience wrapper: builds a Fit object for the mesh (through the
    cache when use_cache is True) and returns the result of its fit method.

    Inputs:
      vertex_coordinates: List of coordinate pairs [xi, eta] of
          points constituting a mesh (or an m x 2 Numeric array or
          a geospatial object).
          Points may appear multiple times
          (e.g. if vertices have discontinuities)

      triangles: List of 3-tuples (or a Numeric array) of
          integers representing indices of all vertices in the mesh.

      point_coordinates: List of coordinate pairs [x, y] of data points
          (or an nx2 Numeric array), or a file name.

      point_attributes: Vector or array of data at the point_coordinates.

      alpha: Smoothing parameter.

      verbose: Passed on to Fit, to the cache call and to Fit.fit.

      acceptable_overshoot: Meant to control the allowed factor by which
          fitted values may exceed the value of input data (lower limit
          min(z) - acceptable_overshoot*delta z, upper limit
          max(z) + acceptable_overshoot*delta z).
          Not used by this function at present.

      mesh_origin: A geo_reference object or 3-tuples consisting of
          UTM zone, easting and northing.
          If specified vertex coordinates are assumed to be
          relative to their respective origins.

      data_origin: Handed to Fit.fit as point_origin.

      max_read_lines, attribute_name: Handed to Fit.fit unchanged.

      use_cache: Only the value True selects the cached construction.

    Returns the vertex attributes computed by Fit.fit.
    """
    mesh_args = (vertex_coordinates, triangles)
    fit_options = {'verbose': verbose,
                   'mesh_origin': mesh_origin,
                   'alpha': alpha}

    if use_cache is not True:
        interp = Fit(*mesh_args, **fit_options)
    else:
        # Go through _fit so the cache sees a plain function
        interp = cache(_fit, mesh_args, fit_options, verbose=verbose)

    # The value checking that least squares does (min/max of the fitted
    # values) is not performed here; it may belong in Fit itself.
    return interp.fit(point_coordinates,
                      point_attributes,
                      verbose=verbose,
                      point_origin=data_origin,
                      attribute_name=attribute_name,
                      max_read_lines=max_read_lines)
501
def _fit(*args, **kwargs):
    """Construct a Fit object from the given arguments.

    Private helper for use with caching: the cache is handed this plain
    function instead of the class, because classes may change their
    byte code between runs.
    """
    fit_object = Fit(*args, **kwargs)
    return fit_object
508
Note: See TracBrowser for help on using the repository browser.