source: inundation/ga/storm_surge/pyvolution/least_squares.py @ 1065

Last change on this file since 1065 was 1004, checked in by duncan, 20 years ago

Ooops

File size: 31.3 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14         
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.   
18"""
19
20
21#FIXME (Ole): Currently datapoints outside the triangular mesh are ignored.
22#             Is there a clean way of including them?
23# (DSG) No clean way was found.  After discussions with stephen the best
24# solution was having the user increase the size of the mesh to
25#       cover all the desired points.
26
27
28
29import exceptions
30class ShapeError(exceptions.Exception): pass
31
32from general_mesh import General_mesh
33from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
34from mesh import Mesh
35
36from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
37from sparse import Sparse, Sparse_CSR
38from cg_solve import conjugate_gradient, VectorShapeError
39import time
40
41try:
42    from util import gradient
43except ImportError, e: 
44    #FIXME reduce the dependency of modules in pyvolution
45    # Have util in a dir, working like load_mesh, and get rid of this
46    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
47        """
48        """
49   
50        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)           
51        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
52        a /= det
53
54        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
55        b /= det           
56
57        return a, b
58
59
60DEFAULT_ALPHA = 0.001
61   
62def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
63                     alpha=DEFAULT_ALPHA, verbose= False,
64                     expand_search = False,
65                     data_origin = None,
66                     mesh_origin = None,
67                     precrop = False):
68    """
69    Given a mesh file (tsh) and a point attribute file (xya), fit
70    point attributes to the mesh and write a mesh file with the
71    results.
72   
73
74    If data_origin is not None it is assumed to be
75    a 3-tuple with geo referenced
76    UTM coordinates (zone, easting, northing)
77   
78    mesh_origin is the same but refers to the input tsh file.
79    FIXME: When the tsh format contains it own origin, these parameters can go.
80    FIXME: And both origins should be obtained from the specified files.
81    """
82   
83    from load_mesh.loadASCII import mesh_file_to_mesh_dictionary, \
84                 load_points_file, export_mesh_file, \
85                 concatinate_attributelist
86   
87    # load in the .tsh file
88    #FIXME (Ole): mesh_origin should be extracted here
89    mesh_dict = mesh_file_to_mesh_dictionary(mesh_file)
90    vertex_coordinates = mesh_dict['vertices']
91    triangles = mesh_dict['triangles']
92    if type(mesh_dict['vertex_attributes']) == ArrayType: 
93        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
94    else:
95        old_point_attributes = mesh_dict['vertex_attributes'] 
96       
97    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
98        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
99    else: 
100        old_title_list = mesh_dict['vertex_attribute_titles']
101       
102    if verbose:print "tsh file loaded"
103   
104    # load in the .pts file
105    #FIXME (Ole): data_origin should be extracted here
106    try:
107        point_dict = load_points_file(point_file,delimiter = ',')
108    except SyntaxError,e:
109        point_dict = load_points_file(point_file,delimiter = ' ')
110    point_coordinates = point_dict['pointlist']
111    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist']) 
112    if verbose: print "points file loaded"
113    if verbose:print "fitting to mesh"
114    f = fit_to_mesh(vertex_coordinates,
115                    triangles,
116                    point_coordinates,
117                    point_attributes,
118                    alpha = alpha,
119                    verbose = verbose,
120                    expand_search = expand_search,
121                    data_origin = data_origin,
122                    mesh_origin = mesh_origin,
123                    precrop = precrop)
124    if verbose: print "finished fitting to mesh"
125   
126    # convert array to list of lists
127    new_point_attributes = f.tolist()
128    #FIXME have this overwrite attributes with the same title - DSG
129    #Put the newer attributes last
130    if old_title_list <> []:
131        old_title_list.extend(title_list)
132        #FIXME can this be done a faster way? - DSG
133        for i in range(len(old_point_attributes)):
134            old_point_attributes[i].extend(new_point_attributes[i])
135        mesh_dict['vertex_attributes'] = old_point_attributes
136        mesh_dict['vertex_attribute_titles'] = old_title_list
137    else:
138        mesh_dict['vertex_attributes'] = new_point_attributes
139        mesh_dict['vertex_attribute_titles'] = title_list
140
141    #FIXME (Ole): Remember to output mesh_origin as well   
142    export_mesh_file(mesh_output_file, mesh_dict)
143       
144
145def fit_to_mesh(vertex_coordinates,
146                triangles,
147                point_coordinates,
148                point_attributes,
149                alpha = DEFAULT_ALPHA,
150                verbose = False,
151                expand_search = False,
152                data_origin = None,
153                mesh_origin = None,
154                precrop = False):
155    """
156    Fit a smooth surface to a triangulation,
157    given data points with attributes.
158
159         
160        Inputs:
161       
162          vertex_coordinates: List of coordinate pairs [xi, eta] of points
163          constituting mesh (or a an m x 2 Numeric array)
164       
165          triangles: List of 3-tuples (or a Numeric array) of
166          integers representing indices of all vertices in the mesh.
167
168          point_coordinates: List of coordinate pairs [x, y] of data points
169          (or an nx2 Numeric array)
170
171          alpha: Smoothing parameter.
172
173          point_attributes: Vector or array of data at the point_coordinates.
174
175          data_origin and mesh_origin are 3-tuples consisting of
176          UTM zone, easting and northing. If specified
177          point coordinates and vertex coordinates are assumed to be
178          relative to their respective origins.
179         
180    """
181    interp = Interpolation(vertex_coordinates,
182                           triangles,
183                           point_coordinates,
184                           alpha = alpha,
185                           verbose = verbose,
186                           expand_search = expand_search,
187                           data_origin = data_origin,
188                           mesh_origin = mesh_origin,
189                           precrop = precrop)
190   
191    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
192    return vertex_attributes
193
194   
195   
196def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
197                    verbose = False, reduction = 1, format = 'netcdf'):
198    """Fits attributes from pts file to MxN rectangular mesh
199   
200    Read pts file and create rectangular mesh of resolution MxN such that
201    it covers all points specified in pts file.
202   
203    FIXME: This may be a temporary function until we decide on
204    netcdf formats etc
205
206    FIXME: Uses elevation hardwired
207    """   
208   
209    import util, mesh_factory
210
211    if verbose: print 'Read pts'
212    points, attributes = util.read_xya(pts_name, format)
213
214    #Reduce number of points a bit
215    points = points[::reduction]
216    elevation = attributes['elevation']  #Must be elevation
217    elevation = elevation[::reduction]
218   
219    if verbose: print 'Got %d data points' %len(points)
220
221    if verbose: print 'Create mesh'
222    #Find extent
223    max_x = min_x = points[0][0]
224    max_y = min_y = points[0][1]
225    for point in points[1:]:
226        x = point[0] 
227        if x > max_x: max_x = x
228        if x < min_x: min_x = x           
229        y = point[1] 
230        if y > max_y: max_y = y
231        if y < min_y: min_y = y           
232   
233    #Create appropriate mesh
234    vertex_coordinates, triangles, boundary =\
235         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y, 
236                                (min_x, min_y))
237
238    #Fit attributes to mesh     
239    vertex_attributes = fit_to_mesh(vertex_coordinates,
240                        triangles,
241                        points,
242                        elevation, alpha=alpha, verbose=verbose)
243
244
245         
246    return vertex_coordinates, triangles, boundary, vertex_attributes
247                         
248   
249
250class Interpolation:
251
252    def __init__(self,
253                 vertex_coordinates,
254                 triangles,
255                 point_coordinates = None,
256                 alpha = DEFAULT_ALPHA,
257                 verbose = False,
258                 expand_search = True,
259                 max_points_per_cell = 30,
260                 mesh_origin = None, 
261                 data_origin = None,
262                 precrop = False):
263
264       
265        """ Build interpolation matrix mapping from
266        function values at vertices to function values at data points
267
268        Inputs:
269       
270          vertex_coordinates: List of coordinate pairs [xi, eta] of
271          points constituting mesh (or a an m x 2 Numeric array)
272       
273          triangles: List of 3-tuples (or a Numeric array) of
274          integers representing indices of all vertices in the mesh.
275
276          point_coordinates: List of coordinate pairs [x, y] of
277          data points (or an nx2 Numeric array)
278          If point_coordinates is absent, only smoothing matrix will
279          be built
280
281          alpha: Smoothing parameter
282
283          data_origin and mesh_origin are 3-tuples consisting of
284          UTM zone, easting and northing. If specified
285          point coordinates and vertex coordinates are assumed to be
286          relative to their respective origins.
287         
288        """
289
290        #Convert input to Numeric arrays
291        triangles = array(triangles).astype(Int)
292        vertex_coordinates = array(vertex_coordinates).astype(Float)
293
294        #Build underlying mesh
295        if verbose: print 'Building mesh' 
296        #self.mesh = General_mesh(vertex_coordinates, triangles,
297        #FIXME: Trying the normal mesh while testing precrop
298        self.mesh = Mesh(vertex_coordinates, triangles,
299                         origin = mesh_origin)
300        self.data_origin = data_origin
301
302        self.point_indices = None
303
304        #Smoothing parameter
305        self.alpha = alpha
306
307        #Build coefficient matrices
308        self.build_coefficient_matrix_B(point_coordinates,
309                                        verbose = verbose,
310                                        expand_search = expand_search,
311                                        max_points_per_cell =\
312                                        max_points_per_cell,
313                                        data_origin = data_origin,
314                                        precrop = precrop)
315       
316
317    def set_point_coordinates(self, point_coordinates,
318                              data_origin = None):
319        """
320        A public interface to setting the point co-ordinates.
321        """
322        self.build_coefficient_matrix_B(point_coordinates, data_origin)
323       
324    def build_coefficient_matrix_B(self, point_coordinates=None,
325                                   verbose = False, expand_search = True,
326                                   max_points_per_cell=30,
327                                   data_origin = None,
328                                   precrop = False):
329        """Build final coefficient matrix"""
330       
331
332        if self.alpha <> 0:
333            if verbose: print 'Building smoothing matrix'         
334            self.build_smoothing_matrix_D()
335       
336        if point_coordinates is not None:
337           
338            if verbose: print 'Building interpolation matrix'         
339            self.build_interpolation_matrix_A(point_coordinates,
340                                              verbose = verbose,
341                                              expand_search = expand_search,
342                                              max_points_per_cell =\
343                                              max_points_per_cell,
344                                              data_origin = data_origin,
345                                              precrop = precrop)
346
347            if self.alpha <> 0:
348                self.B = self.AtA + self.alpha*self.D
349            else:
350                self.B = self.AtA
351
352            #Convert self.B matrix to CSR format for faster matrix vector
353            self.B = Sparse_CSR(self.B)
354       
355    def build_interpolation_matrix_A(self, point_coordinates,
356                                     verbose = False, expand_search = True,
357                                     max_points_per_cell=30,
358                                     data_origin = None,
359                                     precrop = False):
360        """Build n x m interpolation matrix, where
361        n is the number of data points and
362        m is the number of basis functions phi_k (one per vertex)
363
364        This algorithm uses a quad tree data structure for fast binning of data points
365        origin is a 3-tuple consisting of UTM zone, easting and northing.
366        If specified coordinates are assumed to be relative to this origin.
367
368        This one will override any data_origin that may be specified in
369        interpolation instance
370
371        """
372
373        from quad import build_quadtree
374
375        if data_origin is None:
376            data_origin = self.data_origin #Use the one from
377                                           #interpolation instance
378       
379        #Convert input to Numeric arrays just in case.
380        point_coordinates = array(point_coordinates).astype(Float)
381
382
383        #Shift data points to same origin as mesh (if specified)
384        mesh_origin = self.mesh.origin
385        if point_coordinates is not None:
386            if data_origin is not None:
387                if mesh_origin is not None:
388
389                    #Transformation:
390                    #
391                    #Let x_0 be the reference point of the point coordinates
392                    #and xi_0 the reference point of the mesh.
393                    #
394                    #A point coordinate (x + x_0) is then made relative
395                    #to xi_0 by
396                    #
397                    # x_new = x + x_0 - xi_0
398                    #
399                    #and similarly for eta
400                   
401                    x_offset = data_origin[1] - mesh_origin[1]
402                    y_offset = data_origin[2] - mesh_origin[2]
403                else: #Shift back to a zero origin
404                    x_offset = data_origin[1]
405                    y_offset = data_origin[2]               
406                   
407                point_coordinates[:,0] += x_offset
408                point_coordinates[:,1] += y_offset           
409            else:
410                if mesh_origin is not None:
411                    #Use mesh origin for data points
412                    point_coordinates[:,0] -= mesh_origin[1] 
413                    point_coordinates[:,1] -= mesh_origin[2]
414
415
416
417        #Remove points falling outside mesh boundary
418        #This reduced one example from 1356 seconds to 825 seconds
419        #And more could be had by writing util.inside_polygon in C
420        if precrop is True:
421            from Numeric import take
422            from util import inside_polygon
423
424            if verbose: print 'Getting boundary polygon'
425            P = self.mesh.get_boundary_polygon()
426
427            if verbose: print 'Getting indices inside mesh boundary'           
428            indices = inside_polygon(point_coordinates, P, verbose = verbose)
429
430            if verbose:
431                if len(indices) != point_coordinates.shape[0]:
432                    print '%d points outside mesh have been cropped.'\
433                          %(point_coordinates.shape[0] - len(indices))
434            point_coordinates = take(point_coordinates, indices)
435            self.point_indices = indices
436
437
438
439       
440        #Build n x m interpolation matrix       
441        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
442        n = point_coordinates.shape[0]     #Nbr of data points
443
444        if verbose: print 'Number of datapoints: %d' %n
445        if verbose: print 'Number of basis functions: %d' %m
446       
447        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
448        #However, Sparse_CSR does not have the same methods as Sparse yet
449        #The tests will reveal what needs to be done
450        self.A = Sparse(n,m)
451        self.AtA = Sparse(m,m)
452
453        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
454        root = build_quadtree(self.mesh,
455                              max_points_per_cell = max_points_per_cell)
456
457        #Compute matrix elements
458        for i in range(n):
459            #For each data_coordinate point
460
461            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
462
463            x = point_coordinates[i]
464
465            #Find vertices near x
466            candidate_vertices = root.search(x[0], x[1])
467            is_more_elements = True
468           
469            element_found, sigma0, sigma1, sigma2, k = \
470                self.search_triangles_of_vertices(candidate_vertices, x)
471            while not element_found and is_more_elements and expand_search: 
472                #if verbose: print 'Expanding search'
473                candidate_vertices, branch = root.expand_search()
474                if branch == []:
475                    # Searching all the verts from the root cell that haven't
476                    # been searched.  This is the last try
477                    element_found, sigma0, sigma1, sigma2, k = \
478                      self.search_triangles_of_vertices(candidate_vertices, x)
479                    is_more_elements = False
480                else:
481                    element_found, sigma0, sigma1, sigma2, k = \
482                      self.search_triangles_of_vertices(candidate_vertices, x)
483                     
484           
485            #Update interpolation matrix A if necessary     
486            if element_found is True:       
487                #Assign values to matrix A
488
489                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
490                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
491                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
492
493                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
494                js     = [j0,j1,j2]
495
496                for j in js:
497                    self.A[i,j] = sigmas[j]
498                    for k in js:
499                        self.AtA[j,k] += sigmas[j]*sigmas[k]
500            else:
501                pass
502                #Ok if there is no triangle for datapoint
503                #(as in brute force version)
504                #raise 'Could not find triangle for point', x
505
506
507
508    def search_triangles_of_vertices(self, candidate_vertices, x):
509            #Find triangle containing x:
510            element_found = False
511
512            # This will be returned if element_found = False
513            sigma2 = -10.0
514            sigma0 = -10.0
515            sigma1 = -10.0
516            k = -10.0
517
518            #For all vertices in same cell as point x
519            for v in candidate_vertices:
520           
521                #for each triangle id (k) which has v as a vertex
522                for k, _ in self.mesh.vertexlist[v]:
523                   
524                    #Get the three vertex_points of candidate triangle
525                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
526                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
527                    xi2 = self.mesh.get_vertex_coordinate(k, 2)     
528
529                    #print "PDSG - k", k
530                    #print "PDSG - xi0", xi0
531                    #print "PDSG - xi1", xi1
532                    #print "PDSG - xi2", xi2
533                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
534                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
535                   
536                    #Get the three normals
537                    n0 = self.mesh.get_normal(k, 0)
538                    n1 = self.mesh.get_normal(k, 1)
539                    n2 = self.mesh.get_normal(k, 2)               
540
541                   
542                    #Compute interpolation
543                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
544                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
545                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
546
547                    #print "PDSG - sigma0", sigma0
548                    #print "PDSG - sigma1", sigma1
549                    #print "PDSG - sigma2", sigma2
550                   
551                    #FIXME: Maybe move out to test or something
552                    epsilon = 1.0e-6
553                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
554                   
555                    #Check that this triangle contains the data point
556                   
557                    #Sigmas can get negative within
558                    #machine precision on some machines (e.g nautilus)
559                    #Hence the small eps                   
560                    eps = 1.0e-15
561                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
562                        element_found = True
563                        break
564
565                if element_found is True:
566                    #Don't look for any other triangle
567                    break
568            return element_found, sigma0, sigma1, sigma2, k     
569                   
570
571       
572    def build_interpolation_matrix_A_brute(self, point_coordinates):
573        """Build n x m interpolation matrix, where
574        n is the number of data points and
575        m is the number of basis functions phi_k (one per vertex)
576
577        This is the brute force which is too slow for large problems,
578        but could be used for testing
579        """
580
581
582       
583        #Convert input to Numeric arrays
584        point_coordinates = array(point_coordinates).astype(Float)
585       
586        #Build n x m interpolation matrix       
587        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
588        n = point_coordinates.shape[0]     #Nbr of data points
589       
590        self.A = Sparse(n,m)
591        self.AtA = Sparse(m,m)
592
593        #Compute matrix elements
594        for i in range(n):
595            #For each data_coordinate point
596
597            x = point_coordinates[i]
598            element_found = False
599            k = 0
600            while not element_found and k < len(self.mesh):
601                #For each triangle (brute force)
602                #FIXME: Real algorithm should only visit relevant triangles
603
604                #Get the three vertex_points
605                xi0 = self.mesh.get_vertex_coordinate(k, 0)
606                xi1 = self.mesh.get_vertex_coordinate(k, 1)
607                xi2 = self.mesh.get_vertex_coordinate(k, 2)                 
608
609                #Get the three normals
610                n0 = self.mesh.get_normal(k, 0)
611                n1 = self.mesh.get_normal(k, 1)
612                n2 = self.mesh.get_normal(k, 2)               
613
614                #Compute interpolation
615                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
616                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
617                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
618
619                #FIXME: Maybe move out to test or something
620                epsilon = 1.0e-6
621                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
622
623                #Check that this triangle contains data point
624                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
625                    element_found = True
626                    #Assign values to matrix A
627
628                    j0 = self.mesh.triangles[k,0] #Global vertex id
629                    #self.A[i, j0] = sigma0
630
631                    j1 = self.mesh.triangles[k,1] #Global vertex id
632                    #self.A[i, j1] = sigma1
633
634                    j2 = self.mesh.triangles[k,2] #Global vertex id
635                    #self.A[i, j2] = sigma2
636
637                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
638                    js     = [j0,j1,j2]
639
640                    for j in js:
641                        self.A[i,j] = sigmas[j]
642                        for k in js:
643                            self.AtA[j,k] += sigmas[j]*sigmas[k]
644                k = k+1
645       
646
647       
648    def get_A(self):
649        return self.A.todense() 
650
651    def get_B(self):
652        return self.B.todense()
653   
654    def get_D(self):
655        return self.D.todense()
656   
657        #FIXME: Remember to re-introduce the 1/n factor in the
658        #interpolation term
659       
660    def build_smoothing_matrix_D(self):
661        """Build m x m smoothing matrix, where
662        m is the number of basis functions phi_k (one per vertex)
663
664        The smoothing matrix is defined as
665
666        D = D1 + D2
667
668        where
669
670        [D1]_{k,l} = \int_\Omega
671           \frac{\partial \phi_k}{\partial x}
672           \frac{\partial \phi_l}{\partial x}\,
673           dx dy
674
675        [D2]_{k,l} = \int_\Omega
676           \frac{\partial \phi_k}{\partial y}
677           \frac{\partial \phi_l}{\partial y}\,
678           dx dy
679
680
681        The derivatives \frac{\partial \phi_k}{\partial x},
682        \frac{\partial \phi_k}{\partial x} for a particular triangle
683        are obtained by computing the gradient a_k, b_k for basis function k
684        """
685
686        #FIXME: algorithm might be optimised by computing local 9x9
687        #"element stiffness matrices:
688
689        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
690
691        self.D = Sparse(m,m)
692
693        #For each triangle compute contributions to D = D1+D2       
694        for i in range(len(self.mesh)):
695
696            #Get area
697            area = self.mesh.areas[i]
698
699            #Get global vertex indices
700            v0 = self.mesh.triangles[i,0]
701            v1 = self.mesh.triangles[i,1]
702            v2 = self.mesh.triangles[i,2]
703           
704            #Get the three vertex_points
705            xi0 = self.mesh.get_vertex_coordinate(i, 0)
706            xi1 = self.mesh.get_vertex_coordinate(i, 1)
707            xi2 = self.mesh.get_vertex_coordinate(i, 2)                 
708
709            #Compute gradients for each vertex
710            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
711                              1, 0, 0)
712
713            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
714                              0, 1, 0)
715
716            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
717                              0, 0, 1)           
718
719            #Compute diagonal contributions
720            self.D[v0,v0] += (a0*a0 + b0*b0)*area
721            self.D[v1,v1] += (a1*a1 + b1*b1)*area
722            self.D[v2,v2] += (a2*a2 + b2*b2)*area           
723
724            #Compute contributions for basis functions sharing edges
725            e01 = (a0*a1 + b0*b1)*area
726            self.D[v0,v1] += e01
727            self.D[v1,v0] += e01
728
729            e12 = (a1*a2 + b1*b2)*area
730            self.D[v1,v2] += e12
731            self.D[v2,v1] += e12
732
733            e20 = (a2*a0 + b2*b0)*area
734            self.D[v2,v0] += e20
735            self.D[v0,v2] += e20             
736
737           
738    def fit(self, z):
739        """Fit a smooth surface to given 1d array of data points z.
740
741        The smooth surface is computed at each vertex in the underlying
742        mesh using the formula given in the module doc string.
743
744        Pre Condition:
745          self.A, self.At and self.B have been initialised
746         
747        Inputs:
748          z: Single 1d vector or array of data at the point_coordinates.
749        """
750
751        #Convert input to Numeric arrays
752        z = array(z).astype(Float)
753
754        if len(z.shape) > 1 :
755            raise VectorShapeError, 'Can only deal with 1d data vector'
756
757        if self.point_indices is not None:
758            #Remove values for any points that were outside mesh
759            z = take(z, self.point_indices) 
760       
761        #Compute right hand side based on data
762        Atz = self.A.trans_mult(z)
763
764       
765        #Check sanity
766        n, m = self.A.shape
767        if n<m and self.alpha == 0.0:
768            msg = 'ERROR (least_squares): Too few data points\n'
769            msg += 'There are only %d data points and alpha == 0. ' %n
770            msg += 'Need at least %d\n' %m
771            msg += 'Alternatively, set smoothing parameter alpha to a small '
772            msg += 'positive value,\ne.g. 1.0e-3.'
773            raise msg
774
775
776
777        return conjugate_gradient(self.B, Atz, Atz,imax=2*len(Atz) )
778        #FIXME: Should we store the result here for later use? (ON)       
779
780           
781    def fit_points(self, z, verbose=False):
782        """Like fit, but more robust when each point has two or more attributes
783        FIXME (Ole): The name fit_points doesn't carry any meaning
784        for me. How about something like fit_multiple or fit_columns?
785        """
786       
787        try:
788            if verbose: print 'Solving penalised least_squares problem'
789            return self.fit(z)
790        except VectorShapeError, e:
791            # broadcasting is not supported.
792
793            #Convert input to Numeric arrays
794            z = array(z).astype(Float)
795           
796            #Build n x m interpolation matrix       
797            m = self.mesh.coordinates.shape[0] #Number of vertices
798            n = z.shape[1]                     #Number of data points         
799
800            f = zeros((m,n), Float) #Resulting columns
801           
802            for i in range(z.shape[1]):
803                f[:,i] = self.fit(z[:,i])
804               
805            return f
806           
807       
808    def interpolate(self, f):
809        """Evaluate smooth surface f at data points implied in self.A.
810
811        The mesh values representing a smooth surface are
812        assumed to be specified in f. This argument could,
813        for example have been obtained from the method self.fit()
814       
815        Pre Condition:
816          self.A has been initialised
817       
818        Inputs:
819          f: Vector or array of data at the mesh vertices.
820          If f is an array, interpolation will be done for each column as
821          per underlying matrix-matrix multiplication
822         
823        Output:
824          Interpolated values at data points implied in self.A
825           
826        """
827
828        return self.A * f
829   
830    def cull_outsiders(self, f):
831        pass
832       
833           
834#-------------------------------------------------------------
835if __name__ == "__main__":
836    """
837    Load in a mesh and data points with attributes.
838    Fit the attributes to the mesh.
839    Save a new mesh file.
840    """
841    import os, sys
842    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha]"\
843            %os.path.basename(sys.argv[0])
844
845    if len(sys.argv) < 4:
846        print usage
847    else:
848        mesh_file = sys.argv[1]
849        point_file = sys.argv[2]
850        mesh_output_file = sys.argv[3]
851       
852        expand_search = False
853        if len(sys.argv) > 4:
854            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
855                expand_search = True
856            else:   
857                expand_search = False
858               
859        verbose = False
860        if len(sys.argv) > 5:
861            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
862                verbose = False
863            else:   
864                verbose = True
865           
866        if len(sys.argv) > 6:
867            alpha = sys.argv[6]
868        else:
869            alpha = DEFAULT_ALPHA
870           
871        t0 = time.time()
872        fit_to_mesh_file(mesh_file,
873                         point_file,
874                         mesh_output_file,
875                         alpha,
876                         verbose= verbose,
877                         expand_search = expand_search)
878   
879        print 'That took %.2f seconds' %(time.time()-t0)
880       
Note: See TracBrowser for help on using the repository browser.