source: inundation/ga/storm_surge/pyvolution/least_squares.py @ 1154

Last change on this file since 1154 was 1154, checked in by duncan, 19 years ago

removing conversation

File size: 31.9 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14         
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.   
18"""
19
20import exceptions
21class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30import time
31
32try:
33    from util import gradient
34except ImportError, e: 
35    #FIXME reduce the dependency of modules in pyvolution
36    # Have util in a dir, working like load_mesh, and get rid of this
37    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
38        """
39        """
40   
41        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)           
42        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
43        a /= det
44
45        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
46        b /= det           
47
48        return a, b
49
50
51DEFAULT_ALPHA = 0.001
52   
53def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
54                     alpha=DEFAULT_ALPHA, verbose= False,
55                     expand_search = False,
56                     data_origin = None,
57                     mesh_origin = None,
58                     precrop = False):
59    """
60    Given a mesh file (tsh) and a point attribute file (xya), fit
61    point attributes to the mesh and write a mesh file with the
62    results.
63   
64
65    If data_origin is not None it is assumed to be
66    a 3-tuple with geo referenced
67    UTM coordinates (zone, easting, northing)
68   
69    mesh_origin is the same but refers to the input tsh file.
70    FIXME: When the tsh format contains it own origin, these parameters can go.
71    FIXME: And both origins should be obtained from the specified files.
72    """
73   
74    from load_mesh.loadASCII import mesh_file_to_mesh_dictionary, \
75                 load_points_file, export_mesh_file, \
76                 concatinate_attributelist
77   
78    mesh_dict = mesh_file_to_mesh_dictionary(mesh_file)
79    vertex_coordinates = mesh_dict['vertices']
80    triangles = mesh_dict['triangles']
81    if type(mesh_dict['vertex_attributes']) == ArrayType: 
82        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
83    else:
84        old_point_attributes = mesh_dict['vertex_attributes'] 
85       
86    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
87        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
88    else: 
89        old_title_list = mesh_dict['vertex_attribute_titles']
90       
91    if verbose: print 'tsh file %s loaded' %mesh_file
92   
93    # load in the .pts file
94    try:
95        point_dict = load_points_file(point_file,
96                                      delimiter = ',',
97                                      verbose=verbose)
98    except SyntaxError,e:
99        point_dict = load_points_file(point_file,
100                                      delimiter = ' ',
101                                      verbose=verbose)
102       
103    point_coordinates = point_dict['pointlist']
104    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
105   
106    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
107        data_origin = point_dict['geo_reference'].get_origin()
108    else:
109        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
110       
111    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
112        mesh_origin = mesh_dict['geo_reference'].get_origin()
113    else:
114        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
115       
116    if verbose: print "points file loaded"
117    if verbose:print "fitting to mesh"
118    f = fit_to_mesh(vertex_coordinates,
119                    triangles,
120                    point_coordinates,
121                    point_attributes,
122                    alpha = alpha,
123                    verbose = verbose,
124                    expand_search = expand_search,
125                    data_origin = data_origin,
126                    mesh_origin = mesh_origin,
127                    precrop = precrop)
128    if verbose: print "finished fitting to mesh"
129   
130    # convert array to list of lists
131    new_point_attributes = f.tolist()
132    #FIXME have this overwrite attributes with the same title - DSG
133    #Put the newer attributes last
134    if old_title_list <> []:
135        old_title_list.extend(title_list)
136        #FIXME can this be done a faster way? - DSG
137        for i in range(len(old_point_attributes)):
138            old_point_attributes[i].extend(new_point_attributes[i])
139        mesh_dict['vertex_attributes'] = old_point_attributes
140        mesh_dict['vertex_attribute_titles'] = old_title_list
141    else:
142        mesh_dict['vertex_attributes'] = new_point_attributes
143        mesh_dict['vertex_attribute_titles'] = title_list
144
145    #FIXME (Ole): Remember to output mesh_origin as well   
146    if verbose: print "exporting to file ",mesh_output_file
147    export_mesh_file(mesh_output_file, mesh_dict)
148       
149
150def fit_to_mesh(vertex_coordinates,
151                triangles,
152                point_coordinates,
153                point_attributes,
154                alpha = DEFAULT_ALPHA,
155                verbose = False,
156                expand_search = False,
157                data_origin = None,
158                mesh_origin = None,
159                precrop = False):
160    """
161    Fit a smooth surface to a triangulation,
162    given data points with attributes.
163
164         
165        Inputs:
166       
167          vertex_coordinates: List of coordinate pairs [xi, eta] of points
168          constituting mesh (or a an m x 2 Numeric array)
169       
170          triangles: List of 3-tuples (or a Numeric array) of
171          integers representing indices of all vertices in the mesh.
172
173          point_coordinates: List of coordinate pairs [x, y] of data points
174          (or an nx2 Numeric array)
175
176          alpha: Smoothing parameter.
177
178          point_attributes: Vector or array of data at the point_coordinates.
179
180          data_origin and mesh_origin are 3-tuples consisting of
181          UTM zone, easting and northing. If specified
182          point coordinates and vertex coordinates are assumed to be
183          relative to their respective origins.
184         
185    """
186    interp = Interpolation(vertex_coordinates,
187                           triangles,
188                           point_coordinates,
189                           alpha = alpha,
190                           verbose = verbose,
191                           expand_search = expand_search,
192                           data_origin = data_origin,
193                           mesh_origin = mesh_origin,
194                           precrop = precrop)
195   
196    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
197    return vertex_attributes
198
199   
200   
201def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
202                    verbose = False, reduction = 1, format = 'netcdf'):
203    """Fits attributes from pts file to MxN rectangular mesh
204   
205    Read pts file and create rectangular mesh of resolution MxN such that
206    it covers all points specified in pts file.
207   
208    FIXME: This may be a temporary function until we decide on
209    netcdf formats etc
210
211    FIXME: Uses elevation hardwired
212    """   
213   
214    import util, mesh_factory
215
216    if verbose: print 'Read pts'
217    points, attributes = util.read_xya(pts_name, format)
218
219    #Reduce number of points a bit
220    points = points[::reduction]
221    elevation = attributes['elevation']  #Must be elevation
222    elevation = elevation[::reduction]
223   
224    if verbose: print 'Got %d data points' %len(points)
225
226    if verbose: print 'Create mesh'
227    #Find extent
228    max_x = min_x = points[0][0]
229    max_y = min_y = points[0][1]
230    for point in points[1:]:
231        x = point[0] 
232        if x > max_x: max_x = x
233        if x < min_x: min_x = x           
234        y = point[1] 
235        if y > max_y: max_y = y
236        if y < min_y: min_y = y           
237   
238    #Create appropriate mesh
239    vertex_coordinates, triangles, boundary =\
240         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y, 
241                                (min_x, min_y))
242
243    #Fit attributes to mesh     
244    vertex_attributes = fit_to_mesh(vertex_coordinates,
245                        triangles,
246                        points,
247                        elevation, alpha=alpha, verbose=verbose)
248
249
250         
251    return vertex_coordinates, triangles, boundary, vertex_attributes
252                         
253   
254
255class Interpolation:
256
257    def __init__(self,
258                 vertex_coordinates,
259                 triangles,
260                 point_coordinates = None,
261                 alpha = DEFAULT_ALPHA,
262                 verbose = False,
263                 expand_search = True,
264                 max_points_per_cell = 30,
265                 mesh_origin = None, 
266                 data_origin = None,
267                 precrop = False):
268
269       
270        """ Build interpolation matrix mapping from
271        function values at vertices to function values at data points
272
273        Inputs:
274       
275          vertex_coordinates: List of coordinate pairs [xi, eta] of
276          points constituting mesh (or a an m x 2 Numeric array)
277       
278          triangles: List of 3-tuples (or a Numeric array) of
279          integers representing indices of all vertices in the mesh.
280
281          point_coordinates: List of coordinate pairs [x, y] of
282          data points (or an nx2 Numeric array)
283          If point_coordinates is absent, only smoothing matrix will
284          be built
285
286          alpha: Smoothing parameter
287
288          data_origin and mesh_origin are 3-tuples consisting of
289          UTM zone, easting and northing. If specified
290          point coordinates and vertex coordinates are assumed to be
291          relative to their respective origins.
292         
293        """
294        from util import ensure_numeric
295       
296        #Convert input to Numeric arrays
297        triangles = ensure_numeric(triangles, Int)
298        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
299
300        #Build underlying mesh
301        if verbose: print 'Building mesh' 
302        #self.mesh = General_mesh(vertex_coordinates, triangles,
303        #FIXME: Trying the normal mesh while testing precrop,
304        #       The functionality of boundary_polygon is needed for that
305        self.mesh = Mesh(vertex_coordinates, triangles,
306                         origin = mesh_origin)
307        self.data_origin = data_origin
308
309        self.point_indices = None
310
311        #Smoothing parameter
312        self.alpha = alpha
313
314        #Build coefficient matrices
315        self.build_coefficient_matrix_B(point_coordinates,
316                                        verbose = verbose,
317                                        expand_search = expand_search,
318                                        max_points_per_cell =\
319                                        max_points_per_cell,
320                                        data_origin = data_origin,
321                                        precrop = precrop)
322       
323
324    def set_point_coordinates(self, point_coordinates,
325                              data_origin = None):
326        """
327        A public interface to setting the point co-ordinates.
328        """
329        self.build_coefficient_matrix_B(point_coordinates, data_origin)
330       
331    def build_coefficient_matrix_B(self, point_coordinates=None,
332                                   verbose = False, expand_search = True,
333                                   max_points_per_cell=30,
334                                   data_origin = None,
335                                   precrop = False):
336        """Build final coefficient matrix"""
337       
338
339        if self.alpha <> 0:
340            if verbose: print 'Building smoothing matrix'         
341            self.build_smoothing_matrix_D()
342       
343        if point_coordinates is not None:
344           
345            if verbose: print 'Building interpolation matrix'         
346            self.build_interpolation_matrix_A(point_coordinates,
347                                              verbose = verbose,
348                                              expand_search = expand_search,
349                                              max_points_per_cell =\
350                                              max_points_per_cell,
351                                              data_origin = data_origin,
352                                              precrop = precrop)
353
354            if self.alpha <> 0:
355                self.B = self.AtA + self.alpha*self.D
356            else:
357                self.B = self.AtA
358
359            #Convert self.B matrix to CSR format for faster matrix vector
360            self.B = Sparse_CSR(self.B)
361       
362    def build_interpolation_matrix_A(self, point_coordinates,
363                                     verbose = False, expand_search = True,
364                                     max_points_per_cell=30,
365                                     data_origin = None,
366                                     precrop = False):
367        """Build n x m interpolation matrix, where
368        n is the number of data points and
369        m is the number of basis functions phi_k (one per vertex)
370
371        This algorithm uses a quad tree data structure for fast binning of data points
372        origin is a 3-tuple consisting of UTM zone, easting and northing.
373        If specified coordinates are assumed to be relative to this origin.
374
375        This one will override any data_origin that may be specified in
376        interpolation instance
377
378        """
379
380        from quad import build_quadtree
381        from util import ensure_numeric
382
383        if data_origin is None:
384            data_origin = self.data_origin #Use the one from
385                                           #interpolation instance
386       
387        #Convert input to Numeric arrays just in case.
388        point_coordinates = ensure_numeric(point_coordinates, Float)       
389
390
391        #Shift data points to same origin as mesh (if specified)
392        mesh_origin = self.mesh.origin
393        if point_coordinates is not None:
394            if data_origin is not None:
395                if mesh_origin is not None:
396
397                    #Transformation:
398                    #
399                    #Let x_0 be the reference point of the point coordinates
400                    #and xi_0 the reference point of the mesh.
401                    #
402                    #A point coordinate (x + x_0) is then made relative
403                    #to xi_0 by
404                    #
405                    # x_new = x + x_0 - xi_0
406                    #
407                    #and similarly for eta
408                   
409                    x_offset = data_origin[1] - mesh_origin[1]
410                    y_offset = data_origin[2] - mesh_origin[2]
411                else: #Shift back to a zero origin
412                    x_offset = data_origin[1]
413                    y_offset = data_origin[2]               
414                   
415                point_coordinates[:,0] += x_offset
416                point_coordinates[:,1] += y_offset           
417            else:
418                if mesh_origin is not None:
419                    #Use mesh origin for data points
420                    point_coordinates[:,0] -= mesh_origin[1] 
421                    point_coordinates[:,1] -= mesh_origin[2]
422
423
424
425        #Remove points falling outside mesh boundary
426        #This reduced one example from 1356 seconds to 825 seconds
427        #And more could be had by writing util.inside_polygon in C
428        if precrop is True:
429            from Numeric import take
430            from util import inside_polygon
431
432            if verbose: print 'Getting boundary polygon'
433            P = self.mesh.get_boundary_polygon()
434
435            if verbose: print 'Getting indices inside mesh boundary'           
436            indices = inside_polygon(point_coordinates, P, verbose = verbose)
437           
438            if verbose:
439                print 'Done'
440                if len(indices) != point_coordinates.shape[0]:
441                    print '%d points outside mesh have been cropped.'\
442                          %(point_coordinates.shape[0] - len(indices))
443            point_coordinates = take(point_coordinates, indices)
444            self.point_indices = indices
445
446
447
448       
449        #Build n x m interpolation matrix       
450        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
451        n = point_coordinates.shape[0]     #Nbr of data points
452
453        if verbose: print 'Number of datapoints: %d' %n
454        if verbose: print 'Number of basis functions: %d' %m
455       
456        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
457        #However, Sparse_CSR does not have the same methods as Sparse yet
458        #The tests will reveal what needs to be done
459        self.A = Sparse(n,m)
460        self.AtA = Sparse(m,m)
461
462        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
463        root = build_quadtree(self.mesh,
464                              max_points_per_cell = max_points_per_cell)
465
466        #Compute matrix elements
467        for i in range(n):
468            #For each data_coordinate point
469
470            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
471
472            x = point_coordinates[i]
473
474            #Find vertices near x
475            candidate_vertices = root.search(x[0], x[1])
476            is_more_elements = True
477           
478            element_found, sigma0, sigma1, sigma2, k = \
479                self.search_triangles_of_vertices(candidate_vertices, x)
480            while not element_found and is_more_elements and expand_search: 
481                #if verbose: print 'Expanding search'
482                candidate_vertices, branch = root.expand_search()
483                if branch == []:
484                    # Searching all the verts from the root cell that haven't
485                    # been searched.  This is the last try
486                    element_found, sigma0, sigma1, sigma2, k = \
487                      self.search_triangles_of_vertices(candidate_vertices, x)
488                    is_more_elements = False
489                else:
490                    element_found, sigma0, sigma1, sigma2, k = \
491                      self.search_triangles_of_vertices(candidate_vertices, x)
492                     
493           
494            #Update interpolation matrix A if necessary     
495            if element_found is True:       
496                #Assign values to matrix A
497
498                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
499                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
500                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
501
502                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
503                js     = [j0,j1,j2]
504
505                for j in js:
506                    self.A[i,j] = sigmas[j]
507                    for k in js:
508                        self.AtA[j,k] += sigmas[j]*sigmas[k]
509            else:
510                pass
511                #Ok if there is no triangle for datapoint
512                #(as in brute force version)
513                #raise 'Could not find triangle for point', x
514
515
516
517    def search_triangles_of_vertices(self, candidate_vertices, x):
518            #Find triangle containing x:
519            element_found = False
520
521            # This will be returned if element_found = False
522            sigma2 = -10.0
523            sigma0 = -10.0
524            sigma1 = -10.0
525            k = -10.0
526
527            #For all vertices in same cell as point x
528            for v in candidate_vertices:
529           
530                #for each triangle id (k) which has v as a vertex
531                for k, _ in self.mesh.vertexlist[v]:
532                   
533                    #Get the three vertex_points of candidate triangle
534                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
535                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
536                    xi2 = self.mesh.get_vertex_coordinate(k, 2)     
537
538                    #print "PDSG - k", k
539                    #print "PDSG - xi0", xi0
540                    #print "PDSG - xi1", xi1
541                    #print "PDSG - xi2", xi2
542                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
543                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
544                   
545                    #Get the three normals
546                    n0 = self.mesh.get_normal(k, 0)
547                    n1 = self.mesh.get_normal(k, 1)
548                    n2 = self.mesh.get_normal(k, 2)               
549
550                   
551                    #Compute interpolation
552                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
553                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
554                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
555
556                    #print "PDSG - sigma0", sigma0
557                    #print "PDSG - sigma1", sigma1
558                    #print "PDSG - sigma2", sigma2
559                   
560                    #FIXME: Maybe move out to test or something
561                    epsilon = 1.0e-6
562                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
563                   
564                    #Check that this triangle contains the data point
565                   
566                    #Sigmas can get negative within
567                    #machine precision on some machines (e.g nautilus)
568                    #Hence the small eps                   
569                    eps = 1.0e-15
570                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
571                        element_found = True
572                        break
573
574                if element_found is True:
575                    #Don't look for any other triangle
576                    break
577            return element_found, sigma0, sigma1, sigma2, k     
578                   
579
580       
581    def build_interpolation_matrix_A_brute(self, point_coordinates):
582        """Build n x m interpolation matrix, where
583        n is the number of data points and
584        m is the number of basis functions phi_k (one per vertex)
585
586        This is the brute force which is too slow for large problems,
587        but could be used for testing
588        """
589
590        from util import ensure_numeric
591       
592        #Convert input to Numeric arrays
593        point_coordinates = ensure_numeric(point_coordinates, Float)         
594       
595        #Build n x m interpolation matrix       
596        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
597        n = point_coordinates.shape[0]     #Nbr of data points
598       
599        self.A = Sparse(n,m)
600        self.AtA = Sparse(m,m)
601
602        #Compute matrix elements
603        for i in range(n):
604            #For each data_coordinate point
605
606            x = point_coordinates[i]
607            element_found = False
608            k = 0
609            while not element_found and k < len(self.mesh):
610                #For each triangle (brute force)
611                #FIXME: Real algorithm should only visit relevant triangles
612
613                #Get the three vertex_points
614                xi0 = self.mesh.get_vertex_coordinate(k, 0)
615                xi1 = self.mesh.get_vertex_coordinate(k, 1)
616                xi2 = self.mesh.get_vertex_coordinate(k, 2)                 
617
618                #Get the three normals
619                n0 = self.mesh.get_normal(k, 0)
620                n1 = self.mesh.get_normal(k, 1)
621                n2 = self.mesh.get_normal(k, 2)               
622
623                #Compute interpolation
624                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
625                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
626                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
627
628                #FIXME: Maybe move out to test or something
629                epsilon = 1.0e-6
630                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
631
632                #Check that this triangle contains data point
633                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
634                    element_found = True
635                    #Assign values to matrix A
636
637                    j0 = self.mesh.triangles[k,0] #Global vertex id
638                    #self.A[i, j0] = sigma0
639
640                    j1 = self.mesh.triangles[k,1] #Global vertex id
641                    #self.A[i, j1] = sigma1
642
643                    j2 = self.mesh.triangles[k,2] #Global vertex id
644                    #self.A[i, j2] = sigma2
645
646                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
647                    js     = [j0,j1,j2]
648
649                    for j in js:
650                        self.A[i,j] = sigmas[j]
651                        for k in js:
652                            self.AtA[j,k] += sigmas[j]*sigmas[k]
653                k = k+1
654       
655
656       
657    def get_A(self):
658        return self.A.todense() 
659
660    def get_B(self):
661        return self.B.todense()
662   
663    def get_D(self):
664        return self.D.todense()
665   
666        #FIXME: Remember to re-introduce the 1/n factor in the
667        #interpolation term
668       
669    def build_smoothing_matrix_D(self):
670        """Build m x m smoothing matrix, where
671        m is the number of basis functions phi_k (one per vertex)
672
673        The smoothing matrix is defined as
674
675        D = D1 + D2
676
677        where
678
679        [D1]_{k,l} = \int_\Omega
680           \frac{\partial \phi_k}{\partial x}
681           \frac{\partial \phi_l}{\partial x}\,
682           dx dy
683
684        [D2]_{k,l} = \int_\Omega
685           \frac{\partial \phi_k}{\partial y}
686           \frac{\partial \phi_l}{\partial y}\,
687           dx dy
688
689
690        The derivatives \frac{\partial \phi_k}{\partial x},
691        \frac{\partial \phi_k}{\partial x} for a particular triangle
692        are obtained by computing the gradient a_k, b_k for basis function k
693        """
694
695        #FIXME: algorithm might be optimised by computing local 9x9
696        #"element stiffness matrices:
697
698        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
699
700        self.D = Sparse(m,m)
701
702        #For each triangle compute contributions to D = D1+D2       
703        for i in range(len(self.mesh)):
704
705            #Get area
706            area = self.mesh.areas[i]
707
708            #Get global vertex indices
709            v0 = self.mesh.triangles[i,0]
710            v1 = self.mesh.triangles[i,1]
711            v2 = self.mesh.triangles[i,2]
712           
713            #Get the three vertex_points
714            xi0 = self.mesh.get_vertex_coordinate(i, 0)
715            xi1 = self.mesh.get_vertex_coordinate(i, 1)
716            xi2 = self.mesh.get_vertex_coordinate(i, 2)                 
717
718            #Compute gradients for each vertex
719            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
720                              1, 0, 0)
721
722            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
723                              0, 1, 0)
724
725            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
726                              0, 0, 1)           
727
728            #Compute diagonal contributions
729            self.D[v0,v0] += (a0*a0 + b0*b0)*area
730            self.D[v1,v1] += (a1*a1 + b1*b1)*area
731            self.D[v2,v2] += (a2*a2 + b2*b2)*area           
732
733            #Compute contributions for basis functions sharing edges
734            e01 = (a0*a1 + b0*b1)*area
735            self.D[v0,v1] += e01
736            self.D[v1,v0] += e01
737
738            e12 = (a1*a2 + b1*b2)*area
739            self.D[v1,v2] += e12
740            self.D[v2,v1] += e12
741
742            e20 = (a2*a0 + b2*b0)*area
743            self.D[v2,v0] += e20
744            self.D[v0,v2] += e20             
745
746           
747    def fit(self, z):
748        """Fit a smooth surface to given 1d array of data points z.
749
750        The smooth surface is computed at each vertex in the underlying
751        mesh using the formula given in the module doc string.
752
753        Pre Condition:
754          self.A, self.At and self.B have been initialised
755         
756        Inputs:
757          z: Single 1d vector or array of data at the point_coordinates.
758        """
759       
760        #Convert input to Numeric arrays
761        from util import ensure_numeric
762        z = ensure_numeric(z, Float)
763       
764        if len(z.shape) > 1 :
765            raise VectorShapeError, 'Can only deal with 1d data vector'
766
767        if self.point_indices is not None:
768            #Remove values for any points that were outside mesh
769            z = take(z, self.point_indices) 
770       
771        #Compute right hand side based on data
772        Atz = self.A.trans_mult(z)
773
774       
775        #Check sanity
776        n, m = self.A.shape
777        if n<m and self.alpha == 0.0:
778            msg = 'ERROR (least_squares): Too few data points\n'
779            msg += 'There are only %d data points and alpha == 0. ' %n
780            msg += 'Need at least %d\n' %m
781            msg += 'Alternatively, set smoothing parameter alpha to a small '
782            msg += 'positive value,\ne.g. 1.0e-3.'
783            raise msg
784
785
786
787        return conjugate_gradient(self.B, Atz, Atz,imax=2*len(Atz) )
788        #FIXME: Should we store the result here for later use? (ON)       
789
790           
791    def fit_points(self, z, verbose=False):
792        """Like fit, but more robust when each point has two or more attributes
793        FIXME (Ole): The name fit_points doesn't carry any meaning
794        for me. How about something like fit_multiple or fit_columns?
795        """
796       
797        try:
798            if verbose: print 'Solving penalised least_squares problem'
799            return self.fit(z)
800        except VectorShapeError, e:
801            # broadcasting is not supported.
802
803            #Convert input to Numeric arrays
804            from util import ensure_numeric
805            z = ensure_numeric(z, Float)           
806           
807            #Build n x m interpolation matrix       
808            m = self.mesh.coordinates.shape[0] #Number of vertices
809            n = z.shape[1]                     #Number of data points         
810
811            f = zeros((m,n), Float) #Resulting columns
812           
813            for i in range(z.shape[1]):
814                f[:,i] = self.fit(z[:,i])
815               
816            return f
817           
818       
819    def interpolate(self, f):
820        """Evaluate smooth surface f at data points implied in self.A.
821
822        The mesh values representing a smooth surface are
823        assumed to be specified in f. This argument could,
824        for example have been obtained from the method self.fit()
825       
826        Pre Condition:
827          self.A has been initialised
828       
829        Inputs:
830          f: Vector or array of data at the mesh vertices.
831          If f is an array, interpolation will be done for each column as
832          per underlying matrix-matrix multiplication
833         
834        Output:
835          Interpolated values at data points implied in self.A
836           
837        """
838
839        return self.A * f
840   
841    def cull_outsiders(self, f):
842        pass
843       
844           
845#-------------------------------------------------------------
846if __name__ == "__main__":
847    """
848    Load in a mesh and data points with attributes.
849    Fit the attributes to the mesh.
850    Save a new mesh file.
851    """
852    import os, sys
853    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha]"\
854            %os.path.basename(sys.argv[0])
855
856    if len(sys.argv) < 4:
857        print usage
858    else:
859        mesh_file = sys.argv[1]
860        point_file = sys.argv[2]
861        mesh_output_file = sys.argv[3]
862       
863        expand_search = False
864        if len(sys.argv) > 4:
865            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
866                expand_search = True
867            else:   
868                expand_search = False
869               
870        verbose = False
871        if len(sys.argv) > 5:
872            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
873                verbose = False
874            else:   
875                verbose = True
876           
877        if len(sys.argv) > 6:
878            alpha = sys.argv[6]
879        else:
880            alpha = DEFAULT_ALPHA
881           
882        t0 = time.time()
883        fit_to_mesh_file(mesh_file,
884                         point_file,
885                         mesh_output_file,
886                         alpha,
887                         verbose= verbose,
888                         expand_search = expand_search)
889   
890        print 'That took %.2f seconds' %(time.time()-t0)
891       
Note: See TracBrowser for help on using the repository browser.