source: inundation/fit_interpolate/fit.py @ 2891

Last change on this file since 2891 was 2802, checked in by duncan, 19 years ago

adding the fit class

File size: 13.6 KB
Line 
1"""Least squares fitting.
2
3   Implements a penalised least-squares fit.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20from geospatial_data.geospatial_data import Geospatial_data
21from fit_interpolate.general_fit_interpolate import FitInterpolate
22
23DEFAULT_ALPHA = 0.001
24
25
class Fit(FitInterpolate):
    """Penalised least-squares fit of point data to the vertices of a mesh.

    The class accumulates the matrices AtA (m x m) and Atz (m x a), where
    m is the number of mesh vertices (basis functions) and a is the number
    of data attributes.

    NOTE(review): this class relies on names that are not imported in this
    module as shown (Sparse, zeros, Float, ensure_numeric,
    in_and_outside_polygon, search_tree_of_vertices) -- confirm that they
    are brought into scope elsewhere.
    """

    def __init__(self,
                 vertex_coordinates,
                 triangles,
                 mesh_origin=None,
                 alpha=None,
                 verbose=False,
                 max_vertices_per_cell=30):
        """
        Fit data at points to the vertices of a mesh.

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
              points constituting a mesh (or an m x 2 Numeric array or
              a geospatial object)
              Points may appear multiple times
              (e.g. if vertices have discontinuities)

          triangles: List of 3-tuples (or a Numeric array) of
              integers representing indices of all vertices in the mesh.

          mesh_origin: A geo_reference object or 3-tuples consisting of
              UTM zone, easting and northing.
              If specified vertex coordinates are assumed to be
              relative to their respective origins.

          alpha: Smoothing parameter. DEFAULT_ALPHA is used when None.

          verbose: Passed on to FitInterpolate.__init__.

          max_vertices_per_cell: Number of vertices in a quad tree cell
          at which the cell is split into 4.

          Note: Don't supply a vertex coords as a geospatial object and
              a mesh origin, since geospatial has its own mesh origin.
        """

        if alpha is None:
            self.alpha = DEFAULT_ALPHA
        else:
            self.alpha = alpha

        FitInterpolate.__init__(self,
                                vertex_coordinates,
                                triangles,
                                mesh_origin,
                                verbose,
                                max_vertices_per_cell)

        # Number of basis functions (one per mesh vertex).
        # Presumably self.mesh is set up by FitInterpolate.__init__.
        m = self.mesh.coordinates.shape[0]

        # AtA only depends on the mesh size and can be allocated now.
        self.AtA = Sparse(m, m)

        # The width of Atz is the number of data attributes, which is only
        # known once data arrives; it is allocated on first use in
        # _build_matrix_AtA_Atz.
        self.Atz = None

    def _build_coefficient_matrix_B(self,
                                    verbose=False):
        """Build final coefficient matrix"""

    def _build_smoothing_matrix_D(self):
        r"""Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k
        """

    def _build_matrix_AtA_Atz(self,
                              point_coordinates,
                              z,
                              verbose=False):
        """Build:
        AtA  m x m  interpolation matrix, and,
        Atz  m x a  interpolation matrix where,
        m is the number of basis functions phi_k (one per vertex)
        a is the number of data attributes

        This algorithm uses a quad tree data structure for fast binning of
        data points.

        Only points inside the mesh boundary polygon contribute. The
        indices of points inside and outside the boundary are stored in
        self.inside_poly_indices and self.outside_poly_indices.

        Preconditions
        z and points are numeric
        Point_coordindates and mesh vertices have the same origin.

        Raises Exception if no triangle is found for a point that lies
        inside the boundary polygon.
        """

        n = len(point_coordinates)

        if self.Atz is None:
            # First call: size Atz from the data (one column per attribute).
            m = self.mesh.coordinates.shape[0]
            if len(z.shape) > 1:
                self.Atz = zeros((m, z.shape[1]), Float)
            else:
                self.Atz = zeros((m,), Float)

        if verbose:
            print('Getting indices inside mesh boundary')
        self.inside_poly_indices, self.outside_poly_indices = \
            in_and_outside_polygon(point_coordinates,
                                   self.mesh.get_boundary_polygon(),
                                   closed=True, verbose=verbose)

        # Compute matrix elements for points inside the mesh
        for i in self.inside_poly_indices:
            # Progress report roughly every tenth of the data set
            if verbose and i % ((n + 10) // 10) == 0:
                print('Doing %d of %d' % (i, n))

            x = point_coordinates[i]
            # Presumably self.root is the quad tree built by the base class.
            element_found, sigma0, sigma1, sigma2, k = \
                search_tree_of_vertices(self.root, self.mesh, x)

            if element_found is not True:
                msg = 'Could not find triangle for point %s' % (x,)
                raise Exception(msg)

            # Global vertex ids of triangle k, paired with sigma0..sigma2
            j0 = self.mesh.triangles[k, 0]
            j1 = self.mesh.triangles[k, 1]
            j2 = self.mesh.triangles[k, 2]

            sigmas = {j0: sigma0, j1: sigma1, j2: sigma2}
            js = [j0, j1, j2]

            for j in js:
                self.Atz[j] += sigmas[j] * z[i]
                # Use a distinct inner index so the triangle index k is
                # not overwritten.
                for l in js:
                    self.AtA[j, l] += sigmas[j] * sigmas[l]

    def fit(self, point_coordinates=None, z=None):
        """Fit a smooth surface to given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Inputs:
        point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 Numeric array or a Geospatial_data object
          z: Single 1d vector or array of data at the point_coordinates.

        NOTE(review): not implemented yet; the method currently does
        nothing and returns None.
        """
        # TODO: build ata and atz
        # TODO: solve fit

    def build_fit_subset(self, point_coordinates, z):
        """Convert a subset of data points and their values to Numeric arrays.

        Inputs:
        point_coordinates: The co-ordinates of the data points.
              List of coordinate pairs [x, y] of
              data points or an nx2 Numeric array or a Geospatial_data object
          z: Single 1d vector or array of data at the point_coordinates.

        NOTE(review): the converted arrays are not used any further here;
        presumably they are meant to be passed on to
        _build_matrix_AtA_Atz -- confirm the intended behaviour.
        """
        # Note: Don't get the z info from Geospatial_data.attributes yet.
        # That means fit has to handle attribute title info.

        # FIXME(DSG-DSG): Check that the vert and point coords
        # have the same zone.
        if isinstance(point_coordinates, Geospatial_data):
            point_coordinates = point_coordinates.get_data_points(
                absolute=True)

        # Convert input to Numeric arrays
        z = ensure_numeric(z, Float)
        point_coordinates = ensure_numeric(point_coordinates, Float)
215
216
217
218############################################################################
219
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,
                point_attributes,
                alpha=DEFAULT_ALPHA,
                verbose=False,
                acceptable_overshoot=1.01,
                expand_search=False,
                data_origin=None,
                mesh_origin=None,
                precrop=False,
                use_cache=False):
    """
    Fit a smooth surface to a triangulation,
    given data points with attributes.


        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of points
          constituting mesh (or an m x 2 Numeric array)

          triangles: List of 3-tuples (or a Numeric array) of
          integers representing indices of all vertices in the mesh.

          point_coordinates: List of coordinate pairs [x, y] of data points
          (or an nx2 Numeric array)

          point_attributes: Vector or array of data at the point_coordinates.

          alpha: Smoothing parameter, passed on to Fit.

          verbose: If True, print statistics about the data and the fit.

          acceptable_overshoot: controls the allowed factor by which fitted
          values may exceed the value of input data. The lower limit is
          defined as min(z) - acceptable_overshoot*delta z and upper limit
          as max(z) + acceptable_overshoot*delta z

          mesh_origin: 3-tuple consisting of UTM zone, easting and
          northing. If specified, vertex coordinates are assumed to be
          relative to this origin.

          use_cache: If True, build the Fit object through cache().

          expand_search, data_origin, precrop: accepted for backward
          compatibility; they are not used by this function.

        Returns the fitted attribute values at the mesh vertices.

        Raises FittingError if fitted values fall outside the allowed range.

        NOTE(review): cache, ensure_numeric, take, compress, concatenate,
        NewAxis and FittingError are not imported in this module as shown
        -- confirm that they are in scope.
    """

    # Both branches build the same kind of object (a Fit) with the same
    # settings, so that use_cache does not change the result.
    if use_cache is True:
        interp = cache(_fit,
                       (vertex_coordinates,
                        triangles),
                       {'verbose': verbose,
                        'mesh_origin': mesh_origin,
                        'alpha': alpha},
                       verbose=verbose)
    else:
        interp = Fit(vertex_coordinates,
                     triangles,
                     verbose=verbose,
                     mesh_origin=mesh_origin,
                     alpha=alpha)

    vertex_attributes = interp.fit(point_coordinates, point_attributes)

    # Sanity check
    point_coordinates = ensure_numeric(point_coordinates)
    vertex_coordinates = ensure_numeric(vertex_coordinates)

    # Data points
    X = point_coordinates[:, 0]
    Y = point_coordinates[:, 1]
    Z = ensure_numeric(point_attributes)
    if len(Z.shape) == 1:
        Z = Z[:, NewAxis]

    # Data points inside mesh boundary.
    # Fit records these as inside_poly_indices when it bins the points;
    # fall back to all points if that has not happened.
    indices = getattr(interp, 'inside_poly_indices', None)
    if indices is not None:
        Xc = take(X, indices)
        Yc = take(Y, indices)
        Zc = take(Z, indices)
    else:
        Xc = X
        Yc = Y
        Zc = Z

    # Vertex coordinates
    Xi = vertex_coordinates[:, 0]
    Eta = vertex_coordinates[:, 1]
    Zeta = ensure_numeric(vertex_attributes)
    if len(Zeta.shape) == 1:
        Zeta = Zeta[:, NewAxis]

    for i in range(Zeta.shape[1]):  # For each attribute
        zeta = Zeta[:, i]
        z = Z[:, i]
        zc = Zc[:, i]

        max_zc = max(zc)
        min_zc = min(zc)
        delta_zc = max_zc - min_zc
        upper_limit = max_zc + delta_zc*acceptable_overshoot
        lower_limit = min_zc - delta_zc*acceptable_overshoot

        if max(zeta) > upper_limit or min(zeta) < lower_limit:
            msg = 'Least sqares produced values outside the allowed '
            msg += 'range [%f, %f].\n' % (lower_limit, upper_limit)
            msg += 'z in [%f, %f], zeta in [%f, %f].\n' % (min_zc, max_zc,
                                                           min(zeta),
                                                           max(zeta))
            msg += 'If greater range is needed, increase the value of '
            msg += 'acceptable_overshoot (currently %.2f).\n' \
                   % (acceptable_overshoot)

            # Element-wise OR: flag every vertex that is above the upper
            # limit or below the lower limit.
            offending_vertices = (zeta > upper_limit) | (zeta < lower_limit)
            Xi_c = compress(offending_vertices, Xi)
            Eta_c = compress(offending_vertices, Eta)
            offending_coordinates = concatenate((Xi_c[:, NewAxis],
                                                 Eta_c[:, NewAxis]),
                                                axis=1)

            msg += 'Offending locations:\n %s' % (offending_coordinates)

            raise FittingError(msg)

        if verbose:
            print('+------------------------------------------------')
            print('Least squares statistics')
            print('+------------------------------------------------')
            print('points: %d points' % (len(z)))
            print('    x in [%f, %f]' % (min(X), max(X)))
            print('    y in [%f, %f]' % (min(Y), max(Y)))
            print('    z in [%f, %f]' % (min(z), max(z)))
            print('')

            if indices is not None:
                print('Cropped points: %d points' % (len(zc)))
                print('    x in [%f, %f]' % (min(Xc), max(Xc)))
                print('    y in [%f, %f]' % (min(Yc), max(Yc)))
                print('    z in [%f, %f]' % (min(zc), max(zc)))
                print('')

            print('Mesh: %d vertices' % (len(zeta)))
            print('    xi in [%f, %f]' % (min(Xi), max(Xi)))
            print('    eta in [%f, %f]' % (min(Eta), max(Eta)))
            print('    zeta in [%f, %f]' % (min(zeta), max(zeta)))
            print('+------------------------------------------------')

    return vertex_attributes
375
376
def _fit(*args, **kwargs):
    """Construct a Fit object from the given arguments.

    Module-level helper intended for use with caching: a plain function is
    passed to the cache instead of the class itself, because classes may
    change their byte code between runs.
    """

    fitter = Fit(*args, **kwargs)
    return fitter
383
Note: See TracBrowser for help on using the repository browser.