source: inundation/ga/storm_surge/pyvolution/least_squares.py @ 1668

Last change on this file since 1668 was 1653, checked in by ole, 20 years ago

Comments

File size: 32.5 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20import exceptions
21class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False):
63    """
64    Given a mesh file (tsh) and a point attribute file (xya), fit
65    point attributes to the mesh and write a mesh file with the
66    results.
67
68
69    If data_origin is not None it is assumed to be
70    a 3-tuple with geo referenced
71    UTM coordinates (zone, easting, northing)
72
73    mesh_origin is the same but refers to the input tsh file.
74    FIXME: When the tsh format contains it own origin, these parameters can go.
75    FIXME: And both origins should be obtained from the specified files.
76    """
77
78    from load_mesh.loadASCII import import_mesh_file, \
79                 import_points_file, export_mesh_file, \
80                 concatinate_attributelist
81
82    mesh_dict = import_mesh_file(mesh_file)
83    vertex_coordinates = mesh_dict['vertices']
84    triangles = mesh_dict['triangles']
85    if type(mesh_dict['vertex_attributes']) == ArrayType:
86        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
87    else:
88        old_point_attributes = mesh_dict['vertex_attributes']
89
90    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
91        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
92    else:
93        old_title_list = mesh_dict['vertex_attribute_titles']
94
95    if verbose: print 'tsh file %s loaded' %mesh_file
96
97    # load in the .pts file
98    try:
99        point_dict = import_points_file(point_file,
100                                      delimiter = ',',
101                                      verbose=verbose)
102    except SyntaxError,e:
103        point_dict = import_points_file(point_file,
104                                      delimiter = ' ',
105                                      verbose=verbose)
106
107    point_coordinates = point_dict['pointlist']
108    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
109
110    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
111        data_origin = point_dict['geo_reference'].get_origin()
112    else:
113        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
114
115    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
116        mesh_origin = mesh_dict['geo_reference'].get_origin()
117    else:
118        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
119
120    if verbose: print "points file loaded"
121    if verbose:print "fitting to mesh"
122    f = fit_to_mesh(vertex_coordinates,
123                    triangles,
124                    point_coordinates,
125                    point_attributes,
126                    alpha = alpha,
127                    verbose = verbose,
128                    expand_search = expand_search,
129                    data_origin = data_origin,
130                    mesh_origin = mesh_origin,
131                    precrop = precrop)
132    if verbose: print "finished fitting to mesh"
133
134    # convert array to list of lists
135    new_point_attributes = f.tolist()
136    #FIXME have this overwrite attributes with the same title - DSG
137    #Put the newer attributes last
138    if old_title_list <> []:
139        old_title_list.extend(title_list)
140        #FIXME can this be done a faster way? - DSG
141        for i in range(len(old_point_attributes)):
142            old_point_attributes[i].extend(new_point_attributes[i])
143        mesh_dict['vertex_attributes'] = old_point_attributes
144        mesh_dict['vertex_attribute_titles'] = old_title_list
145    else:
146        mesh_dict['vertex_attributes'] = new_point_attributes
147        mesh_dict['vertex_attribute_titles'] = title_list
148
149    #FIXME (Ole): Remember to output mesh_origin as well
150    if verbose: print "exporting to file ",mesh_output_file
151    export_mesh_file(mesh_output_file, mesh_dict)
152
153
154def fit_to_mesh(vertex_coordinates,
155                triangles,
156                point_coordinates,
157                point_attributes,
158                alpha = DEFAULT_ALPHA,
159                verbose = False,
160                expand_search = False,
161                data_origin = None,
162                mesh_origin = None,
163                precrop = False):
164    """
165    Fit a smooth surface to a triangulation,
166    given data points with attributes.
167
168
169        Inputs:
170
171          vertex_coordinates: List of coordinate pairs [xi, eta] of points
172          constituting mesh (or a an m x 2 Numeric array)
173
174          triangles: List of 3-tuples (or a Numeric array) of
175          integers representing indices of all vertices in the mesh.
176
177          point_coordinates: List of coordinate pairs [x, y] of data points
178          (or an nx2 Numeric array)
179
180          alpha: Smoothing parameter.
181
182          point_attributes: Vector or array of data at the point_coordinates.
183
184          data_origin and mesh_origin are 3-tuples consisting of
185          UTM zone, easting and northing. If specified
186          point coordinates and vertex coordinates are assumed to be
187          relative to their respective origins.
188
189    """
190    interp = Interpolation(vertex_coordinates,
191                           triangles,
192                           point_coordinates,
193                           alpha = alpha,
194                           verbose = verbose,
195                           expand_search = expand_search,
196                           data_origin = data_origin,
197                           mesh_origin = mesh_origin,
198                           precrop = precrop)
199
200    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
201    return vertex_attributes
202
203
204
205def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
206                    verbose = False, reduction = 1, format = 'netcdf'):
207    """Fits attributes from pts file to MxN rectangular mesh
208
209    Read pts file and create rectangular mesh of resolution MxN such that
210    it covers all points specified in pts file.
211
212    FIXME: This may be a temporary function until we decide on
213    netcdf formats etc
214
215    FIXME: Uses elevation hardwired
216    """
217
218    import util, mesh_factory
219
220    if verbose: print 'Read pts'
221    points, attributes = util.read_xya(pts_name, format)
222
223    #Reduce number of points a bit
224    points = points[::reduction]
225    elevation = attributes['elevation']  #Must be elevation
226    elevation = elevation[::reduction]
227
228    if verbose: print 'Got %d data points' %len(points)
229
230    if verbose: print 'Create mesh'
231    #Find extent
232    max_x = min_x = points[0][0]
233    max_y = min_y = points[0][1]
234    for point in points[1:]:
235        x = point[0]
236        if x > max_x: max_x = x
237        if x < min_x: min_x = x
238        y = point[1]
239        if y > max_y: max_y = y
240        if y < min_y: min_y = y
241
242    #Create appropriate mesh
243    vertex_coordinates, triangles, boundary =\
244         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
245                                (min_x, min_y))
246
247    #Fit attributes to mesh
248    vertex_attributes = fit_to_mesh(vertex_coordinates,
249                        triangles,
250                        points,
251                        elevation, alpha=alpha, verbose=verbose)
252
253
254
255    return vertex_coordinates, triangles, boundary, vertex_attributes
256
257
258
259class Interpolation:
260
261    def __init__(self,
262                 vertex_coordinates,
263                 triangles,
264                 point_coordinates = None,
265                 alpha = DEFAULT_ALPHA,
266                 verbose = False,
267                 expand_search = True,
268                 max_points_per_cell = 30,
269                 mesh_origin = None,
270                 data_origin = None,
271                 precrop = False):
272
273
274        """ Build interpolation matrix mapping from
275        function values at vertices to function values at data points
276
277        Inputs:
278
279          vertex_coordinates: List of coordinate pairs [xi, eta] of
280          points constituting mesh (or a an m x 2 Numeric array)
281          Points may appear multiple times
282          (e.g. if vertices have discontinuities)
283
284          triangles: List of 3-tuples (or a Numeric array) of
285          integers representing indices of all vertices in the mesh.
286
287          point_coordinates: List of coordinate pairs [x, y] of
288          data points (or an nx2 Numeric array)
289          If point_coordinates is absent, only smoothing matrix will
290          be built
291
292          alpha: Smoothing parameter
293
294          data_origin and mesh_origin are 3-tuples consisting of
295          UTM zone, easting and northing. If specified
296          point coordinates and vertex coordinates are assumed to be
297          relative to their respective origins.
298
299        """
300        from util import ensure_numeric
301
302        #Convert input to Numeric arrays
303        triangles = ensure_numeric(triangles, Int)
304        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
305
306        #Build underlying mesh
307        if verbose: print 'Building mesh'
308        #self.mesh = General_mesh(vertex_coordinates, triangles,
309        #FIXME: Trying the normal mesh while testing precrop,
310        #       The functionality of boundary_polygon is needed for that
311
312        #FIXME - geo ref does not have to go into mesh.
313        # Change the point co-ords to conform to the
314        # mesh co-ords early in the code
315        if mesh_origin == None:
316            geo = None
317        else:
318            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
319        self.mesh = Mesh(vertex_coordinates, triangles,
320                         geo_reference = geo)
321        #FIXME, remove if mesh checks it.
322        self.mesh.check_integrity()
323        self.data_origin = data_origin
324
325        self.point_indices = None
326
327        #Smoothing parameter
328        self.alpha = alpha
329
330        #Build coefficient matrices
331        self.build_coefficient_matrix_B(point_coordinates,
332                                        verbose = verbose,
333                                        expand_search = expand_search,
334                                        max_points_per_cell =\
335                                        max_points_per_cell,
336                                        data_origin = data_origin,
337                                        precrop = precrop)
338
339
340    def set_point_coordinates(self, point_coordinates,
341                              data_origin = None):
342        """
343        A public interface to setting the point co-ordinates.
344        """
345        self.build_coefficient_matrix_B(point_coordinates, data_origin)
346
347    def build_coefficient_matrix_B(self, point_coordinates=None,
348                                   verbose = False, expand_search = True,
349                                   max_points_per_cell=30,
350                                   data_origin = None,
351                                   precrop = False):
352        """Build final coefficient matrix"""
353
354
355        if self.alpha <> 0:
356            if verbose: print 'Building smoothing matrix'
357            self.build_smoothing_matrix_D()
358
359        if point_coordinates is not None:
360
361            if verbose: print 'Building interpolation matrix'
362            self.build_interpolation_matrix_A(point_coordinates,
363                                              verbose = verbose,
364                                              expand_search = expand_search,
365                                              max_points_per_cell =\
366                                              max_points_per_cell,
367                                              data_origin = data_origin,
368                                              precrop = precrop)
369
370            if self.alpha <> 0:
371                self.B = self.AtA + self.alpha*self.D
372            else:
373                self.B = self.AtA
374
375            #Convert self.B matrix to CSR format for faster matrix vector
376            self.B = Sparse_CSR(self.B)
377
378    def build_interpolation_matrix_A(self, point_coordinates,
379                                     verbose = False, expand_search = True,
380                                     max_points_per_cell=30,
381                                     data_origin = None,
382                                     precrop = False):
383        """Build n x m interpolation matrix, where
384        n is the number of data points and
385        m is the number of basis functions phi_k (one per vertex)
386
387        This algorithm uses a quad tree data structure for fast binning of data points
388        origin is a 3-tuple consisting of UTM zone, easting and northing.
389        If specified coordinates are assumed to be relative to this origin.
390
391        This one will override any data_origin that may be specified in
392        interpolation instance
393
394        """
395
396
397        #FIXME (Ole): Check that this function is memeory efficient.
398        #6 million datapoints and 300000 basis functions
399        #causes out-of-memory situation
400        #First thing to check is whether there is room for self.A and self.AtA
401        #
402        #Maybe we need some sort of blocking
403
404        from quad import build_quadtree
405        from util import ensure_numeric
406
407        if data_origin is None:
408            data_origin = self.data_origin #Use the one from
409                                           #interpolation instance
410
411        #Convert input to Numeric arrays just in case.
412        point_coordinates = ensure_numeric(point_coordinates, Float)
413
414
415        #Shift data points to same origin as mesh (if specified)
416
417        #FIXME this will shift if there was no geo_ref.
418        #But all this should be removed anyhow.
419        #change coords before this point
420        mesh_origin = self.mesh.geo_reference.get_origin()
421        if point_coordinates is not None:
422            if data_origin is not None:
423                if mesh_origin is not None:
424
425                    #Transformation:
426                    #
427                    #Let x_0 be the reference point of the point coordinates
428                    #and xi_0 the reference point of the mesh.
429                    #
430                    #A point coordinate (x + x_0) is then made relative
431                    #to xi_0 by
432                    #
433                    # x_new = x + x_0 - xi_0
434                    #
435                    #and similarly for eta
436
437                    x_offset = data_origin[1] - mesh_origin[1]
438                    y_offset = data_origin[2] - mesh_origin[2]
439                else: #Shift back to a zero origin
440                    x_offset = data_origin[1]
441                    y_offset = data_origin[2]
442
443                point_coordinates[:,0] += x_offset
444                point_coordinates[:,1] += y_offset
445            else:
446                if mesh_origin is not None:
447                    #Use mesh origin for data points
448                    point_coordinates[:,0] -= mesh_origin[1]
449                    point_coordinates[:,1] -= mesh_origin[2]
450
451
452
453        #Remove points falling outside mesh boundary
454        #This reduced one example from 1356 seconds to 825 seconds
455        if precrop is True:
456            from Numeric import take
457            from util import inside_polygon
458
459            if verbose: print 'Getting boundary polygon'
460            P = self.mesh.get_boundary_polygon()
461
462            if verbose: print 'Getting indices inside mesh boundary'
463            indices = inside_polygon(point_coordinates, P, verbose = verbose)
464
465            if verbose:
466                print 'Done'
467                if len(indices) != point_coordinates.shape[0]:
468                    print '%d points outside mesh have been cropped.'\
469                          %(point_coordinates.shape[0] - len(indices))
470            point_coordinates = take(point_coordinates, indices)
471            self.point_indices = indices
472
473
474
475
476        #Build n x m interpolation matrix
477        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
478        n = point_coordinates.shape[0]     #Nbr of data points
479
480        if verbose: print 'Number of datapoints: %d' %n
481        if verbose: print 'Number of basis functions: %d' %m
482
483        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
484        #However, Sparse_CSR does not have the same methods as Sparse yet
485        #The tests will reveal what needs to be done
486        self.A = Sparse(n,m)
487        self.AtA = Sparse(m,m)
488
489        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
490        root = build_quadtree(self.mesh,
491                              max_points_per_cell = max_points_per_cell)
492
493        #Compute matrix elements
494        for i in range(n):
495            #For each data_coordinate point
496
497            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
498
499            x = point_coordinates[i]
500
501            #Find vertices near x
502            candidate_vertices = root.search(x[0], x[1])
503            is_more_elements = True
504
505            element_found, sigma0, sigma1, sigma2, k = \
506                self.search_triangles_of_vertices(candidate_vertices, x)
507            while not element_found and is_more_elements and expand_search:
508                #if verbose: print 'Expanding search'
509                candidate_vertices, branch = root.expand_search()
510                if branch == []:
511                    # Searching all the verts from the root cell that haven't
512                    # been searched.  This is the last try
513                    element_found, sigma0, sigma1, sigma2, k = \
514                      self.search_triangles_of_vertices(candidate_vertices, x)
515                    is_more_elements = False
516                else:
517                    element_found, sigma0, sigma1, sigma2, k = \
518                      self.search_triangles_of_vertices(candidate_vertices, x)
519
520
521            #Update interpolation matrix A if necessary
522            if element_found is True:
523                #Assign values to matrix A
524
525                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
526                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
527                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
528
529                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
530                js     = [j0,j1,j2]
531
532                for j in js:
533                    self.A[i,j] = sigmas[j]
534                    for k in js:
535                        self.AtA[j,k] += sigmas[j]*sigmas[k]
536            else:
537                pass
538                #Ok if there is no triangle for datapoint
539                #(as in brute force version)
540                #raise 'Could not find triangle for point', x
541
542
543
544    def search_triangles_of_vertices(self, candidate_vertices, x):
545            #Find triangle containing x:
546            element_found = False
547
548            # This will be returned if element_found = False
549            sigma2 = -10.0
550            sigma0 = -10.0
551            sigma1 = -10.0
552            k = -10.0
553
554            #For all vertices in same cell as point x
555            for v in candidate_vertices:
556
557                #for each triangle id (k) which has v as a vertex
558                for k, _ in self.mesh.vertexlist[v]:
559
560                    #Get the three vertex_points of candidate triangle
561                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
562                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
563                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
564
565                    #print "PDSG - k", k
566                    #print "PDSG - xi0", xi0
567                    #print "PDSG - xi1", xi1
568                    #print "PDSG - xi2", xi2
569                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
570                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
571
572                    #Get the three normals
573                    n0 = self.mesh.get_normal(k, 0)
574                    n1 = self.mesh.get_normal(k, 1)
575                    n2 = self.mesh.get_normal(k, 2)
576
577
578                    #Compute interpolation
579                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
580                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
581                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
582
583                    #print "PDSG - sigma0", sigma0
584                    #print "PDSG - sigma1", sigma1
585                    #print "PDSG - sigma2", sigma2
586
587                    #FIXME: Maybe move out to test or something
588                    epsilon = 1.0e-6
589                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
590
591                    #Check that this triangle contains the data point
592
593                    #Sigmas can get negative within
594                    #machine precision on some machines (e.g nautilus)
595                    #Hence the small eps
596                    eps = 1.0e-15
597                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
598                        element_found = True
599                        break
600
601                if element_found is True:
602                    #Don't look for any other triangle
603                    break
604            return element_found, sigma0, sigma1, sigma2, k
605
606
607
608    def build_interpolation_matrix_A_brute(self, point_coordinates):
609        """Build n x m interpolation matrix, where
610        n is the number of data points and
611        m is the number of basis functions phi_k (one per vertex)
612
613        This is the brute force which is too slow for large problems,
614        but could be used for testing
615        """
616
617        from util import ensure_numeric
618
619        #Convert input to Numeric arrays
620        point_coordinates = ensure_numeric(point_coordinates, Float)
621
622        #Build n x m interpolation matrix
623        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
624        n = point_coordinates.shape[0]     #Nbr of data points
625
626        self.A = Sparse(n,m)
627        self.AtA = Sparse(m,m)
628
629        #Compute matrix elements
630        for i in range(n):
631            #For each data_coordinate point
632
633            x = point_coordinates[i]
634            element_found = False
635            k = 0
636            while not element_found and k < len(self.mesh):
637                #For each triangle (brute force)
638                #FIXME: Real algorithm should only visit relevant triangles
639
640                #Get the three vertex_points
641                xi0 = self.mesh.get_vertex_coordinate(k, 0)
642                xi1 = self.mesh.get_vertex_coordinate(k, 1)
643                xi2 = self.mesh.get_vertex_coordinate(k, 2)
644
645                #Get the three normals
646                n0 = self.mesh.get_normal(k, 0)
647                n1 = self.mesh.get_normal(k, 1)
648                n2 = self.mesh.get_normal(k, 2)
649
650                #Compute interpolation
651                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
652                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
653                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
654
655                #FIXME: Maybe move out to test or something
656                epsilon = 1.0e-6
657                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
658
659                #Check that this triangle contains data point
660                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
661                    element_found = True
662                    #Assign values to matrix A
663
664                    j0 = self.mesh.triangles[k,0] #Global vertex id
665                    #self.A[i, j0] = sigma0
666
667                    j1 = self.mesh.triangles[k,1] #Global vertex id
668                    #self.A[i, j1] = sigma1
669
670                    j2 = self.mesh.triangles[k,2] #Global vertex id
671                    #self.A[i, j2] = sigma2
672
673                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
674                    js     = [j0,j1,j2]
675
676                    for j in js:
677                        self.A[i,j] = sigmas[j]
678                        for k in js:
679                            self.AtA[j,k] += sigmas[j]*sigmas[k]
680                k = k+1
681
682
683
684    def get_A(self):
685        return self.A.todense()
686
687    def get_B(self):
688        return self.B.todense()
689
690    def get_D(self):
691        return self.D.todense()
692
693        #FIXME: Remember to re-introduce the 1/n factor in the
694        #interpolation term
695
696    def build_smoothing_matrix_D(self):
697        """Build m x m smoothing matrix, where
698        m is the number of basis functions phi_k (one per vertex)
699
700        The smoothing matrix is defined as
701
702        D = D1 + D2
703
704        where
705
706        [D1]_{k,l} = \int_\Omega
707           \frac{\partial \phi_k}{\partial x}
708           \frac{\partial \phi_l}{\partial x}\,
709           dx dy
710
711        [D2]_{k,l} = \int_\Omega
712           \frac{\partial \phi_k}{\partial y}
713           \frac{\partial \phi_l}{\partial y}\,
714           dx dy
715
716
717        The derivatives \frac{\partial \phi_k}{\partial x},
718        \frac{\partial \phi_k}{\partial x} for a particular triangle
719        are obtained by computing the gradient a_k, b_k for basis function k
720        """
721
722        #FIXME: algorithm might be optimised by computing local 9x9
723        #"element stiffness matrices:
724
725        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
726
727        self.D = Sparse(m,m)
728
729        #For each triangle compute contributions to D = D1+D2
730        for i in range(len(self.mesh)):
731
732            #Get area
733            area = self.mesh.areas[i]
734
735            #Get global vertex indices
736            v0 = self.mesh.triangles[i,0]
737            v1 = self.mesh.triangles[i,1]
738            v2 = self.mesh.triangles[i,2]
739
740            #Get the three vertex_points
741            xi0 = self.mesh.get_vertex_coordinate(i, 0)
742            xi1 = self.mesh.get_vertex_coordinate(i, 1)
743            xi2 = self.mesh.get_vertex_coordinate(i, 2)
744
745            #Compute gradients for each vertex
746            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
747                              1, 0, 0)
748
749            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
750                              0, 1, 0)
751
752            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
753                              0, 0, 1)
754
755            #Compute diagonal contributions
756            self.D[v0,v0] += (a0*a0 + b0*b0)*area
757            self.D[v1,v1] += (a1*a1 + b1*b1)*area
758            self.D[v2,v2] += (a2*a2 + b2*b2)*area
759
760            #Compute contributions for basis functions sharing edges
761            e01 = (a0*a1 + b0*b1)*area
762            self.D[v0,v1] += e01
763            self.D[v1,v0] += e01
764
765            e12 = (a1*a2 + b1*b2)*area
766            self.D[v1,v2] += e12
767            self.D[v2,v1] += e12
768
769            e20 = (a2*a0 + b2*b0)*area
770            self.D[v2,v0] += e20
771            self.D[v0,v2] += e20
772
773
774    def fit(self, z):
775        """Fit a smooth surface to given 1d array of data points z.
776
777        The smooth surface is computed at each vertex in the underlying
778        mesh using the formula given in the module doc string.
779
780        Pre Condition:
781          self.A, self.AtA and self.B have been initialised
782
783        Inputs:
784          z: Single 1d vector or array of data at the point_coordinates.
785        """
786
787        #Convert input to Numeric arrays
788        from util import ensure_numeric
789        z = ensure_numeric(z, Float)
790
791        if len(z.shape) > 1 :
792            raise VectorShapeError, 'Can only deal with 1d data vector'
793
794        if self.point_indices is not None:
795            #Remove values for any points that were outside mesh
796            z = take(z, self.point_indices)
797
798        #Compute right hand side based on data
799        Atz = self.A.trans_mult(z)
800
801
802        #Check sanity
803        n, m = self.A.shape
804        if n<m and self.alpha == 0.0:
805            msg = 'ERROR (least_squares): Too few data points\n'
806            msg += 'There are only %d data points and alpha == 0. ' %n
807            msg += 'Need at least %d\n' %m
808            msg += 'Alternatively, set smoothing parameter alpha to a small '
809            msg += 'positive value,\ne.g. 1.0e-3.'
810            raise msg
811
812
813
814        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
815        #FIXME: Should we store the result here for later use? (ON)
816
817
818    def fit_points(self, z, verbose=False):
819        """Like fit, but more robust when each point has two or more attributes
820        FIXME (Ole): The name fit_points doesn't carry any meaning
821        for me. How about something like fit_multiple or fit_columns?
822        """
823
824        try:
825            if verbose: print 'Solving penalised least_squares problem'
826            return self.fit(z)
827        except VectorShapeError, e:
828            # broadcasting is not supported.
829
830            #Convert input to Numeric arrays
831            from util import ensure_numeric
832            z = ensure_numeric(z, Float)
833
834            #Build n x m interpolation matrix
835            m = self.mesh.coordinates.shape[0] #Number of vertices
836            n = z.shape[1]                     #Number of data points
837
838            f = zeros((m,n), Float) #Resulting columns
839
840            for i in range(z.shape[1]):
841                f[:,i] = self.fit(z[:,i])
842
843            return f
844
845
846    def interpolate(self, f):
847        """Evaluate smooth surface f at data points implied in self.A.
848
849        The mesh values representing a smooth surface are
850        assumed to be specified in f. This argument could,
851        for example have been obtained from the method self.fit()
852
853        Pre Condition:
854          self.A has been initialised
855
856        Inputs:
857          f: Vector or array of data at the mesh vertices.
858          If f is an array, interpolation will be done for each column as
859          per underlying matrix-matrix multiplication
860
861        Output:
862          Interpolated values at data points implied in self.A
863
864        """
865
866        return self.A * f
867
868    def cull_outsiders(self, f):
869        pass
870
871
872#-------------------------------------------------------------
873if __name__ == "__main__":
874    """
875    Load in a mesh and data points with attributes.
876    Fit the attributes to the mesh.
877    Save a new mesh file.
878    """
879    import os, sys
880    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha]"\
881            %os.path.basename(sys.argv[0])
882
883    if len(sys.argv) < 4:
884        print usage
885    else:
886        mesh_file = sys.argv[1]
887        point_file = sys.argv[2]
888        mesh_output_file = sys.argv[3]
889
890        expand_search = False
891        if len(sys.argv) > 4:
892            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
893                expand_search = True
894            else:
895                expand_search = False
896
897        verbose = False
898        if len(sys.argv) > 5:
899            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
900                verbose = False
901            else:
902                verbose = True
903
904        if len(sys.argv) > 6:
905            alpha = sys.argv[6]
906        else:
907            alpha = DEFAULT_ALPHA
908
909        t0 = time.time()
910        fit_to_mesh_file(mesh_file,
911                         point_file,
912                         mesh_output_file,
913                         alpha,
914                         verbose= verbose,
915                         expand_search = expand_search)
916
917        print 'That took %.2f seconds' %(time.time()-t0)
918
Note: See TracBrowser for help on using the repository browser.