source: inundation/ga/storm_surge/pyvolution/least_squares.py @ 485

Last change on this file since 485 was 485, checked in by ole, 20 years ago

Cleaned up is sparse and least squares.
Verified that quad trees work.
Implemented sparse matrix x matrix mult and simplified
interpolate in least_squares.
Added more unit testing.

File size: 19.1 KB
Line 
1"""Least squares smoothing and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14         
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.   
18"""
19
20
21#FIXME (Ole): Currently datapoints outside the triangular mesh are ignored.
22#             Is there a clean way of including them?
23
24
25import exceptions
class ShapeError(Exception):
    """Exception signalling an argument of unexpected shape."""
    #Derive from the builtin Exception directly: in Python 2 it is the very
    #same class as exceptions.Exception, and the 'exceptions' module no
    #longer exists in Python 3.
    #NOTE(review): not raised anywhere in this module -- presumably used by
    #importers; confirm before removing.
    pass
27
28from general_mesh import General_mesh
29from Numeric import zeros, array, Float, Int, dot, transpose
30from LinearAlgebra import solve_linear_equations
31from sparse import Sparse
32from cg_solve import conjugate_gradient, VectorShapeError
33
try:
    from util import gradient
except ImportError:
    #FIXME reduce the dependency of modules in pyvolution
    # Have util in a dir, working like load_mesh, and get rid of this
    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
        """Return the gradient (a, b) = (dq/dx, dq/dy) of the linear
        function taking the values q0, q1, q2 at the triangle vertices
        (x0, y0), (x1, y1) and (x2, y2).

        The coefficients are obtained by solving
            q1 - q0 = a*(x1-x0) + b*(y1-y0)
            q2 - q0 = a*(x2-x0) + b*(y2-y0)
        with Cramer's rule.

        A degenerate (zero-area) triangle gives det == 0; with plain
        Python numbers this raises ZeroDivisionError.
        """

        #Twice the signed area of the triangle. Converted to float so that
        #integer inputs are not floor-divided under Python 2.
        det = float((y2-y0)*(x1-x0) - (y1-y0)*(x2-x0))

        a = ((y2-y0)*(q1-q0) - (y1-y0)*(q2-q0))/det
        b = ((x1-x0)*(q2-q0) - (x2-x0)*(q1-q0))/det

        return a, b
51
52
#Smoothing parameter used whenever a caller does not pass alpha explicitly
DEFAULT_ALPHA = 1.0e-3
54   
def fit_to_mesh_file(mesh_file, point_file, mesh_output_file, alpha=DEFAULT_ALPHA):
    """Fit point attributes to a mesh and write the result to a mesh file.

    Given a mesh file (tsh) and a point attribute file (xya), fit the
    point attributes to the mesh vertices and write a mesh file with the
    results. Fitted attributes are appended after any vertex attributes
    the input mesh already carries.

    Inputs:
      mesh_file: Name of the input mesh (.tsh) file
      point_file: Name of the point attribute (.xya) file
      mesh_output_file: Name of the mesh file to write
      alpha: Smoothing parameter passed on to fit_to_mesh
    """
    from load_mesh.loadASCII import mesh_file_to_mesh_dictionary, \
                 load_xya_file, export_trianglulation_file

    #Load in the .tsh file
    mesh_dict = mesh_file_to_mesh_dictionary(mesh_file)
    vertex_coordinates = mesh_dict['generatedpointlist']
    triangles = mesh_dict['generatedtrianglelist']

    old_point_attributes = mesh_dict['generatedpointattributelist']
    old_title_list = mesh_dict['generatedpointattributetitlelist']

    #Load in the .xya file
    point_dict = load_xya_file(point_file)
    point_coordinates = point_dict['pointlist']
    point_attributes = point_dict['pointattributelist']

    #FIXME iffy! Hard coding title delimiter
    title_list = [title.strip() for title in point_dict['title'].split(',')]

    f = fit_to_mesh(vertex_coordinates,
                    triangles,
                    point_coordinates,
                    point_attributes,
                    alpha = alpha)

    #Convert array to list of lists
    new_point_attributes = f.tolist()

    #FIXME have this overwrite attributes with the same title - DSG
    #Put the newer attributes last
    if old_title_list:
        old_title_list.extend(title_list)
        #FIXME can this be done a faster way? - DSG
        #Append the fitted values of each vertex to its existing row
        for i, old_row in enumerate(old_point_attributes):
            old_row.extend(new_point_attributes[i])
        mesh_dict['generatedpointattributelist'] = old_point_attributes
        mesh_dict['generatedpointattributetitlelist'] = old_title_list
    else:
        mesh_dict['generatedpointattributelist'] = new_point_attributes
        mesh_dict['generatedpointattributetitlelist'] = title_list

    export_trianglulation_file(mesh_output_file, mesh_dict)
104       
105
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,
                point_attributes,
                alpha = DEFAULT_ALPHA):
    """Fit a smooth surface to a triangulation, given data points with
    attributes.

    Inputs:

      vertex_coordinates: List of coordinate pairs [xi, eta] of points
      constituting the mesh (or an m x 2 Numeric array)

      triangles: List of 3-tuples (or a Numeric array) of
      integers representing indices of all vertices in the mesh.

      point_coordinates: List of coordinate pairs [x, y] of data points
      (or an n x 2 Numeric array)

      point_attributes: Vector or array of data at the point_coordinates.

      alpha: Smoothing parameter.

    Returns the fitted values at the mesh vertices, as produced by
    Interpolation.fit_points.
    """
    return Interpolation(vertex_coordinates,
                         triangles,
                         point_coordinates,
                         alpha = alpha).fit_points(point_attributes)
138
139
class Interpolation:
    """Penalised least-squares fit of point data to a triangular mesh.

    The surface is represented by one piecewise linear basis function per
    mesh vertex. With A the n x m interpolation matrix (data points x
    vertices) and D the smoothing matrix, fit() solves

        B f = A^T z,   where B = A^T A + alpha*D,

    using conjugate_gradient.
    """

    def __init__(self,
                 vertex_coordinates,
                 triangles,
                 point_coordinates = None,
                 alpha = DEFAULT_ALPHA):
        """Build interpolation matrix mapping from
        function values at vertices to function values at data points

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of points
          constituting the mesh (or an m x 2 Numeric array)

          triangles: List of 3-tuples (or a Numeric array) of
          integers representing indices of all vertices in the mesh.

          point_coordinates: List of coordinate pairs [x, y] of data points
          (or an n x 2 Numeric array). If omitted, A, AtA and B are not
          built.

          alpha: Smoothing parameter
        """

        #Convert input to Numeric arrays
        vertex_coordinates = array(vertex_coordinates).astype(Float)
        triangles = array(triangles).astype(Int)

        #Build underlying mesh
        self.mesh = General_mesh(vertex_coordinates, triangles)

        #Smoothing parameter
        self.alpha = alpha

        #Build coefficient matrices
        self.build_coefficient_matrix_B(point_coordinates)


    def build_coefficient_matrix_B(self, point_coordinates=None):
        """Build final coefficient matrix B = AtA + alpha*D.

        D is built whenever alpha is non-zero; A, AtA and B are only
        built when a non-empty collection of data points is given.
        """
        #FIXME: Remember to re-introduce the 1/n factor in the
        #interpolation term

        if self.alpha != 0:
            self.build_smoothing_matrix_D()

        #Test for None/emptiness explicitly rather than relying on the
        #truth value of point_coordinates: for arrays that is not simply
        #"non-empty" (numpy even raises for multi-element arrays).
        if point_coordinates is not None and len(point_coordinates) > 0:

            self.build_interpolation_matrix_A(point_coordinates)

            if self.alpha != 0:
                self.B = self.AtA + self.alpha*self.D
            else:
                self.B = self.AtA


    def _barycentric_coordinates(self, k, x):
        """Return interpolation weights [sigma0, sigma1, sigma2] of point x
        with respect to the three vertices of triangle k.

        The weights sum to one; x lies in triangle k exactly when all three
        are non-negative.
        """
        #Get the three vertex_points of the triangle
        xi0 = self.mesh.get_vertex_coordinate(k, 0)
        xi1 = self.mesh.get_vertex_coordinate(k, 1)
        xi2 = self.mesh.get_vertex_coordinate(k, 2)

        #Get the three normals
        n0 = self.mesh.get_normal(k, 0)
        n1 = self.mesh.get_normal(k, 1)
        n2 = self.mesh.get_normal(k, 2)

        #Compute interpolation
        sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
        sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
        sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)

        #FIXME: Maybe move out to test or something
        epsilon = 1.0e-6
        assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon

        return [sigma0, sigma1, sigma2]


    def _add_point_contribution(self, i, k, sigmas):
        """Store the weights of data point i (lying in triangle k) in row i
        of A and accumulate the corresponding terms of AtA."""
        #Global vertex ids of triangle k, in the same order as sigmas
        js = [self.mesh.triangles[k,0],
              self.mesh.triangles[k,1],
              self.mesh.triangles[k,2]]

        for j, sigma_j in zip(js, sigmas):
            self.A[i,j] = sigma_j
            for l, sigma_l in zip(js, sigmas):
                self.AtA[j,l] += sigma_j*sigma_l


    def build_interpolation_matrix_A(self, point_coordinates):
        """Build n x m interpolation matrix, where
        n is the number of data points and
        m is the number of basis functions phi_k (one per vertex)

        AtA (m x m) is accumulated alongside A.
        Data points not found in any triangle are ignored (their row of A
        stays empty).

        This algorithm uses a quad tree data structure for fast binning of
        data points
        """

        from quad import build_quadtree

        #Convert input to Numeric arrays
        point_coordinates = array(point_coordinates).astype(Float)

        #Build n x m interpolation matrix
        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
        n = point_coordinates.shape[0]     #Nbr of data points

        self.A = Sparse(n,m)
        self.AtA = Sparse(m,m)

        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
        root = build_quadtree(self.mesh)

        #Compute matrix elements
        for i in range(n):
            #For each data_coordinate point
            x = point_coordinates[i]

            #Find vertices near x
            candidate_vertices = root.search(x[0], x[1])

            #Find triangle containing x among the triangles touching a
            #vertex in the same cell as x
            element_found = False
            for v in candidate_vertices:

                #For each triangle id (k) which has v as a vertex
                for k, _ in self.mesh.vertexlist[v]:
                    sigmas = self._barycentric_coordinates(k, x)

                    #Check that this triangle contains the data point
                    if sigmas[0] >= 0 and sigmas[1] >= 0 and sigmas[2] >= 0:
                        element_found = True
                        break

                if element_found:
                    #Don't look for any other triangle
                    break

            #Update interpolation matrix A if necessary
            if element_found:
                self._add_point_contribution(i, k, sigmas)
            #Otherwise the data point lies outside the mesh: it is ignored
            #(as in the brute force version)


    def build_interpolation_matrix_A_brute(self, point_coordinates):
        """Build n x m interpolation matrix, where
        n is the number of data points and
        m is the number of basis functions phi_k (one per vertex)

        This is the brute force which is too slow for large problems,
        but could be used for testing
        """

        #Convert input to Numeric arrays
        point_coordinates = array(point_coordinates).astype(Float)

        #Build n x m interpolation matrix
        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
        n = point_coordinates.shape[0]     #Nbr of data points

        self.A = Sparse(n,m)
        self.AtA = Sparse(m,m)

        #Compute matrix elements
        for i in range(n):
            #For each data_coordinate point
            x = point_coordinates[i]

            #For each triangle (brute force) until one contains x
            #FIXME: Real algorithm should only visit relevant triangles
            for k in range(len(self.mesh)):
                sigmas = self._barycentric_coordinates(k, x)

                #Check that this triangle contains data point
                if sigmas[0] >= 0 and sigmas[1] >= 0 and sigmas[2] >= 0:
                    self._add_point_contribution(i, k, sigmas)
                    break


    def get_A(self):
        """Return the interpolation matrix A as a dense array."""
        return self.A.todense()

    def get_B(self):
        """Return the coefficient matrix B as a dense array."""
        return self.B.todense()

    def get_D(self):
        """Return the smoothing matrix D as a dense array."""
        return self.D.todense()


    def build_smoothing_matrix_D(self):
        r"""Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k.
        The gradients are constant on each triangle, so the triangle's
        contribution to [D]_{k,l} is (a_k*a_l + b_k*b_l) times its area.
        """

        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)

        self.D = Sparse(m,m)

        #For each triangle compute contributions to D = D1+D2
        for i in range(len(self.mesh)):

            #Get area
            area = self.mesh.areas[i]

            #Get global vertex indices
            vertices = [self.mesh.triangles[i,0],
                        self.mesh.triangles[i,1],
                        self.mesh.triangles[i,2]]

            #Get the three vertex_points
            xi0 = self.mesh.get_vertex_coordinate(i, 0)
            xi1 = self.mesh.get_vertex_coordinate(i, 1)
            xi2 = self.mesh.get_vertex_coordinate(i, 2)

            #Gradient (a_k, b_k) of each local basis function: the linear
            #function that is 1 at vertex k and 0 at the other two
            gradients = []
            for values in ([1, 0, 0], [0, 1, 0], [0, 0, 1]):
                gradients.append(gradient(xi0[0], xi0[1],
                                          xi1[0], xi1[1],
                                          xi2[0], xi2[1],
                                          values[0], values[1], values[2]))

            #Assemble the local 3x3 element stiffness matrix into D.
            #It is symmetric: D[v_p,v_q] and D[v_q,v_p] get equal values.
            for p in range(3):
                a_p, b_p = gradients[p]
                for q in range(3):
                    a_q, b_q = gradients[q]
                    self.D[vertices[p], vertices[q]] += (a_p*a_q + b_p*b_q)*area


    def fit(self, z):
        """Fit a smooth surface to given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh by solving B f = A^T z with conjugate_gradient.

        Pre Condition:
          self.A, self.AtA and self.B have been initialised

        Inputs:
          z: Single 1d vector or array of data at the point_coordinates.

        Raises:
          VectorShapeError if z has more than one dimension.
          ValueError if there are fewer data points than vertices and no
          smoothing (alpha == 0).
        """

        #Convert input to Numeric arrays
        z = array(z).astype(Float)

        if len(z.shape) > 1:
            raise VectorShapeError('Can only deal with 1d data vector')

        #Check sanity before doing any work
        n, m = self.A.shape
        if n < m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There are only %d data points. Need at least %d\n' %(n,m)
            msg += 'Alternatively, increase smoothing parameter alpha'
            #A proper exception instance: raising a bare string is not
            #supported by modern Pythons and could never be caught by type.
            raise ValueError(msg)

        #Compute right hand side based on data
        Atz = self.A.trans_mult(z)

        #Atz doubles as the initial guess for the iteration
        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz))
        #FIXME: Should we store the result here for later use? (ON)


    def fit_points(self, z):
        """Like fit, but also handles two or more attributes per point.

        If z is 1d it is passed straight to fit. Otherwise each column of
        z (one attribute) is fitted separately and the results are
        returned as the columns of an m x (number of attributes) array.

        FIXME(Ole): The name fit_points doesn't carry any meaning for me.
        How about something like fit_multiple or fit_columns?
        """

        #Convert input to Numeric arrays
        z = array(z).astype(Float)

        #Dispatch on dimensionality up front instead of catching
        #VectorShapeError from fit, which would also swallow shape errors
        #raised inside the solver.
        if len(z.shape) <= 1:
            return self.fit(z)

        #Broadcasting is not supported by fit: fit each column separately
        m = self.mesh.coordinates.shape[0] #Number of vertices
        num_attributes = z.shape[1]        #Number of attribute columns

        f = zeros((m, num_attributes), Float) #Resulting columns

        for i in range(num_attributes):
            f[:,i] = self.fit(z[:,i])

        return f


    def interpolate(self, f):
        """Compute predicted values at data points implied in self.A.

        The mesh values representing a smooth surface are
        assumed to be specified in f. This argument could,
        for example have been obtained from the method self.fit()

        Pre Condition:
          self.A has been initialised

        Inputs:
          f: Vector or array of data at the mesh vertices.
          If f is an array, interpolation will be done for each column
        """

        return self.A * f


    #FIXME: We will need a method 'evaluate(self):' that will interpolate
    #a computed surface living on the mesh onto a collection of
    #arbitrary data points
    #
    #Precondition: self.fit(z) has stored its result in self.f.
    #
    #Input: data_points
    #Algorithm:
    #  1 Build a new temporary A matrix based on mesh and new data points
    #  2 Apply it to self.f (return A*self.f)
    #
    # ON
561
562
563
564#-------------------------------------------------------------
if __name__ == "__main__":
    #Load in a mesh and data points with attributes.
    #Fit the attributes to the mesh.
    #Save a new mesh file.
    import os, sys
    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh alpha"\
            %os.path.basename(sys.argv[0])

    if len(sys.argv) < 4:
        print(usage)
    else:
        mesh_file = sys.argv[1]
        point_file = sys.argv[2]
        mesh_output_file = sys.argv[3]
        if len(sys.argv) > 4:
            #Command line arguments are strings; the smoothing parameter
            #must be numeric (it is compared with 0 and multiplies D)
            alpha = float(sys.argv[4])
        else:
            alpha = DEFAULT_ALPHA
        fit_to_mesh_file(mesh_file, point_file, mesh_output_file, alpha)
586       
Note: See TracBrowser for help on using the repository browser.