source: inundation/ga/storm_surge/pyvolution/least_squares.py @ 613

Last change on this file since 613 was 611, checked in by duncan, 20 years ago

Optimised the least_squares algorithm for building A matrix

File size: 21.1 KB
Line 
"""Least squares smoothing and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14         
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.   
18"""
19
20
21#FIXME (Ole): Currently datapoints outside the triangular mesh are ignored.
22#             Is there a clean way of including them?
23
24
25
26import exceptions
class ShapeError(Exception):
    """Exception type for shape-related errors.

    The builtin ``Exception`` is the same class that the ``exceptions``
    module exposes as ``exceptions.Exception``.
    NOTE(review): nothing in the visible part of this file raises
    ShapeError; it is kept because other modules may import it.
    """
28
29from general_mesh import General_mesh
30from Numeric import zeros, array, Float, Int, dot, transpose
31from LinearAlgebra import solve_linear_equations
32from sparse import Sparse, Sparse_CSR
33from cg_solve import conjugate_gradient, VectorShapeError
34
try:
    from util import gradient
except ImportError:
    #FIXME reduce the dependency of modules in pyvolution
    # Have util in a dir, working like load_mesh, and get rid of this
    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
        """Return the gradient (a, b) of the plane through three points.

        The plane takes the values q0, q1, q2 at the points
        (x0, y0), (x1, y1), (x2, y2) respectively, i.e.

            q(x, y) = q0 + a*(x - x0) + b*(y - y0)

        Fallback used only when util.gradient cannot be imported.

        Returns:
          a, b: the derivatives of q with respect to x and y.

        Raises:
          ZeroDivisionError if the three points are collinear (det == 0)
          and the inputs are plain Python numbers.
        """

        # Multiplying by 1.0 forces floating point arithmetic so that
        # all-integer input is not truncated by Python 2 integer division.
        det = 1.0*((y2-y0)*(x1-x0) - (y1-y0)*(x2-x0))

        a = ((y2-y0)*(q1-q0) - (y1-y0)*(q2-q0))/det
        b = ((x1-x0)*(q2-q0) - (x2-x0)*(q1-q0))/det

        return a, b
52
53
# Smoothing parameter used whenever a caller does not pass alpha explicitly
# (fit_to_mesh_file, fit_to_mesh, Interpolation and the command line script).
# Note that the module docstring quotes 1.0e-6 as a typical value.
DEFAULT_ALPHA = 1.0e-3
55   
def fit_to_mesh_file(mesh_file, point_file, mesh_output_file, alpha=DEFAULT_ALPHA):
    """Fit the attributes of a point file to a mesh file and save the result.

    Inputs:

      mesh_file: Name of the mesh file (tsh) to read.

      point_file: Name of the point attribute file (xya) to read.

      mesh_output_file: Name of the mesh file to write.

      alpha: Smoothing parameter handed on to fit_to_mesh.

    The fitted attributes are appended after any attributes already
    stored on the mesh vertices; the title lists are merged the same way.
    """
    from load_mesh.loadASCII import mesh_file_to_mesh_dictionary, \
                 load_xya_file, export_trianglulation_file

    # Read the .tsh file and pick out geometry and existing attributes
    mesh_dict = mesh_file_to_mesh_dictionary(mesh_file)
    vertex_coordinates = mesh_dict['generatedpointlist']
    triangles = mesh_dict['generatedtrianglelist']
    old_point_attributes = mesh_dict['generatedpointattributelist']
    old_title_list = mesh_dict['generatedpointattributetitlelist']

    # Read the .xya file; if the loader reports a SyntaxError, retry
    # with a blank as delimiter
    try:
        point_dict = load_xya_file(point_file)
    except SyntaxError:
        point_dict = load_xya_file(point_file,delimiter = ' ')
    point_coordinates = point_dict['pointlist']
    point_attributes = point_dict['pointattributelist']
    title_string = point_dict['title']

    #FIXME iffy! Hard coding title delimiter
    title_list = [title.strip() for title in title_string.split(',')]

    f = fit_to_mesh(vertex_coordinates,
                    triangles,
                    point_coordinates,
                    point_attributes,
                    alpha = alpha)

    # Convert array to list of lists
    new_point_attributes = f.tolist()

    #FIXME have this overwrite attributes with the same title - DSG
    if old_title_list == []:
        # The mesh had no attributes: the fitted ones are the whole story
        mesh_dict['generatedpointattributelist'] = new_point_attributes
        mesh_dict['generatedpointattributetitlelist'] = title_list
    else:
        # Put the newer attributes last, vertex by vertex
        old_title_list.extend(title_list)
        #FIXME can this be done a faster way? - DSG
        for i, old_row in enumerate(old_point_attributes):
            old_row.extend(new_point_attributes[i])
        mesh_dict['generatedpointattributelist'] = old_point_attributes
        mesh_dict['generatedpointattributetitlelist'] = old_title_list

    export_trianglulation_file(mesh_output_file, mesh_dict)
108       
109
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,
                point_attributes,
                alpha = DEFAULT_ALPHA):
    """Fit a smooth surface to a triangulation, given data points
    with attributes, and return the fitted values at the mesh vertices.

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of points
          constituting mesh (or an m x 2 Numeric array)

          triangles: List of 3-tuples (or a Numeric array) of
          integers representing indices of all vertices in the mesh.

          point_coordinates: List of coordinate pairs [x, y] of data points
          (or an nx2 Numeric array)

          point_attributes: Vector or array of data at the point_coordinates.

          alpha: Smoothing parameter.

        This is a convenience wrapper: it builds an Interpolation object
        and returns the result of its fit_points method.
    """
    return Interpolation(vertex_coordinates,
                         triangles,
                         point_coordinates,
                         alpha = alpha).fit_points(point_attributes)
142
143
144class Interpolation:
145
146    def __init__(self,
147                 vertex_coordinates,
148                 triangles,
149                 point_coordinates = None,
150                 alpha = DEFAULT_ALPHA):
151
152       
153        """ Build interpolation matrix mapping from
154        function values at vertices to function values at data points
155
156        Inputs:
157       
158          vertex_coordinates: List of coordinate pairs [xi, eta] of
159          points constituting mesh (or a an m x 2 Numeric array)
160       
161          triangles: List of 3-tuples (or a Numeric array) of
162          integers representing indices of all vertices in the mesh.
163
164          point_coordinates: List of coordinate pairs [x, y] of
165          data points (or an nx2 Numeric array)
166          If point_coordinates is absent, only smoothing matrix will
167          be built
168
169          alpha: Smoothing parameter
170         
171        """
172
173
174        #Convert input to Numeric arrays
175        vertex_coordinates = array(vertex_coordinates).astype(Float)
176        triangles = array(triangles).astype(Int)               
177       
178        #Build underlying mesh
179        self.mesh = General_mesh(vertex_coordinates, triangles)
180
181        #Smoothing parameter
182        self.alpha = alpha
183
184        #Build coefficient matrices
185        self.build_coefficient_matrix_B(point_coordinates)   
186
187    def set_point_coordinates(self, point_coordinates):
188        """
189        A public interface to setting the point co-ordinates.
190        """
191        self.build_coefficient_matrix_B(point_coordinates)
192       
193    def build_coefficient_matrix_B(self, point_coordinates=None):
194        """Build final coefficient matrix"""
195       
196
197        if self.alpha <> 0:
198            self.build_smoothing_matrix_D()
199       
200        if point_coordinates:
201
202            self.build_interpolation_matrix_A(point_coordinates)
203
204            if self.alpha <> 0:
205                self.B = self.AtA + self.alpha*self.D
206            else:
207                self.B = self.AtA
208
209            #Convert self.B matrix to CSR format for faster matrix vector
210            self.B = Sparse_CSR(self.B)
211       
212    def build_interpolation_matrix_A(self, point_coordinates):
213        """Build n x m interpolation matrix, where
214        n is the number of data points and
215        m is the number of basis functions phi_k (one per vertex)
216
217        This algorithm uses a quad tree data structure for fast binning of data points
218        """
219
220        from quad import build_quadtree
221       
222        #Convert input to Numeric arrays
223        point_coordinates = array(point_coordinates).astype(Float)
224       
225        #Build n x m interpolation matrix       
226        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
227        n = point_coordinates.shape[0]     #Nbr of data points
228       
229        self.A = Sparse(n,m)
230        self.AtA = Sparse(m,m)
231
232        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
233        root = build_quadtree(self.mesh)
234
235        #Compute matrix elements
236        for i in range(n):
237            #For each data_coordinate point
238
239            #print 'Doing %d of %d' %(i, n)
240
241            x = point_coordinates[i]
242
243            #Find vertices near x
244            candidate_vertices = root.search(x[0], x[1])
245
246            is_more_elements = True
247            if candidate_vertices == []:
248                # The point isn't even within the root cell!
249                is_more_elements = False
250                element_found = False
251            else:
252                element_found, sigma0, sigma1, sigma2, k = \
253                    self.search_triangles_of_vertices(candidate_vertices, x)
254
255            while not element_found and is_more_elements: 
256                candidate_vertices = root.expand_search()
257                if candidate_vertices == []:
258                    # All the triangles have been searched.
259                    is_more_elements = False
260                else:
261                    element_found, sigma0, sigma1, sigma2, k = \
262                      self.search_triangles_of_vertices(candidate_vertices, x)
263                   
264               
265           
266            #Update interpolation matrix A if necessary     
267            if element_found is True:       
268                #Assign values to matrix A
269
270                j0 = self.mesh.triangles[k,0] #Global vertex id
271                #self.A[i, j0] = sigma0
272
273                j1 = self.mesh.triangles[k,1] #Global vertex id
274                #self.A[i, j1] = sigma1
275
276                j2 = self.mesh.triangles[k,2] #Global vertex id
277                #self.A[i, j2] = sigma2
278
279                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
280                js     = [j0,j1,j2]
281
282                for j in js:
283                    self.A[i,j] = sigmas[j]
284                    for k in js:
285                        self.AtA[j,k] += sigmas[j]*sigmas[k]
286            else:
287                pass
288                #Ok if there is no triangle for datapoint
289                #(as in brute force version)
290                #raise 'Could not find triangle for point', x
291
292
293    def search_triangles_of_vertices(self, candidate_vertices, x):
294            #Find triangle containing x:
295            element_found = False
296
297            # This will be returned if element_found = False
298            sigma2 = -10.0
299            sigma0 = -10.0
300            sigma1 = -10.0
301
302            #For all vertices in same cell as point x
303            for v in candidate_vertices:
304           
305                #for each triangle id (k) which has v as a vertex
306                for k, _ in self.mesh.vertexlist[v]:
307                   
308                    #Get the three vertex_points of candidate triangle
309                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
310                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
311                    xi2 = self.mesh.get_vertex_coordinate(k, 2)     
312
313                    #print "PDSG - k", k
314                    #print "PDSG - xi0", xi0
315                    #print "PDSG - xi1", xi1
316                    #print "PDSG - xi2", xi2
317                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))" \
318                    #      % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
319                   
320                    #Get the three normals
321                    n0 = self.mesh.get_normal(k, 0)
322                    n1 = self.mesh.get_normal(k, 1)
323                    n2 = self.mesh.get_normal(k, 2)               
324
325                   
326
327                    #Compute interpolation
328                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
329                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
330                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
331
332                    #print "PDSG - sigma0", sigma0
333                    #print "PDSG - sigma1", sigma1
334                    #print "PDSG - sigma2", sigma2
335                   
336                    #FIXME: Maybe move out to test or something
337                    epsilon = 1.0e-6
338                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
339                   
340                    #Check that this triangle contains the data point
341                   
342                    #Sigmas can get negative within
343                    #machine precision on some machines (e.g nautilus)
344                    #Hence the small eps                   
345                    eps = 1.0e-15
346                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
347                        element_found = True
348                        break
349
350                if element_found is True:
351                    #Don't look for any other triangle
352                    break
353            return element_found, sigma0, sigma1, sigma2, k     
354                   
355
356       
357    def build_interpolation_matrix_A_brute(self, point_coordinates):
358        """Build n x m interpolation matrix, where
359        n is the number of data points and
360        m is the number of basis functions phi_k (one per vertex)
361
362        This is the brute force which is too slow for large problems,
363        but could be used for testing
364        """
365
366
367       
368        #Convert input to Numeric arrays
369        point_coordinates = array(point_coordinates).astype(Float)
370       
371        #Build n x m interpolation matrix       
372        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
373        n = point_coordinates.shape[0]     #Nbr of data points
374       
375        self.A = Sparse(n,m)
376        self.AtA = Sparse(m,m)
377
378        #Compute matrix elements
379        for i in range(n):
380            #For each data_coordinate point
381
382            x = point_coordinates[i]
383            element_found = False
384            k = 0
385            while not element_found and k < len(self.mesh):
386                #For each triangle (brute force)
387                #FIXME: Real algorithm should only visit relevant triangles
388
389                #Get the three vertex_points
390                xi0 = self.mesh.get_vertex_coordinate(k, 0)
391                xi1 = self.mesh.get_vertex_coordinate(k, 1)
392                xi2 = self.mesh.get_vertex_coordinate(k, 2)                 
393
394                #Get the three normals
395                n0 = self.mesh.get_normal(k, 0)
396                n1 = self.mesh.get_normal(k, 1)
397                n2 = self.mesh.get_normal(k, 2)               
398
399                #Compute interpolation
400                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
401                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
402                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
403
404                #FIXME: Maybe move out to test or something
405                epsilon = 1.0e-6
406                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
407
408                #Check that this triangle contains data point
409                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
410                    element_found = True
411                    #Assign values to matrix A
412
413                    j0 = self.mesh.triangles[k,0] #Global vertex id
414                    #self.A[i, j0] = sigma0
415
416                    j1 = self.mesh.triangles[k,1] #Global vertex id
417                    #self.A[i, j1] = sigma1
418
419                    j2 = self.mesh.triangles[k,2] #Global vertex id
420                    #self.A[i, j2] = sigma2
421
422                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
423                    js     = [j0,j1,j2]
424
425                    for j in js:
426                        self.A[i,j] = sigmas[j]
427                        for k in js:
428                            self.AtA[j,k] += sigmas[j]*sigmas[k]
429                k = k+1
430       
431
432       
433    def get_A(self):
434        return self.A.todense() 
435
436    def get_B(self):
437        return self.B.todense()
438   
439    def get_D(self):
440        return self.D.todense()
441   
442        #FIXME: Remember to re-introduce the 1/n factor in the
443        #interpolation term
444       
445    def build_smoothing_matrix_D(self):
446        """Build m x m smoothing matrix, where
447        m is the number of basis functions phi_k (one per vertex)
448
449        The smoothing matrix is defined as
450
451        D = D1 + D2
452
453        where
454
455        [D1]_{k,l} = \int_\Omega
456           \frac{\partial \phi_k}{\partial x}
457           \frac{\partial \phi_l}{\partial x}\,
458           dx dy
459
460        [D2]_{k,l} = \int_\Omega
461           \frac{\partial \phi_k}{\partial y}
462           \frac{\partial \phi_l}{\partial y}\,
463           dx dy
464
465
466        The derivatives \frac{\partial \phi_k}{\partial x},
467        \frac{\partial \phi_k}{\partial x} for a particular triangle
468        are obtained by computing the gradient a_k, b_k for basis function k
469        """
470
471        #FIXME: algorithm might be optimised by computing local 9x9
472        #"element stiffness matrices:
473
474        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
475
476        self.D = Sparse(m,m)
477
478        #For each triangle compute contributions to D = D1+D2       
479        for i in range(len(self.mesh)):
480
481            #Get area
482            area = self.mesh.areas[i]
483
484            #Get global vertex indices
485            v0 = self.mesh.triangles[i,0]
486            v1 = self.mesh.triangles[i,1]
487            v2 = self.mesh.triangles[i,2]
488           
489            #Get the three vertex_points
490            xi0 = self.mesh.get_vertex_coordinate(i, 0)
491            xi1 = self.mesh.get_vertex_coordinate(i, 1)
492            xi2 = self.mesh.get_vertex_coordinate(i, 2)                 
493
494            #Compute gradients for each vertex
495            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
496                              1, 0, 0)
497
498            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
499                              0, 1, 0)
500
501            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
502                              0, 0, 1)           
503
504            #Compute diagonal contributions
505            self.D[v0,v0] += (a0*a0 + b0*b0)*area
506            self.D[v1,v1] += (a1*a1 + b1*b1)*area
507            self.D[v2,v2] += (a2*a2 + b2*b2)*area           
508
509            #Compute contributions for basis functions sharing edges
510            e01 = (a0*a1 + b0*b1)*area
511            self.D[v0,v1] += e01
512            self.D[v1,v0] += e01
513
514            e12 = (a1*a2 + b1*b2)*area
515            self.D[v1,v2] += e12
516            self.D[v2,v1] += e12
517
518            e20 = (a2*a0 + b2*b0)*area
519            self.D[v2,v0] += e20
520            self.D[v0,v2] += e20             
521
522           
523    def fit(self, z):
524        """Fit a smooth surface to given 1d array of data points z.
525
526        The smooth surface is computed at each vertex in the underlying
527        mesh using the formula given in the module doc string.
528
529        Pre Condition:
530          self.A, self.At and self.B have been initialised
531         
532        Inputs:
533          z: Single 1d vector or array of data at the point_coordinates.
534        """
535
536        #Convert input to Numeric arrays
537        z = array(z).astype(Float)
538
539
540        if len(z.shape) > 1 :
541            raise VectorShapeError, 'Can only deal with 1d data vector'
542       
543        #Compute right hand side based on data
544        Atz = self.A.trans_mult(z)
545
546       
547        #Check sanity
548        n, m = self.A.shape
549        if n<m and self.alpha == 0.0:
550            msg = 'ERROR (least_squares): Too few data points\n'
551            msg += 'There only %d data points. Need at least %d\n' %(n,m)
552            msg += 'Alternatively, increase smoothing parameter alpha' 
553            raise msg
554
555
556
557        return conjugate_gradient(self.B, Atz, Atz,imax=2*len(Atz) )
558        #FIXME: Should we store the result here for later use? (ON)       
559
560           
561    def fit_points(self, z):
562        """Like fit, but more robust when each point has two or more attributes
563        FIXME (Ole): The name fit_points doesn't carry any meaning
564        for me. How about something like fit_multiple or fit_columns?
565        """
566       
567        try:
568            return self.fit(z)
569        except VectorShapeError, e:
570            # broadcasting is not supported.
571
572            #Convert input to Numeric arrays
573            z = array(z).astype(Float)
574           
575            #Build n x m interpolation matrix       
576            m = self.mesh.coordinates.shape[0] #Number of vertices
577            n = z.shape[1]               #Number of data points         
578
579            f = zeros((m,n), Float) #Resulting columns
580           
581            for i in range(z.shape[1]):
582                f[:,i] = self.fit(z[:,i])
583               
584            return f
585           
586       
587    def interpolate(self, f):
588        """Evaluate smooth surface f at data points implied in self.A.
589
590        The mesh values representing a smooth surface are
591        assumed to be specified in f. This argument could,
592        for example have been obtained from the method self.fit()
593       
594        Pre Condition:
595          self.A has been initialised
596       
597        Inputs:
598          f: Vector or array of data at the mesh vertices.
599          If f is an array, interpolation will be done for each column
600        """
601
602        return self.A * f
603       
604           
605#-------------------------------------------------------------
if __name__ == "__main__":
    # Load in a mesh and data points with attributes.
    # Fit the attributes to the mesh.
    # Save a new mesh file.
    #
    # The fourth argument (alpha) is optional; DEFAULT_ALPHA is used
    # when it is left out.
    import os
    import sys
    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [alpha]"\
            %os.path.basename(sys.argv[0])

    if len(sys.argv) < 4:
        print(usage)
    else:
        mesh_file = sys.argv[1]
        point_file = sys.argv[2]
        mesh_output_file = sys.argv[3]

        alpha = DEFAULT_ALPHA
        if len(sys.argv) > 4:
            # Command line arguments are strings; the fitting code
            # needs a number (it compares alpha with 0 and multiplies
            # the smoothing matrix by it).
            alpha = float(sys.argv[4])

        fit_to_mesh_file(mesh_file, point_file, mesh_output_file, alpha)
627       
Note: See TracBrowser for help on using the repository browser.