source: inundation/pyvolution/least_squares.py @ 1741

Last change on this file since 1741 was 1741, checked in by ole, 19 years ago
File size: 44.2 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20import exceptions
class ShapeError(Exception):
    """Error type for shape mismatches.

    Derives from the builtin Exception, which is the same class as the
    Python 2 ``exceptions.Exception``.
    NOTE(review): not raised anywhere in the visible part of this module.
    """
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
try:
    from util import gradient
except ImportError:
    #FIXME reduce the dependency of modules in pyvolution
    # Have util in a dir, working like load_mesh, and get rid of this
    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
        """Return the gradient (a, b) of the plane through three points.

        Finds a, b such that q(x, y) = q0 + a*(x - x0) + b*(y - y0)
        takes the values q0, q1, q2 at (x0, y0), (x1, y1), (x2, y2).
        The 2x2 linear system is solved with Cramer's rule.

        This local definition is only used when util.gradient cannot be
        imported.

        Returns:
          (a, b): the partial derivatives dq/dx and dq/dy.

        Raises:
          ZeroDivisionError (for plain Python numbers) if the three
          points are collinear, i.e. the triangle is degenerate.
        """
        # Multiplying by 1.0 promotes integer input to float, so the
        # divisions below are true divisions even under Python 2.
        det = 1.0*((y2-y0)*(x1-x0) - (y1-y0)*(x2-x0))

        a = ((y2-y0)*(q1-q0) - (y1-y0)*(q2-q0))/det
        b = ((x1-x0)*(q2-q0) - (x2-x0)*(q1-q0))/det

        return a, b
53
54
# Default smoothing parameter (alpha) for the fitting routines below.
# NOTE(review): the module docstring quotes 1.0e-6 as a typical value;
# this default is considerably larger.
DEFAULT_ALPHA = 1.0e-3
56
def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
                     alpha=DEFAULT_ALPHA, verbose= False,
                     expand_search = False,
                     data_origin = None,
                     mesh_origin = None,
                     precrop = False):
    """Fit point attributes to a mesh and write the result to a mesh file.

    Given a mesh file (tsh) and a point attribute file (xya), fit
    point attributes to the mesh and write a mesh file with the
    results. The newly fitted attributes are appended after any vertex
    attributes already present in the mesh file.

    data_origin: 3-tuple with geo referenced UTM coordinates
    (zone, easting, northing) of the point data. It is used only if
    the point file carries no geo_reference of its own.

    mesh_origin: the same for the input tsh file. It is used only if
    the mesh file carries no geo_reference of its own.

    If neither the file nor the argument provides an origin,
    (56, 0, 0) is used.

    The point file is first parsed as comma delimited. If that raises
    SyntaxError, it is parsed again as space delimited.

    FIXME: When the tsh format contains it own origin, these parameters can go.
    FIXME: And both origins should be obtained from the specified files.
    """

    from load_mesh.loadASCII import import_mesh_file, \
                 import_points_file, export_mesh_file, \
                 concatinate_attributelist

    mesh_dict = import_mesh_file(mesh_file)
    vertex_coordinates = mesh_dict['vertices']
    triangles = mesh_dict['triangles']

    # Existing attributes may come back as Numeric arrays. Convert them to
    # lists so the new columns and titles can be appended with extend().
    old_point_attributes = mesh_dict['vertex_attributes']
    if isinstance(old_point_attributes, ArrayType):
        old_point_attributes = old_point_attributes.tolist()

    old_title_list = mesh_dict['vertex_attribute_titles']
    if isinstance(old_title_list, ArrayType):
        old_title_list = old_title_list.tolist()

    if verbose: print('tsh file %s loaded' % mesh_file)

    # load in the .pts file
    try:
        point_dict = import_points_file(point_file,
                                        delimiter = ',',
                                        verbose=verbose)
    except SyntaxError:
        point_dict = import_points_file(point_file,
                                        delimiter = ' ',
                                        verbose=verbose)

    point_coordinates = point_dict['pointlist']
    title_list, point_attributes = \
        concatinate_attributelist(point_dict['attributelist'])

    # An origin stored in a file wins. Otherwise use the caller's origin,
    # and only if that is also missing use the hard-wired default.
    if point_dict.get('geo_reference') is not None:
        data_origin = point_dict['geo_reference'].get_origin()
    elif data_origin is None:
        data_origin = (56, 0, 0) #FIXME(DSG-DSG)

    if mesh_dict.get('geo_reference') is not None:
        mesh_origin = mesh_dict['geo_reference'].get_origin()
    elif mesh_origin is None:
        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)

    if verbose: print('points file loaded')
    if verbose: print('fitting to mesh')
    f = fit_to_mesh(vertex_coordinates,
                    triangles,
                    point_coordinates,
                    point_attributes,
                    alpha = alpha,
                    verbose = verbose,
                    expand_search = expand_search,
                    data_origin = data_origin,
                    mesh_origin = mesh_origin,
                    precrop = precrop)
    if verbose: print('finished fitting to mesh')

    # convert array to list of lists
    new_point_attributes = f.tolist()
    #FIXME have this overwrite attributes with the same title - DSG
    #Put the newer attributes last
    if old_title_list:
        old_title_list.extend(title_list)
        #FIXME can this be done a faster way? - DSG
        for old_row, new_row in zip(old_point_attributes,
                                    new_point_attributes):
            old_row.extend(new_row)
        mesh_dict['vertex_attributes'] = old_point_attributes
        mesh_dict['vertex_attribute_titles'] = old_title_list
    else:
        mesh_dict['vertex_attributes'] = new_point_attributes
        mesh_dict['vertex_attribute_titles'] = title_list

    #FIXME (Ole): Remember to output mesh_origin as well
    if verbose: print('exporting to file %s' % mesh_output_file)
    export_mesh_file(mesh_output_file, mesh_dict)
152
153
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,
                point_attributes,
                alpha = DEFAULT_ALPHA,
                verbose = False,
                expand_search = False,
                data_origin = None,
                mesh_origin = None,
                precrop = False):
    """Fit a smooth surface to a triangulation, given data points with attributes.

    Inputs:

      vertex_coordinates: List of coordinate pairs [xi, eta] of points
      constituting mesh (or an m x 2 Numeric array)

      triangles: List of 3-tuples (or a Numeric array) of
      integers representing indices of all vertices in the mesh.

      point_coordinates: List of coordinate pairs [x, y] of data points
      (or an nx2 Numeric array)

      point_attributes: Vector (one value per data point) or array
      (one column per attribute) of data at the point_coordinates.

      alpha: Smoothing parameter.

      verbose, expand_search, precrop: passed on unchanged to
      Interpolation.

      data_origin and mesh_origin are 3-tuples consisting of
      UTM zone, easting and northing. If specified
      point coordinates and vertex coordinates are assumed to be
      relative to their respective origins.

    Returns:
      The fitted values at the mesh vertices, as computed by
      Interpolation.fit_points.
    """
    interpolator = Interpolation(vertex_coordinates, triangles,
                                 point_coordinates,
                                 alpha=alpha,
                                 verbose=verbose,
                                 expand_search=expand_search,
                                 data_origin=data_origin,
                                 mesh_origin=mesh_origin,
                                 precrop=precrop)
    return interpolator.fit_points(point_attributes, verbose=verbose)
202
203
204
def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
                    verbose = False, reduction = 1, format = 'netcdf'):
    """Fits attributes from pts file to MxN rectangular mesh

    Read pts file and create rectangular mesh of resolution MxN such that
    it covers all points specified in pts file.

    Inputs:
      pts_name: name of the pts file, read with util.read_xya.
      M, N: mesh resolution passed to mesh_factory.rectangular.
      alpha: smoothing parameter for the fit.
      verbose: print progress information.
      reduction: keep only every reduction'th data point.
      format: file format passed to util.read_xya.

    Returns:
      (vertex_coordinates, triangles, boundary, vertex_attributes)

    Raises:
      ValueError if no data points remain after reduction.

    FIXME: This may be a temporary function until we decide on
    netcdf formats etc

    FIXME: Uses elevation hardwired
    """

    import util, mesh_factory

    if verbose: print('Read pts')
    points, attributes = util.read_xya(pts_name, format)

    #Reduce number of points a bit
    points = points[::reduction]
    elevation = attributes['elevation']  #Must be elevation
    elevation = elevation[::reduction]

    if len(points) == 0:
        raise ValueError('No data points found in %s' % pts_name)

    if verbose: print('Got %d data points' % len(points))

    if verbose: print('Create mesh')
    #Find extent of the data
    xs = [point[0] for point in points]
    ys = [point[1] for point in points]
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)

    #Create appropriate mesh
    vertex_coordinates, triangles, boundary = \
         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
                                  (min_x, min_y))

    #Fit attributes to mesh
    vertex_attributes = fit_to_mesh(vertex_coordinates,
                                    triangles,
                                    points,
                                    elevation, alpha=alpha, verbose=verbose)

    return vertex_coordinates, triangles, boundary, vertex_attributes
256
257
258
259class Interpolation:
260
261    def __init__(self,
262                 vertex_coordinates,
263                 triangles,
264                 point_coordinates = None,
265                 alpha = DEFAULT_ALPHA,
266                 verbose = False,
267                 expand_search = True,
268                 max_points_per_cell = 30,
269                 mesh_origin = None,
270                 data_origin = None,
271                 precrop = False):
272
273
274        """ Build interpolation matrix mapping from
275        function values at vertices to function values at data points
276
277        Inputs:
278
279          vertex_coordinates: List of coordinate pairs [xi, eta] of
280          points constituting mesh (or a an m x 2 Numeric array)
281          Points may appear multiple times
282          (e.g. if vertices have discontinuities)
283
284          triangles: List of 3-tuples (or a Numeric array) of
285          integers representing indices of all vertices in the mesh.
286
287          point_coordinates: List of coordinate pairs [x, y] of
288          data points (or an nx2 Numeric array)
289          If point_coordinates is absent, only smoothing matrix will
290          be built
291
292          alpha: Smoothing parameter
293
294          data_origin and mesh_origin are 3-tuples consisting of
295          UTM zone, easting and northing. If specified
296          point coordinates and vertex coordinates are assumed to be
297          relative to their respective origins.
298
299        """
300        from util import ensure_numeric
301
302        #Convert input to Numeric arrays
303        triangles = ensure_numeric(triangles, Int)
304        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
305
306        #Build underlying mesh
307        if verbose: print 'Building mesh'
308        #self.mesh = General_mesh(vertex_coordinates, triangles,
309        #FIXME: Trying the normal mesh while testing precrop,
310        #       The functionality of boundary_polygon is needed for that
311
312        #FIXME - geo ref does not have to go into mesh.
313        # Change the point co-ords to conform to the
314        # mesh co-ords early in the code
315        if mesh_origin == None:
316            geo = None
317        else:
318            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
319        self.mesh = Mesh(vertex_coordinates, triangles,
320                         geo_reference = geo)
321       
322        self.mesh.check_integrity()
323
324        self.data_origin = data_origin
325
326        self.point_indices = None
327
328        #Smoothing parameter
329        self.alpha = alpha
330
331        #Build coefficient matrices
332        self.build_coefficient_matrix_B(point_coordinates,
333                                        verbose = verbose,
334                                        expand_search = expand_search,
335                                        max_points_per_cell =\
336                                        max_points_per_cell,
337                                        data_origin = data_origin,
338                                        precrop = precrop)
339
340
341    def set_point_coordinates(self, point_coordinates,
342                              data_origin = None):
343        """
344        A public interface to setting the point co-ordinates.
345        """
346        self.build_coefficient_matrix_B(point_coordinates, data_origin)
347
348    def build_coefficient_matrix_B(self, point_coordinates=None,
349                                   verbose = False, expand_search = True,
350                                   max_points_per_cell=30,
351                                   data_origin = None,
352                                   precrop = False):
353        """Build final coefficient matrix"""
354
355
356        if self.alpha <> 0:
357            if verbose: print 'Building smoothing matrix'
358            self.build_smoothing_matrix_D()
359
360        if point_coordinates is not None:
361
362            if verbose: print 'Building interpolation matrix'
363            self.build_interpolation_matrix_A(point_coordinates,
364                                              verbose = verbose,
365                                              expand_search = expand_search,
366                                              max_points_per_cell =\
367                                              max_points_per_cell,
368                                              data_origin = data_origin,
369                                              precrop = precrop)
370
371            if self.alpha <> 0:
372                self.B = self.AtA + self.alpha*self.D
373            else:
374                self.B = self.AtA
375
376            #Convert self.B matrix to CSR format for faster matrix vector
377            self.B = Sparse_CSR(self.B)
378
379    def build_interpolation_matrix_A(self, point_coordinates,
380                                     verbose = False, expand_search = True,
381                                     max_points_per_cell=30,
382                                     data_origin = None,
383                                     precrop = False):
384        """Build n x m interpolation matrix, where
385        n is the number of data points and
386        m is the number of basis functions phi_k (one per vertex)
387
388        This algorithm uses a quad tree data structure for fast binning of data points
389        origin is a 3-tuple consisting of UTM zone, easting and northing.
390        If specified coordinates are assumed to be relative to this origin.
391
392        This one will override any data_origin that may be specified in
393        interpolation instance
394
395        """
396
397
398        #FIXME (Ole): Check that this function is memeory efficient.
399        #6 million datapoints and 300000 basis functions
400        #causes out-of-memory situation
401        #First thing to check is whether there is room for self.A and self.AtA
402        #
403        #Maybe we need some sort of blocking
404
405        from quad import build_quadtree
406        from util import ensure_numeric
407
408        if data_origin is None:
409            data_origin = self.data_origin #Use the one from
410                                           #interpolation instance
411
412        #Convert input to Numeric arrays just in case.
413        point_coordinates = ensure_numeric(point_coordinates, Float)
414
415        #Keep track of discarded points (if any).
416        #This is only registered if precrop is True
417        self.cropped_points = False
418
419        #Shift data points to same origin as mesh (if specified)
420
421        #FIXME this will shift if there was no geo_ref.
422        #But all this should be removed anyhow.
423        #change coords before this point
424        mesh_origin = self.mesh.geo_reference.get_origin()
425        if point_coordinates is not None:
426            if data_origin is not None:
427                if mesh_origin is not None:
428
429                    #Transformation:
430                    #
431                    #Let x_0 be the reference point of the point coordinates
432                    #and xi_0 the reference point of the mesh.
433                    #
434                    #A point coordinate (x + x_0) is then made relative
435                    #to xi_0 by
436                    #
437                    # x_new = x + x_0 - xi_0
438                    #
439                    #and similarly for eta
440
441                    x_offset = data_origin[1] - mesh_origin[1]
442                    y_offset = data_origin[2] - mesh_origin[2]
443                else: #Shift back to a zero origin
444                    x_offset = data_origin[1]
445                    y_offset = data_origin[2]
446
447                point_coordinates[:,0] += x_offset
448                point_coordinates[:,1] += y_offset
449            else:
450                if mesh_origin is not None:
451                    #Use mesh origin for data points
452                    point_coordinates[:,0] -= mesh_origin[1]
453                    point_coordinates[:,1] -= mesh_origin[2]
454
455
456
457        #Remove points falling outside mesh boundary
458        #This reduced one example from 1356 seconds to 825 seconds
459        if precrop is True:
460            from Numeric import take
461            from util import inside_polygon
462
463            if verbose: print 'Getting boundary polygon'
464            P = self.mesh.get_boundary_polygon()
465
466            if verbose: print 'Getting indices inside mesh boundary'
467            indices = inside_polygon(point_coordinates, P, verbose = verbose)
468
469
470            if len(indices) != point_coordinates.shape[0]:
471                self.cropped_points = True
472                if verbose:
473                    print 'Done - %d points outside mesh have been cropped.'\
474                          %(point_coordinates.shape[0] - len(indices))
475
476            point_coordinates = take(point_coordinates, indices)
477            self.point_indices = indices
478
479
480
481
482        #Build n x m interpolation matrix
483        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
484        n = point_coordinates.shape[0]     #Nbr of data points
485
486        if verbose: print 'Number of datapoints: %d' %n
487        if verbose: print 'Number of basis functions: %d' %m
488
489        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
490        #However, Sparse_CSR does not have the same methods as Sparse yet
491        #The tests will reveal what needs to be done
492        self.A = Sparse(n,m)
493        self.AtA = Sparse(m,m)
494
495        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
496        root = build_quadtree(self.mesh,
497                              max_points_per_cell = max_points_per_cell)
498
499        #Compute matrix elements
500        for i in range(n):
501            #For each data_coordinate point
502
503            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
504
505            x = point_coordinates[i]
506
507            #Find vertices near x
508            candidate_vertices = root.search(x[0], x[1])
509            is_more_elements = True
510
511            element_found, sigma0, sigma1, sigma2, k = \
512                self.search_triangles_of_vertices(candidate_vertices, x)
513            while not element_found and is_more_elements and expand_search:
514                #if verbose: print 'Expanding search'
515                candidate_vertices, branch = root.expand_search()
516                if branch == []:
517                    # Searching all the verts from the root cell that haven't
518                    # been searched.  This is the last try
519                    element_found, sigma0, sigma1, sigma2, k = \
520                      self.search_triangles_of_vertices(candidate_vertices, x)
521                    is_more_elements = False
522                else:
523                    element_found, sigma0, sigma1, sigma2, k = \
524                      self.search_triangles_of_vertices(candidate_vertices, x)
525
526
527            #Update interpolation matrix A if necessary
528            if element_found is True:
529                #Assign values to matrix A
530
531                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
532                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
533                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
534
535                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
536                js     = [j0,j1,j2]
537
538                for j in js:
539                    self.A[i,j] = sigmas[j]
540                    for k in js:
541                        self.AtA[j,k] += sigmas[j]*sigmas[k]
542            else:
543                pass
544                #Ok if there is no triangle for datapoint
545                #(as in brute force version)
546                #raise 'Could not find triangle for point', x
547
548
549
550    def search_triangles_of_vertices(self, candidate_vertices, x):
551            #Find triangle containing x:
552            element_found = False
553
554            # This will be returned if element_found = False
555            sigma2 = -10.0
556            sigma0 = -10.0
557            sigma1 = -10.0
558            k = -10.0
559
560            #For all vertices in same cell as point x
561            for v in candidate_vertices:
562
563                #for each triangle id (k) which has v as a vertex
564                for k, _ in self.mesh.vertexlist[v]:
565
566                    #Get the three vertex_points of candidate triangle
567                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
568                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
569                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
570
571                    #print "PDSG - k", k
572                    #print "PDSG - xi0", xi0
573                    #print "PDSG - xi1", xi1
574                    #print "PDSG - xi2", xi2
575                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
576                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
577
578                    #Get the three normals
579                    n0 = self.mesh.get_normal(k, 0)
580                    n1 = self.mesh.get_normal(k, 1)
581                    n2 = self.mesh.get_normal(k, 2)
582
583
584                    #Compute interpolation
585                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
586                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
587                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
588
589                    #print "PDSG - sigma0", sigma0
590                    #print "PDSG - sigma1", sigma1
591                    #print "PDSG - sigma2", sigma2
592
593                    #FIXME: Maybe move out to test or something
594                    epsilon = 1.0e-6
595                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
596
597                    #Check that this triangle contains the data point
598
599                    #Sigmas can get negative within
600                    #machine precision on some machines (e.g nautilus)
601                    #Hence the small eps
602                    eps = 1.0e-15
603                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
604                        element_found = True
605                        break
606
607                if element_found is True:
608                    #Don't look for any other triangle
609                    break
610            return element_found, sigma0, sigma1, sigma2, k
611
612
613
614    def build_interpolation_matrix_A_brute(self, point_coordinates):
615        """Build n x m interpolation matrix, where
616        n is the number of data points and
617        m is the number of basis functions phi_k (one per vertex)
618
619        This is the brute force which is too slow for large problems,
620        but could be used for testing
621        """
622
623        from util import ensure_numeric
624
625        #Convert input to Numeric arrays
626        point_coordinates = ensure_numeric(point_coordinates, Float)
627
628        #Build n x m interpolation matrix
629        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
630        n = point_coordinates.shape[0]     #Nbr of data points
631
632        self.A = Sparse(n,m)
633        self.AtA = Sparse(m,m)
634
635        #Compute matrix elements
636        for i in range(n):
637            #For each data_coordinate point
638
639            x = point_coordinates[i]
640            element_found = False
641            k = 0
642            while not element_found and k < len(self.mesh):
643                #For each triangle (brute force)
644                #FIXME: Real algorithm should only visit relevant triangles
645
646                #Get the three vertex_points
647                xi0 = self.mesh.get_vertex_coordinate(k, 0)
648                xi1 = self.mesh.get_vertex_coordinate(k, 1)
649                xi2 = self.mesh.get_vertex_coordinate(k, 2)
650
651                #Get the three normals
652                n0 = self.mesh.get_normal(k, 0)
653                n1 = self.mesh.get_normal(k, 1)
654                n2 = self.mesh.get_normal(k, 2)
655
656                #Compute interpolation
657                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
658                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
659                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
660
661                #FIXME: Maybe move out to test or something
662                epsilon = 1.0e-6
663                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
664
665                #Check that this triangle contains data point
666                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
667                    element_found = True
668                    #Assign values to matrix A
669
670                    j0 = self.mesh.triangles[k,0] #Global vertex id
671                    #self.A[i, j0] = sigma0
672
673                    j1 = self.mesh.triangles[k,1] #Global vertex id
674                    #self.A[i, j1] = sigma1
675
676                    j2 = self.mesh.triangles[k,2] #Global vertex id
677                    #self.A[i, j2] = sigma2
678
679                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
680                    js     = [j0,j1,j2]
681
682                    for j in js:
683                        self.A[i,j] = sigmas[j]
684                        for k in js:
685                            self.AtA[j,k] += sigmas[j]*sigmas[k]
686                k = k+1
687
688
689
690    def get_A(self):
691        return self.A.todense()
692
693    def get_B(self):
694        return self.B.todense()
695
696    def get_D(self):
697        return self.D.todense()
698
699        #FIXME: Remember to re-introduce the 1/n factor in the
700        #interpolation term
701
702    def build_smoothing_matrix_D(self):
703        """Build m x m smoothing matrix, where
704        m is the number of basis functions phi_k (one per vertex)
705
706        The smoothing matrix is defined as
707
708        D = D1 + D2
709
710        where
711
712        [D1]_{k,l} = \int_\Omega
713           \frac{\partial \phi_k}{\partial x}
714           \frac{\partial \phi_l}{\partial x}\,
715           dx dy
716
717        [D2]_{k,l} = \int_\Omega
718           \frac{\partial \phi_k}{\partial y}
719           \frac{\partial \phi_l}{\partial y}\,
720           dx dy
721
722
723        The derivatives \frac{\partial \phi_k}{\partial x},
724        \frac{\partial \phi_k}{\partial x} for a particular triangle
725        are obtained by computing the gradient a_k, b_k for basis function k
726        """
727
728        #FIXME: algorithm might be optimised by computing local 9x9
729        #"element stiffness matrices:
730
731        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
732
733        self.D = Sparse(m,m)
734
735        #For each triangle compute contributions to D = D1+D2
736        for i in range(len(self.mesh)):
737
738            #Get area
739            area = self.mesh.areas[i]
740
741            #Get global vertex indices
742            v0 = self.mesh.triangles[i,0]
743            v1 = self.mesh.triangles[i,1]
744            v2 = self.mesh.triangles[i,2]
745
746            #Get the three vertex_points
747            xi0 = self.mesh.get_vertex_coordinate(i, 0)
748            xi1 = self.mesh.get_vertex_coordinate(i, 1)
749            xi2 = self.mesh.get_vertex_coordinate(i, 2)
750
751            #Compute gradients for each vertex
752            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
753                              1, 0, 0)
754
755            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
756                              0, 1, 0)
757
758            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
759                              0, 0, 1)
760
761            #Compute diagonal contributions
762            self.D[v0,v0] += (a0*a0 + b0*b0)*area
763            self.D[v1,v1] += (a1*a1 + b1*b1)*area
764            self.D[v2,v2] += (a2*a2 + b2*b2)*area
765
766            #Compute contributions for basis functions sharing edges
767            e01 = (a0*a1 + b0*b1)*area
768            self.D[v0,v1] += e01
769            self.D[v1,v0] += e01
770
771            e12 = (a1*a2 + b1*b2)*area
772            self.D[v1,v2] += e12
773            self.D[v2,v1] += e12
774
775            e20 = (a2*a0 + b2*b0)*area
776            self.D[v2,v0] += e20
777            self.D[v0,v2] += e20
778
779
780    def fit(self, z):
781        """Fit a smooth surface to given 1d array of data points z.
782
783        The smooth surface is computed at each vertex in the underlying
784        mesh using the formula given in the module doc string.
785
786        Pre Condition:
787          self.A, self.AtA and self.B have been initialised
788
789        Inputs:
790          z: Single 1d vector or array of data at the point_coordinates.
791        """
792
793        #Convert input to Numeric arrays
794        from util import ensure_numeric
795        z = ensure_numeric(z, Float)
796
797        if len(z.shape) > 1 :
798            raise VectorShapeError, 'Can only deal with 1d data vector'
799
800        if self.point_indices is not None:
801            #Remove values for any points that were outside mesh
802            z = take(z, self.point_indices)
803
804        #Compute right hand side based on data
805        Atz = self.A.trans_mult(z)
806
807
808        #Check sanity
809        n, m = self.A.shape
810        if n<m and self.alpha == 0.0:
811            msg = 'ERROR (least_squares): Too few data points\n'
812            msg += 'There are only %d data points and alpha == 0. ' %n
813            msg += 'Need at least %d\n' %m
814            msg += 'Alternatively, set smoothing parameter alpha to a small '
815            msg += 'positive value,\ne.g. 1.0e-3.'
816            raise msg
817
818
819
820        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
821        #FIXME: Should we store the result here for later use? (ON)
822
823
824    def fit_points(self, z, verbose=False):
825        """Like fit, but more robust when each point has two or more attributes
826        FIXME (Ole): The name fit_points doesn't carry any meaning
827        for me. How about something like fit_multiple or fit_columns?
828        """
829
830        try:
831            if verbose: print 'Solving penalised least_squares problem'
832            return self.fit(z)
833        except VectorShapeError, e:
834            # broadcasting is not supported.
835
836            #Convert input to Numeric arrays
837            from util import ensure_numeric
838            z = ensure_numeric(z, Float)
839
840            #Build n x m interpolation matrix
841            m = self.mesh.coordinates.shape[0] #Number of vertices
842            n = z.shape[1]                     #Number of data points
843
844            f = zeros((m,n), Float) #Resulting columns
845
846            for i in range(z.shape[1]):
847                f[:,i] = self.fit(z[:,i])
848
849            return f
850
851
852    def interpolate(self, f):
853        """Evaluate smooth surface f at data points implied in self.A.
854
855        The mesh values representing a smooth surface are
856        assumed to be specified in f. This argument could,
857        for example have been obtained from the method self.fit()
858
859        Pre Condition:
860          self.A has been initialised
861
862        Inputs:
863          f: Vector or array of data at the mesh vertices.
864          If f is an array, interpolation will be done for each column as
865          per underlying matrix-matrix multiplication
866
867        Output:
868          Interpolated values at data points implied in self.A
869
870        """
871
872        return self.A * f
873
874    def cull_outsiders(self, f):
875        pass
876
877
878
879
880class Interpolation_function:
881    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
882    which is interpolated from time series defined at vertices of
883    triangular mesh (such as those stored in sww files)
884
885    Let m be the number of vertices, n the number of triangles
886    and p the number of timesteps.
887
888    Mandatory input
889        time:               px1 array of monotonously increasing times (Float)
890        quantities:         Dictionary of pxm arrays or 1 pxm array (Float)
891       
892    Optional input:
893        quantity_names:     List of keys into the quantities dictionary
894        vertex_coordinates: mx2 array of coordinates (Float)
895        triangles:          nx3 array of indices into vertex_coordinates (Int)
896        interpolation_points: array of coordinates to be interpolated to
897        verbose:            Level of reporting
898   
899   
900    The quantities returned by the callable object are specified by
901    the list quantities which must contain the names of the
902    quantities to be returned and also reflect the order, e.g. for
903    the shallow water wave equation, on would have
904    quantities = ['stage', 'xmomentum', 'ymomentum']
905
906    The parameter interpolation_points decides at which points interpolated
907    quantities are to be computed whenever object is called.
908    If None, return average value
909    """
910
911   
912   
913    def __init__(self,
914                 time,
915                 quantities,
916                 quantity_names = None, 
917                 vertex_coordinates = None,
918                 triangles = None,
919                 interpolation_points = None,
920                 verbose = False):
921        """Initialise object and build spatial interpolation if required
922        """
923
924        from Numeric import array, zeros, Float, alltrue, concatenate,\
925             reshape, ArrayType
926
927        from util import mean, ensure_numeric
928        from config import time_format
929        import types
930
931
932
933        #Check temporal info
934        time = ensure_numeric(time)       
935        msg = 'Time must be a monotonuosly '
936        msg += 'increasing sequence %s' %time
937        assert alltrue(time[1:] - time[:-1] > 0 ), msg
938
939
940        #Check if quantities is a single array only
941        if type(quantities) != types.DictType:
942            quantities = ensure_numeric(quantities)
943            quantity_names = ['Attribute']
944
945            #Make it a dictionary
946            quantities = {quantity_names[0]: quantities}
947
948
949        #Use keys if no names are specified
950        if quantity_names is not None:
951            self.quantity_names = quantity_names
952        else:
953            self.quantity_names = quantities.keys()
954
955
956        #Check spatial info
957        if vertex_coordinates is None:
958            self.spatial = False
959        else:   
960            vertex_coordinates = ensure_numeric(vertex_coordinates)
961
962            assert triangles is not None, 'Triangles array must be specified'
963            triangles = ensure_numeric(triangles)
964            self.spatial = True           
965           
966 
967        #     
968        self.interpolation_points = interpolation_points #FIXWME Needed?
969        self.T = time[:]  #Time assumed to be relative to starttime
970        self.index = 0    #Initial time index
971        self.precomputed_values = {}
972           
973
974
975        #Precomputed spatial interpolation if requested
976        if interpolation_points is not None:
977            if self.spatial is False:
978                raise 'Triangles and vertex_coordinates must be specified'
979           
980
981            try:
982                interpolation_points = ensure_numeric(interpolation_points)
983            except:
984                msg = 'Interpolation points must be an N x 2 Numeric array '+\
985                      'or a list of points\n'
986                msg += 'I got: %s.' %( str(interpolation_points)[:60] + '...')
987                raise msg
988
989
990            for name in quantity_names:
991                self.precomputed_values[name] =\
992                                              zeros((len(self.T),
993                                                     len(interpolation_points)),
994                                                    Float)
995
996            #Build interpolator
997            interpol = Interpolation(vertex_coordinates,
998                                     triangles,
999                                     point_coordinates = interpolation_points,
1000                                     alpha = 0,
1001                                     precrop = False, 
1002                                     verbose = verbose)
1003
1004            #if interpol.cropped_points is True:
1005            #    raise 'Some interpolation points were outside mesh'
1006            #FIXME: This will be raised if triangles are listed as
1007            #discontinuous even though there is no need to stop
1008            #(precrop = True above)
1009
1010            if verbose: print 'Interpolate'
1011            for i, t in enumerate(self.T):
1012                #Interpolate quantities at this timestep
1013                if verbose: print ' time step %d of %d' %(i, len(self.T))
1014                for name in quantity_names:
1015                    self.precomputed_values[name][i, :] =\
1016                    interpol.interpolate(quantities[name][i,:])
1017
1018            #Report
1019            if verbose:
1020                x = vertex_coordinates[:,0]
1021                y = vertex_coordinates[:,1]               
1022           
1023                print '------------------------------------------------'
1024                print 'Interpolation_function statistics:'
1025                print '  Extent:'
1026                print '    x in [%f, %f], len(x) == %d'\
1027                      %(min(x), max(x), len(x))
1028                print '    y in [%f, %f], len(y) == %d'\
1029                      %(min(y), max(y), len(y))
1030                print '    t in [%f, %f], len(t) == %d'\
1031                      %(min(self.T), max(self.T), len(self.T))
1032                print '  Quantities:'
1033                for name in quantity_names:
1034                    q = quantities[name][:].flat
1035                    print '    %s in [%f, %f]' %(name, min(q), max(q))
1036                print '  Interpolation points (xi, eta):'\
1037                      ' number of points == %d ' %interpolation_points.shape[0]
1038                print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1039                                             max(interpolation_points[:,0]))
1040                print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1041                                              max(interpolation_points[:,1]))
1042                print '  Interpolated quantities (over all timesteps):'
1043               
1044                for name in quantity_names:
1045                    q = self.precomputed_values[name][:].flat
1046                    print '    %s at interpolation points in [%f, %f]'\
1047                          %(name, min(q), max(q))
1048                print '------------------------------------------------'
1049           
1050        else:
1051            #Store quantitites as is
1052            for name in quantity_names:
1053                self.precomputed_values[name] = quantities[name]
1054
1055
1056        #else:
1057        #    #Return an average, making this a time series
1058        #    for name in quantity_names:
1059        #        self.values[name] = zeros(len(self.T), Float)
1060        #
1061        #    if verbose: print 'Compute mean values'
1062        #    for i, t in enumerate(self.T):
1063        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1064        #        for name in quantity_names:
1065        #           self.values[name][i] = mean(quantities[name][i,:])
1066
1067
1068
1069
1070    def __repr__(self):
1071        return 'Interpolation function (spation-temporal)'
1072
1073    def __call__(self, t, point_id = None, x = None, y = None):
1074        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1075
1076        Inputs:
1077          t: time - Model time. Must lie within existing timesteps
1078          point_id: index of one of the preprocessed points.
1079          x, y:     Overrides location, point_id ignored
1080         
1081          If spatial info is present and all of x,y,point_id
1082          are None an exception is raised
1083                   
1084          If no spatial info is present, point_id and x,y arguments are ignored
1085          making f a function of time only.
1086
1087         
1088          FIXME: point_id could also be a slice
1089          FIXME: What if x and y are vectors?
1090          FIXME: What about f(x,y) without t?
1091        """
1092
1093        from math import pi, cos, sin, sqrt
1094        from Numeric import zeros, Float
1095        from util import mean       
1096
1097        if self.spatial is True:
1098            if point_id is None:
1099                if x is None or y is None:
1100                    msg = 'Either point_id or x and y must be specified'
1101                    raise msg
1102            else:
1103                if self.interpolation_points is None:
1104                    msg = 'Interpolation_function must be instantiated ' +\
1105                          'with a list of interpolation points before parameter ' +\
1106                          'point_id can be used'
1107                    raise msg
1108
1109
1110        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1111        msg += ' does not match model time: %s\n' %t
1112        if t < self.T[0]: raise msg
1113        if t > self.T[-1]: raise msg
1114
1115        oldindex = self.index #Time index
1116
1117        #Find current time slot
1118        while t > self.T[self.index]: self.index += 1
1119        while t < self.T[self.index]: self.index -= 1
1120
1121        if t == self.T[self.index]:
1122            #Protect against case where t == T[-1] (last time)
1123            # - also works in general when t == T[i]
1124            ratio = 0
1125        else:
1126            #t is now between index and index+1
1127            ratio = (t - self.T[self.index])/\
1128                    (self.T[self.index+1] - self.T[self.index])
1129
1130        #Compute interpolated values
1131        q = zeros(len(self.quantity_names), Float)
1132
1133        for i, name in enumerate(self.quantity_names):
1134            Q = self.precomputed_values[name]
1135
1136            if self.spatial is False:
1137                #If there is no spatial info               
1138                assert len(Q.shape) == 1
1139
1140                Q0 = Q[self.index]
1141                if ratio > 0: Q1 = Q[self.index+1]
1142
1143            else:
1144                if x is not None and y is not None:
1145                    #Interpolate to x, y
1146                   
1147                    raise 'x,y interpolation not yet implemented'
1148                else:
1149                    #Use precomputed point
1150                    Q0 = Q[self.index, point_id]
1151                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1152
1153            #Linear temporal interpolation   
1154            if ratio > 0:
1155                q[i] = Q0 + ratio*(Q1 - Q0)
1156            else:
1157                q[i] = Q0
1158
1159
1160        #Return vector of interpolated values
1161        #if len(q) == 1:
1162        #    return q[0]
1163        #else:
1164        #    return q
1165
1166
1167        #Return vector of interpolated values
1168        #FIXME:
1169        if self.spatial is True:
1170            return q
1171        else:
1172            #Replicate q according to x and y
1173            #This is e.g used for Wind_stress
1174            if x == None or y == None: 
1175                return q
1176            else:
1177                try:
1178                    N = len(x)
1179                except:
1180                    return q
1181                else:
1182                    from Numeric import ones, Float
1183                    #x is a vector - Create one constant column for each value
1184                    N = len(x)
1185                    assert len(y) == N, 'x and y must have same length'
1186                    res = []
1187                    for col in q:
1188                        res.append(col*ones(N, Float))
1189                       
1190                return res
1191
1192
1193
1194
1195#-------------------------------------------------------------
1196if __name__ == "__main__":
1197    """
1198    Load in a mesh and data points with attributes.
1199    Fit the attributes to the mesh.
1200    Save a new mesh file.
1201    """
1202    import os, sys
1203    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha]"\
1204            %os.path.basename(sys.argv[0])
1205
1206    if len(sys.argv) < 4:
1207        print usage
1208    else:
1209        mesh_file = sys.argv[1]
1210        point_file = sys.argv[2]
1211        mesh_output_file = sys.argv[3]
1212
1213        expand_search = False
1214        if len(sys.argv) > 4:
1215            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1216                expand_search = True
1217            else:
1218                expand_search = False
1219
1220        verbose = False
1221        if len(sys.argv) > 5:
1222            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1223                verbose = False
1224            else:
1225                verbose = True
1226
1227        if len(sys.argv) > 6:
1228            alpha = sys.argv[6]
1229        else:
1230            alpha = DEFAULT_ALPHA
1231
1232        t0 = time.time()
1233        fit_to_mesh_file(mesh_file,
1234                         point_file,
1235                         mesh_output_file,
1236                         alpha,
1237                         verbose= verbose,
1238                         expand_search = expand_search)
1239
1240        print 'That took %.2f seconds' %(time.time()-t0)
1241
Note: See TracBrowser for help on using the repository browser.