source: inundation/ga/storm_surge/pyvolution/least_squares.py @ 1454

Last change on this file since 1454 was 1423, checked in by duncan, 20 years ago

I've removed the function mesh_file_to_mesh_dictionary. Please use the function import_mesh_file instead.

File size: 32.2 KB
RevLine 
[1160]1"""Least squares smoothing and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20import exceptions
class ShapeError(Exception):
    """Exception for array arguments of unexpected shape.

    NOTE(review): no code in this module raises it; it appears to be kept
    for callers that import it from here - confirm before removing.
    """
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
[1178]30
31from coordinate_transforms.geo_reference import Geo_reference
32
[1160]33import time
34
[1178]35
try:
    from util import gradient
except ImportError:
    #FIXME reduce the dependency of modules in pyvolution
    # Have util in a dir, working like load_mesh, and get rid of this
    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
        """Return the gradient (a, b) = (dq/dx, dq/dy) of the linear
        function that takes the values q0, q1, q2 at the triangle vertices
        (x0, y0), (x1, y1), (x2, y2).

        This fallback is only defined when util.gradient cannot be imported.

        A degenerate (zero area) triangle gives a zero determinant and
        scalar inputs then raise ZeroDivisionError.
        """

        #Multiply by 1.0 so integer inputs do not trigger truncating
        #integer division (Python 2 semantics of '/').
        det = 1.0*((y2-y0)*(x1-x0) - (y1-y0)*(x2-x0))

        a = ((y2-y0)*(q1-q0) - (y1-y0)*(q2-q0))/det
        b = ((x1-x0)*(q2-q0) - (x2-x0)*(q1-q0))/det

        return a, b


#Default smoothing parameter for the fitting functions and Interpolation
DEFAULT_ALPHA = 0.001
56
def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
                     alpha=DEFAULT_ALPHA, verbose= False,
                     expand_search = False,
                     data_origin = None,
                     mesh_origin = None,
                     precrop = False):
    """
    Given a mesh file (tsh) and a point attribute file (xya), fit
    point attributes to the mesh and write a mesh file with the
    results.

    Inputs:
      mesh_file: Name of the mesh file to read.
      point_file: Name of the point attribute file. It is first read as
        comma delimited; if that raises SyntaxError it is re-read as
        space delimited.
      mesh_output_file: Name of the mesh file to write.
      alpha, verbose, expand_search, precrop: Passed on to fit_to_mesh.

    Origins:
      data_origin and mesh_origin are optional 3-tuples with geo referenced
      UTM coordinates (zone, easting, northing) for the point file and the
      mesh file respectively. A geo_reference stored in the corresponding
      file takes precedence. The argument is used only when the file has
      none. If neither is available, (56, 0, 0) is assumed.

    The fitted attributes (and their titles) are appended after any vertex
    attributes the mesh already has. Attributes with the same title are
    not merged.

    FIXME: When the tsh format contains its own origin, these parameters can go.
    FIXME: And both origins should be obtained from the specified files.
    """

    from load_mesh.loadASCII import import_mesh_file, \
                 import_points_file, export_mesh_file, \
                 concatinate_attributelist

    mesh_dict = import_mesh_file(mesh_file)
    vertex_coordinates = mesh_dict['vertices']
    triangles = mesh_dict['triangles']

    #Numeric arrays are turned into lists so new columns can be appended
    old_point_attributes = mesh_dict['vertex_attributes']
    if isinstance(old_point_attributes, ArrayType):
        old_point_attributes = old_point_attributes.tolist()

    old_title_list = mesh_dict['vertex_attribute_titles']
    if isinstance(old_title_list, ArrayType):
        old_title_list = old_title_list.tolist()

    if verbose: print('tsh file %s loaded' % mesh_file)

    # load in the .pts file
    try:
        point_dict = import_points_file(point_file,
                                        delimiter = ',',
                                        verbose=verbose)
    except SyntaxError:
        #Not comma delimited - retry with space as delimiter
        point_dict = import_points_file(point_file,
                                        delimiter = ' ',
                                        verbose=verbose)

    point_coordinates = point_dict['pointlist']
    title_list, point_attributes = \
                concatinate_attributelist(point_dict['attributelist'])

    #Origins stored in the files win. The arguments are only a fallback,
    #and (56, 0, 0) is the last resort.
    data_geo_reference = point_dict.get('geo_reference')
    if data_geo_reference is not None:
        data_origin = data_geo_reference.get_origin()
    elif data_origin is None:
        data_origin = (56, 0, 0) #FIXME(DSG-DSG)

    mesh_geo_reference = mesh_dict.get('geo_reference')
    if mesh_geo_reference is not None:
        mesh_origin = mesh_geo_reference.get_origin()
    elif mesh_origin is None:
        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)

    if verbose: print("points file loaded")
    if verbose: print("fitting to mesh")
    f = fit_to_mesh(vertex_coordinates,
                    triangles,
                    point_coordinates,
                    point_attributes,
                    alpha = alpha,
                    verbose = verbose,
                    expand_search = expand_search,
                    data_origin = data_origin,
                    mesh_origin = mesh_origin,
                    precrop = precrop)
    if verbose: print("finished fitting to mesh")

    # convert array to list of lists
    new_point_attributes = f.tolist()
    #FIXME have this overwrite attributes with the same title - DSG
    #Put the newer attributes last
    if old_title_list != []:
        old_title_list.extend(title_list)
        #FIXME can this be done a faster way? - DSG
        for i, old_row in enumerate(old_point_attributes):
            old_row.extend(new_point_attributes[i])
        mesh_dict['vertex_attributes'] = old_point_attributes
        mesh_dict['vertex_attribute_titles'] = old_title_list
    else:
        mesh_dict['vertex_attributes'] = new_point_attributes
        mesh_dict['vertex_attribute_titles'] = title_list

    #FIXME (Ole): Remember to output mesh_origin as well
    if verbose: print("exporting to file  %s" % (mesh_output_file,))
    export_mesh_file(mesh_output_file, mesh_dict)
152
153
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,
                point_attributes,
                alpha = DEFAULT_ALPHA,
                verbose = False,
                expand_search = False,
                data_origin = None,
                mesh_origin = None,
                precrop = False):
    """Fit a smooth surface to a triangulation from attributed data points.

    Inputs:
      vertex_coordinates: List of coordinate pairs [xi, eta] of the mesh
      vertices (or an m x 2 Numeric array).

      triangles: List of 3-tuples (or a Numeric array) of integer indices
      into vertex_coordinates, one tuple per triangle.

      point_coordinates: List of coordinate pairs [x, y] of the data points
      (or an n x 2 Numeric array).

      point_attributes: Vector or array of data at the point_coordinates.

      alpha: Smoothing parameter.

      verbose, expand_search, precrop: Passed on to Interpolation.

      data_origin and mesh_origin are 3-tuples consisting of
      UTM zone, easting and northing. If specified
      point coordinates and vertex coordinates are assumed to be
      relative to their respective origins.

    Returns:
      The values at the mesh vertices computed by Interpolation.fit_points.
    """
    #Build the interpolation/smoothing matrices once, then solve.
    settings = dict(alpha = alpha,
                    verbose = verbose,
                    expand_search = expand_search,
                    data_origin = data_origin,
                    mesh_origin = mesh_origin,
                    precrop = precrop)
    fitter = Interpolation(vertex_coordinates, triangles,
                           point_coordinates, **settings)
    return fitter.fit_points(point_attributes, verbose = verbose)
202
203
204
def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
                    verbose = False, reduction = 1, format = 'netcdf'):
    """Fits attributes from pts file to MxN rectangular mesh

    Read pts file and create rectangular mesh of resolution MxN such that
    it covers all points specified in pts file.

    Inputs:
      pts_name: Name of the points file, read with util.read_xya.
      M, N: Mesh resolution passed to mesh_factory.rectangular.
      alpha: Smoothing parameter passed to fit_to_mesh.
      verbose: Print progress messages.
      reduction: Only every reduction'th point (and elevation) is used.
      format: File format passed to util.read_xya.

    Returns:
      vertex_coordinates, triangles, boundary, vertex_attributes
      where the first three come from mesh_factory.rectangular and
      vertex_attributes are the fitted 'elevation' values.

    Raises:
      ValueError if no data points remain after the reduction.

    FIXME: This may be a temporary function until we decide on
    netcdf formats etc

    FIXME: Uses elevation hardwired
    """

    import util, mesh_factory

    if verbose: print('Read pts')
    points, attributes = util.read_xya(pts_name, format)

    #Reduce number of points a bit
    points = points[::reduction]
    elevation = attributes['elevation']  #Must be elevation
    elevation = elevation[::reduction]

    if verbose: print('Got %d data points' %len(points))

    if len(points) == 0:
        raise ValueError('No data points available from %s' % (pts_name,))

    if verbose: print('Create mesh')
    #Find extent of the data points
    xs = [point[0] for point in points]
    ys = [point[1] for point in points]
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)

    #Create appropriate mesh
    vertex_coordinates, triangles, boundary = \
         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
                                  (min_x, min_y))

    #Fit attributes to mesh
    vertex_attributes = fit_to_mesh(vertex_coordinates,
                                    triangles,
                                    points,
                                    elevation, alpha=alpha, verbose=verbose)

    return vertex_coordinates, triangles, boundary, vertex_attributes
256
257
258
class Interpolation:
    """Penalised least-squares fit and interpolation on a triangular mesh.

    The matrices built are
      A   - n x m interpolation matrix (n data points, m mesh vertices),
      AtA - the m x m product A^T A,
      D   - m x m smoothing matrix (only when alpha != 0),
      B   - AtA + alpha*D (or AtA when alpha == 0), stored as Sparse_CSR.
    fit() solves B f = A^T z with conjugate_gradient.
    """

    def __init__(self,
                 vertex_coordinates,
                 triangles,
                 point_coordinates = None,
                 alpha = DEFAULT_ALPHA,
                 verbose = False,
                 expand_search = True,
                 max_points_per_cell = 30,
                 mesh_origin = None,
                 data_origin = None,
                 precrop = False):
        """ Build interpolation matrix mapping from
        function values at vertices to function values at data points

        Inputs:

          vertex_coordinates: List of coordinate pairs [xi, eta] of
          points constituting mesh (or a an m x 2 Numeric array)

          triangles: List of 3-tuples (or a Numeric array) of
          integers representing indices of all vertices in the mesh.

          point_coordinates: List of coordinate pairs [x, y] of
          data points (or an nx2 Numeric array)
          If point_coordinates is absent, only smoothing matrix will
          be built

          alpha: Smoothing parameter

          verbose: Print progress messages.

          expand_search: If True, keep widening the quad tree search
          when no triangle containing a data point is found among the
          nearby vertices.

          max_points_per_cell: Passed on to build_quadtree.

          data_origin and mesh_origin are 3-tuples consisting of
          UTM zone, easting and northing. If specified
          point coordinates and vertex coordinates are assumed to be
          relative to their respective origins.

          precrop: If True, data points outside the mesh boundary
          polygon are dropped before the interpolation matrix is built.
        """
        from util import ensure_numeric

        #Convert input to Numeric arrays
        triangles = ensure_numeric(triangles, Int)
        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)

        #Build underlying mesh
        if verbose: print('Building mesh')
        #self.mesh = General_mesh(vertex_coordinates, triangles,
        #FIXME: Trying the normal mesh while testing precrop,
        #       The functionality of boundary_polygon is needed for that

        #FIXME - geo ref does not have to go into mesh.
        # Change the point co-ords to conform to the
        # mesh co-ords early in the code
        if mesh_origin is None:
            geo = None
        else:
            geo = Geo_reference(mesh_origin[0], mesh_origin[1], mesh_origin[2])
        self.mesh = Mesh(vertex_coordinates, triangles,
                         geo_reference = geo)
        #FIXME, remove if mesh checks it.
        self.mesh.check_integrity()
        self.data_origin = data_origin

        #Indices of the data points kept by precropping (None: all kept)
        self.point_indices = None

        #Smoothing parameter
        self.alpha = alpha

        #Build coefficient matrices
        self.build_coefficient_matrix_B(point_coordinates,
                                        verbose = verbose,
                                        expand_search = expand_search,
                                        max_points_per_cell = max_points_per_cell,
                                        data_origin = data_origin,
                                        precrop = precrop)


    def set_point_coordinates(self, point_coordinates,
                              data_origin = None):
        """
        A public interface to setting the point co-ordinates.

        Rebuilds D (when alpha != 0), A, AtA and B for the new points.
        If data_origin is None, the one given to the constructor is used.
        """
        #data_origin must go by keyword: the second positional parameter
        #of build_coefficient_matrix_B is verbose.
        self.build_coefficient_matrix_B(point_coordinates,
                                        data_origin = data_origin)

    def build_coefficient_matrix_B(self, point_coordinates=None,
                                   verbose = False, expand_search = True,
                                   max_points_per_cell=30,
                                   data_origin = None,
                                   precrop = False):
        """Build final coefficient matrix B = AtA + alpha*D.

        D is (re)built whenever alpha != 0. A, AtA and B are only
        (re)built when point_coordinates is given.
        """

        if self.alpha != 0:
            if verbose: print('Building smoothing matrix')
            self.build_smoothing_matrix_D()

        if point_coordinates is not None:

            if verbose: print('Building interpolation matrix')
            self.build_interpolation_matrix_A(point_coordinates,
                                              verbose = verbose,
                                              expand_search = expand_search,
                                              max_points_per_cell = max_points_per_cell,
                                              data_origin = data_origin,
                                              precrop = precrop)

            if self.alpha != 0:
                self.B = self.AtA + self.alpha*self.D
            else:
                self.B = self.AtA

            #Convert self.B matrix to CSR format for faster matrix vector
            self.B = Sparse_CSR(self.B)

    def build_interpolation_matrix_A(self, point_coordinates,
                                     verbose = False, expand_search = True,
                                     max_points_per_cell=30,
                                     data_origin = None,
                                     precrop = False):
        """Build n x m interpolation matrix, where
        n is the number of data points and
        m is the number of basis functions phi_k (one per vertex)

        A quad tree of the mesh vertices is used to find candidate
        triangles for each data point.

        data_origin is a 3-tuple consisting of UTM zone, easting and northing.
        If specified coordinates are assumed to be relative to this origin.
        This one will override any data_origin that may be specified in
        interpolation instance.

        Data points that fall in no triangle get an all-zero row in A.
        Also sets self.AtA, and self.point_indices (None unless precrop).
        """

        from quad import build_quadtree
        from util import ensure_numeric

        if data_origin is None:
            data_origin = self.data_origin #Use the one from
                                           #interpolation instance

        #Forget cropping from any previous build. It is set again below if
        #precrop is requested. Stale indices would make fit() pick the
        #wrong data values.
        self.point_indices = None

        #Convert input to Numeric arrays just in case.
        #NOTE(review): if ensure_numeric returns an existing Float array
        #unchanged, the in-place shifts below modify the caller's array.
        point_coordinates = ensure_numeric(point_coordinates, Float)


        #Shift data points to same origin as mesh (if specified)

        #FIXME this will shift if there was no geo_ref.
        #But all this should be removed amyhow.
        #change coords before this point
        mesh_origin = self.mesh.geo_reference.get_origin()
        if point_coordinates is not None:
            if data_origin is not None:
                if mesh_origin is not None:

                    #Transformation:
                    #
                    #Let x_0 be the reference point of the point coordinates
                    #and xi_0 the reference point of the mesh.
                    #
                    #A point coordinate (x + x_0) is then made relative
                    #to xi_0 by
                    #
                    # x_new = x + x_0 - xi_0
                    #
                    #and similarly for eta

                    x_offset = data_origin[1] - mesh_origin[1]
                    y_offset = data_origin[2] - mesh_origin[2]
                else: #Shift back to a zero origin
                    x_offset = data_origin[1]
                    y_offset = data_origin[2]

                point_coordinates[:,0] += x_offset
                point_coordinates[:,1] += y_offset
            else:
                if mesh_origin is not None:
                    #Use mesh origin for data points
                    point_coordinates[:,0] -= mesh_origin[1]
                    point_coordinates[:,1] -= mesh_origin[2]


        #Remove points falling outside mesh boundary
        #This reduced one example from 1356 seconds to 825 seconds
        #And more could be had by writing util.inside_polygon in C
        if precrop is True:
            from util import inside_polygon

            if verbose: print('Getting boundary polygon')
            P = self.mesh.get_boundary_polygon()

            if verbose: print('Getting indices inside mesh boundary')
            indices = inside_polygon(point_coordinates, P, verbose = verbose)

            if verbose:
                print('Done')
                if len(indices) != point_coordinates.shape[0]:
                    print('%d points outside mesh have been cropped.'
                          %(point_coordinates.shape[0] - len(indices)))
            point_coordinates = take(point_coordinates, indices)
            self.point_indices = indices


        #Build n x m interpolation matrix
        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
        n = point_coordinates.shape[0]     #Nbr of data points

        if verbose: print('Number of datapoints: %d' %n)
        if verbose: print('Number of basis functions: %d' %m)

        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
        #However, Sparse_CSR does not have the same methods as Sparse yet
        #The tests will reveal what needs to be done
        self.A = Sparse(n,m)
        self.AtA = Sparse(m,m)

        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
        root = build_quadtree(self.mesh,
                              max_points_per_cell = max_points_per_cell)

        #Compute matrix elements
        for i in range(n):
            #For each data_coordinate point

            if verbose and i%((n+10)//10)==0: print('Doing %d of %d' %(i, n))

            x = point_coordinates[i]

            #Find vertices near x
            candidate_vertices = root.search(x[0], x[1])
            is_more_elements = True

            element_found, sigma0, sigma1, sigma2, k = \
                self.search_triangles_of_vertices(candidate_vertices, x)
            while not element_found and is_more_elements and expand_search:
                #if verbose: print 'Expanding search'
                candidate_vertices, branch = root.expand_search()
                if branch == []:
                    # Searching all the verts from the root cell that haven't
                    # been searched.  This is the last try
                    is_more_elements = False
                element_found, sigma0, sigma1, sigma2, k = \
                  self.search_triangles_of_vertices(candidate_vertices, x)

            #Update interpolation matrix A if necessary
            if element_found is True:
                #Assign values to matrix A

                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2

                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
                js     = [j0,j1,j2]

                for j in js:
                    self.A[i,j] = sigmas[j]
                    for jj in js:
                        self.AtA[j,jj] += sigmas[j]*sigmas[jj]
            #else: no triangle contains this data point. That is OK
            #(as in brute force version); its row of A stays zero.


    def search_triangles_of_vertices(self, candidate_vertices, x):
        """Find a triangle that has a vertex in candidate_vertices and
        contains point x.

        Returns (element_found, sigma0, sigma1, sigma2, k), where k is the
        triangle id and sigma0..2 are the interpolation weights
        (barycentric coordinates) of x for vertices 0..2 of triangle k.
        When element_found is False the other values are those of the last
        triangle tried, or -10.0 if none was tried, and must not be used.
        """
        #Find triangle containing x:
        element_found = False

        # This will be returned if element_found = False
        sigma2 = -10.0
        sigma0 = -10.0
        sigma1 = -10.0
        k = -10.0

        #For all vertices in same cell as point x
        for v in candidate_vertices:

            #for each triangle id (k) which has v as a vertex
            for k, _ in self.mesh.vertexlist[v]:

                #Get the three vertex_points of candidate triangle
                xi0 = self.mesh.get_vertex_coordinate(k, 0)
                xi1 = self.mesh.get_vertex_coordinate(k, 1)
                xi2 = self.mesh.get_vertex_coordinate(k, 2)

                #Get the three normals
                n0 = self.mesh.get_normal(k, 0)
                n1 = self.mesh.get_normal(k, 1)
                n2 = self.mesh.get_normal(k, 2)

                #Compute interpolation
                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)

                #FIXME: Maybe move out to test or something
                epsilon = 1.0e-6
                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon

                #Check that this triangle contains the data point

                #Sigmas can get negative within
                #machine precision on some machines (e.g nautilus)
                #Hence the small eps
                eps = 1.0e-15
                if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
                    element_found = True
                    break

            if element_found is True:
                #Don't look for any other triangle
                break
        return element_found, sigma0, sigma1, sigma2, k



    def build_interpolation_matrix_A_brute(self, point_coordinates):
        """Build n x m interpolation matrix, where
        n is the number of data points and
        m is the number of basis functions phi_k (one per vertex)

        This is the brute force which is too slow for large problems,
        but could be used for testing.

        Unlike build_interpolation_matrix_A it applies no origin shift,
        no precropping and no eps tolerance on the sigmas.
        """

        from util import ensure_numeric

        #Convert input to Numeric arrays
        point_coordinates = ensure_numeric(point_coordinates, Float)

        #Build n x m interpolation matrix
        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
        n = point_coordinates.shape[0]     #Nbr of data points

        self.A = Sparse(n,m)
        self.AtA = Sparse(m,m)

        #Compute matrix elements
        for i in range(n):
            #For each data_coordinate point

            x = point_coordinates[i]
            element_found = False
            k = 0
            while not element_found and k < len(self.mesh):
                #For each triangle (brute force)
                #FIXME: Real algorithm should only visit relevant triangles

                #Get the three vertex_points
                xi0 = self.mesh.get_vertex_coordinate(k, 0)
                xi1 = self.mesh.get_vertex_coordinate(k, 1)
                xi2 = self.mesh.get_vertex_coordinate(k, 2)

                #Get the three normals
                n0 = self.mesh.get_normal(k, 0)
                n1 = self.mesh.get_normal(k, 1)
                n2 = self.mesh.get_normal(k, 2)

                #Compute interpolation
                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)

                #FIXME: Maybe move out to test or something
                epsilon = 1.0e-6
                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon

                #Check that this triangle contains data point
                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
                    element_found = True
                    #Assign values to matrix A

                    j0 = self.mesh.triangles[k,0] #Global vertex id
                    j1 = self.mesh.triangles[k,1] #Global vertex id
                    j2 = self.mesh.triangles[k,2] #Global vertex id

                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
                    js     = [j0,j1,j2]

                    for j in js:
                        self.A[i,j] = sigmas[j]
                        for jj in js:
                            self.AtA[j,jj] += sigmas[j]*sigmas[jj]
                k = k+1



    def get_A(self):
        """Return the interpolation matrix A as a dense array."""
        return self.A.todense()

    def get_B(self):
        """Return the coefficient matrix B as a dense array."""
        return self.B.todense()

    def get_D(self):
        """Return the smoothing matrix D as a dense array."""
        return self.D.todense()

    #FIXME: Remember to re-introduce the 1/n factor in the
    #interpolation term

    def build_smoothing_matrix_D(self):
        r"""Build m x m smoothing matrix, where
        m is the number of basis functions phi_k (one per vertex)

        The smoothing matrix is defined as

        D = D1 + D2

        where

        [D1]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial x}
           \frac{\partial \phi_l}{\partial x}\,
           dx dy

        [D2]_{k,l} = \int_\Omega
           \frac{\partial \phi_k}{\partial y}
           \frac{\partial \phi_l}{\partial y}\,
           dx dy


        The derivatives \frac{\partial \phi_k}{\partial x},
        \frac{\partial \phi_k}{\partial y} for a particular triangle
        are obtained by computing the gradient a_k, b_k for basis function k.
        The gradient is constant on each triangle, so every triangle
        contributes (a_k a_l + b_k b_l) * area to [D]_{k,l}.
        """

        #FIXME: algorithm might be optimised by computing local 9x9
        #"element stiffness matrices:

        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)

        self.D = Sparse(m,m)

        #For each triangle compute contributions to D = D1+D2
        for i in range(len(self.mesh)):

            #Get area
            area = self.mesh.areas[i]

            #Get global vertex indices
            v0 = self.mesh.triangles[i,0]
            v1 = self.mesh.triangles[i,1]
            v2 = self.mesh.triangles[i,2]

            #Get the three vertex_points
            xi0 = self.mesh.get_vertex_coordinate(i, 0)
            xi1 = self.mesh.get_vertex_coordinate(i, 1)
            xi2 = self.mesh.get_vertex_coordinate(i, 2)

            #Compute gradients for each vertex
            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              1, 0, 0)

            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 1, 0)

            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
                              0, 0, 1)

            #Compute diagonal contributions
            self.D[v0,v0] += (a0*a0 + b0*b0)*area
            self.D[v1,v1] += (a1*a1 + b1*b1)*area
            self.D[v2,v2] += (a2*a2 + b2*b2)*area

            #Compute contributions for basis functions sharing edges
            e01 = (a0*a1 + b0*b1)*area
            self.D[v0,v1] += e01
            self.D[v1,v0] += e01

            e12 = (a1*a2 + b1*b2)*area
            self.D[v1,v2] += e12
            self.D[v2,v1] += e12

            e20 = (a2*a0 + b2*b0)*area
            self.D[v2,v0] += e20
            self.D[v0,v2] += e20


    def fit(self, z):
        """Fit a smooth surface to given 1d array of data points z.

        The smooth surface is computed at each vertex in the underlying
        mesh using the formula given in the module doc string.

        Pre Condition:
          self.A, self.AtA and self.B have been initialised

        Inputs:
          z: Single 1d vector or array of data at the point_coordinates.

        Returns:
          The fitted values at the mesh vertices (solution of B f = A^T z).

        Raises:
          VectorShapeError if z has more than one dimension.
          ValueError if there are fewer data points than vertices while
          alpha == 0.
        """

        #Convert input to Numeric arrays
        from util import ensure_numeric
        z = ensure_numeric(z, Float)

        if len(z.shape) > 1 :
            raise VectorShapeError('Can only deal with 1d data vector')

        if self.point_indices is not None:
            #Remove values for any points that were outside mesh
            z = take(z, self.point_indices)

        #Compute right hand side based on data
        Atz = self.A.trans_mult(z)


        #Check sanity
        n, m = self.A.shape
        if n<m and self.alpha == 0.0:
            msg = 'ERROR (least_squares): Too few data points\n'
            msg += 'There are only %d data points and alpha == 0. ' %n
            msg += 'Need at least %d\n' %m
            msg += 'Alternatively, set smoothing parameter alpha to a small '
            msg += 'positive value,\ne.g. 1.0e-3.'
            #String exceptions are invalid from Python 2.6; use a class
            raise ValueError(msg)

        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz))
        #FIXME: Should we store the result here for later use? (ON)


    def fit_points(self, z, verbose=False):
        """Like fit, but more robust when each point has two or more attributes

        If z has one column per attribute, each column is fitted separately
        and the result is an (m x number of columns) array, m being the
        number of mesh vertices.

        FIXME (Ole): The name fit_points doesn't carry any meaning
        for me. How about something like fit_multiple or fit_columns?
        """

        try:
            if verbose: print('Solving penalised least_squares problem')
            return self.fit(z)
        except VectorShapeError:
            # broadcasting is not supported, so fit column by column.

            #Convert input to Numeric arrays
            from util import ensure_numeric
            z = ensure_numeric(z, Float)

            m = self.mesh.coordinates.shape[0] #Number of vertices
            n = z.shape[1]                     #Number of attributes (columns)

            f = zeros((m,n), Float) #Resulting columns

            for i in range(n):
                f[:,i] = self.fit(z[:,i])

            return f


    def interpolate(self, f):
        """Evaluate smooth surface f at data points implied in self.A.

        The mesh values representing a smooth surface are
        assumed to be specified in f. This argument could,
        for example have been obtained from the method self.fit()

        Pre Condition:
          self.A has been initialised

        Inputs:
          f: Vector or array of data at the mesh vertices.
          If f is an array, interpolation will be done for each column as
          per underlying matrix-matrix multiplication

        Output:
          Interpolated values at data points implied in self.A

        """

        return self.A * f

    def cull_outsiders(self, f):
        """Placeholder: not implemented, does nothing."""
        pass
861
862
863#-------------------------------------------------------------
if __name__ == "__main__":
    # Load in a mesh and data points with attributes.
    # Fit the attributes to the mesh.
    # Save a new mesh file.
    import os
    import sys
    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][verbose|non_verbose] [alpha]"\
            %os.path.basename(sys.argv[0])

    if len(sys.argv) < 4:
        print(usage)
    else:
        mesh_file = sys.argv[1]
        point_file = sys.argv[2]
        mesh_output_file = sys.argv[3]

        #Optional flags are recognised by their first letter only
        expand_search = False
        if len(sys.argv) > 4:
            expand_search = sys.argv[4][0] in ('e', 'E')

        verbose = False
        if len(sys.argv) > 5:
            verbose = sys.argv[5][0] not in ('n', 'N')

        #Command line arguments are strings; alpha must be numeric
        if len(sys.argv) > 6:
            alpha = float(sys.argv[6])
        else:
            alpha = DEFAULT_ALPHA

        t0 = time.time()
        fit_to_mesh_file(mesh_file,
                         point_file,
                         mesh_output_file,
                         alpha,
                         verbose= verbose,
                         expand_search = expand_search)

        print('That took %.2f seconds' %(time.time()-t0))
909
Note: See TracBrowser for help on using the repository browser.