source: inundation/pyvolution/least_squares.py @ 1812

Last change on this file since 1812 was 1787, checked in by ole, 19 years ago

Better output

File size: 44.3 KB
Line 
"""Least squares smoothing and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
   A typical value of alpha is 1.0e-6; the default used by this module
   (DEFAULT_ALPHA) is 1.0e-3.
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20import exceptions
class ShapeError(Exception):
    """Error presumably intended for array-shape mismatches.

    It is not raised in the visible part of this module but is kept as part
    of its interface. Derives from the builtin Exception, which is the same
    class as exceptions.Exception on Python 2.
    """
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
try:
    from util import gradient
except ImportError:
    #FIXME reduce the dependency of modules in pyvolution
    # Have util in a dir, working like load_mesh, and get rid of this
    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
        """Return the gradient (a, b) of the plane through three points.

        The plane q(x, y) takes the value qi at (xi, yi) for i = 0, 1, 2.
        a is its x-derivative and b its y-derivative, obtained with
        Cramer's rule.

        The three points must not be collinear: then det == 0 and the
        division fails (ZeroDivisionError for plain Python numbers).
        """

        #Multiply by 1.0 so that all-integer input uses true division
        #instead of Python 2 integer (floor) division.
        det = 1.0*((y2-y0)*(x1-x0) - (y1-y0)*(x2-x0))

        a = ((y2-y0)*(q1-q0) - (y1-y0)*(q2-q0))/det
        b = ((x1-x0)*(q2-q0) - (x2-x0)*(q1-q0))/det

        return a, b
53
54
# Smoothing parameter used when a caller does not supply alpha
# (fit_to_mesh_file, fit_to_mesh, pts2rectangular and Interpolation).
DEFAULT_ALPHA = 1.0e-3
56
def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
                     alpha=DEFAULT_ALPHA, verbose= False,
                     expand_search = False,
                     data_origin = None,
                     mesh_origin = None,
                     precrop = False):
    """Fit point attributes to a mesh and write a mesh file with the results.

    Given a mesh file (tsh) and a point attribute file (xya), fit the
    point attributes to the mesh vertices and write a mesh file in which
    the fitted attributes are appended after any existing vertex
    attributes.

    Inputs:
      mesh_file: Name of the input mesh file.
      point_file: Name of the point attribute file. It is read with ','
        as delimiter; if that raises SyntaxError it is re-read with ' '.
      mesh_output_file: Name of the mesh file to write.
      alpha: Smoothing parameter (see module docstring).
      verbose: Print progress messages.
      expand_search, precrop: Passed on to fit_to_mesh.
      data_origin: Optional 3-tuple with geo referenced UTM coordinates
        (zone, easting, northing) for the point file. Used only when the
        point file carries no geo_reference.
      mesh_origin: Same as data_origin but refers to the input tsh file.
        Used only when the mesh file carries no geo_reference.

    When neither the file nor the caller supplies an origin, (56, 0, 0)
    is used.

    FIXME: When the tsh format contains it own origin, these parameters can go.
    FIXME: And both origins should be obtained from the specified files.
    """

    from load_mesh.loadASCII import import_mesh_file, \
                 import_points_file, export_mesh_file, \
                 concatinate_attributelist

    def as_list(value):
        #Numeric arrays become (nested) lists so that rows can be
        #extended in place further down
        if isinstance(value, ArrayType):
            return value.tolist()
        return value

    def origin_of(file_dict, explicit_origin):
        #Precedence: origin stored in the file, then the origin given
        #by the caller, then the hardwired default.
        geo_reference = file_dict.get('geo_reference')
        if geo_reference is not None:
            return geo_reference.get_origin()
        if explicit_origin is not None:
            return explicit_origin
        return (56, 0, 0) #FIXME(DSG-DSG)

    mesh_dict = import_mesh_file(mesh_file)
    vertex_coordinates = mesh_dict['vertices']
    triangles = mesh_dict['triangles']
    old_point_attributes = as_list(mesh_dict['vertex_attributes'])
    old_title_list = as_list(mesh_dict['vertex_attribute_titles'])

    if verbose: print('tsh file %s loaded' %mesh_file)

    # load in the .pts file
    try:
        point_dict = import_points_file(point_file,
                                        delimiter = ',',
                                        verbose=verbose)
    except SyntaxError:
        point_dict = import_points_file(point_file,
                                        delimiter = ' ',
                                        verbose=verbose)

    point_coordinates = point_dict['pointlist']
    title_list, point_attributes = \
        concatinate_attributelist(point_dict['attributelist'])

    data_origin = origin_of(point_dict, data_origin)
    mesh_origin = origin_of(mesh_dict, mesh_origin)

    if verbose: print("points file loaded")
    if verbose: print("fitting to mesh")
    f = fit_to_mesh(vertex_coordinates,
                    triangles,
                    point_coordinates,
                    point_attributes,
                    alpha = alpha,
                    verbose = verbose,
                    expand_search = expand_search,
                    data_origin = data_origin,
                    mesh_origin = mesh_origin,
                    precrop = precrop)
    if verbose: print("finished fitting to mesh")

    # convert array to list of lists
    new_point_attributes = f.tolist()
    #FIXME have this overwrite attributes with the same title - DSG
    #Put the newer attributes last
    if old_title_list:
        old_title_list.extend(title_list)
        #Rows presumably correspond to mesh vertices on both sides; indexing
        #(rather than zip) keeps a length mismatch from passing silently.
        for i, old_row in enumerate(old_point_attributes):
            old_row.extend(new_point_attributes[i])
        mesh_dict['vertex_attributes'] = old_point_attributes
        mesh_dict['vertex_attribute_titles'] = old_title_list
    else:
        mesh_dict['vertex_attributes'] = new_point_attributes
        mesh_dict['vertex_attribute_titles'] = title_list

    #FIXME (Ole): Remember to output mesh_origin as well
    #(Two spaces: matches the original Python 2 'print "... ", name' output.)
    if verbose: print('exporting to file  %s' %mesh_output_file)
    export_mesh_file(mesh_output_file, mesh_dict)
152
153
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,
                point_attributes,
                alpha = DEFAULT_ALPHA,
                verbose = False,
                expand_search = False,
                data_origin = None,
                mesh_origin = None,
                precrop = False):
    """Fit a smooth surface to a triangulation, given attributed data points.

    Inputs:
      vertex_coordinates: Coordinate pairs [xi, eta] of the mesh points
        (list or m x 2 Numeric array).
      triangles: 3-tuples (list or Numeric array) of vertex indices.
      point_coordinates: Coordinate pairs [x, y] of the data points
        (list or n x 2 Numeric array).
      point_attributes: Vector or array of data at point_coordinates.
      alpha: Smoothing parameter.
      verbose, expand_search, precrop: Passed on to Interpolation.
      data_origin, mesh_origin: 3-tuples (UTM zone, easting, northing).
        If given, point and vertex coordinates are taken to be relative
        to their respective origins.

    Returns:
      The fitted values at the mesh vertices, as computed by
      Interpolation.fit_points (one column per attribute when
      point_attributes is two-dimensional).
    """
    options = dict(alpha = alpha,
                   verbose = verbose,
                   expand_search = expand_search,
                   data_origin = data_origin,
                   mesh_origin = mesh_origin,
                   precrop = precrop)

    return Interpolation(vertex_coordinates,
                         triangles,
                         point_coordinates,
                         **options).fit_points(point_attributes,
                                               verbose = verbose)
202
203
204
def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
                    verbose = False, reduction = 1, format = 'netcdf'):
    """Fits attributes from pts file to MxN rectangular mesh

    Read pts file and create rectangular mesh of resolution MxN such that
    it covers all points specified in pts file.

    Inputs:
      pts_name: Name of the pts file (read with util.read_xya).
      M, N: Resolution of the rectangular mesh.
      alpha: Smoothing parameter.
      verbose: Print progress messages.
      reduction: Keep only every reduction'th data point.
      format: File format passed on to util.read_xya.

    Returns:
      vertex_coordinates, triangles, boundary, vertex_attributes
      where the first three come from mesh_factory.rectangular and
      vertex_attributes are the fitted elevation values.

    Raises:
      ValueError: if no data points remain after reduction.

    FIXME: This may be a temporary function until we decide on
    netcdf formats etc

    FIXME: Uses elevation hardwired
    """

    import util
    import mesh_factory

    if verbose: print('Read pts')
    points, attributes = util.read_xya(pts_name, format)

    #Reduce number of points a bit
    points = points[::reduction]
    elevation = attributes['elevation'][::reduction]  #Must be elevation

    if verbose: print('Got %d data points' %len(points))

    if len(points) == 0:
        raise ValueError('No data points available in %s' %pts_name)

    if verbose: print('Create mesh')
    #Find extent of the (reduced) data
    xs = [point[0] for point in points]
    ys = [point[1] for point in points]
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)

    #Create appropriate mesh
    vertex_coordinates, triangles, boundary =\
         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
                                  (min_x, min_y))

    #Fit attributes to mesh
    vertex_attributes = fit_to_mesh(vertex_coordinates,
                                    triangles,
                                    points,
                                    elevation, alpha=alpha, verbose=verbose)

    return vertex_coordinates, triangles, boundary, vertex_attributes
256
257
258
259class Interpolation:
260
261    def __init__(self,
262                 vertex_coordinates,
263                 triangles,
264                 point_coordinates = None,
265                 alpha = None,
266                 verbose = False,
267                 expand_search = True,
268                 max_points_per_cell = 30,
269                 mesh_origin = None,
270                 data_origin = None,
271                 precrop = False):
272
273
274        """ Build interpolation matrix mapping from
275        function values at vertices to function values at data points
276
277        Inputs:
278
279          vertex_coordinates: List of coordinate pairs [xi, eta] of
280          points constituting mesh (or a an m x 2 Numeric array)
281          Points may appear multiple times
282          (e.g. if vertices have discontinuities)
283
284          triangles: List of 3-tuples (or a Numeric array) of
285          integers representing indices of all vertices in the mesh.
286
287          point_coordinates: List of coordinate pairs [x, y] of
288          data points (or an nx2 Numeric array)
289          If point_coordinates is absent, only smoothing matrix will
290          be built
291
292          alpha: Smoothing parameter
293
294          data_origin and mesh_origin are 3-tuples consisting of
295          UTM zone, easting and northing. If specified
296          point coordinates and vertex coordinates are assumed to be
297          relative to their respective origins.
298
299        """
300        from util import ensure_numeric
301
302        #Convert input to Numeric arrays
303        triangles = ensure_numeric(triangles, Int)
304        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
305
306        #Build underlying mesh
307        if verbose: print 'Building mesh'
308        #self.mesh = General_mesh(vertex_coordinates, triangles,
309        #FIXME: Trying the normal mesh while testing precrop,
310        #       The functionality of boundary_polygon is needed for that
311
312        #FIXME - geo ref does not have to go into mesh.
313        # Change the point co-ords to conform to the
314        # mesh co-ords early in the code
315        if mesh_origin == None:
316            geo = None
317        else:
318            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
319        self.mesh = Mesh(vertex_coordinates, triangles,
320                         geo_reference = geo)
321       
322        self.mesh.check_integrity()
323
324        self.data_origin = data_origin
325
326        self.point_indices = None
327
328        #Smoothing parameter
329        if alpha is None:
330            self.alpha = DEFAULT_ALPHA
331        else:   
332            self.alpha = alpha
333
334        #Build coefficient matrices
335        self.build_coefficient_matrix_B(point_coordinates,
336                                        verbose = verbose,
337                                        expand_search = expand_search,
338                                        max_points_per_cell =\
339                                        max_points_per_cell,
340                                        data_origin = data_origin,
341                                        precrop = precrop)
342
343
344    def set_point_coordinates(self, point_coordinates,
345                              data_origin = None):
346        """
347        A public interface to setting the point co-ordinates.
348        """
349        self.build_coefficient_matrix_B(point_coordinates, data_origin)
350
351    def build_coefficient_matrix_B(self, point_coordinates=None,
352                                   verbose = False, expand_search = True,
353                                   max_points_per_cell=30,
354                                   data_origin = None,
355                                   precrop = False):
356        """Build final coefficient matrix"""
357
358
359        if self.alpha <> 0:
360            if verbose: print 'Building smoothing matrix'
361            self.build_smoothing_matrix_D()
362
363        if point_coordinates is not None:
364
365            if verbose: print 'Building interpolation matrix'
366            self.build_interpolation_matrix_A(point_coordinates,
367                                              verbose = verbose,
368                                              expand_search = expand_search,
369                                              max_points_per_cell =\
370                                              max_points_per_cell,
371                                              data_origin = data_origin,
372                                              precrop = precrop)
373
374            if self.alpha <> 0:
375                self.B = self.AtA + self.alpha*self.D
376            else:
377                self.B = self.AtA
378
379            #Convert self.B matrix to CSR format for faster matrix vector
380            self.B = Sparse_CSR(self.B)
381
382    def build_interpolation_matrix_A(self, point_coordinates,
383                                     verbose = False, expand_search = True,
384                                     max_points_per_cell=30,
385                                     data_origin = None,
386                                     precrop = False):
387        """Build n x m interpolation matrix, where
388        n is the number of data points and
389        m is the number of basis functions phi_k (one per vertex)
390
391        This algorithm uses a quad tree data structure for fast binning of data points
392        origin is a 3-tuple consisting of UTM zone, easting and northing.
393        If specified coordinates are assumed to be relative to this origin.
394
395        This one will override any data_origin that may be specified in
396        interpolation instance
397
398        """
399
400
401        #FIXME (Ole): Check that this function is memeory efficient.
402        #6 million datapoints and 300000 basis functions
403        #causes out-of-memory situation
404        #First thing to check is whether there is room for self.A and self.AtA
405        #
406        #Maybe we need some sort of blocking
407
408        from quad import build_quadtree
409        from util import ensure_numeric
410
411        if data_origin is None:
412            data_origin = self.data_origin #Use the one from
413                                           #interpolation instance
414
415        #Convert input to Numeric arrays just in case.
416        point_coordinates = ensure_numeric(point_coordinates, Float)
417
418        #Keep track of discarded points (if any).
419        #This is only registered if precrop is True
420        self.cropped_points = False
421
422        #Shift data points to same origin as mesh (if specified)
423
424        #FIXME this will shift if there was no geo_ref.
425        #But all this should be removed anyhow.
426        #change coords before this point
427        mesh_origin = self.mesh.geo_reference.get_origin()
428        if point_coordinates is not None:
429            if data_origin is not None:
430                if mesh_origin is not None:
431
432                    #Transformation:
433                    #
434                    #Let x_0 be the reference point of the point coordinates
435                    #and xi_0 the reference point of the mesh.
436                    #
437                    #A point coordinate (x + x_0) is then made relative
438                    #to xi_0 by
439                    #
440                    # x_new = x + x_0 - xi_0
441                    #
442                    #and similarly for eta
443
444                    x_offset = data_origin[1] - mesh_origin[1]
445                    y_offset = data_origin[2] - mesh_origin[2]
446                else: #Shift back to a zero origin
447                    x_offset = data_origin[1]
448                    y_offset = data_origin[2]
449
450                point_coordinates[:,0] += x_offset
451                point_coordinates[:,1] += y_offset
452            else:
453                if mesh_origin is not None:
454                    #Use mesh origin for data points
455                    point_coordinates[:,0] -= mesh_origin[1]
456                    point_coordinates[:,1] -= mesh_origin[2]
457
458
459
460        #Remove points falling outside mesh boundary
461        #This reduced one example from 1356 seconds to 825 seconds
462        if precrop is True:
463            from Numeric import take
464            from util import inside_polygon
465
466            if verbose: print 'Getting boundary polygon'
467            P = self.mesh.get_boundary_polygon()
468
469            if verbose: print 'Getting indices inside mesh boundary'
470            indices = inside_polygon(point_coordinates, P, verbose = verbose)
471
472
473            if len(indices) != point_coordinates.shape[0]:
474                self.cropped_points = True
475                if verbose:
476                    print 'Done - %d points outside mesh have been cropped.'\
477                          %(point_coordinates.shape[0] - len(indices))
478
479            point_coordinates = take(point_coordinates, indices)
480            self.point_indices = indices
481
482
483
484
485        #Build n x m interpolation matrix
486        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
487        n = point_coordinates.shape[0]     #Nbr of data points
488
489        if verbose: print 'Number of datapoints: %d' %n
490        if verbose: print 'Number of basis functions: %d' %m
491
492        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
493        #However, Sparse_CSR does not have the same methods as Sparse yet
494        #The tests will reveal what needs to be done
495        self.A = Sparse(n,m)
496        self.AtA = Sparse(m,m)
497
498        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
499        root = build_quadtree(self.mesh,
500                              max_points_per_cell = max_points_per_cell)
501
502        #Compute matrix elements
503        for i in range(n):
504            #For each data_coordinate point
505
506            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
507
508            x = point_coordinates[i]
509
510            #Find vertices near x
511            candidate_vertices = root.search(x[0], x[1])
512            is_more_elements = True
513
514            element_found, sigma0, sigma1, sigma2, k = \
515                self.search_triangles_of_vertices(candidate_vertices, x)
516            while not element_found and is_more_elements and expand_search:
517                #if verbose: print 'Expanding search'
518                candidate_vertices, branch = root.expand_search()
519                if branch == []:
520                    # Searching all the verts from the root cell that haven't
521                    # been searched.  This is the last try
522                    element_found, sigma0, sigma1, sigma2, k = \
523                      self.search_triangles_of_vertices(candidate_vertices, x)
524                    is_more_elements = False
525                else:
526                    element_found, sigma0, sigma1, sigma2, k = \
527                      self.search_triangles_of_vertices(candidate_vertices, x)
528
529
530            #Update interpolation matrix A if necessary
531            if element_found is True:
532                #Assign values to matrix A
533
534                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
535                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
536                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
537
538                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
539                js     = [j0,j1,j2]
540
541                for j in js:
542                    self.A[i,j] = sigmas[j]
543                    for k in js:
544                        self.AtA[j,k] += sigmas[j]*sigmas[k]
545            else:
546                pass
547                #Ok if there is no triangle for datapoint
548                #(as in brute force version)
549                #raise 'Could not find triangle for point', x
550
551
552
553    def search_triangles_of_vertices(self, candidate_vertices, x):
554            #Find triangle containing x:
555            element_found = False
556
557            # This will be returned if element_found = False
558            sigma2 = -10.0
559            sigma0 = -10.0
560            sigma1 = -10.0
561            k = -10.0
562
563            #For all vertices in same cell as point x
564            for v in candidate_vertices:
565
566                #for each triangle id (k) which has v as a vertex
567                for k, _ in self.mesh.vertexlist[v]:
568
569                    #Get the three vertex_points of candidate triangle
570                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
571                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
572                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
573
574                    #print "PDSG - k", k
575                    #print "PDSG - xi0", xi0
576                    #print "PDSG - xi1", xi1
577                    #print "PDSG - xi2", xi2
578                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
579                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
580
581                    #Get the three normals
582                    n0 = self.mesh.get_normal(k, 0)
583                    n1 = self.mesh.get_normal(k, 1)
584                    n2 = self.mesh.get_normal(k, 2)
585
586
587                    #Compute interpolation
588                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
589                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
590                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
591
592                    #print "PDSG - sigma0", sigma0
593                    #print "PDSG - sigma1", sigma1
594                    #print "PDSG - sigma2", sigma2
595
596                    #FIXME: Maybe move out to test or something
597                    epsilon = 1.0e-6
598                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
599
600                    #Check that this triangle contains the data point
601
602                    #Sigmas can get negative within
603                    #machine precision on some machines (e.g nautilus)
604                    #Hence the small eps
605                    eps = 1.0e-15
606                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
607                        element_found = True
608                        break
609
610                if element_found is True:
611                    #Don't look for any other triangle
612                    break
613            return element_found, sigma0, sigma1, sigma2, k
614
615
616
617    def build_interpolation_matrix_A_brute(self, point_coordinates):
618        """Build n x m interpolation matrix, where
619        n is the number of data points and
620        m is the number of basis functions phi_k (one per vertex)
621
622        This is the brute force which is too slow for large problems,
623        but could be used for testing
624        """
625
626        from util import ensure_numeric
627
628        #Convert input to Numeric arrays
629        point_coordinates = ensure_numeric(point_coordinates, Float)
630
631        #Build n x m interpolation matrix
632        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
633        n = point_coordinates.shape[0]     #Nbr of data points
634
635        self.A = Sparse(n,m)
636        self.AtA = Sparse(m,m)
637
638        #Compute matrix elements
639        for i in range(n):
640            #For each data_coordinate point
641
642            x = point_coordinates[i]
643            element_found = False
644            k = 0
645            while not element_found and k < len(self.mesh):
646                #For each triangle (brute force)
647                #FIXME: Real algorithm should only visit relevant triangles
648
649                #Get the three vertex_points
650                xi0 = self.mesh.get_vertex_coordinate(k, 0)
651                xi1 = self.mesh.get_vertex_coordinate(k, 1)
652                xi2 = self.mesh.get_vertex_coordinate(k, 2)
653
654                #Get the three normals
655                n0 = self.mesh.get_normal(k, 0)
656                n1 = self.mesh.get_normal(k, 1)
657                n2 = self.mesh.get_normal(k, 2)
658
659                #Compute interpolation
660                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
661                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
662                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
663
664                #FIXME: Maybe move out to test or something
665                epsilon = 1.0e-6
666                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
667
668                #Check that this triangle contains data point
669                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
670                    element_found = True
671                    #Assign values to matrix A
672
673                    j0 = self.mesh.triangles[k,0] #Global vertex id
674                    #self.A[i, j0] = sigma0
675
676                    j1 = self.mesh.triangles[k,1] #Global vertex id
677                    #self.A[i, j1] = sigma1
678
679                    j2 = self.mesh.triangles[k,2] #Global vertex id
680                    #self.A[i, j2] = sigma2
681
682                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
683                    js     = [j0,j1,j2]
684
685                    for j in js:
686                        self.A[i,j] = sigmas[j]
687                        for k in js:
688                            self.AtA[j,k] += sigmas[j]*sigmas[k]
689                k = k+1
690
691
692
693    def get_A(self):
694        return self.A.todense()
695
696    def get_B(self):
697        return self.B.todense()
698
699    def get_D(self):
700        return self.D.todense()
701
702        #FIXME: Remember to re-introduce the 1/n factor in the
703        #interpolation term
704
705    def build_smoothing_matrix_D(self):
706        """Build m x m smoothing matrix, where
707        m is the number of basis functions phi_k (one per vertex)
708
709        The smoothing matrix is defined as
710
711        D = D1 + D2
712
713        where
714
715        [D1]_{k,l} = \int_\Omega
716           \frac{\partial \phi_k}{\partial x}
717           \frac{\partial \phi_l}{\partial x}\,
718           dx dy
719
720        [D2]_{k,l} = \int_\Omega
721           \frac{\partial \phi_k}{\partial y}
722           \frac{\partial \phi_l}{\partial y}\,
723           dx dy
724
725
726        The derivatives \frac{\partial \phi_k}{\partial x},
727        \frac{\partial \phi_k}{\partial x} for a particular triangle
728        are obtained by computing the gradient a_k, b_k for basis function k
729        """
730
731        #FIXME: algorithm might be optimised by computing local 9x9
732        #"element stiffness matrices:
733
734        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
735
736        self.D = Sparse(m,m)
737
738        #For each triangle compute contributions to D = D1+D2
739        for i in range(len(self.mesh)):
740
741            #Get area
742            area = self.mesh.areas[i]
743
744            #Get global vertex indices
745            v0 = self.mesh.triangles[i,0]
746            v1 = self.mesh.triangles[i,1]
747            v2 = self.mesh.triangles[i,2]
748
749            #Get the three vertex_points
750            xi0 = self.mesh.get_vertex_coordinate(i, 0)
751            xi1 = self.mesh.get_vertex_coordinate(i, 1)
752            xi2 = self.mesh.get_vertex_coordinate(i, 2)
753
754            #Compute gradients for each vertex
755            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
756                              1, 0, 0)
757
758            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
759                              0, 1, 0)
760
761            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
762                              0, 0, 1)
763
764            #Compute diagonal contributions
765            self.D[v0,v0] += (a0*a0 + b0*b0)*area
766            self.D[v1,v1] += (a1*a1 + b1*b1)*area
767            self.D[v2,v2] += (a2*a2 + b2*b2)*area
768
769            #Compute contributions for basis functions sharing edges
770            e01 = (a0*a1 + b0*b1)*area
771            self.D[v0,v1] += e01
772            self.D[v1,v0] += e01
773
774            e12 = (a1*a2 + b1*b2)*area
775            self.D[v1,v2] += e12
776            self.D[v2,v1] += e12
777
778            e20 = (a2*a0 + b2*b0)*area
779            self.D[v2,v0] += e20
780            self.D[v0,v2] += e20
781
782
783    def fit(self, z):
784        """Fit a smooth surface to given 1d array of data points z.
785
786        The smooth surface is computed at each vertex in the underlying
787        mesh using the formula given in the module doc string.
788
789        Pre Condition:
790          self.A, self.AtA and self.B have been initialised
791
792        Inputs:
793          z: Single 1d vector or array of data at the point_coordinates.
794        """
795
796        #Convert input to Numeric arrays
797        from util import ensure_numeric
798        z = ensure_numeric(z, Float)
799
800        if len(z.shape) > 1 :
801            raise VectorShapeError, 'Can only deal with 1d data vector'
802
803        if self.point_indices is not None:
804            #Remove values for any points that were outside mesh
805            z = take(z, self.point_indices)
806
807        #Compute right hand side based on data
808        Atz = self.A.trans_mult(z)
809
810
811        #Check sanity
812        n, m = self.A.shape
813        if n<m and self.alpha == 0.0:
814            msg = 'ERROR (least_squares): Too few data points\n'
815            msg += 'There are only %d data points and alpha == 0. ' %n
816            msg += 'Need at least %d\n' %m
817            msg += 'Alternatively, set smoothing parameter alpha to a small '
818            msg += 'positive value,\ne.g. 1.0e-3.'
819            raise msg
820
821
822
823        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
824        #FIXME: Should we store the result here for later use? (ON)
825
826
827    def fit_points(self, z, verbose=False):
828        """Like fit, but more robust when each point has two or more attributes
829        FIXME (Ole): The name fit_points doesn't carry any meaning
830        for me. How about something like fit_multiple or fit_columns?
831        """
832
833        try:
834            if verbose: print 'Solving penalised least_squares problem'
835            return self.fit(z)
836        except VectorShapeError, e:
837            # broadcasting is not supported.
838
839            #Convert input to Numeric arrays
840            from util import ensure_numeric
841            z = ensure_numeric(z, Float)
842
843            #Build n x m interpolation matrix
844            m = self.mesh.coordinates.shape[0] #Number of vertices
845            n = z.shape[1]                     #Number of data points
846
847            f = zeros((m,n), Float) #Resulting columns
848
849            for i in range(z.shape[1]):
850                f[:,i] = self.fit(z[:,i])
851
852            return f
853
854
855    def interpolate(self, f):
856        """Evaluate smooth surface f at data points implied in self.A.
857
858        The mesh values representing a smooth surface are
859        assumed to be specified in f. This argument could,
860        for example have been obtained from the method self.fit()
861
862        Pre Condition:
863          self.A has been initialised
864
865        Inputs:
866          f: Vector or array of data at the mesh vertices.
867          If f is an array, interpolation will be done for each column as
868          per underlying matrix-matrix multiplication
869
870        Output:
871          Interpolated values at data points implied in self.A
872
873        """
874
875        return self.A * f
876
877    def cull_outsiders(self, f):
878        pass
879
880
881
882
class Interpolation_function:
    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
    which is interpolated from time series defined at vertices of
    triangular mesh (such as those stored in sww files)

    Let m be the number of vertices, n the number of triangles
    and p the number of timesteps.

    Mandatory input
        time:               px1 array of strictly increasing times.
                            Converted to Float so that temporal
                            interpolation also works for integer times.
        quantities:         Dictionary of pxm arrays or 1 pxm array (Float).
                            Without spatial info each array must be 1d
                            (one value per timestep).

    Optional input:
        quantity_names:     List of keys into the quantities dictionary.
                            Defaults to all keys of quantities. Ignored
                            when quantities is a single array, which is
                            then stored under the name 'Attribute'.
        vertex_coordinates: mx2 array of coordinates (Float)
        triangles:          nx3 array of indices into vertex_coordinates (Int)
        interpolation_points: array of coordinates to be interpolated to
        verbose:            Level of reporting


    The quantities returned by the callable object are specified by
    the list quantity_names which must contain the names of the
    quantities to be returned and also reflect the order, e.g. for
    the shallow water wave equation, one would have
    quantity_names = ['stage', 'xmomentum', 'ymomentum']

    The parameter interpolation_points decides at which points interpolated
    quantities are to be computed whenever object is called.
    If None, the quantities are stored unchanged; with spatial info the
    object can then not be evaluated at a point_id.
    """


    def __init__(self,
                 time,
                 quantities,
                 quantity_names = None,
                 vertex_coordinates = None,
                 triangles = None,
                 interpolation_points = None,
                 verbose = False):
        """Initialise object and build spatial interpolation if required

        Raises:
          AssertionError if time is not strictly increasing, or if
          vertex_coordinates are given without triangles.
          A string exception if interpolation_points are given without
          spatial info, or cannot be converted to a Numeric array.
        """

        from Numeric import zeros, Float, alltrue
        from util import ensure_numeric


        #Check temporal info. Float is enforced because the interpolation
        #ratio in __call__ would otherwise use integer division.
        time = ensure_numeric(time, Float)
        msg = 'Time must be a monotonuosly '
        msg += 'increasing sequence %s' %time
        assert alltrue(time[1:] - time[:-1] > 0 ), msg


        #Check if quantities is a single array only
        if not isinstance(quantities, dict):
            quantities = ensure_numeric(quantities)
            quantity_names = ['Attribute']

            #Make it a dictionary
            quantities = {quantity_names[0]: quantities}


        #Use keys if no names are specified. The local name is updated too,
        #because all loops below iterate over it.
        if quantity_names is None:
            quantity_names = list(quantities.keys())
        self.quantity_names = quantity_names


        #Check spatial info
        if vertex_coordinates is None:
            self.spatial = False
        else:
            vertex_coordinates = ensure_numeric(vertex_coordinates)

            assert triangles is not None, 'Triangles array must be specified'
            triangles = ensure_numeric(triangles)
            self.spatial = True


        #Kept so that __call__ can tell whether point_id lookups are possible
        self.interpolation_points = interpolation_points
        self.T = time[:]  #Time assumed to be relative to starttime
        self.index = 0    #Initial time index
        self.precomputed_values = {}


        #Precomputed spatial interpolation if requested
        if interpolation_points is not None:
            if self.spatial is False:
                raise 'Triangles and vertex_coordinates must be specified'

            try:
                interpolation_points = ensure_numeric(interpolation_points)
            except:
                msg = 'Interpolation points must be an N x 2 Numeric array '+\
                      'or a list of points\n'
                msg += 'I got: %s.' %( str(interpolation_points)[:60] + '...')
                raise msg


            #One row per timestep, one column per interpolation point
            for name in quantity_names:
                self.precomputed_values[name] = \
                    zeros((len(self.T), len(interpolation_points)), Float)

            #Build interpolator
            interpol = Interpolation(vertex_coordinates,
                                     triangles,
                                     point_coordinates = interpolation_points,
                                     alpha = 0,
                                     precrop = False,
                                     verbose = verbose)

            #FIXME: Checking interpol.cropped_points here and raising
            #'Some interpolation points were outside mesh' would also
            #trigger if triangles are listed as discontinuous, even
            #though there is no need to stop (cf. precrop above).

            if verbose: print('Interpolate')
            for i in range(len(self.T)):
                #Interpolate quantities at this timestep
                if verbose: print(' time step %d of %d' %(i, len(self.T)))
                for name in quantity_names:
                    self.precomputed_values[name][i, :] = \
                        interpol.interpolate(quantities[name][i,:])

            #Report
            if verbose:
                x = vertex_coordinates[:,0]
                y = vertex_coordinates[:,1]

                print('------------------------------------------------')
                print('Interpolation_function statistics:')
                print('  Extent:')
                print('    x in [%f, %f], len(x) == %d'
                      %(min(x), max(x), len(x)))
                print('    y in [%f, %f], len(y) == %d'
                      %(min(y), max(y), len(y)))
                print('    t in [%f, %f], len(t) == %d'
                      %(min(self.T), max(self.T), len(self.T)))
                print('  Quantities:')
                for name in quantity_names:
                    q = quantities[name][:].flat
                    print('    %s in [%f, %f]' %(name, min(q), max(q)))
                print('  Interpolation points (xi, eta):'
                      ' number of points == %d ' %interpolation_points.shape[0])
                print('    xi in [%f, %f]' %(min(interpolation_points[:,0]),
                                             max(interpolation_points[:,0])))
                print('    eta in [%f, %f]' %(min(interpolation_points[:,1]),
                                              max(interpolation_points[:,1])))
                print('  Interpolated quantities (over all timesteps):')

                for name in quantity_names:
                    q = self.precomputed_values[name][:].flat
                    print('    %s at interpolation points in [%f, %f]'
                          %(name, min(q), max(q)))
                print('------------------------------------------------')

        else:
            #Store quantities as is
            for name in quantity_names:
                self.precomputed_values[name] = quantities[name]


    def __repr__(self):
        return 'Interpolation function (spation-temporal)'

    def __call__(self, t, point_id = None, x = None, y = None):
        """Evaluate f(t), f(t, point_id) or f(t, x, y)

        Inputs:
          t: time - Model time. Must lie within existing timesteps
          point_id: index of one of the preprocessed points.
          x, y:     Location. With spatial info this raises (not yet
                    implemented); without spatial info vectors x, y
                    replicate the result (see Output).

          If spatial info is present and all of x,y,point_id
          are None an exception is raised

          If no spatial info is present, point_id and x,y arguments are ignored
          making f a function of time only.

          FIXME: point_id could also be a slice
          FIXME: What if x and y are vectors?
          FIXME: What about f(x,y) without t?

        Output:
          Float vector with one value per name in self.quantity_names,
          linearly interpolated in time. Without spatial info and with
          x, y given as sequences, a list with one constant array of
          length len(x) per quantity is returned instead.
        """

        from Numeric import zeros, ones, Float

        if self.spatial is True:
            if point_id is None:
                if x is None or y is None:
                    msg = 'Either point_id or x and y must be specified'
                    raise msg
            else:
                if self.interpolation_points is None:
                    msg = 'Interpolation_function must be instantiated ' +\
                          'with a list of interpolation points before parameter ' +\
                          'point_id can be used'
                    raise msg


        #Reject times outside the stored time series [T[0], T[-1]]
        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[-1])
        msg += ' does not match model time: %s\n' %t
        if t < self.T[0]: raise msg
        if t > self.T[-1]: raise msg

        #Find current time slot, starting from the slot found by the
        #previous call (cheap when t changes gradually between calls)
        while t > self.T[self.index]: self.index += 1
        while t < self.T[self.index]: self.index -= 1

        if t == self.T[self.index]:
            #Protect against case where t == T[-1] (last time)
            # - also works in general when t == T[i]
            ratio = 0
        else:
            #t is now between index and index+1
            ratio = (t - self.T[self.index])/\
                    (self.T[self.index+1] - self.T[self.index])

        #Compute interpolated values
        q = zeros(len(self.quantity_names), Float)

        for i, name in enumerate(self.quantity_names):
            Q = self.precomputed_values[name]

            if self.spatial is False:
                #If there is no spatial info
                assert len(Q.shape) == 1

                Q0 = Q[self.index]
                if ratio > 0: Q1 = Q[self.index+1]

            else:
                if x is not None and y is not None:
                    #Interpolate to x, y
                    raise 'x,y interpolation not yet implemented'
                else:
                    #Use precomputed point
                    Q0 = Q[self.index, point_id]
                    if ratio > 0: Q1 = Q[self.index+1, point_id]

            #Linear temporal interpolation
            if ratio > 0:
                q[i] = Q0 + ratio*(Q1 - Q0)
            else:
                q[i] = Q0


        #Return vector of interpolated values
        if self.spatial is True:
            return q

        #Replicate q according to x and y
        #This is e.g used for Wind_stress
        if x is None or y is None:
            return q

        try:
            N = len(x)
        except:
            #x has no length (a scalar): return values unreplicated
            return q

        #x is a vector - Create one constant column for each value
        assert len(y) == N, 'x and y must have same length'
        return [col*ones(N, Float) for col in q]
1194
1195
1196
1197
1198#-------------------------------------------------------------
if __name__ == "__main__":
    #Load in a mesh and data points with attributes,
    #fit the attributes to the mesh and save a new mesh file.
    import os, sys
    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand] [verbose|non_verbose] [alpha]"\
            %os.path.basename(sys.argv[0])

    if len(sys.argv) < 4:
        print(usage)
    else:
        mesh_file = sys.argv[1]
        point_file = sys.argv[2]
        mesh_output_file = sys.argv[3]

        #Optional 4th argument: a value starting with 'e' or 'E'
        #sets expand_search
        expand_search = False
        if len(sys.argv) > 4:
            expand_search = sys.argv[4][:1] in ('e', 'E')

        #Optional 5th argument: verbose unless it starts with 'n' or 'N'
        verbose = False
        if len(sys.argv) > 5:
            verbose = sys.argv[5][:1] not in ('n', 'N')

        #Optional 6th argument: smoothing parameter. Command line
        #arguments are strings, so convert to a number before use.
        if len(sys.argv) > 6:
            alpha = float(sys.argv[6])
        else:
            alpha = DEFAULT_ALPHA

        t0 = time.time()
        fit_to_mesh_file(mesh_file,
                         point_file,
                         mesh_output_file,
                         alpha,
                         verbose= verbose,
                         expand_search = expand_search)

        print('That took %.2f seconds' %(time.time()-t0))
1244
Note: See TracBrowser for help on using the repository browser.