source: inundation/ga/storm_surge/pyvolution/least_squares.py @ 820

Last change on this file since 820 was 820, checked in by ole, 20 years ago

Implemented xya2rectangular and tested it

File size: 22.6 KB
Line 
"""Least squares smoothing and interpolation.

   Implements a penalised least-squares fit and associated interpolations.

   The penalty term (or smoothing term) is controlled by the smoothing
   parameter alpha.
   With a value of alpha=0, the fit function will attempt
   to interpolate as closely as possible in the least-squares sense.
   With values alpha > 0, a certain amount of smoothing will be applied.
   A positive alpha is essential in cases where there are too few
   data points.
   A negative alpha is not allowed.
   A typical value of alpha is 1.0e-6; the module default
   (DEFAULT_ALPHA) is 1.0e-3.


   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
   Geoscience Australia, 2004.
"""
19
20
21#FIXME (Ole): Currently datapoints outside the triangular mesh are ignored.
22#             Is there a clean way of including them?
23# (DSG) No clean way was found.  After discussions with stephen the best
24# solution was having the user increase the size of the mesh to
25#       cover all the desired points.
26
27
28
29import exceptions
class ShapeError(Exception):
    """Exception type for array shape errors (not raised within this module).

    Derives from the builtin Exception, which is the same class that the
    Python 2 'exceptions' module exposed as exceptions.Exception.
    """
    pass
31
32from general_mesh import General_mesh
33from Numeric import zeros, array, Float, Int, dot, transpose
34from sparse import Sparse, Sparse_CSR
35from cg_solve import conjugate_gradient, VectorShapeError
36
try:
    from util import gradient
except ImportError:
    # FIXME: reduce the dependency of modules in pyvolution.
    # Have util in a dir, working like load_mesh, and get rid of this.
    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
        """Return the gradient (a, b) = (dq/dx, dq/dy) of the linear function
        taking values q0, q1, q2 at the triangle vertices (x0, y0),
        (x1, y1), (x2, y2).

        Fallback used only when util.gradient cannot be imported.

        A degenerate (zero-area) triangle gives det == 0; with plain Python
        numbers this raises ZeroDivisionError.
        """
        # Force float arithmetic so that all-integer inputs are not
        # truncated by Python 2 integer division.
        det = float((y2-y0)*(x1-x0) - (y1-y0)*(x2-x0))

        a = ((y2-y0)*(q1-q0) - (y1-y0)*(q2-q0))/det
        b = ((x1-x0)*(q2-q0) - (x2-x0)*(q1-q0))/det

        return a, b
54
55
# Smoothing parameter used by the fitting functions in this module when
# the caller does not supply one.
DEFAULT_ALPHA = 1.0e-3
57   
def fit_to_mesh_file(mesh_file, point_file, mesh_output_file, alpha=DEFAULT_ALPHA):
    """Fit point attributes from an xya file to a tsh mesh and save the result.

    Given a mesh file (tsh) and a point attribute file (xya), fit
    point attributes to the mesh and write a mesh file with the
    results.

    The fitted attributes are appended after any vertex attributes already
    present in the mesh, and their titles are appended to the mesh's
    attribute title list.

    mesh_file: name of the input mesh (.tsh) file.
    point_file: name of the point attribute (.xya) file.
    mesh_output_file: name of the mesh file to write.
    alpha: smoothing parameter passed on to fit_to_mesh.
    """
    from load_mesh.loadASCII import mesh_file_to_mesh_dictionary, \
         load_xya_file, export_trianglulation_file

    # Load the mesh (.tsh) file
    mesh_dict = mesh_file_to_mesh_dictionary(mesh_file)
    vertex_coordinates = mesh_dict['generatedpointlist']
    triangles = mesh_dict['generatedtrianglelist']

    old_point_attributes = mesh_dict['generatedpointattributelist']
    old_title_list = mesh_dict['generatedpointattributetitlelist']

    # Load the point (.xya) file; if parsing with the default delimiter
    # fails, retry with a space delimiter.
    try:
        point_dict = load_xya_file(point_file)
    except SyntaxError:
        point_dict = load_xya_file(point_file, delimiter=' ')
    point_coordinates = point_dict['pointlist']
    point_attributes = point_dict['pointattributelist']

    # FIXME: iffy! The title delimiter is hard coded.
    title_list = [title.strip() for title in point_dict['title'].split(',')]

    f = fit_to_mesh(vertex_coordinates,
                    triangles,
                    point_coordinates,
                    point_attributes,
                    alpha=alpha)

    # Convert array to list of lists (one row per mesh vertex).
    # NOTE(review): assumes fit_to_mesh returned a 2d array; a 1d result
    # would make each row a scalar - confirm against load_xya_file output.
    new_point_attributes = f.tolist()

    # FIXME: have this overwrite attributes with the same title - DSG
    # Put the newer attributes last
    if old_title_list:
        old_title_list.extend(title_list)
        # FIXME: can this be done a faster way? - DSG
        for i, old_row in enumerate(old_point_attributes):
            old_row.extend(new_point_attributes[i])
        mesh_dict['generatedpointattributelist'] = old_point_attributes
        mesh_dict['generatedpointattributetitlelist'] = old_title_list
    else:
        mesh_dict['generatedpointattributelist'] = new_point_attributes
        mesh_dict['generatedpointattributetitlelist'] = title_list

    export_trianglulation_file(mesh_output_file, mesh_dict)
110       
111
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,
                point_attributes,
                alpha = DEFAULT_ALPHA):
    """Fit a smooth surface to a triangulation, given data points with
    attributes, and return the fitted values at the mesh vertices.

    Inputs:

      vertex_coordinates: List of coordinate pairs [xi, eta] of points
      constituting the mesh (or an m x 2 Numeric array).

      triangles: List of 3-tuples (or a Numeric array) of integers
      representing indices of all vertices in the mesh.

      point_coordinates: List of coordinate pairs [x, y] of data points
      (or an n x 2 Numeric array).

      point_attributes: Vector or array of data at the point_coordinates.

      alpha: Smoothing parameter.

    Returns whatever Interpolation.fit_points produces for point_attributes.
    """
    fitter = Interpolation(vertex_coordinates, triangles, point_coordinates,
                           alpha=alpha)
    return fitter.fit_points(point_attributes)
144
145   
146   
def xya2rectangular(xya_name, M, N, alpha = DEFAULT_ALPHA):
    """Fit attributes from an xya file to an MxN rectangular mesh.

    Read the xya file and create a rectangular mesh of resolution MxN
    spanning the bounding box of all points in the file, then fit the
    point attributes to that mesh.

    Returns (vertex_coordinates, triangles, boundary, vertex_attributes).

    Raises ValueError if the xya file contains no points.

    FIXME: This may be a temporary function until we decide on
    netcdf formats etc
    """
    import util, mesh_factory

    points, attributes = util.read_xya(xya_name)
    if len(points) == 0:
        raise ValueError('No points found in xya file %s' % xya_name)

    # Bounding box of the data points
    xs = [point[0] for point in points]
    ys = [point[1] for point in points]
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)

    # Create a mesh from the box extent and its (min_x, min_y) corner.
    # NOTE(review): if all points share an x (or y) value the extent is
    # zero and the mesh is degenerate; not handled here.
    vertex_coordinates, triangles, boundary = \
        mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
                                 (min_x, min_y))

    # Fit attributes to mesh
    vertex_attributes = fit_to_mesh(vertex_coordinates,
                                    triangles,
                                    points,
                                    attributes, alpha=alpha)

    return vertex_coordinates, triangles, boundary, vertex_attributes
186                         
187   
188
class Interpolation:
    """Penalised least-squares fitting and interpolation on a triangular mesh.

    Holds the interpolation matrix A (data points x mesh vertices), the
    product AtA = A^T A, the smoothing matrix D and the coefficient matrix
    B = AtA + alpha*D that is solved to obtain values at the mesh vertices.
    """

    # Sigmas can get negative within machine precision on some machines
    # (e.g nautilus), hence the small tolerance when testing whether a
    # point lies inside a triangle.
    _CONTAINMENT_EPS = 1.0e-15

    def __init__(self,
                 vertex_coordinates,
                 triangles,
                 point_coordinates = None,
                 alpha = DEFAULT_ALPHA):
        """Build interpolation matrix mapping from
        function values at vertices to function values at data points

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
          points constituting mesh (or an m x 2 Numeric array)

          triangles: List of 3-tuples (or a Numeric array) of
          integers representing indices of all vertices in the mesh.

          point_coordinates: List of coordinate pairs [x, y] of
          data points (or an nx2 Numeric array)
          If point_coordinates is absent, only the smoothing matrix
          will be built (and only when alpha != 0)

          alpha: Smoothing parameter
        """
        # Convert input to Numeric arrays
        vertex_coordinates = array(vertex_coordinates).astype(Float)
        triangles = array(triangles).astype(Int)

        # Build underlying mesh
        self.mesh = General_mesh(vertex_coordinates, triangles)

        # Smoothing parameter
        self.alpha = alpha

        # Build coefficient matrices
        self.build_coefficient_matrix_B(point_coordinates)

    def set_point_coordinates(self, point_coordinates):
        """
        A public interface to setting the point co-ordinates.
        """
        self.build_coefficient_matrix_B(point_coordinates)

    def build_coefficient_matrix_B(self, point_coordinates=None):
        """Build final coefficient matrix B = AtA + alpha*D.

        The smoothing matrix D is (re)built whenever alpha is non-zero.
        A, AtA and B are only built when point_coordinates is given and
        non-empty; B is stored in CSR format.
        """
        if self.alpha != 0:
            self.build_smoothing_matrix_D()

        # Test explicitly for None/emptiness: an array's truth value is not
        # a reliable "were points given" test (numpy, for one, refuses to
        # evaluate it for multi-element arrays).
        if point_coordinates is None or len(point_coordinates) == 0:
            return

        self.build_interpolation_matrix_A(point_coordinates)

        # FIXME: Remember to re-introduce the 1/n factor in the
        # interpolation term
        if self.alpha != 0:
            self.B = self.AtA + self.alpha*self.D
        else:
            self.B = self.AtA

        # Convert self.B matrix to CSR format for faster matrix vector
        # multiplication
        self.B = Sparse_CSR(self.B)

    def build_interpolation_matrix_A(self, point_coordinates):
        """Build n x m interpolation matrix, where
        n is the number of data points and
        m is the number of basis functions phi_k (one per vertex)

        AtA = A^T A (m x m) is accumulated at the same time.

        This algorithm uses a quad tree data structure for fast binning of
        data points. Data points that lie in no triangle get no entries in
        their row of A.
        """
        from quad import build_quadtree

        # Convert input to Numeric arrays
        point_coordinates = array(point_coordinates).astype(Float)

        m = self.mesh.coordinates.shape[0]  # Nbr of basis functions (1/vertex)
        n = point_coordinates.shape[0]      # Nbr of data points

        # FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
        # However, Sparse_CSR does not have the same methods as Sparse yet
        # The tests will reveal what needs to be done
        self.A = Sparse(n, m)
        self.AtA = Sparse(m, m)

        # Build quad tree of vertices (FIXME: Is this the right spot for that?)
        root = build_quadtree(self.mesh)

        for i in range(n):
            x = point_coordinates[i]

            # Test triangles around vertices near x first, then widen the
            # search through the quad tree until a containing triangle is
            # found or the whole tree has been searched.
            candidate_vertices = root.search(x[0], x[1])
            element_found, sigma0, sigma1, sigma2, k = \
                self.search_triangles_of_vertices(candidate_vertices, x)

            is_more_elements = True
            while not element_found and is_more_elements:
                candidate_vertices, branch = root.expand_search()
                if branch == []:
                    # Searching all the verts from the root cell that
                    # haven't been searched.  This is the last try
                    is_more_elements = False
                element_found, sigma0, sigma1, sigma2, k = \
                    self.search_triangles_of_vertices(candidate_vertices, x)

            if element_found:
                self._add_interpolation_weights(i, k, (sigma0, sigma1, sigma2))
            # Otherwise: OK if there is no triangle for the datapoint
            # (as in the brute force version).

    def search_triangles_of_vertices(self, candidate_vertices, x):
        """Find a triangle, among those touching candidate_vertices, that
        contains point x.

        Returns (element_found, sigma0, sigma1, sigma2, k) where k is the
        triangle id and sigma0..2 are the interpolation weights of x with
        respect to the triangle's three vertices. When element_found is
        False the remaining values are not meaningful.
        """
        # Returned unchanged if there are no candidate triangles at all
        sigma0 = sigma1 = sigma2 = -10.0
        k = -10.0

        # A triangle appears in the vertex list of each of its vertices;
        # test it only once per search.
        checked = {}

        # For all vertices in same cell as point x
        for v in candidate_vertices:
            # For each triangle id (k) which has v as a vertex
            for k, _ in self.mesh.vertexlist[v]:
                if k in checked:
                    continue
                checked[k] = True

                sigma0, sigma1, sigma2 = self._compute_sigmas(k, x)
                if self._contains((sigma0, sigma1, sigma2)):
                    # Don't look for any other triangle
                    return True, sigma0, sigma1, sigma2, k

        return False, sigma0, sigma1, sigma2, k

    def _compute_sigmas(self, k, x):
        """Return the interpolation weights (sigma0, sigma1, sigma2) of
        point x with respect to the three vertices of triangle k."""
        # The three vertex points of the triangle
        xi0 = self.mesh.get_vertex_coordinate(k, 0)
        xi1 = self.mesh.get_vertex_coordinate(k, 1)
        xi2 = self.mesh.get_vertex_coordinate(k, 2)

        # The three normals
        n0 = self.mesh.get_normal(k, 0)
        n1 = self.mesh.get_normal(k, 1)
        n2 = self.mesh.get_normal(k, 2)

        sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
        sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
        sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)

        # Sanity check: the weights must sum to one.
        # FIXME: Maybe move out to test or something
        epsilon = 1.0e-6
        assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon

        return sigma0, sigma1, sigma2

    def _contains(self, sigmas):
        """Return True if the weights place the point inside (or on the
        boundary of) the triangle, allowing a tiny negative round-off."""
        eps = self._CONTAINMENT_EPS
        return sigmas[0] >= -eps and sigmas[1] >= -eps and sigmas[2] >= -eps

    def _add_interpolation_weights(self, i, k, sigmas):
        """Store the weights of data point i (inside triangle k) in row i
        of A and accumulate their outer product into AtA."""
        # Global vertex ids corresponding to sigma0, sigma1, sigma2
        js = [self.mesh.triangles[k, 0],
              self.mesh.triangles[k, 1],
              self.mesh.triangles[k, 2]]

        for j, sigma_j in zip(js, sigmas):
            self.A[i, j] = sigma_j
            for l, sigma_l in zip(js, sigmas):
                self.AtA[j, l] += sigma_j*sigma_l

    def build_interpolation_matrix_A_brute(self, point_coordinates):
        """Build n x m interpolation matrix, where
        n is the number of data points and
        m is the number of basis functions phi_k (one per vertex)

        This is the brute force which is too slow for large problems,
        but could be used for testing. It uses the same containment
        tolerance as the quad tree version so both give the same result.
        """
        # Convert input to Numeric arrays
        point_coordinates = array(point_coordinates).astype(Float)

        m = self.mesh.coordinates.shape[0]  # Nbr of basis functions (1/vertex)
        n = point_coordinates.shape[0]      # Nbr of data points

        self.A = Sparse(n, m)
        self.AtA = Sparse(m, m)

        for i in range(n):
            x = point_coordinates[i]

            # Brute force: test every triangle until one contains x
            for k in range(len(self.mesh)):
                sigmas = self._compute_sigmas(k, x)
                if self._contains(sigmas):
                    self._add_interpolation_weights(i, k, sigmas)
                    break

    def get_A(self):
        """Return the interpolation matrix A as a dense array."""
        return self.A.todense()

    def get_B(self):
        """Return the coefficient matrix B as a dense array."""
        return self.B.todense()

    def get_D(self):
        """Return the smoothing matrix D as a dense array."""
        return self.D.todense()

    def build_smoothing_matrix_D(self):
        r"""Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k
        """
        # FIXME: algorithm might be optimised by computing local 3x3
        # "element stiffness matrices"

        m = self.mesh.coordinates.shape[0]  # Nbr of basis functions (1/vertex)

        self.D = Sparse(m, m)

        # Vertex values defining the three local basis functions of a triangle
        unit_values = ((1, 0, 0), (0, 1, 0), (0, 0, 1))

        # For each triangle compute contributions to D = D1+D2
        for i in range(len(self.mesh)):
            area = self.mesh.areas[i]

            # Global vertex indices
            vs = [self.mesh.triangles[i, 0],
                  self.mesh.triangles[i, 1],
                  self.mesh.triangles[i, 2]]

            # The three vertex points
            xi0 = self.mesh.get_vertex_coordinate(i, 0)
            xi1 = self.mesh.get_vertex_coordinate(i, 1)
            xi2 = self.mesh.get_vertex_coordinate(i, 2)

            # One gradient (a, b) per local basis function on this triangle,
            # so each integral reduces to a product times the triangle area.
            grads = [gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              q0, q1, q2)
                     for q0, q1, q2 in unit_values]

            # Diagonal terms and (symmetric) terms for basis functions
            # sharing an edge
            for p in range(3):
                a_p, b_p = grads[p]
                for q in range(3):
                    a_q, b_q = grads[q]
                    self.D[vs[p], vs[q]] += (a_p*a_q + b_p*b_q)*area

    def fit(self, z):
        """Fit a smooth surface to given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Pre Condition:
          self.A and self.B have been initialised

        Inputs:
          z: Single 1d vector or array of data at the point_coordinates.

        Returns the vertex values from conjugate_gradient applied to B and
        the right hand side A^T z (also used as the initial guess -
        argument meaning assumed from cg_solve; confirm there).

        Raises VectorShapeError if z is not one dimensional, and
        ValueError if there are fewer data points than vertices while
        alpha == 0.
        """
        # Convert input to Numeric arrays
        z = array(z).astype(Float)

        if len(z.shape) > 1:
            raise VectorShapeError('Can only deal with 1d data vector')

        # Check sanity before doing any work
        n, m = self.A.shape
        if n < m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There are only %d data points and alpha == 0. ' %n
            msg += 'Need at least %d\n' %m
            msg += 'Alternatively, set smoothing parameter alpha to a small '
            msg += 'positive value,\ne.g. 1.0e-3.'
            raise ValueError(msg)

        # Compute right hand side based on data
        Atz = self.A.trans_mult(z)

        # FIXME: Should we store the result here for later use? (ON)
        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz))

    def fit_points(self, z):
        """Like fit, but also accepts a 2d array with one column per attribute.

        A 1d z is fitted directly by fit(). Otherwise each column of z is
        fitted separately and the results are returned as an m x c array
        (m mesh vertices, c attribute columns).

        FIXME (Ole): The name fit_points doesn't carry any meaning
        for me. How about something like fit_multiple or fit_columns?
        """
        try:
            return self.fit(z)
        except VectorShapeError:
            # fit only handles 1d data (broadcasting is not supported),
            # so fit each attribute column on its own.
            z = array(z).astype(Float)

            m = self.mesh.coordinates.shape[0]  # Number of vertices
            num_columns = z.shape[1]            # Number of attributes per point

            f = zeros((m, num_columns), Float)  # Resulting columns
            for i in range(num_columns):
                f[:, i] = self.fit(z[:, i])

            return f

    def interpolate(self, f):
        """Evaluate smooth surface f at data points implied in self.A.

        The mesh values representing a smooth surface are
        assumed to be specified in f. This argument could,
        for example have been obtained from the method self.fit()

        Pre Condition:
          self.A has been initialised

        Inputs:
          f: Vector or array of data at the mesh vertices.
          If f is an array, interpolation will be done for each column as
          per underlying matrix-matrix multiplication
        """
        return self.A * f
648       
649           
650#-------------------------------------------------------------
if __name__ == "__main__":
    # Load in a mesh and data points with attributes,
    # fit the attributes to the mesh and save a new mesh file.
    import os, sys
    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [alpha]"\
            %os.path.basename(sys.argv[0])

    if len(sys.argv) < 4:
        print(usage)
    else:
        mesh_file = sys.argv[1]
        point_file = sys.argv[2]
        mesh_output_file = sys.argv[3]
        if len(sys.argv) > 4:
            # Command line arguments are strings; the fit needs a number
            # (it compares alpha with 0 and multiplies matrices by it).
            alpha = float(sys.argv[4])
        else:
            alpha = DEFAULT_ALPHA
        fit_to_mesh_file(mesh_file, point_file, mesh_output_file, alpha)
672       
Note: See TracBrowser for help on using the repository browser.