source: inundation/pyvolution/least_squares.py @ 1846

Last change on this file since 1846 was 1846, checked in by duncan, 19 years ago

added comments

File size: 46.4 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20import exceptions
21class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    mesh_origin is the same but refers to the input tsh file.
75    FIXME: When the tsh format contains it own origin, these parameters can go.
76    FIXME: And both origins should be obtained from the specified files.
77    """
78
79    from load_mesh.loadASCII import import_mesh_file, \
80                 import_points_file, export_mesh_file, \
81                 concatinate_attributelist
82
83   
84    try:
85        mesh_dict = import_mesh_file(mesh_file)
86    except IOError,e:
87        if display_errors:
88            print "Could not load bad file. ", e
89        import sys; sys.exit()
90    vertex_coordinates = mesh_dict['vertices']
91    triangles = mesh_dict['triangles']
92    if type(mesh_dict['vertex_attributes']) == ArrayType:
93        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
94    else:
95        old_point_attributes = mesh_dict['vertex_attributes']
96
97    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
98        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
99    else:
100        old_title_list = mesh_dict['vertex_attribute_titles']
101
102    if verbose: print 'tsh file %s loaded' %mesh_file
103
104    # load in the .pts file
105    try:
106        point_dict = import_points_file(point_file, verbose=verbose)
107    except IOError,e:
108        if display_errors:
109            print "Could not load bad file. ", e
110        import sys; sys.exit()
111
112    point_coordinates = point_dict['pointlist']
113    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
114
115    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
116        data_origin = point_dict['geo_reference'].get_origin()
117    else:
118        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
119
120    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
121        mesh_origin = mesh_dict['geo_reference'].get_origin()
122    else:
123        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
124
125    if verbose: print "points file loaded"
126    if verbose:print "fitting to mesh"
127    f = fit_to_mesh(vertex_coordinates,
128                    triangles,
129                    point_coordinates,
130                    point_attributes,
131                    alpha = alpha,
132                    verbose = verbose,
133                    expand_search = expand_search,
134                    data_origin = data_origin,
135                    mesh_origin = mesh_origin,
136                    precrop = precrop)
137    if verbose: print "finished fitting to mesh"
138
139    # convert array to list of lists
140    new_point_attributes = f.tolist()
141    #FIXME have this overwrite attributes with the same title - DSG
142    #Put the newer attributes last
143    if old_title_list <> []:
144        old_title_list.extend(title_list)
145        #FIXME can this be done a faster way? - DSG
146        for i in range(len(old_point_attributes)):
147            old_point_attributes[i].extend(new_point_attributes[i])
148        mesh_dict['vertex_attributes'] = old_point_attributes
149        mesh_dict['vertex_attribute_titles'] = old_title_list
150    else:
151        mesh_dict['vertex_attributes'] = new_point_attributes
152        mesh_dict['vertex_attribute_titles'] = title_list
153
154    #FIXME (Ole): Remember to output mesh_origin as well
155    if verbose: print "exporting to file ",mesh_output_file
156
157    try:
158        export_mesh_file(mesh_output_file, mesh_dict)
159    except IOError,e:
160        if display_errors:
161            print "Could not write file. ", e
162        import sys; sys.exit()
163
164def fit_to_mesh(vertex_coordinates,
165                triangles,
166                point_coordinates,
167                point_attributes,
168                alpha = DEFAULT_ALPHA,
169                verbose = False,
170                expand_search = False,
171                data_origin = None,
172                mesh_origin = None,
173                precrop = False):
174    """
175    Fit a smooth surface to a triangulation,
176    given data points with attributes.
177
178
179        Inputs:
180
181          vertex_coordinates: List of coordinate pairs [xi, eta] of points
182          constituting mesh (or a an m x 2 Numeric array)
183
184          triangles: List of 3-tuples (or a Numeric array) of
185          integers representing indices of all vertices in the mesh.
186
187          point_coordinates: List of coordinate pairs [x, y] of data points
188          (or an nx2 Numeric array)
189
190          alpha: Smoothing parameter.
191
192          point_attributes: Vector or array of data at the point_coordinates.
193
194          data_origin and mesh_origin are 3-tuples consisting of
195          UTM zone, easting and northing. If specified
196          point coordinates and vertex coordinates are assumed to be
197          relative to their respective origins.
198
199    """
200    interp = Interpolation(vertex_coordinates,
201                           triangles,
202                           point_coordinates,
203                           alpha = alpha,
204                           verbose = verbose,
205                           expand_search = expand_search,
206                           data_origin = data_origin,
207                           mesh_origin = mesh_origin,
208                           precrop = precrop)
209
210    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
211    return vertex_attributes
212
213
214
215def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
216                    verbose = False, reduction = 1, format = 'netcdf'):
217    """Fits attributes from pts file to MxN rectangular mesh
218
219    Read pts file and create rectangular mesh of resolution MxN such that
220    it covers all points specified in pts file.
221
222    FIXME: This may be a temporary function until we decide on
223    netcdf formats etc
224
225    FIXME: Uses elevation hardwired
226    """
227
228    import util, mesh_factory
229
230    if verbose: print 'Read pts'
231    points, attributes = util.read_xya(pts_name, format)
232
233    #Reduce number of points a bit
234    points = points[::reduction]
235    elevation = attributes['elevation']  #Must be elevation
236    elevation = elevation[::reduction]
237
238    if verbose: print 'Got %d data points' %len(points)
239
240    if verbose: print 'Create mesh'
241    #Find extent
242    max_x = min_x = points[0][0]
243    max_y = min_y = points[0][1]
244    for point in points[1:]:
245        x = point[0]
246        if x > max_x: max_x = x
247        if x < min_x: min_x = x
248        y = point[1]
249        if y > max_y: max_y = y
250        if y < min_y: min_y = y
251
252    #Create appropriate mesh
253    vertex_coordinates, triangles, boundary =\
254         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
255                                (min_x, min_y))
256
257    #Fit attributes to mesh
258    vertex_attributes = fit_to_mesh(vertex_coordinates,
259                        triangles,
260                        points,
261                        elevation, alpha=alpha, verbose=verbose)
262
263
264
265    return vertex_coordinates, triangles, boundary, vertex_attributes
266
267
268
269class Interpolation:
270
271    def __init__(self,
272                 vertex_coordinates,
273                 triangles,
274                 point_coordinates = None,
275                 alpha = None,
276                 verbose = False,
277                 expand_search = True,
278                 max_points_per_cell = 30,
279                 mesh_origin = None,
280                 data_origin = None,
281                 precrop = False):
282
283
284        """ Build interpolation matrix mapping from
285        function values at vertices to function values at data points
286
287        Inputs:
288
289          vertex_coordinates: List of coordinate pairs [xi, eta] of
290          points constituting mesh (or a an m x 2 Numeric array)
291          Points may appear multiple times
292          (e.g. if vertices have discontinuities)
293
294          triangles: List of 3-tuples (or a Numeric array) of
295          integers representing indices of all vertices in the mesh.
296
297          point_coordinates: List of coordinate pairs [x, y] of
298          data points (or an nx2 Numeric array)
299          If point_coordinates is absent, only smoothing matrix will
300          be built
301
302          alpha: Smoothing parameter
303
304          data_origin and mesh_origin are 3-tuples consisting of
305          UTM zone, easting and northing. If specified
306          point coordinates and vertex coordinates are assumed to be
307          relative to their respective origins.
308
309        """
310        from util import ensure_numeric
311
312        #Convert input to Numeric arrays
313        triangles = ensure_numeric(triangles, Int)
314        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
315
316        #Build underlying mesh
317        if verbose: print 'Building mesh'
318        #self.mesh = General_mesh(vertex_coordinates, triangles,
319        #FIXME: Trying the normal mesh while testing precrop,
320        #       The functionality of boundary_polygon is needed for that
321
322        #FIXME - geo ref does not have to go into mesh.
323        # Change the point co-ords to conform to the
324        # mesh co-ords early in the code
325        if mesh_origin == None:
326            geo = None
327        else:
328            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
329        self.mesh = Mesh(vertex_coordinates, triangles,
330                         geo_reference = geo)
331       
332        self.mesh.check_integrity()
333
334        self.data_origin = data_origin
335
336        self.point_indices = None
337
338        #Smoothing parameter
339        if alpha is None:
340            self.alpha = DEFAULT_ALPHA
341        else:   
342            self.alpha = alpha
343
344        #Build coefficient matrices
345        self.build_coefficient_matrix_B(point_coordinates,
346                                        verbose = verbose,
347                                        expand_search = expand_search,
348                                        max_points_per_cell =\
349                                        max_points_per_cell,
350                                        data_origin = data_origin,
351                                        precrop = precrop)
352
353
354    def set_point_coordinates(self, point_coordinates,
355                              data_origin = None):
356        """
357        A public interface to setting the point co-ordinates.
358        """
359        self.build_coefficient_matrix_B(point_coordinates, data_origin)
360
361    def build_coefficient_matrix_B(self, point_coordinates=None,
362                                   verbose = False, expand_search = True,
363                                   max_points_per_cell=30,
364                                   data_origin = None,
365                                   precrop = False):
366        """Build final coefficient matrix"""
367
368
369        if self.alpha <> 0:
370            if verbose: print 'Building smoothing matrix'
371            self.build_smoothing_matrix_D()
372
373        if point_coordinates is not None:
374
375            if verbose: print 'Building interpolation matrix'
376            self.build_interpolation_matrix_A(point_coordinates,
377                                              verbose = verbose,
378                                              expand_search = expand_search,
379                                              max_points_per_cell =\
380                                              max_points_per_cell,
381                                              data_origin = data_origin,
382                                              precrop = precrop)
383
384            if self.alpha <> 0:
385                self.B = self.AtA + self.alpha*self.D
386            else:
387                self.B = self.AtA
388
389            #Convert self.B matrix to CSR format for faster matrix vector
390            self.B = Sparse_CSR(self.B)
391
392    def build_interpolation_matrix_A(self, point_coordinates,
393                                     verbose = False, expand_search = True,
394                                     max_points_per_cell=30,
395                                     data_origin = None,
396                                     precrop = False):
397        """Build n x m interpolation matrix, where
398        n is the number of data points and
399        m is the number of basis functions phi_k (one per vertex)
400
401        This algorithm uses a quad tree data structure for fast binning of data points
402        origin is a 3-tuple consisting of UTM zone, easting and northing.
403        If specified coordinates are assumed to be relative to this origin.
404
405        This one will override any data_origin that may be specified in
406        interpolation instance
407
408        """
409
410
411        #FIXME (Ole): Check that this function is memeory efficient.
412        #6 million datapoints and 300000 basis functions
413        #causes out-of-memory situation
414        #First thing to check is whether there is room for self.A and self.AtA
415        #
416        #Maybe we need some sort of blocking
417
418        from quad import build_quadtree
419        from util import ensure_numeric
420
421        if data_origin is None:
422            data_origin = self.data_origin #Use the one from
423                                           #interpolation instance
424
425        #Convert input to Numeric arrays just in case.
426        point_coordinates = ensure_numeric(point_coordinates, Float)
427
428        #Keep track of discarded points (if any).
429        #This is only registered if precrop is True
430        self.cropped_points = False
431
432        #Shift data points to same origin as mesh (if specified)
433
434        #FIXME this will shift if there was no geo_ref.
435        #But all this should be removed anyhow.
436        #change coords before this point
437        mesh_origin = self.mesh.geo_reference.get_origin()
438        if point_coordinates is not None:
439            if data_origin is not None:
440                if mesh_origin is not None:
441
442                    #Transformation:
443                    #
444                    #Let x_0 be the reference point of the point coordinates
445                    #and xi_0 the reference point of the mesh.
446                    #
447                    #A point coordinate (x + x_0) is then made relative
448                    #to xi_0 by
449                    #
450                    # x_new = x + x_0 - xi_0
451                    #
452                    #and similarly for eta
453
454                    x_offset = data_origin[1] - mesh_origin[1]
455                    y_offset = data_origin[2] - mesh_origin[2]
456                else: #Shift back to a zero origin
457                    x_offset = data_origin[1]
458                    y_offset = data_origin[2]
459
460                point_coordinates[:,0] += x_offset
461                point_coordinates[:,1] += y_offset
462            else:
463                if mesh_origin is not None:
464                    #Use mesh origin for data points
465                    point_coordinates[:,0] -= mesh_origin[1]
466                    point_coordinates[:,1] -= mesh_origin[2]
467
468
469
470        #Remove points falling outside mesh boundary
471        #This reduced one example from 1356 seconds to 825 seconds
472        if precrop is True:
473            from Numeric import take
474            from util import inside_polygon
475
476            if verbose: print 'Getting boundary polygon'
477            P = self.mesh.get_boundary_polygon()
478
479            if verbose: print 'Getting indices inside mesh boundary'
480            indices = inside_polygon(point_coordinates, P, verbose = verbose)
481
482
483            if len(indices) != point_coordinates.shape[0]:
484                self.cropped_points = True
485                if verbose:
486                    print 'Done - %d points outside mesh have been cropped.'\
487                          %(point_coordinates.shape[0] - len(indices))
488
489            point_coordinates = take(point_coordinates, indices)
490            self.point_indices = indices
491
492
493
494
495        #Build n x m interpolation matrix
496        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
497        n = point_coordinates.shape[0]     #Nbr of data points
498
499        if verbose: print 'Number of datapoints: %d' %n
500        if verbose: print 'Number of basis functions: %d' %m
501
502        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
503        #However, Sparse_CSR does not have the same methods as Sparse yet
504        #The tests will reveal what needs to be done
505        #self.A = Sparse_CSR(Sparse(n,m))
506        #self.AtA = Sparse_CSR(Sparse(m,m))
507        self.A = Sparse(n,m)
508        self.AtA = Sparse(m,m)
509
510        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
511        root = build_quadtree(self.mesh,
512                              max_points_per_cell = max_points_per_cell)
513
514        #Compute matrix elements
515        for i in range(n):
516            #For each data_coordinate point
517
518            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
519            x = point_coordinates[i]
520
521            #Find vertices near x
522            candidate_vertices = root.search(x[0], x[1])
523            is_more_elements = True
524
525            element_found, sigma0, sigma1, sigma2, k = \
526                self.search_triangles_of_vertices(candidate_vertices, x)
527            while not element_found and is_more_elements and expand_search:
528                #if verbose: print 'Expanding search'
529                candidate_vertices, branch = root.expand_search()
530                if branch == []:
531                    # Searching all the verts from the root cell that haven't
532                    # been searched.  This is the last try
533                    element_found, sigma0, sigma1, sigma2, k = \
534                      self.search_triangles_of_vertices(candidate_vertices, x)
535                    is_more_elements = False
536                else:
537                    element_found, sigma0, sigma1, sigma2, k = \
538                      self.search_triangles_of_vertices(candidate_vertices, x)
539
540
541            #Update interpolation matrix A if necessary
542            if element_found is True:
543                #Assign values to matrix A
544
545                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
546                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
547                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
548
549                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
550                js     = [j0,j1,j2]
551
552                for j in js:
553                    self.A[i,j] = sigmas[j]
554                    for k in js:
555                        self.AtA[j,k] += sigmas[j]*sigmas[k]
556            else:
557                pass
558                #Ok if there is no triangle for datapoint
559                #(as in brute force version)
560                #raise 'Could not find triangle for point', x
561
562
563
564    def search_triangles_of_vertices(self, candidate_vertices, x):
565            #Find triangle containing x:
566            element_found = False
567
568            # This will be returned if element_found = False
569            sigma2 = -10.0
570            sigma0 = -10.0
571            sigma1 = -10.0
572            k = -10.0
573
574            #For all vertices in same cell as point x
575            for v in candidate_vertices:
576
577                #for each triangle id (k) which has v as a vertex
578                for k, _ in self.mesh.vertexlist[v]:
579
580                    #Get the three vertex_points of candidate triangle
581                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
582                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
583                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
584
585                    #print "PDSG - k", k
586                    #print "PDSG - xi0", xi0
587                    #print "PDSG - xi1", xi1
588                    #print "PDSG - xi2", xi2
589                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
590                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
591
592                    #Get the three normals
593                    n0 = self.mesh.get_normal(k, 0)
594                    n1 = self.mesh.get_normal(k, 1)
595                    n2 = self.mesh.get_normal(k, 2)
596
597
598                    #Compute interpolation
599                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
600                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
601                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
602
603                    #print "PDSG - sigma0", sigma0
604                    #print "PDSG - sigma1", sigma1
605                    #print "PDSG - sigma2", sigma2
606
607                    #FIXME: Maybe move out to test or something
608                    epsilon = 1.0e-6
609                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
610
611                    #Check that this triangle contains the data point
612
613                    #Sigmas can get negative within
614                    #machine precision on some machines (e.g nautilus)
615                    #Hence the small eps
616                    eps = 1.0e-15
617                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
618                        element_found = True
619                        break
620
621                if element_found is True:
622                    #Don't look for any other triangle
623                    break
624            return element_found, sigma0, sigma1, sigma2, k
625
626
627
628    def build_interpolation_matrix_A_brute(self, point_coordinates):
629        """Build n x m interpolation matrix, where
630        n is the number of data points and
631        m is the number of basis functions phi_k (one per vertex)
632
633        This is the brute force which is too slow for large problems,
634        but could be used for testing
635        """
636
637        from util import ensure_numeric
638
639        #Convert input to Numeric arrays
640        point_coordinates = ensure_numeric(point_coordinates, Float)
641
642        #Build n x m interpolation matrix
643        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
644        n = point_coordinates.shape[0]     #Nbr of data points
645
646        self.A = Sparse(n,m)
647        self.AtA = Sparse(m,m)
648
649        #Compute matrix elements
650        for i in range(n):
651            #For each data_coordinate point
652
653            x = point_coordinates[i]
654            element_found = False
655            k = 0
656            while not element_found and k < len(self.mesh):
657                #For each triangle (brute force)
658                #FIXME: Real algorithm should only visit relevant triangles
659
660                #Get the three vertex_points
661                xi0 = self.mesh.get_vertex_coordinate(k, 0)
662                xi1 = self.mesh.get_vertex_coordinate(k, 1)
663                xi2 = self.mesh.get_vertex_coordinate(k, 2)
664
665                #Get the three normals
666                n0 = self.mesh.get_normal(k, 0)
667                n1 = self.mesh.get_normal(k, 1)
668                n2 = self.mesh.get_normal(k, 2)
669
670                #Compute interpolation
671                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
672                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
673                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
674
675                #FIXME: Maybe move out to test or something
676                epsilon = 1.0e-6
677                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
678
679                #Check that this triangle contains data point
680                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
681                    element_found = True
682                    #Assign values to matrix A
683
684                    j0 = self.mesh.triangles[k,0] #Global vertex id
685                    #self.A[i, j0] = sigma0
686
687                    j1 = self.mesh.triangles[k,1] #Global vertex id
688                    #self.A[i, j1] = sigma1
689
690                    j2 = self.mesh.triangles[k,2] #Global vertex id
691                    #self.A[i, j2] = sigma2
692
693                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
694                    js     = [j0,j1,j2]
695
696                    for j in js:
697                        self.A[i,j] = sigmas[j]
698                        for k in js:
699                            self.AtA[j,k] += sigmas[j]*sigmas[k]
700                k = k+1
701
702
703
704    def get_A(self):
705        return self.A.todense()
706
707    def get_B(self):
708        return self.B.todense()
709
710    def get_D(self):
711        return self.D.todense()
712
713        #FIXME: Remember to re-introduce the 1/n factor in the
714        #interpolation term
715
716    def build_smoothing_matrix_D(self):
717        """Build m x m smoothing matrix, where
718        m is the number of basis functions phi_k (one per vertex)
719
720        The smoothing matrix is defined as
721
722        D = D1 + D2
723
724        where
725
726        [D1]_{k,l} = \int_\Omega
727           \frac{\partial \phi_k}{\partial x}
728           \frac{\partial \phi_l}{\partial x}\,
729           dx dy
730
731        [D2]_{k,l} = \int_\Omega
732           \frac{\partial \phi_k}{\partial y}
733           \frac{\partial \phi_l}{\partial y}\,
734           dx dy
735
736
737        The derivatives \frac{\partial \phi_k}{\partial x},
738        \frac{\partial \phi_k}{\partial x} for a particular triangle
739        are obtained by computing the gradient a_k, b_k for basis function k
740        """
741
742        #FIXME: algorithm might be optimised by computing local 9x9
743        #"element stiffness matrices:
744
745        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
746
747        self.D = Sparse(m,m)
748
749        #For each triangle compute contributions to D = D1+D2
750        for i in range(len(self.mesh)):
751
752            #Get area
753            area = self.mesh.areas[i]
754
755            #Get global vertex indices
756            v0 = self.mesh.triangles[i,0]
757            v1 = self.mesh.triangles[i,1]
758            v2 = self.mesh.triangles[i,2]
759
760            #Get the three vertex_points
761            xi0 = self.mesh.get_vertex_coordinate(i, 0)
762            xi1 = self.mesh.get_vertex_coordinate(i, 1)
763            xi2 = self.mesh.get_vertex_coordinate(i, 2)
764
765            #Compute gradients for each vertex
766            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
767                              1, 0, 0)
768
769            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
770                              0, 1, 0)
771
772            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
773                              0, 0, 1)
774
775            #Compute diagonal contributions
776            self.D[v0,v0] += (a0*a0 + b0*b0)*area
777            self.D[v1,v1] += (a1*a1 + b1*b1)*area
778            self.D[v2,v2] += (a2*a2 + b2*b2)*area
779
780            #Compute contributions for basis functions sharing edges
781            e01 = (a0*a1 + b0*b1)*area
782            self.D[v0,v1] += e01
783            self.D[v1,v0] += e01
784
785            e12 = (a1*a2 + b1*b2)*area
786            self.D[v1,v2] += e12
787            self.D[v2,v1] += e12
788
789            e20 = (a2*a0 + b2*b0)*area
790            self.D[v2,v0] += e20
791            self.D[v0,v2] += e20
792
793
794    def fit(self, z):
795        """Fit a smooth surface to given 1d array of data points z.
796
797        The smooth surface is computed at each vertex in the underlying
798        mesh using the formula given in the module doc string.
799
800        Pre Condition:
801          self.A, self.AtA and self.B have been initialised
802
803        Inputs:
804          z: Single 1d vector or array of data at the point_coordinates.
805        """
806
807        #Convert input to Numeric arrays
808        from util import ensure_numeric
809        z = ensure_numeric(z, Float)
810
811        if len(z.shape) > 1 :
812            raise VectorShapeError, 'Can only deal with 1d data vector'
813
814        if self.point_indices is not None:
815            #Remove values for any points that were outside mesh
816            z = take(z, self.point_indices)
817
818        #Compute right hand side based on data
819        Atz = self.A.trans_mult(z)
820
821
822        #Check sanity
823        n, m = self.A.shape
824        if n<m and self.alpha == 0.0:
825            msg = 'ERROR (least_squares): Too few data points\n'
826            msg += 'There are only %d data points and alpha == 0. ' %n
827            msg += 'Need at least %d\n' %m
828            msg += 'Alternatively, set smoothing parameter alpha to a small '
829            msg += 'positive value,\ne.g. 1.0e-3.'
830            raise msg
831
832
833
834        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
835        #FIXME: Should we store the result here for later use? (ON)
836
837
838    def fit_points(self, z, verbose=False):
839        """Like fit, but more robust when each point has two or more attributes
840        FIXME (Ole): The name fit_points doesn't carry any meaning
841        for me. How about something like fit_multiple or fit_columns?
842        """
843
844        try:
845            if verbose: print 'Solving penalised least_squares problem'
846            return self.fit(z)
847        except VectorShapeError, e:
848            # broadcasting is not supported.
849
850            #Convert input to Numeric arrays
851            from util import ensure_numeric
852            z = ensure_numeric(z, Float)
853
854            #Build n x m interpolation matrix
855            m = self.mesh.coordinates.shape[0] #Number of vertices
856            n = z.shape[1]                     #Number of data points
857
858            f = zeros((m,n), Float) #Resulting columns
859
860            for i in range(z.shape[1]):
861                f[:,i] = self.fit(z[:,i])
862
863            return f
864
865
866    def interpolate(self, f):
867        """Evaluate smooth surface f at data points implied in self.A.
868
869        The mesh values representing a smooth surface are
870        assumed to be specified in f. This argument could,
871        for example have been obtained from the method self.fit()
872
873        Pre Condition:
874          self.A has been initialised
875
876        Inputs:
877          f: Vector or array of data at the mesh vertices.
878          If f is an array, interpolation will be done for each column as
879          per underlying matrix-matrix multiplication
880
881        Output:
882          Interpolated values at data points implied in self.A
883
884        """
885
886        return self.A * f
887
888    def cull_outsiders(self, f):
889        pass
890
891
892
893
894class Interpolation_function:
895    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
896    which is interpolated from time series defined at vertices of
897    triangular mesh (such as those stored in sww files)
898
899    Let m be the number of vertices, n the number of triangles
900    and p the number of timesteps.
901
902    Mandatory input
903        time:               px1 array of monotonously increasing times (Float)
904        quantities:         Dictionary of pxm arrays or 1 pxm array (Float)
905       
906    Optional input:
907        quantity_names:     List of keys into the quantities dictionary
908        vertex_coordinates: mx2 array of coordinates (Float)
909        triangles:          nx3 array of indices into vertex_coordinates (Int)
910        interpolation_points: array of coordinates to be interpolated to
911        verbose:            Level of reporting
912   
913   
914    The quantities returned by the callable object are specified by
915    the list quantities which must contain the names of the
916    quantities to be returned and also reflect the order, e.g. for
917    the shallow water wave equation, on would have
918    quantities = ['stage', 'xmomentum', 'ymomentum']
919
920    The parameter interpolation_points decides at which points interpolated
921    quantities are to be computed whenever object is called.
922    If None, return average value
923    """
924
925   
926   
927    def __init__(self,
928                 time,
929                 quantities,
930                 quantity_names = None, 
931                 vertex_coordinates = None,
932                 triangles = None,
933                 interpolation_points = None,
934                 verbose = False):
935        """Initialise object and build spatial interpolation if required
936        """
937
938        from Numeric import array, zeros, Float, alltrue, concatenate,\
939             reshape, ArrayType
940
941        from util import mean, ensure_numeric
942        from config import time_format
943        import types
944
945
946
947        #Check temporal info
948        time = ensure_numeric(time)       
949        msg = 'Time must be a monotonuosly '
950        msg += 'increasing sequence %s' %time
951        assert alltrue(time[1:] - time[:-1] > 0 ), msg
952
953
954        #Check if quantities is a single array only
955        if type(quantities) != types.DictType:
956            quantities = ensure_numeric(quantities)
957            quantity_names = ['Attribute']
958
959            #Make it a dictionary
960            quantities = {quantity_names[0]: quantities}
961
962
963        #Use keys if no names are specified
964        if quantity_names is None:
965            quantity_names = quantities.keys()
966
967
968        #Check spatial info
969        if vertex_coordinates is None:
970            self.spatial = False
971        else:   
972            vertex_coordinates = ensure_numeric(vertex_coordinates)
973
974            assert triangles is not None, 'Triangles array must be specified'
975            triangles = ensure_numeric(triangles)
976            self.spatial = True           
977           
978
979 
980        #Save for use with statistics
981        self.quantity_names = quantity_names       
982        self.quantities = quantities       
983        self.vertex_coordinates = vertex_coordinates
984        self.interpolation_points = interpolation_points
985        self.T = time[:]  #Time assumed to be relative to starttime
986        self.index = 0    #Initial time index
987        self.precomputed_values = {}
988           
989
990
991        #Precomputed spatial interpolation if requested
992        if interpolation_points is not None:
993            if self.spatial is False:
994                raise 'Triangles and vertex_coordinates must be specified'
995           
996
997            try:
998                self.interpolation_points =\
999                                          ensure_numeric(self.interpolation_points)
1000            except:
1001                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1002                      'or a list of points\n'
1003                msg += 'I got: %s.' %( str(self.interpolation_points)[:60] + '...')
1004                raise msg
1005
1006
1007            for name in quantity_names:
1008                self.precomputed_values[name] =\
1009                                              zeros((len(self.T),
1010                                                     len(self.interpolation_points)),
1011                                                    Float)
1012
1013            #Build interpolator
1014            interpol = Interpolation(vertex_coordinates,
1015                                     triangles,
1016                                     point_coordinates = self.interpolation_points,
1017                                     alpha = 0,
1018                                     precrop = False, 
1019                                     verbose = verbose)
1020
1021            if verbose: print 'Interpolate'
1022            n = len(self.T)
1023            for i, t in enumerate(self.T):
1024                #Interpolate quantities at this timestep
1025                if verbose and i%((n+10)/10)==0:
1026                    print ' time step %d of %d' %(i, n)
1027                   
1028                for name in quantity_names:
1029                    self.precomputed_values[name][i, :] =\
1030                    interpol.interpolate(quantities[name][i,:])
1031
1032            #Report
1033            if verbose:
1034                print self.statistics()
1035                #self.print_statistics()
1036           
1037        else:
1038            #Store quantitites as is
1039            for name in quantity_names:
1040                self.precomputed_values[name] = quantities[name]
1041
1042
1043        #else:
1044        #    #Return an average, making this a time series
1045        #    for name in quantity_names:
1046        #        self.values[name] = zeros(len(self.T), Float)
1047        #
1048        #    if verbose: print 'Compute mean values'
1049        #    for i, t in enumerate(self.T):
1050        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1051        #        for name in quantity_names:
1052        #           self.values[name][i] = mean(quantities[name][i,:])
1053
1054
1055
1056
1057    def __repr__(self):
1058        #return 'Interpolation function (spation-temporal)'
1059        return self.statistics()
1060   
1061
1062    def __call__(self, t, point_id = None, x = None, y = None):
1063        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1064
1065        Inputs:
1066          t: time - Model time. Must lie within existing timesteps
1067          point_id: index of one of the preprocessed points.
1068          x, y:     Overrides location, point_id ignored
1069         
1070          If spatial info is present and all of x,y,point_id
1071          are None an exception is raised
1072                   
1073          If no spatial info is present, point_id and x,y arguments are ignored
1074          making f a function of time only.
1075
1076         
1077          FIXME: point_id could also be a slice
1078          FIXME: What if x and y are vectors?
1079          FIXME: What about f(x,y) without t?
1080        """
1081
1082        from math import pi, cos, sin, sqrt
1083        from Numeric import zeros, Float
1084        from util import mean       
1085
1086        if self.spatial is True:
1087            if point_id is None:
1088                if x is None or y is None:
1089                    msg = 'Either point_id or x and y must be specified'
1090                    raise msg
1091            else:
1092                if self.interpolation_points is None:
1093                    msg = 'Interpolation_function must be instantiated ' +\
1094                          'with a list of interpolation points before parameter ' +\
1095                          'point_id can be used'
1096                    raise msg
1097
1098
1099        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1100        msg += ' does not match model time: %s\n' %t
1101        if t < self.T[0]: raise msg
1102        if t > self.T[-1]: raise msg
1103
1104        oldindex = self.index #Time index
1105
1106        #Find current time slot
1107        while t > self.T[self.index]: self.index += 1
1108        while t < self.T[self.index]: self.index -= 1
1109
1110        if t == self.T[self.index]:
1111            #Protect against case where t == T[-1] (last time)
1112            # - also works in general when t == T[i]
1113            ratio = 0
1114        else:
1115            #t is now between index and index+1
1116            ratio = (t - self.T[self.index])/\
1117                    (self.T[self.index+1] - self.T[self.index])
1118
1119        #Compute interpolated values
1120        q = zeros(len(self.quantity_names), Float)
1121
1122        for i, name in enumerate(self.quantity_names):
1123            Q = self.precomputed_values[name]
1124
1125            if self.spatial is False:
1126                #If there is no spatial info               
1127                assert len(Q.shape) == 1
1128
1129                Q0 = Q[self.index]
1130                if ratio > 0: Q1 = Q[self.index+1]
1131
1132            else:
1133                if x is not None and y is not None:
1134                    #Interpolate to x, y
1135                   
1136                    raise 'x,y interpolation not yet implemented'
1137                else:
1138                    #Use precomputed point
1139                    Q0 = Q[self.index, point_id]
1140                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1141
1142            #Linear temporal interpolation   
1143            if ratio > 0:
1144                q[i] = Q0 + ratio*(Q1 - Q0)
1145            else:
1146                q[i] = Q0
1147
1148
1149        #Return vector of interpolated values
1150        #if len(q) == 1:
1151        #    return q[0]
1152        #else:
1153        #    return q
1154
1155
1156        #Return vector of interpolated values
1157        #FIXME:
1158        if self.spatial is True:
1159            return q
1160        else:
1161            #Replicate q according to x and y
1162            #This is e.g used for Wind_stress
1163            if x == None or y == None: 
1164                return q
1165            else:
1166                try:
1167                    N = len(x)
1168                except:
1169                    return q
1170                else:
1171                    from Numeric import ones, Float
1172                    #x is a vector - Create one constant column for each value
1173                    N = len(x)
1174                    assert len(y) == N, 'x and y must have same length'
1175                    res = []
1176                    for col in q:
1177                        res.append(col*ones(N, Float))
1178                       
1179                return res
1180
1181
1182    def statistics(self):
1183        """Output statistics about interpolation_function
1184        """
1185       
1186        vertex_coordinates = self.vertex_coordinates
1187        interpolation_points = self.interpolation_points               
1188        quantity_names = self.quantity_names
1189        quantities = self.quantities
1190        precomputed_values = self.precomputed_values                 
1191               
1192        x = vertex_coordinates[:,0]
1193        y = vertex_coordinates[:,1]               
1194
1195        str =  '------------------------------------------------\n'
1196        str += 'Interpolation_function (spation-temporal) statistics:\n'
1197        str += '  Extent:\n'
1198        str += '    x in [%f, %f], len(x) == %d\n'\
1199               %(min(x), max(x), len(x))
1200        str += '    y in [%f, %f], len(y) == %d\n'\
1201               %(min(y), max(y), len(y))
1202        str += '    t in [%f, %f], len(t) == %d\n'\
1203               %(min(self.T), max(self.T), len(self.T))
1204        str += '  Quantities:\n'
1205        for name in quantity_names:
1206            q = quantities[name][:].flat
1207            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1208
1209        if interpolation_points is not None:   
1210            str += '  Interpolation points (xi, eta):'\
1211                   ' number of points == %d\n' %interpolation_points.shape[0]
1212            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1213                                            max(interpolation_points[:,0]))
1214            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1215                                             max(interpolation_points[:,1]))
1216            str += '  Interpolated quantities (over all timesteps):\n'
1217       
1218            for name in quantity_names:
1219                q = precomputed_values[name][:].flat
1220                str += '    %s at interpolation points in [%f, %f]\n'\
1221                       %(name, min(q), max(q))
1222        str += '------------------------------------------------\n'
1223
1224        return str
1225
1226        #FIXME: Delete
1227        #print '------------------------------------------------'
1228        #print 'Interpolation_function statistics:'
1229        #print '  Extent:'
1230        #print '    x in [%f, %f], len(x) == %d'\
1231        #      %(min(x), max(x), len(x))
1232        #print '    y in [%f, %f], len(y) == %d'\
1233        #      %(min(y), max(y), len(y))
1234        #print '    t in [%f, %f], len(t) == %d'\
1235        #      %(min(self.T), max(self.T), len(self.T))
1236        #print '  Quantities:'
1237        #for name in quantity_names:
1238        #    q = quantities[name][:].flat
1239        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1240        #print '  Interpolation points (xi, eta):'\
1241        #      ' number of points == %d ' %interpolation_points.shape[0]
1242        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1243        #                             max(interpolation_points[:,0]))
1244        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1245        #                              max(interpolation_points[:,1]))
1246        #print '  Interpolated quantities (over all timesteps):'
1247        #
1248        #for name in quantity_names:
1249        #    q = precomputed_values[name][:].flat
1250        #    print '    %s at interpolation points in [%f, %f]'\
1251        #          %(name, min(q), max(q))
1252        #print '------------------------------------------------'
1253
1254
1255#-------------------------------------------------------------
1256if __name__ == "__main__":
1257    """
1258    Load in a mesh and data points with attributes.
1259    Fit the attributes to the mesh.
1260    Save a new mesh file.
1261    """
1262    import os, sys
1263    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha]"\
1264            %os.path.basename(sys.argv[0])
1265
1266    if len(sys.argv) < 4:
1267        print usage
1268    else:
1269        mesh_file = sys.argv[1]
1270        point_file = sys.argv[2]
1271        mesh_output_file = sys.argv[3]
1272
1273        expand_search = False
1274        if len(sys.argv) > 4:
1275            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1276                expand_search = True
1277            else:
1278                expand_search = False
1279
1280        verbose = False
1281        if len(sys.argv) > 5:
1282            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1283                verbose = False
1284            else:
1285                verbose = True
1286
1287        if len(sys.argv) > 6:
1288            alpha = sys.argv[6]
1289        else:
1290            alpha = DEFAULT_ALPHA
1291
1292        t0 = time.time()
1293        fit_to_mesh_file(mesh_file,
1294                         point_file,
1295                         mesh_output_file,
1296                         alpha,
1297                         verbose= verbose,
1298                         expand_search = expand_search)
1299
1300        print 'That took %.2f seconds' %(time.time()-t0)
1301
Note: See TracBrowser for help on using the repository browser.