source: inundation/ga/storm_surge/pyvolution/least_squares.py @ 544

Last change on this file since 544 was 488, checked in by ole, 20 years ago

Played with regridding of Cornell data

File size: 18.7 KB
Line 
1"""Least squares smoothing and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14         
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.   
18"""
19
20
21#FIXME (Ole): Currently datapoints outside the triangular mesh are ignored.
22#             Is there a clean way of including them?
23
24
25import exceptions
class ShapeError(Exception):
    """Exception type declared by this module.

    Derives directly from the builtin Exception, which is the same class
    that Python 2 exposes as exceptions.Exception, so the class definition
    no longer needs the Python-2-only ``exceptions`` module.

    NOTE(review): presumably meant to signal array shape mismatches; it is
    not raised anywhere in the visible part of this file.
    """
    pass
27
28from general_mesh import General_mesh
29from Numeric import zeros, array, Float, Int, dot, transpose
30from LinearAlgebra import solve_linear_equations
31from sparse import Sparse
32from cg_solve import conjugate_gradient, VectorShapeError
33
try:
    from util import gradient
except ImportError:
    #FIXME reduce the dependency of modules in pyvolution
    # Have util in a dir, working like load_mesh, and get rid of this
    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
        """Return the gradient (a, b) of the plane through three points.

        The plane is q(x, y) = q0 + a*(x-x0) + b*(y-y0), fixed by the
        values q0, q1, q2 at the points (x0, y0), (x1, y1), (x2, y2).
        a and b are obtained by solving the resulting 2x2 linear system
        with Cramer's rule.

        Fallback used only when util.gradient cannot be imported.

        Raises ZeroDivisionError for scalar input if the three points are
        collinear (zero determinant).
        """

        # Multiplying by 1.0 forces a floating point determinant so that
        # the divisions below never truncate when all arguments are
        # integers (Python 2 integer division).
        det = 1.0*((y2-y0)*(x1-x0) - (y1-y0)*(x2-x0))

        a = ((y2-y0)*(q1-q0) - (y1-y0)*(q2-q0))/det
        b = ((x1-x0)*(q2-q0) - (x2-x0)*(q1-q0))/det

        return a, b
51
52
# Smoothing parameter used by the functions and classes below whenever the
# caller does not supply alpha.
# NOTE(review): the module docstring quotes 1.0e-6 as a typical value;
# confirm that 0.001 is the intended default.
DEFAULT_ALPHA = 0.001
54   
def fit_to_mesh_file(mesh_file, point_file, mesh_output_file, alpha=DEFAULT_ALPHA):
    """Fit point attributes to a mesh and write the result to a new mesh file.

    Inputs:

      mesh_file: Name of the mesh file (tsh) to read

      point_file: Name of the point attribute file (xya) to read

      mesh_output_file: Name of the mesh file to write

      alpha: Smoothing parameter passed on to fit_to_mesh

    The fitted attributes are appended after any attributes already present
    on the mesh vertices; otherwise they become the only vertex attributes.
    """
    from load_mesh.loadASCII import mesh_file_to_mesh_dictionary, \
                 load_xya_file, export_trianglulation_file

    # Read the mesh together with any attributes it already carries
    mesh_dict = mesh_file_to_mesh_dictionary(mesh_file)
    mesh_vertices = mesh_dict['generatedpointlist']
    mesh_triangles = mesh_dict['generatedtrianglelist']
    existing_attributes = mesh_dict['generatedpointattributelist']
    existing_titles = mesh_dict['generatedpointattributetitlelist']

    # Read the data points and their attributes
    point_dict = load_xya_file(point_file)
    data_points = point_dict['pointlist']
    data_attributes = point_dict['pointattributelist']

    #FIXME iffy! Hard coding title delimiter
    fitted_titles = [title.strip() for title in point_dict['title'].split(',')]

    fitted = fit_to_mesh(mesh_vertices,
                         mesh_triangles,
                         data_points,
                         data_attributes,
                         alpha = alpha)

    # Array -> list of lists, the form stored in the mesh dictionary
    fitted_attributes = fitted.tolist()

    #FIXME have this overwrite attributes with the same title - DSG
    if existing_titles == []:
        mesh_dict['generatedpointattributelist'] = fitted_attributes
        mesh_dict['generatedpointattributetitlelist'] = fitted_titles
    else:
        # Keep the old attributes and put the newer ones last
        existing_titles.extend(fitted_titles)
        #FIXME can this be done a faster way? - DSG
        for i, vertex_attributes in enumerate(existing_attributes):
            vertex_attributes.extend(fitted_attributes[i])
        mesh_dict['generatedpointattributelist'] = existing_attributes
        mesh_dict['generatedpointattributetitlelist'] = existing_titles

    export_trianglulation_file(mesh_output_file, mesh_dict)
104       
105
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,
                point_attributes,
                alpha = DEFAULT_ALPHA):
    """Fit a smooth surface to a triangulation from attributed data points.

    Inputs:

      vertex_coordinates: List of coordinate pairs [xi, eta] of points
      constituting mesh (or an m x 2 Numeric array)

      triangles: List of 3-tuples (or a Numeric array) of
      integers representing indices of all vertices in the mesh.

      point_coordinates: List of coordinate pairs [x, y] of data points
      (or an nx2 Numeric array)

      point_attributes: Vector or array of data at the point_coordinates.

      alpha: Smoothing parameter.

    Returns the result of Interpolation.fit_points, i.e. the fitted
    values at the mesh vertices.
    """
    fitter = Interpolation(vertex_coordinates,
                           triangles,
                           point_coordinates,
                           alpha = alpha)

    return fitter.fit_points(point_attributes)
138
139
class Interpolation:
    """Penalised least-squares fitting and interpolation on a triangular mesh.

    Data points are located in the triangles of the mesh and expressed
    through one basis function per mesh vertex.  With n data points and
    m vertices, the following attributes are built:

      mesh:  General_mesh built from the vertex coordinates and triangles
      alpha: Smoothing parameter
      D:     m x m smoothing matrix (only built when alpha is non-zero)
      A:     n x m interpolation matrix (only built when points are given)
      AtA:   m x m matrix accumulating the products A^T A
      B:     AtA + alpha*D (or AtA when alpha is zero); solved by fit()

    Data points for which no containing triangle is found are ignored.
    """

    def __init__(self,
                 vertex_coordinates,
                 triangles,
                 point_coordinates = None,
                 alpha = DEFAULT_ALPHA):
        """Build interpolation matrix mapping from
        function values at vertices to function values at data points

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
          points constituting mesh (or an m x 2 Numeric array)

          triangles: List of 3-tuples (or a Numeric array) of
          integers representing indices of all vertices in the mesh.

          point_coordinates: List of coordinate pairs [x, y] of
          data points (or an nx2 Numeric array)
          If point_coordinates is absent, only smoothing matrix will
          be built

          alpha: Smoothing parameter

        Raises:

          ValueError: if alpha is negative (the module docstring states
          that a negative alpha is not allowed).
        """

        if alpha < 0:
            msg = 'Smoothing parameter alpha must be non-negative '
            msg += '(got %s)' %str(alpha)
            raise ValueError(msg)

        #Convert input to Numeric arrays
        vertex_coordinates = array(vertex_coordinates).astype(Float)
        triangles = array(triangles).astype(Int)

        #Build underlying mesh
        self.mesh = General_mesh(vertex_coordinates, triangles)

        #Smoothing parameter
        self.alpha = alpha

        #Build coefficient matrices
        self.build_coefficient_matrix_B(point_coordinates)

    def build_coefficient_matrix_B(self, point_coordinates=None):
        """Build final coefficient matrix B = AtA + alpha*D

        The smoothing matrix D is built whenever alpha is non-zero.
        A, AtA and B are only built when data points are supplied.

        Inputs:

          point_coordinates: List of coordinate pairs [x, y] of data
          points (or an nx2 Numeric array), or None
        """

        if self.alpha != 0:
            self.build_smoothing_matrix_D()

        #Test for presence explicitly: point_coordinates may be an array,
        #whose truth value depends on its contents rather than on
        #whether any points were given
        if point_coordinates is not None and len(point_coordinates) > 0:

            self.build_interpolation_matrix_A(point_coordinates)

            if self.alpha != 0:
                self.B = self.AtA + self.alpha*self.D
            else:
                self.B = self.AtA

    def _get_sigmas(self, k, x):
        """Return the weights (sigma0, sigma1, sigma2) of point x
        with respect to the three vertices of triangle k.

        The weights are checked to sum to one. Callers treat x as lying
        in triangle k when all three weights are non-negative.

        NOTE(review): presumably mesh.get_normal(k, i) is the normal of
        the edge opposite vertex i, which makes these the barycentric
        coordinates of x - confirm against General_mesh.
        """

        #Get the three vertex_points of candidate triangle
        xi0 = self.mesh.get_vertex_coordinate(k, 0)
        xi1 = self.mesh.get_vertex_coordinate(k, 1)
        xi2 = self.mesh.get_vertex_coordinate(k, 2)

        #Get the three normals
        n0 = self.mesh.get_normal(k, 0)
        n1 = self.mesh.get_normal(k, 1)
        n2 = self.mesh.get_normal(k, 2)

        #Compute interpolation
        sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
        sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
        sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)

        #FIXME: Maybe move out to test or something
        epsilon = 1.0e-6
        assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon

        return sigma0, sigma1, sigma2

    def _add_data_point(self, i, k, sigma0, sigma1, sigma2):
        """Enter data point i, found in triangle k, into A and AtA.

        Row i of A receives the three weights in the columns of the
        triangle's global vertex ids; AtA receives the products of
        every pair of those weights.
        """

        j0 = self.mesh.triangles[k,0] #Global vertex id
        j1 = self.mesh.triangles[k,1] #Global vertex id
        j2 = self.mesh.triangles[k,2] #Global vertex id

        sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
        js     = [j0,j1,j2]

        for j in js:
            self.A[i,j] = sigmas[j]
            for l in js:
                self.AtA[j,l] += sigmas[j]*sigmas[l]

    def build_interpolation_matrix_A(self, point_coordinates):
        """Build n x m interpolation matrix, where
        n is the number of data points and
        m is the number of basis functions phi_k (one per vertex)

        Also accumulates the m x m matrix AtA.

        This algorithm uses a quad tree data structure for fast binning
        of data points. Data points for which no triangle is found among
        the candidates are ignored (as in the brute force version).
        """

        from quad import build_quadtree

        #Convert input to Numeric arrays
        point_coordinates = array(point_coordinates).astype(Float)

        #Build n x m interpolation matrix
        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
        n = point_coordinates.shape[0]     #Nbr of data points

        self.A = Sparse(n,m)
        self.AtA = Sparse(m,m)

        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
        root = build_quadtree(self.mesh)

        #Compute matrix elements
        for i in range(n):
            #For each data_coordinate point
            x = point_coordinates[i]

            #Find vertices near x
            candidate_vertices = root.search(x[0], x[1])

            element_found = False

            #For all vertices in same cell as point x
            for v in candidate_vertices:

                #for each triangle id (k) which has v as a vertex
                for k, _ in self.mesh.vertexlist[v]:
                    sigma0, sigma1, sigma2 = self._get_sigmas(k, x)

                    #Check that this triangle contains the data point
                    if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
                        self._add_data_point(i, k, sigma0, sigma1, sigma2)
                        element_found = True
                        break

                if element_found:
                    #Don't look for any other triangle
                    break

            #Ok if there is no triangle for datapoint
            #(as in brute force version)

    def build_interpolation_matrix_A_brute(self, point_coordinates):
        """Build n x m interpolation matrix, where
        n is the number of data points and
        m is the number of basis functions phi_k (one per vertex)

        Also accumulates the m x m matrix AtA.

        This is the brute force which is too slow for large problems,
        but could be used for testing
        """

        #Convert input to Numeric arrays
        point_coordinates = array(point_coordinates).astype(Float)

        #Build n x m interpolation matrix
        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
        n = point_coordinates.shape[0]     #Nbr of data points

        self.A = Sparse(n,m)
        self.AtA = Sparse(m,m)

        #Compute matrix elements
        for i in range(n):
            #For each data_coordinate point
            x = point_coordinates[i]

            #For each triangle (brute force)
            #FIXME: Real algorithm should only visit relevant triangles
            for k in range(len(self.mesh)):
                sigma0, sigma1, sigma2 = self._get_sigmas(k, x)

                #Check that this triangle contains data point
                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
                    self._add_data_point(i, k, sigma0, sigma1, sigma2)
                    break

    def get_A(self):
        """Return the interpolation matrix A in dense form"""
        return self.A.todense()

    def get_B(self):
        """Return the coefficient matrix B in dense form"""
        return self.B.todense()

    def get_D(self):
        """Return the smoothing matrix D in dense form"""
        return self.D.todense()

    #FIXME: Remember to re-introduce the 1/n factor in the
    #interpolation term

    def build_smoothing_matrix_D(self):
        r"""Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k
        """

        #FIXME: algorithm might be optimised by computing local 9x9
        #"element stiffness matrices:

        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)

        self.D = Sparse(m,m)

        #For each triangle compute contributions to D = D1+D2
        for i in range(len(self.mesh)):

            #Get area
            area = self.mesh.areas[i]

            #Get global vertex indices
            v0 = self.mesh.triangles[i,0]
            v1 = self.mesh.triangles[i,1]
            v2 = self.mesh.triangles[i,2]

            #Get the three vertex_points
            xi0 = self.mesh.get_vertex_coordinate(i, 0)
            xi1 = self.mesh.get_vertex_coordinate(i, 1)
            xi2 = self.mesh.get_vertex_coordinate(i, 2)

            #Compute gradients for each vertex: basis function k is
            #one at vertex k and zero at the other two vertices
            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              1, 0, 0)

            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 1, 0)

            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 0, 1)

            #Compute diagonal contributions
            self.D[v0,v0] += (a0*a0 + b0*b0)*area
            self.D[v1,v1] += (a1*a1 + b1*b1)*area
            self.D[v2,v2] += (a2*a2 + b2*b2)*area

            #Compute contributions for basis functions sharing edges
            #(entered symmetrically)
            e01 = (a0*a1 + b0*b1)*area
            self.D[v0,v1] += e01
            self.D[v1,v0] += e01

            e12 = (a1*a2 + b1*b2)*area
            self.D[v1,v2] += e12
            self.D[v2,v1] += e12

            e20 = (a2*a0 + b2*b0)*area
            self.D[v2,v0] += e20
            self.D[v0,v2] += e20

    def fit(self, z):
        """Fit a smooth surface to given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh by solving B f = A^T z with the conjugate gradient method.

        Pre Condition:
          self.A and self.B have been initialised

        Inputs:
          z: Single 1d vector or array of data at the point_coordinates.

        Raises:
          VectorShapeError: if z has more than one dimension
          ValueError: if there are fewer data points than vertices and
          no smoothing is applied (alpha == 0)
        """

        #Convert input to Numeric arrays
        z = array(z).astype(Float)

        if len(z.shape) > 1:
            raise VectorShapeError('Can only deal with 1d data vector')

        #Compute right hand side based on data
        Atz = self.A.trans_mult(z)

        #Check sanity
        n, m = self.A.shape
        if n<m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There only %d data points. Need at least %d\n' %(n,m)
            msg += 'Alternatively, increase smoothing parameter alpha'
            #String exceptions are not supported by newer Pythons,
            #raise a proper exception object instead
            raise ValueError(msg)

        return conjugate_gradient(self.B, Atz, Atz,imax=2*len(Atz) )
        #FIXME: Should we store the result here for later use? (ON)

    def fit_points(self, z):
        """Like fit, but more robust when each point has two or more attributes

        Inputs:
          z: 1d vector of data at the point_coordinates, or a 2d array
          with one column per attribute.

        Returns the result of fit(z) for 1d input; otherwise an array
        with one row per vertex and one fitted column per column of z.
        """
        #FIXME (Ole): The name fit_points doesn't carry any meaning
        #for me. How about something like fit_multiple or fit_columns?

        #Convert input to Numeric arrays
        z = array(z).astype(Float)

        #Dispatch on the shape of z instead of catching VectorShapeError
        #from fit: an error of that type coming from the solver must not
        #be mistaken for multi column input
        if len(z.shape) <= 1:
            return self.fit(z)

        # broadcasting is not supported: fit one column at a time
        m = self.mesh.coordinates.shape[0] #Number of vertices
        n = z.shape[1]                     #Number of attribute columns

        f = zeros((m,n), Float) #Resulting columns

        for i in range(n):
            f[:,i] = self.fit(z[:,i])

        return f

    def interpolate(self, f):
        """Evaluate smooth surface f at data points implied in self.A.

        The mesh values representing a smooth surface are
        assumed to be specified in f. This argument could,
        for example have been obtained from the method self.fit()

        Pre Condition:
          self.A has been initialised

        Inputs:
          f: Vector or array of data at the mesh vertices.
          If f is an array, interpolation will be done for each column
        """

        return self.A * f
548       
549           
550#-------------------------------------------------------------
if __name__ == "__main__":
    # Load in a mesh and data points with attributes.
    # Fit the attributes to the mesh.
    # Save a new mesh file.
    #
    # The fourth command line argument (alpha) is optional; DEFAULT_ALPHA
    # is used when it is left out.
    import os
    import sys
    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh alpha"\
            %os.path.basename(sys.argv[0])

    if len(sys.argv) < 4:
        print(usage)
    else:
        mesh_file = sys.argv[1]
        point_file = sys.argv[2]
        mesh_output_file = sys.argv[3]
        if len(sys.argv) > 4:
            # Command line arguments are strings; the fitting code
            # multiplies alpha with a matrix, so convert it to a number
            alpha = float(sys.argv[4])
        else:
            alpha = DEFAULT_ALPHA
        fit_to_mesh_file(mesh_file, point_file, mesh_output_file, alpha)
572       
Note: See TracBrowser for help on using the repository browser.