source: inundation/ga/storm_surge/pyvolution/least_squares.py @ 1697

Last change on this file since 1697 was 1681, checked in by ole, 20 years ago

Played with set_quantity

File size: 44.3 KB
Line 
"""Least squares smoothing and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20import exceptions
class ShapeError(Exception):
    """Exception for array-shape mismatches.

    Not raised anywhere in this module's visible code; presumably used by
    importers of this module -- TODO confirm its users.
    """
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
# Default smoothing parameter alpha: the weight of the smoothing matrix D
# in the system matrix B = AtA + alpha*D. Used as the default by the
# fitting functions and the Interpolation class in this module.
DEFAULT_ALPHA = 1.0e-3
56
def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
                     alpha=DEFAULT_ALPHA, verbose= False,
                     expand_search = False,
                     data_origin = None,
                     mesh_origin = None,
                     precrop = False):
    """Fit point attributes to a mesh and write the result to a mesh file.

    Given a mesh file (tsh) and a point attribute file (xya), fit
    point attributes to the mesh and write a mesh file with the
    results. The fitted attributes (and their titles) are appended after
    any vertex attributes already present in the input mesh file.

    The point file is first read comma-delimited; if that raises a
    SyntaxError it is re-read space-delimited.

    If data_origin is not None it is assumed to be
    a 3-tuple with geo referenced
    UTM coordinates (zone, easting, northing)

    mesh_origin is the same but refers to the input tsh file.

    Origin precedence: a geo_reference stored in the point (respectively
    mesh) file wins; otherwise the data_origin (respectively mesh_origin)
    argument is used if given; otherwise (56, 0, 0).

    FIXME: When the tsh format contains it own origin, these parameters can go.
    FIXME: And both origins should be obtained from the specified files.
    """

    from load_mesh.loadASCII import import_mesh_file, \
                 import_points_file, export_mesh_file, \
                 concatinate_attributelist

    #Origin used when neither the file nor the caller supplies one
    default_origin = (56, 0, 0) #FIXME(DSG-DSG)

    def as_list(value):
        #Numeric arrays become (nested) lists so rows can be extended below
        if isinstance(value, ArrayType):
            return value.tolist()
        return value

    def resolve_origin(file_dict, origin):
        #An origin stored in the file takes precedence over the argument.
        #Previously the argument was always discarded.
        geo_reference = file_dict.get('geo_reference')
        if geo_reference is not None:
            return geo_reference.get_origin()
        if origin is not None:
            return origin
        return default_origin

    mesh_dict = import_mesh_file(mesh_file)
    vertex_coordinates = mesh_dict['vertices']
    triangles = mesh_dict['triangles']
    old_point_attributes = as_list(mesh_dict['vertex_attributes'])
    old_title_list = as_list(mesh_dict['vertex_attribute_titles'])

    if verbose: print('tsh file %s loaded' %mesh_file)

    # load in the .pts file
    try:
        point_dict = import_points_file(point_file,
                                        delimiter = ',',
                                        verbose=verbose)
    except SyntaxError:
        point_dict = import_points_file(point_file,
                                        delimiter = ' ',
                                        verbose=verbose)

    point_coordinates = point_dict['pointlist']
    title_list, point_attributes = \
        concatinate_attributelist(point_dict['attributelist'])

    data_origin = resolve_origin(point_dict, data_origin)
    mesh_origin = resolve_origin(mesh_dict, mesh_origin)

    if verbose: print('points file loaded')
    if verbose: print('fitting to mesh')
    f = fit_to_mesh(vertex_coordinates,
                    triangles,
                    point_coordinates,
                    point_attributes,
                    alpha = alpha,
                    verbose = verbose,
                    expand_search = expand_search,
                    data_origin = data_origin,
                    mesh_origin = mesh_origin,
                    precrop = precrop)
    if verbose: print('finished fitting to mesh')

    # convert array to list of lists
    new_point_attributes = f.tolist()
    #FIXME have this overwrite attributes with the same title - DSG
    #Put the newer attributes last
    if old_title_list:
        old_title_list.extend(title_list)
        #FIXME can this be done a faster way? - DSG
        for old_row, new_row in zip(old_point_attributes,
                                    new_point_attributes):
            old_row.extend(new_row)
        mesh_dict['vertex_attributes'] = old_point_attributes
        mesh_dict['vertex_attribute_titles'] = old_title_list
    else:
        mesh_dict['vertex_attributes'] = new_point_attributes
        mesh_dict['vertex_attribute_titles'] = title_list

    #FIXME (Ole): Remember to output mesh_origin as well
    if verbose: print('exporting to file  %s' % (mesh_output_file,))
    export_mesh_file(mesh_output_file, mesh_dict)
152
153
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,
                point_attributes,
                alpha = DEFAULT_ALPHA,
                verbose = False,
                expand_search = False,
                data_origin = None,
                mesh_origin = None,
                precrop = False):
    """Fit a smooth surface to a triangulation, given attributed data points.

    An Interpolation object is built for the mesh and the data points.
    The result of its fit_points() method is returned: the fitted
    attribute values at the mesh vertices.

    Inputs:

      vertex_coordinates: List of coordinate pairs [xi, eta] of points
      constituting mesh (or an m x 2 Numeric array)

      triangles: List of 3-tuples (or a Numeric array) of
      integers representing indices of all vertices in the mesh.

      point_coordinates: List of coordinate pairs [x, y] of data points
      (or an nx2 Numeric array)

      point_attributes: Vector or array of data at the point_coordinates.

      alpha: Smoothing parameter.

      verbose, expand_search, precrop: forwarded to Interpolation.

      data_origin and mesh_origin are 3-tuples consisting of
      UTM zone, easting and northing. If specified
      point coordinates and vertex coordinates are assumed to be
      relative to their respective origins.
    """
    return Interpolation(vertex_coordinates, triangles, point_coordinates,
                         alpha=alpha,
                         verbose=verbose,
                         expand_search=expand_search,
                         data_origin=data_origin,
                         mesh_origin=mesh_origin,
                         precrop=precrop).fit_points(point_attributes,
                                                     verbose=verbose)
202
203
204
def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
                    verbose = False, reduction = 1, format = 'netcdf',
                    attribute_name = 'elevation'):
    """Fits an attribute from a pts file to an MxN rectangular mesh

    Read pts file and create rectangular mesh of resolution MxN such that
    it covers all points specified in pts file.

    Inputs:
      pts_name: Name of the pts file, read with util.read_xya
      M, N: Mesh resolution, passed to mesh_factory.rectangular
      alpha: Smoothing parameter for the fit
      verbose: Print progress messages
      reduction: Only every reduction'th data point is used
      format: File format, passed to util.read_xya
      attribute_name: Name of the attribute to fit. Defaults to
                      'elevation', which used to be hardwired.

    Output:
      vertex_coordinates, triangles, boundary, vertex_attributes

    Raises ValueError if no data points remain after reduction.

    FIXME: This may be a temporary function until we decide on
    netcdf formats etc
    """

    import util, mesh_factory

    if verbose: print('Read pts')
    points, attributes = util.read_xya(pts_name, format)

    #Reduce number of points a bit
    points = points[::reduction]
    values = attributes[attribute_name][::reduction]

    if verbose: print('Got %d data points' %len(points))

    if len(points) == 0:
        raise ValueError('No data points found in %s' % (pts_name,))

    if verbose: print('Create mesh')
    #Find extent
    xs = [point[0] for point in points]
    ys = [point[1] for point in points]
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)

    #Create appropriate mesh
    vertex_coordinates, triangles, boundary = \
        mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
                                 (min_x, min_y))

    #Fit attributes to mesh
    vertex_attributes = fit_to_mesh(vertex_coordinates,
                                    triangles,
                                    points,
                                    values, alpha=alpha, verbose=verbose)

    return vertex_coordinates, triangles, boundary, vertex_attributes
256
257
258
259class Interpolation:
260
261    def __init__(self,
262                 vertex_coordinates,
263                 triangles,
264                 point_coordinates = None,
265                 alpha = DEFAULT_ALPHA,
266                 verbose = False,
267                 expand_search = True,
268                 max_points_per_cell = 30,
269                 mesh_origin = None,
270                 data_origin = None,
271                 precrop = False):
272
273
274        """ Build interpolation matrix mapping from
275        function values at vertices to function values at data points
276
277        Inputs:
278
279          vertex_coordinates: List of coordinate pairs [xi, eta] of
280          points constituting mesh (or a an m x 2 Numeric array)
281          Points may appear multiple times
282          (e.g. if vertices have discontinuities)
283
284          triangles: List of 3-tuples (or a Numeric array) of
285          integers representing indices of all vertices in the mesh.
286
287          point_coordinates: List of coordinate pairs [x, y] of
288          data points (or an nx2 Numeric array)
289          If point_coordinates is absent, only smoothing matrix will
290          be built
291
292          alpha: Smoothing parameter
293
294          data_origin and mesh_origin are 3-tuples consisting of
295          UTM zone, easting and northing. If specified
296          point coordinates and vertex coordinates are assumed to be
297          relative to their respective origins.
298
299        """
300        from util import ensure_numeric
301
302        #Convert input to Numeric arrays
303        triangles = ensure_numeric(triangles, Int)
304        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
305
306        #Build underlying mesh
307        if verbose: print 'Building mesh'
308        #self.mesh = General_mesh(vertex_coordinates, triangles,
309        #FIXME: Trying the normal mesh while testing precrop,
310        #       The functionality of boundary_polygon is needed for that
311
312        #FIXME - geo ref does not have to go into mesh.
313        # Change the point co-ords to conform to the
314        # mesh co-ords early in the code
315        if mesh_origin == None:
316            geo = None
317        else:
318            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
319        self.mesh = Mesh(vertex_coordinates, triangles,
320                         geo_reference = geo)
321        #FIXME, remove if mesh checks it.
322        self.mesh.check_integrity()
323        self.data_origin = data_origin
324
325        self.point_indices = None
326
327        #Smoothing parameter
328        self.alpha = alpha
329
330        #Build coefficient matrices
331        self.build_coefficient_matrix_B(point_coordinates,
332                                        verbose = verbose,
333                                        expand_search = expand_search,
334                                        max_points_per_cell =\
335                                        max_points_per_cell,
336                                        data_origin = data_origin,
337                                        precrop = precrop)
338
339
340    def set_point_coordinates(self, point_coordinates,
341                              data_origin = None):
342        """
343        A public interface to setting the point co-ordinates.
344        """
345        self.build_coefficient_matrix_B(point_coordinates, data_origin)
346
347    def build_coefficient_matrix_B(self, point_coordinates=None,
348                                   verbose = False, expand_search = True,
349                                   max_points_per_cell=30,
350                                   data_origin = None,
351                                   precrop = False):
352        """Build final coefficient matrix"""
353
354
355        if self.alpha <> 0:
356            if verbose: print 'Building smoothing matrix'
357            self.build_smoothing_matrix_D()
358
359        if point_coordinates is not None:
360
361            if verbose: print 'Building interpolation matrix'
362            self.build_interpolation_matrix_A(point_coordinates,
363                                              verbose = verbose,
364                                              expand_search = expand_search,
365                                              max_points_per_cell =\
366                                              max_points_per_cell,
367                                              data_origin = data_origin,
368                                              precrop = precrop)
369
370            if self.alpha <> 0:
371                self.B = self.AtA + self.alpha*self.D
372            else:
373                self.B = self.AtA
374
375            #Convert self.B matrix to CSR format for faster matrix vector
376            self.B = Sparse_CSR(self.B)
377
378    def build_interpolation_matrix_A(self, point_coordinates,
379                                     verbose = False, expand_search = True,
380                                     max_points_per_cell=30,
381                                     data_origin = None,
382                                     precrop = False):
383        """Build n x m interpolation matrix, where
384        n is the number of data points and
385        m is the number of basis functions phi_k (one per vertex)
386
387        This algorithm uses a quad tree data structure for fast binning of data points
388        origin is a 3-tuple consisting of UTM zone, easting and northing.
389        If specified coordinates are assumed to be relative to this origin.
390
391        This one will override any data_origin that may be specified in
392        interpolation instance
393
394        """
395
396
397        #FIXME (Ole): Check that this function is memeory efficient.
398        #6 million datapoints and 300000 basis functions
399        #causes out-of-memory situation
400        #First thing to check is whether there is room for self.A and self.AtA
401        #
402        #Maybe we need some sort of blocking
403
404        from quad import build_quadtree
405        from util import ensure_numeric
406
407        if data_origin is None:
408            data_origin = self.data_origin #Use the one from
409                                           #interpolation instance
410
411        #Convert input to Numeric arrays just in case.
412        point_coordinates = ensure_numeric(point_coordinates, Float)
413
414        #Keep track of discarded points (if any).
415        #This is only registered if precrop is True
416        self.cropped_points = False
417
418        #Shift data points to same origin as mesh (if specified)
419
420        #FIXME this will shift if there was no geo_ref.
421        #But all this should be removed anyhow.
422        #change coords before this point
423        mesh_origin = self.mesh.geo_reference.get_origin()
424        if point_coordinates is not None:
425            if data_origin is not None:
426                if mesh_origin is not None:
427
428                    #Transformation:
429                    #
430                    #Let x_0 be the reference point of the point coordinates
431                    #and xi_0 the reference point of the mesh.
432                    #
433                    #A point coordinate (x + x_0) is then made relative
434                    #to xi_0 by
435                    #
436                    # x_new = x + x_0 - xi_0
437                    #
438                    #and similarly for eta
439
440                    x_offset = data_origin[1] - mesh_origin[1]
441                    y_offset = data_origin[2] - mesh_origin[2]
442                else: #Shift back to a zero origin
443                    x_offset = data_origin[1]
444                    y_offset = data_origin[2]
445
446                point_coordinates[:,0] += x_offset
447                point_coordinates[:,1] += y_offset
448            else:
449                if mesh_origin is not None:
450                    #Use mesh origin for data points
451                    point_coordinates[:,0] -= mesh_origin[1]
452                    point_coordinates[:,1] -= mesh_origin[2]
453
454
455
456        #Remove points falling outside mesh boundary
457        #This reduced one example from 1356 seconds to 825 seconds
458        if precrop is True:
459            from Numeric import take
460            from util import inside_polygon
461
462            if verbose: print 'Getting boundary polygon'
463            P = self.mesh.get_boundary_polygon()
464
465            if verbose: print 'Getting indices inside mesh boundary'
466            indices = inside_polygon(point_coordinates, P, verbose = verbose)
467
468
469            if len(indices) != point_coordinates.shape[0]:
470                self.cropped_points = True
471                if verbose:
472                    print 'Done - %d points outside mesh have been cropped.'\
473                          %(point_coordinates.shape[0] - len(indices))
474
475            point_coordinates = take(point_coordinates, indices)
476            self.point_indices = indices
477
478
479
480
481        #Build n x m interpolation matrix
482        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
483        n = point_coordinates.shape[0]     #Nbr of data points
484
485        if verbose: print 'Number of datapoints: %d' %n
486        if verbose: print 'Number of basis functions: %d' %m
487
488        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
489        #However, Sparse_CSR does not have the same methods as Sparse yet
490        #The tests will reveal what needs to be done
491        self.A = Sparse(n,m)
492        self.AtA = Sparse(m,m)
493
494        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
495        root = build_quadtree(self.mesh,
496                              max_points_per_cell = max_points_per_cell)
497
498        #Compute matrix elements
499        for i in range(n):
500            #For each data_coordinate point
501
502            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
503
504            x = point_coordinates[i]
505
506            #Find vertices near x
507            candidate_vertices = root.search(x[0], x[1])
508            is_more_elements = True
509
510            element_found, sigma0, sigma1, sigma2, k = \
511                self.search_triangles_of_vertices(candidate_vertices, x)
512            while not element_found and is_more_elements and expand_search:
513                #if verbose: print 'Expanding search'
514                candidate_vertices, branch = root.expand_search()
515                if branch == []:
516                    # Searching all the verts from the root cell that haven't
517                    # been searched.  This is the last try
518                    element_found, sigma0, sigma1, sigma2, k = \
519                      self.search_triangles_of_vertices(candidate_vertices, x)
520                    is_more_elements = False
521                else:
522                    element_found, sigma0, sigma1, sigma2, k = \
523                      self.search_triangles_of_vertices(candidate_vertices, x)
524
525
526            #Update interpolation matrix A if necessary
527            if element_found is True:
528                #Assign values to matrix A
529
530                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
531                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
532                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
533
534                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
535                js     = [j0,j1,j2]
536
537                for j in js:
538                    self.A[i,j] = sigmas[j]
539                    for k in js:
540                        self.AtA[j,k] += sigmas[j]*sigmas[k]
541            else:
542                pass
543                #Ok if there is no triangle for datapoint
544                #(as in brute force version)
545                #raise 'Could not find triangle for point', x
546
547
548
549    def search_triangles_of_vertices(self, candidate_vertices, x):
550            #Find triangle containing x:
551            element_found = False
552
553            # This will be returned if element_found = False
554            sigma2 = -10.0
555            sigma0 = -10.0
556            sigma1 = -10.0
557            k = -10.0
558
559            #For all vertices in same cell as point x
560            for v in candidate_vertices:
561
562                #for each triangle id (k) which has v as a vertex
563                for k, _ in self.mesh.vertexlist[v]:
564
565                    #Get the three vertex_points of candidate triangle
566                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
567                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
568                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
569
570                    #print "PDSG - k", k
571                    #print "PDSG - xi0", xi0
572                    #print "PDSG - xi1", xi1
573                    #print "PDSG - xi2", xi2
574                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
575                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
576
577                    #Get the three normals
578                    n0 = self.mesh.get_normal(k, 0)
579                    n1 = self.mesh.get_normal(k, 1)
580                    n2 = self.mesh.get_normal(k, 2)
581
582
583                    #Compute interpolation
584                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
585                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
586                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
587
588                    #print "PDSG - sigma0", sigma0
589                    #print "PDSG - sigma1", sigma1
590                    #print "PDSG - sigma2", sigma2
591
592                    #FIXME: Maybe move out to test or something
593                    epsilon = 1.0e-6
594                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
595
596                    #Check that this triangle contains the data point
597
598                    #Sigmas can get negative within
599                    #machine precision on some machines (e.g nautilus)
600                    #Hence the small eps
601                    eps = 1.0e-15
602                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
603                        element_found = True
604                        break
605
606                if element_found is True:
607                    #Don't look for any other triangle
608                    break
609            return element_found, sigma0, sigma1, sigma2, k
610
611
612
613    def build_interpolation_matrix_A_brute(self, point_coordinates):
614        """Build n x m interpolation matrix, where
615        n is the number of data points and
616        m is the number of basis functions phi_k (one per vertex)
617
618        This is the brute force which is too slow for large problems,
619        but could be used for testing
620        """
621
622        from util import ensure_numeric
623
624        #Convert input to Numeric arrays
625        point_coordinates = ensure_numeric(point_coordinates, Float)
626
627        #Build n x m interpolation matrix
628        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
629        n = point_coordinates.shape[0]     #Nbr of data points
630
631        self.A = Sparse(n,m)
632        self.AtA = Sparse(m,m)
633
634        #Compute matrix elements
635        for i in range(n):
636            #For each data_coordinate point
637
638            x = point_coordinates[i]
639            element_found = False
640            k = 0
641            while not element_found and k < len(self.mesh):
642                #For each triangle (brute force)
643                #FIXME: Real algorithm should only visit relevant triangles
644
645                #Get the three vertex_points
646                xi0 = self.mesh.get_vertex_coordinate(k, 0)
647                xi1 = self.mesh.get_vertex_coordinate(k, 1)
648                xi2 = self.mesh.get_vertex_coordinate(k, 2)
649
650                #Get the three normals
651                n0 = self.mesh.get_normal(k, 0)
652                n1 = self.mesh.get_normal(k, 1)
653                n2 = self.mesh.get_normal(k, 2)
654
655                #Compute interpolation
656                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
657                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
658                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
659
660                #FIXME: Maybe move out to test or something
661                epsilon = 1.0e-6
662                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
663
664                #Check that this triangle contains data point
665                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
666                    element_found = True
667                    #Assign values to matrix A
668
669                    j0 = self.mesh.triangles[k,0] #Global vertex id
670                    #self.A[i, j0] = sigma0
671
672                    j1 = self.mesh.triangles[k,1] #Global vertex id
673                    #self.A[i, j1] = sigma1
674
675                    j2 = self.mesh.triangles[k,2] #Global vertex id
676                    #self.A[i, j2] = sigma2
677
678                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
679                    js     = [j0,j1,j2]
680
681                    for j in js:
682                        self.A[i,j] = sigmas[j]
683                        for k in js:
684                            self.AtA[j,k] += sigmas[j]*sigmas[k]
685                k = k+1
686
687
688
689    def get_A(self):
690        return self.A.todense()
691
692    def get_B(self):
693        return self.B.todense()
694
695    def get_D(self):
696        return self.D.todense()
697
698        #FIXME: Remember to re-introduce the 1/n factor in the
699        #interpolation term
700
701    def build_smoothing_matrix_D(self):
702        """Build m x m smoothing matrix, where
703        m is the number of basis functions phi_k (one per vertex)
704
705        The smoothing matrix is defined as
706
707        D = D1 + D2
708
709        where
710
711        [D1]_{k,l} = \int_\Omega
712           \frac{\partial \phi_k}{\partial x}
713           \frac{\partial \phi_l}{\partial x}\,
714           dx dy
715
716        [D2]_{k,l} = \int_\Omega
717           \frac{\partial \phi_k}{\partial y}
718           \frac{\partial \phi_l}{\partial y}\,
719           dx dy
720
721
722        The derivatives \frac{\partial \phi_k}{\partial x},
723        \frac{\partial \phi_k}{\partial x} for a particular triangle
724        are obtained by computing the gradient a_k, b_k for basis function k
725        """
726
727        #FIXME: algorithm might be optimised by computing local 9x9
728        #"element stiffness matrices:
729
730        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
731
732        self.D = Sparse(m,m)
733
734        #For each triangle compute contributions to D = D1+D2
735        for i in range(len(self.mesh)):
736
737            #Get area
738            area = self.mesh.areas[i]
739
740            #Get global vertex indices
741            v0 = self.mesh.triangles[i,0]
742            v1 = self.mesh.triangles[i,1]
743            v2 = self.mesh.triangles[i,2]
744
745            #Get the three vertex_points
746            xi0 = self.mesh.get_vertex_coordinate(i, 0)
747            xi1 = self.mesh.get_vertex_coordinate(i, 1)
748            xi2 = self.mesh.get_vertex_coordinate(i, 2)
749
750            #Compute gradients for each vertex
751            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
752                              1, 0, 0)
753
754            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
755                              0, 1, 0)
756
757            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
758                              0, 0, 1)
759
760            #Compute diagonal contributions
761            self.D[v0,v0] += (a0*a0 + b0*b0)*area
762            self.D[v1,v1] += (a1*a1 + b1*b1)*area
763            self.D[v2,v2] += (a2*a2 + b2*b2)*area
764
765            #Compute contributions for basis functions sharing edges
766            e01 = (a0*a1 + b0*b1)*area
767            self.D[v0,v1] += e01
768            self.D[v1,v0] += e01
769
770            e12 = (a1*a2 + b1*b2)*area
771            self.D[v1,v2] += e12
772            self.D[v2,v1] += e12
773
774            e20 = (a2*a0 + b2*b0)*area
775            self.D[v2,v0] += e20
776            self.D[v0,v2] += e20
777
778
779    def fit(self, z):
780        """Fit a smooth surface to given 1d array of data points z.
781
782        The smooth surface is computed at each vertex in the underlying
783        mesh using the formula given in the module doc string.
784
785        Pre Condition:
786          self.A, self.AtA and self.B have been initialised
787
788        Inputs:
789          z: Single 1d vector or array of data at the point_coordinates.
790        """
791
792        #Convert input to Numeric arrays
793        from util import ensure_numeric
794        z = ensure_numeric(z, Float)
795
796        if len(z.shape) > 1 :
797            raise VectorShapeError, 'Can only deal with 1d data vector'
798
799        if self.point_indices is not None:
800            #Remove values for any points that were outside mesh
801            z = take(z, self.point_indices)
802
803        #Compute right hand side based on data
804        Atz = self.A.trans_mult(z)
805
806
807        #Check sanity
808        n, m = self.A.shape
809        if n<m and self.alpha == 0.0:
810            msg = 'ERROR (least_squares): Too few data points\n'
811            msg += 'There are only %d data points and alpha == 0. ' %n
812            msg += 'Need at least %d\n' %m
813            msg += 'Alternatively, set smoothing parameter alpha to a small '
814            msg += 'positive value,\ne.g. 1.0e-3.'
815            raise msg
816
817
818
819        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
820        #FIXME: Should we store the result here for later use? (ON)
821
822
823    def fit_points(self, z, verbose=False):
824        """Like fit, but more robust when each point has two or more attributes
825        FIXME (Ole): The name fit_points doesn't carry any meaning
826        for me. How about something like fit_multiple or fit_columns?
827        """
828
829        try:
830            if verbose: print 'Solving penalised least_squares problem'
831            return self.fit(z)
832        except VectorShapeError, e:
833            # broadcasting is not supported.
834
835            #Convert input to Numeric arrays
836            from util import ensure_numeric
837            z = ensure_numeric(z, Float)
838
839            #Build n x m interpolation matrix
840            m = self.mesh.coordinates.shape[0] #Number of vertices
841            n = z.shape[1]                     #Number of data points
842
843            f = zeros((m,n), Float) #Resulting columns
844
845            for i in range(z.shape[1]):
846                f[:,i] = self.fit(z[:,i])
847
848            return f
849
850
851    def interpolate(self, f):
852        """Evaluate smooth surface f at data points implied in self.A.
853
854        The mesh values representing a smooth surface are
855        assumed to be specified in f. This argument could,
856        for example have been obtained from the method self.fit()
857
858        Pre Condition:
859          self.A has been initialised
860
861        Inputs:
862          f: Vector or array of data at the mesh vertices.
863          If f is an array, interpolation will be done for each column as
864          per underlying matrix-matrix multiplication
865
866        Output:
867          Interpolated values at data points implied in self.A
868
869        """
870
871        return self.A * f
872
873    def cull_outsiders(self, f):
874        pass
875
876
877
878
879class Interpolation_function:
880    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
881    which is interpolated from time series defined at vertices of
882    triangular mesh (such as those stored in sww files)
883
884    Let m be the number of vertices, n the number of triangles
885    and p the number of timesteps.
886
887    Mandatory input
888        time:               px1 array of monotonously increasing times (Float)
889        quantities:         Dictionary of pxm arrays or 1 pxm array (Float)
890       
891    Optional input:
892        quantity_names:     List of keys into the quantities dictionary
893        vertex_coordinates: mx2 array of coordinates (Float)
894        triangles:          nx3 array of indices into vertex_coordinates (Int)
895        interpolation_points: array of coordinates to be interpolated to
896        verbose:            Level of reporting
897   
898   
899    The quantities returned by the callable object are specified by
900    the list quantities which must contain the names of the
901    quantities to be returned and also reflect the order, e.g. for
902    the shallow water wave equation, on would have
903    quantities = ['stage', 'xmomentum', 'ymomentum']
904
905    The parameter interpolation_points decides at which points interpolated
906    quantities are to be computed whenever object is called.
907    If None, return average value
908    """
909
910   
911   
912    def __init__(self,
913                 time,
914                 quantities,
915                 quantity_names = None, 
916                 vertex_coordinates = None,
917                 triangles = None,
918                 interpolation_points = None,
919                 verbose = False):
920        """Initialise object and build spatial interpolation if required
921        """
922
923        from Numeric import array, zeros, Float, alltrue, concatenate,\
924             reshape, ArrayType
925
926        from util import mean, ensure_numeric
927        from config import time_format
928        import types
929
930
931
932        #Check temporal info
933        time = ensure_numeric(time)       
934        msg = 'Time must be a monotonuosly '
935        msg += 'increasing sequence %s' %time
936        assert alltrue(time[1:] - time[:-1] > 0 ), msg
937
938
939        #Check if quantities is a single array only
940        if type(quantities) != types.DictType:
941            quantities = ensure_numeric(quantities)
942            quantity_names = ['Attribute']
943
944            #Make it a dictionary
945            quantities = {quantity_names[0]: quantities}
946
947
948        #Use keys if no names are specified
949        if quantity_names is not None:
950            self.quantity_names = quantity_names
951        else:
952            self.quantity_names = quantities.keys()
953
954
955        #Check spatial info
956        if vertex_coordinates is None:
957            self.spatial = False
958        else:   
959            vertex_coordinates = ensure_numeric(vertex_coordinates)
960
961            assert triangles is not None, 'Triangles array must be specified'
962            triangles = ensure_numeric(triangles)
963            self.spatial = True           
964           
965 
966        #     
967        self.interpolation_points = interpolation_points #FIXWME Needed?
968        self.T = time[:]  #Time assumed to be relative to starttime
969        self.index = 0    #Initial time index
970        self.precomputed_values = {}
971           
972
973
974        #Precomputed spatial interpolation if requested
975        if interpolation_points is not None:
976            if self.spatial is False:
977                raise 'Triangles and vertex_coordinates must be specified'
978           
979
980            try:
981                interpolation_points = ensure_numeric(interpolation_points)
982            except:
983                msg = 'Interpolation points must be an N x 2 Numeric array '+\
984                      'or a list of points\n'
985                msg += 'I got: %s.' %( str(interpolation_points)[:60] + '...')
986                raise msg
987
988
989            for name in quantity_names:
990                self.precomputed_values[name] =\
991                                              zeros((len(self.T),
992                                                     len(interpolation_points)),
993                                                    Float)
994
995            #Build interpolator
996            interpol = Interpolation(vertex_coordinates,
997                                     triangles,
998                                     point_coordinates = interpolation_points,
999                                     alpha = 0,
1000                                     precrop = False, 
1001                                     verbose = verbose)
1002
1003            #if interpol.cropped_points is True:
1004            #    raise 'Some interpolation points were outside mesh'
1005            #FIXME: This will be raised if triangles are listed as
1006            #discontinuous even though there is no need to stop
1007            #(precrop = True above)
1008
1009            if verbose: print 'Interpolate'
1010            for i, t in enumerate(self.T):
1011                #Interpolate quantities at this timestep
1012                if verbose: print ' time step %d of %d' %(i, len(self.T))
1013                for name in quantity_names:
1014                    self.precomputed_values[name][i, :] =\
1015                    interpol.interpolate(quantities[name][i,:])
1016
1017            #Report
1018            if verbose:
1019                x = vertex_coordinates[:,0]
1020                y = vertex_coordinates[:,1]               
1021           
1022                print '------------------------------------------------'
1023                print 'Interpolation_function statistics:'
1024                print '  Extent:'
1025                print '    x in [%f, %f], len(x) == %d'\
1026                      %(min(x), max(x), len(x))
1027                print '    y in [%f, %f], len(y) == %d'\
1028                      %(min(y), max(y), len(y))
1029                print '    t in [%f, %f], len(t) == %d'\
1030                      %(min(self.T), max(self.T), len(self.T))
1031                print '  Quantities:'
1032                for name in quantity_names:
1033                    q = quantities[name][:].flat
1034                    print '    %s in [%f, %f]' %(name, min(q), max(q))
1035                print '  Interpolation points (xi, eta):'\
1036                      ' number of points == %d ' %interpolation_points.shape[0]
1037                print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1038                                             max(interpolation_points[:,0]))
1039                print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1040                                              max(interpolation_points[:,1]))
1041                print '  Interpolated quantities (over all timesteps):'
1042               
1043                for name in quantity_names:
1044                    q = self.precomputed_values[name][:].flat
1045                    print '    %s at interpolation points in [%f, %f]'\
1046                          %(name, min(q), max(q))
1047                print '------------------------------------------------'
1048           
1049        else:
1050            #Store quantitites as is
1051            for name in quantity_names:
1052                self.precomputed_values[name] = quantities[name]
1053
1054
1055        #else:
1056        #    #Return an average, making this a time series
1057        #    for name in quantity_names:
1058        #        self.values[name] = zeros(len(self.T), Float)
1059        #
1060        #    if verbose: print 'Compute mean values'
1061        #    for i, t in enumerate(self.T):
1062        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1063        #        for name in quantity_names:
1064        #           self.values[name][i] = mean(quantities[name][i,:])
1065
1066
1067
1068
1069    def __repr__(self):
1070        return 'Interpolation function (spation-temporal)'
1071
1072    def __call__(self, t, point_id = None, x = None, y = None):
1073        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1074
1075        Inputs:
1076          t: time - Model time. Must lie within existing timesteps
1077          point_id: index of one of the preprocessed points.
1078          x, y:     Overrides location, point_id ignored
1079         
1080          If spatial info is present and all of x,y,point_id
1081          are None an exception is raised
1082                   
1083          If no spatial info is present, point_id and x,y arguments are ignored
1084          making f a function of time only.
1085
1086         
1087          FIXME: point_id could also be a slice
1088          FIXME: What if x and y are vectors?
1089          FIXME: What about f(x,y) without t?
1090        """
1091
1092        from math import pi, cos, sin, sqrt
1093        from Numeric import zeros, Float
1094        from util import mean       
1095
1096        if self.spatial is True:
1097            if point_id is None:
1098                if x is None or y is None:
1099                    msg = 'Either point_id or x and y must be specified'
1100                    raise msg
1101            else:
1102                if self.interpolation_points is None:
1103                    msg = 'Interpolation_function must be instantiated ' +\
1104                          'with a list of interpolation points before parameter ' +\
1105                          'point_id can be used'
1106                    raise msg
1107
1108
1109        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1110        msg += ' does not match model time: %s\n' %t
1111        if t < self.T[0]: raise msg
1112        if t > self.T[-1]: raise msg
1113
1114        oldindex = self.index #Time index
1115
1116        #Find current time slot
1117        while t > self.T[self.index]: self.index += 1
1118        while t < self.T[self.index]: self.index -= 1
1119
1120        if t == self.T[self.index]:
1121            #Protect against case where t == T[-1] (last time)
1122            # - also works in general when t == T[i]
1123            ratio = 0
1124        else:
1125            #t is now between index and index+1
1126            ratio = (t - self.T[self.index])/\
1127                    (self.T[self.index+1] - self.T[self.index])
1128
1129        #Compute interpolated values
1130        q = zeros(len(self.quantity_names), Float)
1131
1132        for i, name in enumerate(self.quantity_names):
1133            Q = self.precomputed_values[name]
1134
1135            if self.spatial is False:
1136                #If there is no spatial info               
1137                assert len(Q.shape) == 1
1138
1139                Q0 = Q[self.index]
1140                if ratio > 0: Q1 = Q[self.index+1]
1141
1142            else:
1143                if x is not None and y is not None:
1144                    #Interpolate to x, y
1145                   
1146                    raise 'x,y interpolation not yet implemented'
1147                else:
1148                    #Use precomputed point
1149                    Q0 = Q[self.index, point_id]
1150                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1151
1152            #Linear temporal interpolation   
1153            if ratio > 0:
1154                q[i] = Q0 + ratio*(Q1 - Q0)
1155            else:
1156                q[i] = Q0
1157
1158
1159        #Return vector of interpolated values
1160        #if len(q) == 1:
1161        #    return q[0]
1162        #else:
1163        #    return q
1164
1165
1166        #Return vector of interpolated values
1167        #FIXME:
1168        if self.spatial is True:
1169            return q
1170        else:
1171            #Replicate q according to x and y
1172            #This is e.g used for Wind_stress
1173            if x == None or y == None: 
1174                return q
1175            else:
1176                try:
1177                    N = len(x)
1178                except:
1179                    return q
1180                else:
1181                    from Numeric import ones, Float
1182                    #x is a vector - Create one constant column for each value
1183                    N = len(x)
1184                    assert len(y) == N, 'x and y must have same length'
1185                    res = []
1186                    for col in q:
1187                        res.append(col*ones(N, Float))
1188                       
1189                return res
1190
1191
1192
1193
1194#-------------------------------------------------------------
1195if __name__ == "__main__":
1196    """
1197    Load in a mesh and data points with attributes.
1198    Fit the attributes to the mesh.
1199    Save a new mesh file.
1200    """
1201    import os, sys
1202    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha]"\
1203            %os.path.basename(sys.argv[0])
1204
1205    if len(sys.argv) < 4:
1206        print usage
1207    else:
1208        mesh_file = sys.argv[1]
1209        point_file = sys.argv[2]
1210        mesh_output_file = sys.argv[3]
1211
1212        expand_search = False
1213        if len(sys.argv) > 4:
1214            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1215                expand_search = True
1216            else:
1217                expand_search = False
1218
1219        verbose = False
1220        if len(sys.argv) > 5:
1221            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1222                verbose = False
1223            else:
1224                verbose = True
1225
1226        if len(sys.argv) > 6:
1227            alpha = sys.argv[6]
1228        else:
1229            alpha = DEFAULT_ALPHA
1230
1231        t0 = time.time()
1232        fit_to_mesh_file(mesh_file,
1233                         point_file,
1234                         mesh_output_file,
1235                         alpha,
1236                         verbose= verbose,
1237                         expand_search = expand_search)
1238
1239        print 'That took %.2f seconds' %(time.time()-t0)
1240
Note: See TracBrowser for help on using the repository browser.