source: inundation/pyvolution/least_squares.py @ 1916

Last change on this file since 1916 was 1911, checked in by ole, 18 years ago

Refactored pyvolution to use polygon functionality from new utilities package

File size: 47.1 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20#import exceptions
21#class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME (DSG-ON) reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    NOTE: Throws IOErrors, for a variety of file problems.
75   
76    mesh_origin is the same but refers to the input tsh file.
77    FIXME: When the tsh format contains it own origin, these parameters can go.
78    FIXME: And both origins should be obtained from the specified files.
79    """
80
81    from load_mesh.loadASCII import import_mesh_file, \
82                 import_points_file, export_mesh_file, \
83                 concatinate_attributelist
84
85
86    try:
87        mesh_dict = import_mesh_file(mesh_file)
88    except IOError,e:
89        if display_errors:
90            print "Could not load bad file. ", e
91        raise IOError  #Re-raise exception
92       
93    vertex_coordinates = mesh_dict['vertices']
94    triangles = mesh_dict['triangles']
95    if type(mesh_dict['vertex_attributes']) == ArrayType:
96        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
97    else:
98        old_point_attributes = mesh_dict['vertex_attributes']
99
100    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
101        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
102    else:
103        old_title_list = mesh_dict['vertex_attribute_titles']
104
105    if verbose: print 'tsh file %s loaded' %mesh_file
106
107    # load in the .pts file
108    try:
109        point_dict = import_points_file(point_file, verbose=verbose)
110    except IOError,e:
111        if display_errors:
112            print "Could not load bad file. ", e
113        raise IOError  #Re-raise exception 
114
115    point_coordinates = point_dict['pointlist']
116    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
117
118    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
119        data_origin = point_dict['geo_reference'].get_origin()
120    else:
121        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
122
123    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
124        mesh_origin = mesh_dict['geo_reference'].get_origin()
125    else:
126        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
127
128    if verbose: print "points file loaded"
129    if verbose:print "fitting to mesh"
130    f = fit_to_mesh(vertex_coordinates,
131                    triangles,
132                    point_coordinates,
133                    point_attributes,
134                    alpha = alpha,
135                    verbose = verbose,
136                    expand_search = expand_search,
137                    data_origin = data_origin,
138                    mesh_origin = mesh_origin,
139                    precrop = precrop)
140    if verbose: print "finished fitting to mesh"
141
142    # convert array to list of lists
143    new_point_attributes = f.tolist()
144    #FIXME have this overwrite attributes with the same title - DSG
145    #Put the newer attributes last
146    if old_title_list <> []:
147        old_title_list.extend(title_list)
148        #FIXME can this be done a faster way? - DSG
149        for i in range(len(old_point_attributes)):
150            old_point_attributes[i].extend(new_point_attributes[i])
151        mesh_dict['vertex_attributes'] = old_point_attributes
152        mesh_dict['vertex_attribute_titles'] = old_title_list
153    else:
154        mesh_dict['vertex_attributes'] = new_point_attributes
155        mesh_dict['vertex_attribute_titles'] = title_list
156
157    #FIXME (Ole): Remember to output mesh_origin as well
158    if verbose: print "exporting to file ",mesh_output_file
159
160    try:
161        export_mesh_file(mesh_output_file, mesh_dict)
162    except IOError,e:
163        if display_errors:
164            print "Could not write file. ", e
165        raise IOError
166
167def fit_to_mesh(vertex_coordinates,
168                triangles,
169                point_coordinates,
170                point_attributes,
171                alpha = DEFAULT_ALPHA,
172                verbose = False,
173                expand_search = False,
174                data_origin = None,
175                mesh_origin = None,
176                precrop = False):
177    """
178    Fit a smooth surface to a triangulation,
179    given data points with attributes.
180
181
182        Inputs:
183
184          vertex_coordinates: List of coordinate pairs [xi, eta] of points
185          constituting mesh (or a an m x 2 Numeric array)
186
187          triangles: List of 3-tuples (or a Numeric array) of
188          integers representing indices of all vertices in the mesh.
189
190          point_coordinates: List of coordinate pairs [x, y] of data points
191          (or an nx2 Numeric array)
192
193          alpha: Smoothing parameter.
194
195          point_attributes: Vector or array of data at the point_coordinates.
196
197          data_origin and mesh_origin are 3-tuples consisting of
198          UTM zone, easting and northing. If specified
199          point coordinates and vertex coordinates are assumed to be
200          relative to their respective origins.
201
202    """
203    interp = Interpolation(vertex_coordinates,
204                           triangles,
205                           point_coordinates,
206                           alpha = alpha,
207                           verbose = verbose,
208                           expand_search = expand_search,
209                           data_origin = data_origin,
210                           mesh_origin = mesh_origin,
211                           precrop = precrop)
212
213    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
214    return vertex_attributes
215
216
217
218def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
219                    verbose = False, reduction = 1):
220    """Fits attributes from pts file to MxN rectangular mesh
221
222    Read pts file and create rectangular mesh of resolution MxN such that
223    it covers all points specified in pts file.
224
225    FIXME: This may be a temporary function until we decide on
226    netcdf formats etc
227
228    FIXME: Uses elevation hardwired
229    """
230
231    import  mesh_factory
232    from load_mesh.loadASCII import import_points_file
233   
234    if verbose: print 'Read pts'
235    points_dict = import_points_file(pts_name)
236    #points, attributes = util.read_xya(pts_name)
237
238    #Reduce number of points a bit
239    points = points_dict['pointlist'][::reduction]
240    elevation = points_dict['attributelist']['elevation']  #Must be elevation
241    elevation = elevation[::reduction]
242
243    if verbose: print 'Got %d data points' %len(points)
244
245    if verbose: print 'Create mesh'
246    #Find extent
247    max_x = min_x = points[0][0]
248    max_y = min_y = points[0][1]
249    for point in points[1:]:
250        x = point[0]
251        if x > max_x: max_x = x
252        if x < min_x: min_x = x
253        y = point[1]
254        if y > max_y: max_y = y
255        if y < min_y: min_y = y
256
257    #Create appropriate mesh
258    vertex_coordinates, triangles, boundary =\
259         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
260                                (min_x, min_y))
261
262    #Fit attributes to mesh
263    vertex_attributes = fit_to_mesh(vertex_coordinates,
264                        triangles,
265                        points,
266                        elevation, alpha=alpha, verbose=verbose)
267
268
269
270    return vertex_coordinates, triangles, boundary, vertex_attributes
271
272
273
274class Interpolation:
275
276    def __init__(self,
277                 vertex_coordinates,
278                 triangles,
279                 point_coordinates = None,
280                 alpha = None,
281                 verbose = False,
282                 expand_search = True,
283                 interp_only = False,
284                 max_points_per_cell = 30,
285                 mesh_origin = None,
286                 data_origin = None,
287                 precrop = False):
288
289
290        """ Build interpolation matrix mapping from
291        function values at vertices to function values at data points
292
293        Inputs:
294
295          vertex_coordinates: List of coordinate pairs [xi, eta] of
296          points constituting mesh (or a an m x 2 Numeric array)
297          Points may appear multiple times
298          (e.g. if vertices have discontinuities)
299
300          triangles: List of 3-tuples (or a Numeric array) of
301          integers representing indices of all vertices in the mesh.
302
303          point_coordinates: List of coordinate pairs [x, y] of
304          data points (or an nx2 Numeric array)
305          If point_coordinates is absent, only smoothing matrix will
306          be built
307
308          alpha: Smoothing parameter
309
310          data_origin and mesh_origin are 3-tuples consisting of
311          UTM zone, easting and northing. If specified
312          point coordinates and vertex coordinates are assumed to be
313          relative to their respective origins.
314
315        """
316        from util import ensure_numeric
317
318        #Convert input to Numeric arrays
319        triangles = ensure_numeric(triangles, Int)
320        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
321
322        #Build underlying mesh
323        if verbose: print 'Building mesh'
324        #self.mesh = General_mesh(vertex_coordinates, triangles,
325        #FIXME: Trying the normal mesh while testing precrop,
326        #       The functionality of boundary_polygon is needed for that
327
328        #FIXME - geo ref does not have to go into mesh.
329        # Change the point co-ords to conform to the
330        # mesh co-ords early in the code
331        if mesh_origin == None:
332            geo = None
333        else:
334            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
335        self.mesh = Mesh(vertex_coordinates, triangles,
336                         geo_reference = geo)
337       
338        self.mesh.check_integrity()
339
340        self.data_origin = data_origin
341
342        self.point_indices = None
343
344        #Smoothing parameter
345        if alpha is None:
346            self.alpha = DEFAULT_ALPHA
347        else:   
348            self.alpha = alpha
349
350
351        if point_coordinates is not None:
352            if verbose: print 'Building interpolation matrix'
353            self.build_interpolation_matrix_A(point_coordinates,
354                                              verbose = verbose,
355                                              expand_search = expand_search,
356                                              interp_only = interp_only, 
357                                              max_points_per_cell =\
358                                              max_points_per_cell,
359                                              data_origin = data_origin,
360                                              precrop = precrop)
361        #Build coefficient matrices
362        if interp_only == False:
363            self.build_coefficient_matrix_B(point_coordinates,
364                                        verbose = verbose,
365                                        expand_search = expand_search,
366                                        max_points_per_cell =\
367                                        max_points_per_cell,
368                                        data_origin = data_origin,
369                                        precrop = precrop)
370
371
372    def set_point_coordinates(self, point_coordinates,
373                              data_origin = None,
374                              verbose = False,
375                              precrop = True):
376        """
377        A public interface to setting the point co-ordinates.
378        """
379        if point_coordinates is not None:
380            if verbose: print 'Building interpolation matrix'
381            self.build_interpolation_matrix_A(point_coordinates,
382                                              verbose = verbose,
383                                              data_origin = data_origin,
384                                              precrop = precrop)
385        self.build_coefficient_matrix_B(point_coordinates, data_origin)
386
387    def build_coefficient_matrix_B(self, point_coordinates=None,
388                                   verbose = False, expand_search = True,
389                                   max_points_per_cell=30,
390                                   data_origin = None,
391                                   precrop = False):
392        """Build final coefficient matrix"""
393
394
395        if self.alpha <> 0:
396            if verbose: print 'Building smoothing matrix'
397            self.build_smoothing_matrix_D()
398
399        if point_coordinates is not None:
400            if self.alpha <> 0:
401                self.B = self.AtA + self.alpha*self.D
402            else:
403                self.B = self.AtA
404
405            #Convert self.B matrix to CSR format for faster matrix vector
406            self.B = Sparse_CSR(self.B)
407
408    def build_interpolation_matrix_A(self, point_coordinates,
409                                     verbose = False, expand_search = True,
410                                     max_points_per_cell=30,
411                                     data_origin = None,
412                                     precrop = False,
413                                     interp_only = False):
414        """Build n x m interpolation matrix, where
415        n is the number of data points and
416        m is the number of basis functions phi_k (one per vertex)
417
418        This algorithm uses a quad tree data structure for fast binning of data points
419        origin is a 3-tuple consisting of UTM zone, easting and northing.
420        If specified coordinates are assumed to be relative to this origin.
421
422        This one will override any data_origin that may be specified in
423        interpolation instance
424
425        """
426
427
428        #FIXME (Ole): Check that this function is memeory efficient.
429        #6 million datapoints and 300000 basis functions
430        #causes out-of-memory situation
431        #First thing to check is whether there is room for self.A and self.AtA
432        #
433        #Maybe we need some sort of blocking
434
435        from quad import build_quadtree
436        from utilities.numerical_tools import ensure_numeric
437        from utilities.polygon import inside_polygon
438       
439
440        if data_origin is None:
441            data_origin = self.data_origin #Use the one from
442                                           #interpolation instance
443
444        #Convert input to Numeric arrays just in case.
445        point_coordinates = ensure_numeric(point_coordinates, Float)
446
447        #Keep track of discarded points (if any).
448        #This is only registered if precrop is True
449        self.cropped_points = False
450
451        #Shift data points to same origin as mesh (if specified)
452
453        #FIXME this will shift if there was no geo_ref.
454        #But all this should be removed anyhow.
455        #change coords before this point
456        mesh_origin = self.mesh.geo_reference.get_origin()
457        if point_coordinates is not None:
458            if data_origin is not None:
459                if mesh_origin is not None:
460
461                    #Transformation:
462                    #
463                    #Let x_0 be the reference point of the point coordinates
464                    #and xi_0 the reference point of the mesh.
465                    #
466                    #A point coordinate (x + x_0) is then made relative
467                    #to xi_0 by
468                    #
469                    # x_new = x + x_0 - xi_0
470                    #
471                    #and similarly for eta
472
473                    x_offset = data_origin[1] - mesh_origin[1]
474                    y_offset = data_origin[2] - mesh_origin[2]
475                else: #Shift back to a zero origin
476                    x_offset = data_origin[1]
477                    y_offset = data_origin[2]
478
479                point_coordinates[:,0] += x_offset
480                point_coordinates[:,1] += y_offset
481            else:
482                if mesh_origin is not None:
483                    #Use mesh origin for data points
484                    point_coordinates[:,0] -= mesh_origin[1]
485                    point_coordinates[:,1] -= mesh_origin[2]
486
487
488
489        #Remove points falling outside mesh boundary
490        #This reduced one example from 1356 seconds to 825 seconds
491        if precrop is True:
492            from Numeric import take
493
494            if verbose: print 'Getting boundary polygon'
495            P = self.mesh.get_boundary_polygon()
496
497            if verbose: print 'Getting indices inside mesh boundary'
498            indices = inside_polygon(point_coordinates, P, verbose = verbose)
499
500
501            if len(indices) != point_coordinates.shape[0]:
502                self.cropped_points = True
503                if verbose:
504                    print 'Done - %d points outside mesh have been cropped.'\
505                          %(point_coordinates.shape[0] - len(indices))
506
507            point_coordinates = take(point_coordinates, indices)
508            self.point_indices = indices
509
510
511
512
513        #Build n x m interpolation matrix
514        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
515        n = point_coordinates.shape[0]     #Nbr of data points
516
517        if verbose: print 'Number of datapoints: %d' %n
518        if verbose: print 'Number of basis functions: %d' %m
519
520        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
521        #However, Sparse_CSR does not have the same methods as Sparse yet
522        #The tests will reveal what needs to be done
523
524        #
525        #self.A = Sparse_CSR(Sparse(n,m))
526        #self.AtA = Sparse_CSR(Sparse(m,m))
527        self.A = Sparse(n,m)
528        self.AtA = Sparse(m,m)
529
530        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
531        root = build_quadtree(self.mesh,
532                              max_points_per_cell = max_points_per_cell)
533
534        #Compute matrix elements
535        for i in range(n):
536            #For each data_coordinate point
537
538            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
539            x = point_coordinates[i]
540
541            #Find vertices near x
542            candidate_vertices = root.search(x[0], x[1])
543            is_more_elements = True
544
545            element_found, sigma0, sigma1, sigma2, k = \
546                self.search_triangles_of_vertices(candidate_vertices, x)
547            while not element_found and is_more_elements and expand_search:
548                #if verbose: print 'Expanding search'
549                candidate_vertices, branch = root.expand_search()
550                if branch == []:
551                    # Searching all the verts from the root cell that haven't
552                    # been searched.  This is the last try
553                    element_found, sigma0, sigma1, sigma2, k = \
554                      self.search_triangles_of_vertices(candidate_vertices, x)
555                    is_more_elements = False
556                else:
557                    element_found, sigma0, sigma1, sigma2, k = \
558                      self.search_triangles_of_vertices(candidate_vertices, x)
559
560
561            #Update interpolation matrix A if necessary
562            if element_found is True:
563                #Assign values to matrix A
564
565                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
566                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
567                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
568
569                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
570                js     = [j0,j1,j2]
571
572                for j in js:
573                    self.A[i,j] = sigmas[j]
574                    for k in js:
575                        if interp_only == False:
576                            self.AtA[j,k] += sigmas[j]*sigmas[k]
577            else:
578                pass
579                #Ok if there is no triangle for datapoint
580                #(as in brute force version)
581                #raise 'Could not find triangle for point', x
582
583
584
585    def search_triangles_of_vertices(self, candidate_vertices, x):
586            #Find triangle containing x:
587            element_found = False
588
589            # This will be returned if element_found = False
590            sigma2 = -10.0
591            sigma0 = -10.0
592            sigma1 = -10.0
593            k = -10.0
594
595            #For all vertices in same cell as point x
596            for v in candidate_vertices:
597
598                #for each triangle id (k) which has v as a vertex
599                for k, _ in self.mesh.vertexlist[v]:
600
601                    #Get the three vertex_points of candidate triangle
602                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
603                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
604                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
605
606                    #print "PDSG - k", k
607                    #print "PDSG - xi0", xi0
608                    #print "PDSG - xi1", xi1
609                    #print "PDSG - xi2", xi2
610                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
611                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
612
613                    #Get the three normals
614                    n0 = self.mesh.get_normal(k, 0)
615                    n1 = self.mesh.get_normal(k, 1)
616                    n2 = self.mesh.get_normal(k, 2)
617
618
619                    #Compute interpolation
620                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
621                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
622                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
623
624                    #print "PDSG - sigma0", sigma0
625                    #print "PDSG - sigma1", sigma1
626                    #print "PDSG - sigma2", sigma2
627
628                    #FIXME: Maybe move out to test or something
629                    epsilon = 1.0e-6
630                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
631
632                    #Check that this triangle contains the data point
633
634                    #Sigmas can get negative within
635                    #machine precision on some machines (e.g nautilus)
636                    #Hence the small eps
637                    eps = 1.0e-15
638                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
639                        element_found = True
640                        break
641
642                if element_found is True:
643                    #Don't look for any other triangle
644                    break
645            return element_found, sigma0, sigma1, sigma2, k
646
647
648
649    def build_interpolation_matrix_A_brute(self, point_coordinates):
650        """Build n x m interpolation matrix, where
651        n is the number of data points and
652        m is the number of basis functions phi_k (one per vertex)
653
654        This is the brute force which is too slow for large problems,
655        but could be used for testing
656        """
657
658        from util import ensure_numeric
659
660        #Convert input to Numeric arrays
661        point_coordinates = ensure_numeric(point_coordinates, Float)
662
663        #Build n x m interpolation matrix
664        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
665        n = point_coordinates.shape[0]     #Nbr of data points
666
667        self.A = Sparse(n,m)
668        self.AtA = Sparse(m,m)
669
670        #Compute matrix elements
671        for i in range(n):
672            #For each data_coordinate point
673
674            x = point_coordinates[i]
675            element_found = False
676            k = 0
677            while not element_found and k < len(self.mesh):
678                #For each triangle (brute force)
679                #FIXME: Real algorithm should only visit relevant triangles
680
681                #Get the three vertex_points
682                xi0 = self.mesh.get_vertex_coordinate(k, 0)
683                xi1 = self.mesh.get_vertex_coordinate(k, 1)
684                xi2 = self.mesh.get_vertex_coordinate(k, 2)
685
686                #Get the three normals
687                n0 = self.mesh.get_normal(k, 0)
688                n1 = self.mesh.get_normal(k, 1)
689                n2 = self.mesh.get_normal(k, 2)
690
691                #Compute interpolation
692                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
693                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
694                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
695
696                #FIXME: Maybe move out to test or something
697                epsilon = 1.0e-6
698                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
699
700                #Check that this triangle contains data point
701                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
702                    element_found = True
703                    #Assign values to matrix A
704
705                    j0 = self.mesh.triangles[k,0] #Global vertex id
706                    #self.A[i, j0] = sigma0
707
708                    j1 = self.mesh.triangles[k,1] #Global vertex id
709                    #self.A[i, j1] = sigma1
710
711                    j2 = self.mesh.triangles[k,2] #Global vertex id
712                    #self.A[i, j2] = sigma2
713
714                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
715                    js     = [j0,j1,j2]
716
717                    for j in js:
718                        self.A[i,j] = sigmas[j]
719                        for k in js:
720                            self.AtA[j,k] += sigmas[j]*sigmas[k]
721                k = k+1
722
723
724
725    def get_A(self):
726        return self.A.todense()
727
728    def get_B(self):
729        return self.B.todense()
730
731    def get_D(self):
732        return self.D.todense()
733
734        #FIXME: Remember to re-introduce the 1/n factor in the
735        #interpolation term
736
737    def build_smoothing_matrix_D(self):
738        """Build m x m smoothing matrix, where
739        m is the number of basis functions phi_k (one per vertex)
740
741        The smoothing matrix is defined as
742
743        D = D1 + D2
744
745        where
746
747        [D1]_{k,l} = \int_\Omega
748           \frac{\partial \phi_k}{\partial x}
749           \frac{\partial \phi_l}{\partial x}\,
750           dx dy
751
752        [D2]_{k,l} = \int_\Omega
753           \frac{\partial \phi_k}{\partial y}
754           \frac{\partial \phi_l}{\partial y}\,
755           dx dy
756
757
758        The derivatives \frac{\partial \phi_k}{\partial x},
759        \frac{\partial \phi_k}{\partial x} for a particular triangle
760        are obtained by computing the gradient a_k, b_k for basis function k
761        """
762
763        #FIXME: algorithm might be optimised by computing local 9x9
764        #"element stiffness matrices:
765
766        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
767
768        self.D = Sparse(m,m)
769
770        #For each triangle compute contributions to D = D1+D2
771        for i in range(len(self.mesh)):
772
773            #Get area
774            area = self.mesh.areas[i]
775
776            #Get global vertex indices
777            v0 = self.mesh.triangles[i,0]
778            v1 = self.mesh.triangles[i,1]
779            v2 = self.mesh.triangles[i,2]
780
781            #Get the three vertex_points
782            xi0 = self.mesh.get_vertex_coordinate(i, 0)
783            xi1 = self.mesh.get_vertex_coordinate(i, 1)
784            xi2 = self.mesh.get_vertex_coordinate(i, 2)
785
786            #Compute gradients for each vertex
787            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
788                              1, 0, 0)
789
790            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
791                              0, 1, 0)
792
793            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
794                              0, 0, 1)
795
796            #Compute diagonal contributions
797            self.D[v0,v0] += (a0*a0 + b0*b0)*area
798            self.D[v1,v1] += (a1*a1 + b1*b1)*area
799            self.D[v2,v2] += (a2*a2 + b2*b2)*area
800
801            #Compute contributions for basis functions sharing edges
802            e01 = (a0*a1 + b0*b1)*area
803            self.D[v0,v1] += e01
804            self.D[v1,v0] += e01
805
806            e12 = (a1*a2 + b1*b2)*area
807            self.D[v1,v2] += e12
808            self.D[v2,v1] += e12
809
810            e20 = (a2*a0 + b2*b0)*area
811            self.D[v2,v0] += e20
812            self.D[v0,v2] += e20
813
814
815    def fit(self, z):
816        """Fit a smooth surface to given 1d array of data points z.
817
818        The smooth surface is computed at each vertex in the underlying
819        mesh using the formula given in the module doc string.
820
821        Pre Condition:
822          self.A, self.AtA and self.B have been initialised
823
824        Inputs:
825          z: Single 1d vector or array of data at the point_coordinates.
826        """
827
828        #Convert input to Numeric arrays
829        from util import ensure_numeric
830        z = ensure_numeric(z, Float)
831
832        if len(z.shape) > 1 :
833            raise VectorShapeError, 'Can only deal with 1d data vector'
834
835        if self.point_indices is not None:
836            #Remove values for any points that were outside mesh
837            z = take(z, self.point_indices)
838
839        #Compute right hand side based on data
840        #FIXME (DSG-DsG): could Sparse_CSR be used here?  Use this format
841        # after a matrix is built, before calcs.
842        Atz = self.A.trans_mult(z)
843
844
845        #Check sanity
846        n, m = self.A.shape
847        if n<m and self.alpha == 0.0:
848            msg = 'ERROR (least_squares): Too few data points\n'
849            msg += 'There are only %d data points and alpha == 0. ' %n
850            msg += 'Need at least %d\n' %m
851            msg += 'Alternatively, set smoothing parameter alpha to a small '
852            msg += 'positive value,\ne.g. 1.0e-3.'
853            raise msg
854
855
856
857        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
858        #FIXME: Should we store the result here for later use? (ON)
859
860
861    def fit_points(self, z, verbose=False):
862        """Like fit, but more robust when each point has two or more attributes
863        FIXME (Ole): The name fit_points doesn't carry any meaning
864        for me. How about something like fit_multiple or fit_columns?
865        """
866
867        try:
868            if verbose: print 'Solving penalised least_squares problem'
869            return self.fit(z)
870        except VectorShapeError, e:
871            # broadcasting is not supported.
872
873            #Convert input to Numeric arrays
874            from util import ensure_numeric
875            z = ensure_numeric(z, Float)
876
877            #Build n x m interpolation matrix
878            m = self.mesh.coordinates.shape[0] #Number of vertices
879            n = z.shape[1]                     #Number of data points
880
881            f = zeros((m,n), Float) #Resulting columns
882
883            for i in range(z.shape[1]):
884                f[:,i] = self.fit(z[:,i])
885
886            return f
887
888
889    def interpolate(self, f):
890        """Evaluate smooth surface f at data points implied in self.A.
891
892        The mesh values representing a smooth surface are
893        assumed to be specified in f. This argument could,
894        for example have been obtained from the method self.fit()
895
896        Pre Condition:
897          self.A has been initialised
898
899        Inputs:
900          f: Vector or array of data at the mesh vertices.
901          If f is an array, interpolation will be done for each column as
902          per underlying matrix-matrix multiplication
903
904        Output:
905          Interpolated values at data points implied in self.A
906
907        """
908
909        return self.A * f
910
911    def cull_outsiders(self, f):
912        pass
913
914
915
916
917class Interpolation_function:
918    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
919    which is interpolated from time series defined at vertices of
920    triangular mesh (such as those stored in sww files)
921
922    Let m be the number of vertices, n the number of triangles
923    and p the number of timesteps.
924
925    Mandatory input
926        time:               px1 array of monotonously increasing times (Float)
927        quantities:         Dictionary of arrays or 1 array (Float)
928                            The arrays must either have dimensions pxm or mx1.
929                            The resulting function will be time dependent in
930                            the former case while it will be constan with
931                            respect to time in the latter case.
932       
933    Optional input:
934        quantity_names:     List of keys into the quantities dictionary
935        vertex_coordinates: mx2 array of coordinates (Float)
936        triangles:          nx3 array of indices into vertex_coordinates (Int)
937        interpolation_points: Nx2 array of coordinates to be interpolated to
938        verbose:            Level of reporting
939   
940   
941    The quantities returned by the callable object are specified by
942    the list quantities which must contain the names of the
943    quantities to be returned and also reflect the order, e.g. for
944    the shallow water wave equation, on would have
945    quantities = ['stage', 'xmomentum', 'ymomentum']
946
947    The parameter interpolation_points decides at which points interpolated
948    quantities are to be computed whenever object is called.
949    If None, return average value
950    """
951
952   
953   
954    def __init__(self,
955                 time,
956                 quantities,
957                 quantity_names = None, 
958                 vertex_coordinates = None,
959                 triangles = None,
960                 interpolation_points = None,
961                 verbose = False):
962        """Initialise object and build spatial interpolation if required
963        """
964
965        from Numeric import array, zeros, Float, alltrue, concatenate,\
966             reshape, ArrayType
967
968
969        from util import mean, ensure_numeric
970        from config import time_format
971        import types
972
973
974
975        #Check temporal info
976        time = ensure_numeric(time)       
977        msg = 'Time must be a monotonuosly '
978        msg += 'increasing sequence %s' %time
979        assert alltrue(time[1:] - time[:-1] >= 0 ), msg
980
981
982        #Check if quantities is a single array only
983        if type(quantities) != types.DictType:
984            quantities = ensure_numeric(quantities)
985            quantity_names = ['Attribute']
986
987            #Make it a dictionary
988            quantities = {quantity_names[0]: quantities}
989
990
991        #Use keys if no names are specified
992        if quantity_names is None:
993            quantity_names = quantities.keys()
994
995
996        #Check spatial info
997        if vertex_coordinates is None:
998            self.spatial = False
999        else:   
1000            vertex_coordinates = ensure_numeric(vertex_coordinates)
1001
1002            assert triangles is not None, 'Triangles array must be specified'
1003            triangles = ensure_numeric(triangles)
1004            self.spatial = True           
1005           
1006
1007 
1008        #Save for use with statistics
1009        self.quantity_names = quantity_names       
1010        self.quantities = quantities       
1011        self.vertex_coordinates = vertex_coordinates
1012        self.interpolation_points = interpolation_points
1013        self.T = time[:]  # Time assumed to be relative to starttime
1014        self.index = 0    # Initial time index
1015        self.precomputed_values = {}
1016           
1017
1018
1019        #Precomputed spatial interpolation if requested
1020        if interpolation_points is not None:
1021            if self.spatial is False:
1022                raise 'Triangles and vertex_coordinates must be specified'
1023           
1024            try:
1025                self.interpolation_points = ensure_numeric(interpolation_points)
1026            except:
1027                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1028                      'or a list of points\n'
1029                msg += 'I got: %s.' %(str(self.interpolation_points)[:60] +\
1030                                      '...')
1031                raise msg
1032
1033
1034            m = len(self.interpolation_points)
1035            p = len(self.T)
1036           
1037            for name in quantity_names:
1038                self.precomputed_values[name] = zeros((p, m), Float)
1039
1040            #Build interpolator
1041            interpol = Interpolation(vertex_coordinates,
1042                                     triangles,
1043                                     point_coordinates = \
1044                                     self.interpolation_points,
1045                                     alpha = 0,
1046                                     precrop = False, 
1047                                     verbose = verbose)
1048
1049            if verbose: print 'Interpolate'
1050            for i, t in enumerate(self.T):
1051                #Interpolate quantities at this timestep
1052                if verbose and i%((p+10)/10)==0:
1053                    print ' time step %d of %d' %(i, p)
1054                   
1055                for name in quantity_names:
1056                    if len(quantities[name].shape) == 2:
1057                        result = interpol.interpolate(quantities[name][i,:])
1058                    else:
1059                       #Assume no time dependency
1060                       result = interpol.interpolate(quantities[name][:])
1061                       
1062                    self.precomputed_values[name][i, :] = result
1063                   
1064                       
1065
1066            #Report
1067            if verbose:
1068                print self.statistics()
1069                #self.print_statistics()
1070           
1071        else:
1072            #Store quantitites as is
1073            for name in quantity_names:
1074                self.precomputed_values[name] = quantities[name]
1075
1076
1077        #else:
1078        #    #Return an average, making this a time series
1079        #    for name in quantity_names:
1080        #        self.values[name] = zeros(len(self.T), Float)
1081        #
1082        #    if verbose: print 'Compute mean values'
1083        #    for i, t in enumerate(self.T):
1084        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1085        #        for name in quantity_names:
1086        #           self.values[name][i] = mean(quantities[name][i,:])
1087
1088
1089
1090
1091    def __repr__(self):
1092        #return 'Interpolation function (spatio-temporal)'
1093        return self.statistics()
1094   
1095
1096    def __call__(self, t, point_id = None, x = None, y = None):
1097        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1098
1099        Inputs:
1100          t: time - Model time. Must lie within existing timesteps
1101          point_id: index of one of the preprocessed points.
1102          x, y:     Overrides location, point_id ignored
1103         
1104          If spatial info is present and all of x,y,point_id
1105          are None an exception is raised
1106                   
1107          If no spatial info is present, point_id and x,y arguments are ignored
1108          making f a function of time only.
1109
1110         
1111          FIXME: point_id could also be a slice
1112          FIXME: What if x and y are vectors?
1113          FIXME: What about f(x,y) without t?
1114        """
1115
1116        from math import pi, cos, sin, sqrt
1117        from Numeric import zeros, Float
1118        from util import mean       
1119
1120        if self.spatial is True:
1121            if point_id is None:
1122                if x is None or y is None:
1123                    msg = 'Either point_id or x and y must be specified'
1124                    raise msg
1125            else:
1126                if self.interpolation_points is None:
1127                    msg = 'Interpolation_function must be instantiated ' +\
1128                          'with a list of interpolation points before parameter ' +\
1129                          'point_id can be used'
1130                    raise msg
1131
1132
1133        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1134        msg += ' does not match model time: %s\n' %t
1135        if t < self.T[0]: raise msg
1136        if t > self.T[-1]: raise msg
1137
1138        oldindex = self.index #Time index
1139
1140        #Find current time slot
1141        while t > self.T[self.index]: self.index += 1
1142        while t < self.T[self.index]: self.index -= 1
1143
1144        if t == self.T[self.index]:
1145            #Protect against case where t == T[-1] (last time)
1146            # - also works in general when t == T[i]
1147            ratio = 0
1148        else:
1149            #t is now between index and index+1
1150            ratio = (t - self.T[self.index])/\
1151                    (self.T[self.index+1] - self.T[self.index])
1152
1153        #Compute interpolated values
1154        q = zeros(len(self.quantity_names), Float)
1155
1156        for i, name in enumerate(self.quantity_names):
1157            Q = self.precomputed_values[name]
1158
1159            if self.spatial is False:
1160                #If there is no spatial info               
1161                assert len(Q.shape) == 1
1162
1163                Q0 = Q[self.index]
1164                if ratio > 0: Q1 = Q[self.index+1]
1165
1166            else:
1167                if x is not None and y is not None:
1168                    #Interpolate to x, y
1169                   
1170                    raise 'x,y interpolation not yet implemented'
1171                else:
1172                    #Use precomputed point
1173                    Q0 = Q[self.index, point_id]
1174                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1175
1176            #Linear temporal interpolation   
1177            if ratio > 0:
1178                q[i] = Q0 + ratio*(Q1 - Q0)
1179            else:
1180                q[i] = Q0
1181
1182
1183        #Return vector of interpolated values
1184        #if len(q) == 1:
1185        #    return q[0]
1186        #else:
1187        #    return q
1188
1189
1190        #Return vector of interpolated values
1191        #FIXME:
1192        if self.spatial is True:
1193            return q
1194        else:
1195            #Replicate q according to x and y
1196            #This is e.g used for Wind_stress
1197            if x == None or y == None: 
1198                return q
1199            else:
1200                try:
1201                    N = len(x)
1202                except:
1203                    return q
1204                else:
1205                    from Numeric import ones, Float
1206                    #x is a vector - Create one constant column for each value
1207                    N = len(x)
1208                    assert len(y) == N, 'x and y must have same length'
1209                    res = []
1210                    for col in q:
1211                        res.append(col*ones(N, Float))
1212                       
1213                return res
1214
1215
1216    def statistics(self):
1217        """Output statistics about interpolation_function
1218        """
1219       
1220        vertex_coordinates = self.vertex_coordinates
1221        interpolation_points = self.interpolation_points               
1222        quantity_names = self.quantity_names
1223        quantities = self.quantities
1224        precomputed_values = self.precomputed_values                 
1225               
1226        x = vertex_coordinates[:,0]
1227        y = vertex_coordinates[:,1]               
1228
1229        str =  '------------------------------------------------\n'
1230        str += 'Interpolation_function (spatio-temporal) statistics:\n'
1231        str += '  Extent:\n'
1232        str += '    x in [%f, %f], len(x) == %d\n'\
1233               %(min(x), max(x), len(x))
1234        str += '    y in [%f, %f], len(y) == %d\n'\
1235               %(min(y), max(y), len(y))
1236        str += '    t in [%f, %f], len(t) == %d\n'\
1237               %(min(self.T), max(self.T), len(self.T))
1238        str += '  Quantities:\n'
1239        for name in quantity_names:
1240            q = quantities[name][:].flat
1241            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1242
1243        if interpolation_points is not None:   
1244            str += '  Interpolation points (xi, eta):'\
1245                   ' number of points == %d\n' %interpolation_points.shape[0]
1246            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1247                                            max(interpolation_points[:,0]))
1248            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1249                                             max(interpolation_points[:,1]))
1250            str += '  Interpolated quantities (over all timesteps):\n'
1251       
1252            for name in quantity_names:
1253                q = precomputed_values[name][:].flat
1254                str += '    %s at interpolation points in [%f, %f]\n'\
1255                       %(name, min(q), max(q))
1256        str += '------------------------------------------------\n'
1257
1258        return str
1259
1260        #FIXME: Delete
1261        #print '------------------------------------------------'
1262        #print 'Interpolation_function statistics:'
1263        #print '  Extent:'
1264        #print '    x in [%f, %f], len(x) == %d'\
1265        #      %(min(x), max(x), len(x))
1266        #print '    y in [%f, %f], len(y) == %d'\
1267        #      %(min(y), max(y), len(y))
1268        #print '    t in [%f, %f], len(t) == %d'\
1269        #      %(min(self.T), max(self.T), len(self.T))
1270        #print '  Quantities:'
1271        #for name in quantity_names:
1272        #    q = quantities[name][:].flat
1273        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1274        #print '  Interpolation points (xi, eta):'\
1275        #      ' number of points == %d ' %interpolation_points.shape[0]
1276        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1277        #                             max(interpolation_points[:,0]))
1278        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1279        #                              max(interpolation_points[:,1]))
1280        #print '  Interpolated quantities (over all timesteps):'
1281        #
1282        #for name in quantity_names:
1283        #    q = precomputed_values[name][:].flat
1284        #    print '    %s at interpolation points in [%f, %f]'\
1285        #          %(name, min(q), max(q))
1286        #print '------------------------------------------------'
1287
1288
1289#-------------------------------------------------------------
1290if __name__ == "__main__":
1291    """
1292    Load in a mesh and data points with attributes.
1293    Fit the attributes to the mesh.
1294    Save a new mesh file.
1295    """
1296    import os, sys
1297    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1298            %os.path.basename(sys.argv[0])
1299
1300    if len(sys.argv) < 4:
1301        print usage
1302    else:
1303        mesh_file = sys.argv[1]
1304        point_file = sys.argv[2]
1305        mesh_output_file = sys.argv[3]
1306
1307        expand_search = False
1308        if len(sys.argv) > 4:
1309            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1310                expand_search = True
1311            else:
1312                expand_search = False
1313
1314        verbose = False
1315        if len(sys.argv) > 5:
1316            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1317                verbose = False
1318            else:
1319                verbose = True
1320
1321        if len(sys.argv) > 6:
1322            alpha = sys.argv[6]
1323        else:
1324            alpha = DEFAULT_ALPHA
1325
1326        # This is used more for testing
1327        if len(sys.argv) > 7:
1328            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1329                display_errors = False
1330            else:
1331                display_errors = True
1332           
1333        t0 = time.time()
1334        try:
1335            fit_to_mesh_file(mesh_file,
1336                         point_file,
1337                         mesh_output_file,
1338                         alpha,
1339                         verbose= verbose,
1340                         expand_search = expand_search,
1341                         display_errors = display_errors)
1342        except IOError,e:
1343            import sys; sys.exit(1)
1344
1345        print 'That took %.2f seconds' %(time.time()-t0)
1346
Note: See TracBrowser for help on using the repository browser.