source: inundation/pyvolution/least_squares.py @ 1941

Last change on this file since 1941 was 1941, checked in by duncan, 18 years ago

adding diagnostic info

File size: 47.7 KB
RevLine 
[1911]1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20#import exceptions
21#class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
[1941]25from pyvolution.mesh import Mesh
[1911]26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
[1941]28from pyvolution.sparse import Sparse, Sparse_CSR
29from pyvolution.cg_solve import conjugate_gradient, VectorShapeError
[1911]30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME (DSG-ON) reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    NOTE: Throws IOErrors, for a variety of file problems.
75   
76    mesh_origin is the same but refers to the input tsh file.
77    FIXME: When the tsh format contains it own origin, these parameters can go.
78    FIXME: And both origins should be obtained from the specified files.
79    """
80
81    from load_mesh.loadASCII import import_mesh_file, \
82                 import_points_file, export_mesh_file, \
83                 concatinate_attributelist
84
85
86    try:
87        mesh_dict = import_mesh_file(mesh_file)
88    except IOError,e:
89        if display_errors:
90            print "Could not load bad file. ", e
91        raise IOError  #Re-raise exception
92       
93    vertex_coordinates = mesh_dict['vertices']
94    triangles = mesh_dict['triangles']
95    if type(mesh_dict['vertex_attributes']) == ArrayType:
96        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
97    else:
98        old_point_attributes = mesh_dict['vertex_attributes']
99
100    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
101        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
102    else:
103        old_title_list = mesh_dict['vertex_attribute_titles']
104
105    if verbose: print 'tsh file %s loaded' %mesh_file
106
107    # load in the .pts file
108    try:
109        point_dict = import_points_file(point_file, verbose=verbose)
110    except IOError,e:
111        if display_errors:
112            print "Could not load bad file. ", e
113        raise IOError  #Re-raise exception 
114
115    point_coordinates = point_dict['pointlist']
116    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
117
118    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
119        data_origin = point_dict['geo_reference'].get_origin()
120    else:
121        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
122
123    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
124        mesh_origin = mesh_dict['geo_reference'].get_origin()
125    else:
126        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
127
128    if verbose: print "points file loaded"
129    if verbose:print "fitting to mesh"
130    f = fit_to_mesh(vertex_coordinates,
131                    triangles,
132                    point_coordinates,
133                    point_attributes,
134                    alpha = alpha,
135                    verbose = verbose,
136                    expand_search = expand_search,
137                    data_origin = data_origin,
138                    mesh_origin = mesh_origin,
139                    precrop = precrop)
140    if verbose: print "finished fitting to mesh"
141
142    # convert array to list of lists
143    new_point_attributes = f.tolist()
144    #FIXME have this overwrite attributes with the same title - DSG
145    #Put the newer attributes last
146    if old_title_list <> []:
147        old_title_list.extend(title_list)
148        #FIXME can this be done a faster way? - DSG
149        for i in range(len(old_point_attributes)):
150            old_point_attributes[i].extend(new_point_attributes[i])
151        mesh_dict['vertex_attributes'] = old_point_attributes
152        mesh_dict['vertex_attribute_titles'] = old_title_list
153    else:
154        mesh_dict['vertex_attributes'] = new_point_attributes
155        mesh_dict['vertex_attribute_titles'] = title_list
156
157    #FIXME (Ole): Remember to output mesh_origin as well
158    if verbose: print "exporting to file ",mesh_output_file
159
160    try:
161        export_mesh_file(mesh_output_file, mesh_dict)
162    except IOError,e:
163        if display_errors:
164            print "Could not write file. ", e
165        raise IOError
166
167def fit_to_mesh(vertex_coordinates,
168                triangles,
169                point_coordinates,
170                point_attributes,
171                alpha = DEFAULT_ALPHA,
172                verbose = False,
173                expand_search = False,
174                data_origin = None,
175                mesh_origin = None,
176                precrop = False):
177    """
178    Fit a smooth surface to a triangulation,
179    given data points with attributes.
180
181
182        Inputs:
183
184          vertex_coordinates: List of coordinate pairs [xi, eta] of points
185          constituting mesh (or a an m x 2 Numeric array)
186
187          triangles: List of 3-tuples (or a Numeric array) of
188          integers representing indices of all vertices in the mesh.
189
190          point_coordinates: List of coordinate pairs [x, y] of data points
191          (or an nx2 Numeric array)
192
193          alpha: Smoothing parameter.
194
195          point_attributes: Vector or array of data at the point_coordinates.
196
197          data_origin and mesh_origin are 3-tuples consisting of
198          UTM zone, easting and northing. If specified
199          point coordinates and vertex coordinates are assumed to be
200          relative to their respective origins.
201
202    """
203    interp = Interpolation(vertex_coordinates,
204                           triangles,
205                           point_coordinates,
206                           alpha = alpha,
207                           verbose = verbose,
208                           expand_search = expand_search,
209                           data_origin = data_origin,
210                           mesh_origin = mesh_origin,
211                           precrop = precrop)
212
213    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
214    return vertex_attributes
215
216
217
218def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
219                    verbose = False, reduction = 1):
220    """Fits attributes from pts file to MxN rectangular mesh
221
222    Read pts file and create rectangular mesh of resolution MxN such that
223    it covers all points specified in pts file.
224
225    FIXME: This may be a temporary function until we decide on
226    netcdf formats etc
227
228    FIXME: Uses elevation hardwired
229    """
230
231    import  mesh_factory
232    from load_mesh.loadASCII import import_points_file
233   
234    if verbose: print 'Read pts'
235    points_dict = import_points_file(pts_name)
236    #points, attributes = util.read_xya(pts_name)
237
238    #Reduce number of points a bit
239    points = points_dict['pointlist'][::reduction]
240    elevation = points_dict['attributelist']['elevation']  #Must be elevation
241    elevation = elevation[::reduction]
242
243    if verbose: print 'Got %d data points' %len(points)
244
245    if verbose: print 'Create mesh'
246    #Find extent
247    max_x = min_x = points[0][0]
248    max_y = min_y = points[0][1]
249    for point in points[1:]:
250        x = point[0]
251        if x > max_x: max_x = x
252        if x < min_x: min_x = x
253        y = point[1]
254        if y > max_y: max_y = y
255        if y < min_y: min_y = y
256
257    #Create appropriate mesh
258    vertex_coordinates, triangles, boundary =\
259         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
260                                (min_x, min_y))
261
262    #Fit attributes to mesh
263    vertex_attributes = fit_to_mesh(vertex_coordinates,
264                        triangles,
265                        points,
266                        elevation, alpha=alpha, verbose=verbose)
267
268
269
270    return vertex_coordinates, triangles, boundary, vertex_attributes
271
272
273
274class Interpolation:
275
276    def __init__(self,
277                 vertex_coordinates,
278                 triangles,
279                 point_coordinates = None,
280                 alpha = None,
281                 verbose = False,
282                 expand_search = True,
283                 interp_only = False,
284                 max_points_per_cell = 30,
285                 mesh_origin = None,
286                 data_origin = None,
287                 precrop = False):
288
289
290        """ Build interpolation matrix mapping from
291        function values at vertices to function values at data points
292
293        Inputs:
294
295          vertex_coordinates: List of coordinate pairs [xi, eta] of
296          points constituting mesh (or a an m x 2 Numeric array)
297          Points may appear multiple times
298          (e.g. if vertices have discontinuities)
299
300          triangles: List of 3-tuples (or a Numeric array) of
301          integers representing indices of all vertices in the mesh.
302
303          point_coordinates: List of coordinate pairs [x, y] of
304          data points (or an nx2 Numeric array)
305          If point_coordinates is absent, only smoothing matrix will
306          be built
307
308          alpha: Smoothing parameter
309
310          data_origin and mesh_origin are 3-tuples consisting of
311          UTM zone, easting and northing. If specified
312          point coordinates and vertex coordinates are assumed to be
313          relative to their respective origins.
314
315        """
[1941]316        from pyvolution.util import ensure_numeric
[1911]317
318        #Convert input to Numeric arrays
319        triangles = ensure_numeric(triangles, Int)
320        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
321
322        #Build underlying mesh
323        if verbose: print 'Building mesh'
324        #self.mesh = General_mesh(vertex_coordinates, triangles,
325        #FIXME: Trying the normal mesh while testing precrop,
326        #       The functionality of boundary_polygon is needed for that
327
328        #FIXME - geo ref does not have to go into mesh.
329        # Change the point co-ords to conform to the
330        # mesh co-ords early in the code
331        if mesh_origin == None:
332            geo = None
333        else:
334            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
335        self.mesh = Mesh(vertex_coordinates, triangles,
336                         geo_reference = geo)
337       
338        self.mesh.check_integrity()
339
340        self.data_origin = data_origin
341
342        self.point_indices = None
343
344        #Smoothing parameter
345        if alpha is None:
346            self.alpha = DEFAULT_ALPHA
347        else:   
348            self.alpha = alpha
349
350
351        if point_coordinates is not None:
352            if verbose: print 'Building interpolation matrix'
353            self.build_interpolation_matrix_A(point_coordinates,
354                                              verbose = verbose,
355                                              expand_search = expand_search,
356                                              interp_only = interp_only, 
357                                              max_points_per_cell =\
358                                              max_points_per_cell,
359                                              data_origin = data_origin,
360                                              precrop = precrop)
361        #Build coefficient matrices
362        if interp_only == False:
363            self.build_coefficient_matrix_B(point_coordinates,
364                                        verbose = verbose,
365                                        expand_search = expand_search,
366                                        max_points_per_cell =\
367                                        max_points_per_cell,
368                                        data_origin = data_origin,
369                                        precrop = precrop)
370
371
372    def set_point_coordinates(self, point_coordinates,
373                              data_origin = None,
374                              verbose = False,
375                              precrop = True):
376        """
377        A public interface to setting the point co-ordinates.
378        """
379        if point_coordinates is not None:
380            if verbose: print 'Building interpolation matrix'
381            self.build_interpolation_matrix_A(point_coordinates,
382                                              verbose = verbose,
383                                              data_origin = data_origin,
384                                              precrop = precrop)
385        self.build_coefficient_matrix_B(point_coordinates, data_origin)
386
387    def build_coefficient_matrix_B(self, point_coordinates=None,
388                                   verbose = False, expand_search = True,
389                                   max_points_per_cell=30,
390                                   data_origin = None,
391                                   precrop = False):
392        """Build final coefficient matrix"""
393
394
395        if self.alpha <> 0:
396            if verbose: print 'Building smoothing matrix'
397            self.build_smoothing_matrix_D()
398
399        if point_coordinates is not None:
400            if self.alpha <> 0:
401                self.B = self.AtA + self.alpha*self.D
402            else:
403                self.B = self.AtA
404
405            #Convert self.B matrix to CSR format for faster matrix vector
406            self.B = Sparse_CSR(self.B)
407
408    def build_interpolation_matrix_A(self, point_coordinates,
409                                     verbose = False, expand_search = True,
410                                     max_points_per_cell=30,
411                                     data_origin = None,
412                                     precrop = False,
413                                     interp_only = False):
414        """Build n x m interpolation matrix, where
415        n is the number of data points and
416        m is the number of basis functions phi_k (one per vertex)
417
418        This algorithm uses a quad tree data structure for fast binning of data points
419        origin is a 3-tuple consisting of UTM zone, easting and northing.
420        If specified coordinates are assumed to be relative to this origin.
421
422        This one will override any data_origin that may be specified in
423        interpolation instance
424
425        """
426
427
428        #FIXME (Ole): Check that this function is memeory efficient.
429        #6 million datapoints and 300000 basis functions
430        #causes out-of-memory situation
431        #First thing to check is whether there is room for self.A and self.AtA
432        #
433        #Maybe we need some sort of blocking
434
[1941]435        from pyvolution.quad import build_quadtree
[1911]436        from utilities.numerical_tools import ensure_numeric
437        from utilities.polygon import inside_polygon
438       
439
440        if data_origin is None:
441            data_origin = self.data_origin #Use the one from
442                                           #interpolation instance
443
444        #Convert input to Numeric arrays just in case.
445        point_coordinates = ensure_numeric(point_coordinates, Float)
446
447        #Keep track of discarded points (if any).
448        #This is only registered if precrop is True
449        self.cropped_points = False
450
451        #Shift data points to same origin as mesh (if specified)
452
453        #FIXME this will shift if there was no geo_ref.
454        #But all this should be removed anyhow.
455        #change coords before this point
456        mesh_origin = self.mesh.geo_reference.get_origin()
457        if point_coordinates is not None:
458            if data_origin is not None:
459                if mesh_origin is not None:
460
461                    #Transformation:
462                    #
463                    #Let x_0 be the reference point of the point coordinates
464                    #and xi_0 the reference point of the mesh.
465                    #
466                    #A point coordinate (x + x_0) is then made relative
467                    #to xi_0 by
468                    #
469                    # x_new = x + x_0 - xi_0
470                    #
471                    #and similarly for eta
472
473                    x_offset = data_origin[1] - mesh_origin[1]
474                    y_offset = data_origin[2] - mesh_origin[2]
475                else: #Shift back to a zero origin
476                    x_offset = data_origin[1]
477                    y_offset = data_origin[2]
478
479                point_coordinates[:,0] += x_offset
480                point_coordinates[:,1] += y_offset
481            else:
482                if mesh_origin is not None:
483                    #Use mesh origin for data points
484                    point_coordinates[:,0] -= mesh_origin[1]
485                    point_coordinates[:,1] -= mesh_origin[2]
486
487
488
489        #Remove points falling outside mesh boundary
490        #This reduced one example from 1356 seconds to 825 seconds
491        if precrop is True:
492            from Numeric import take
493
494            if verbose: print 'Getting boundary polygon'
495            P = self.mesh.get_boundary_polygon()
496
497            if verbose: print 'Getting indices inside mesh boundary'
498            indices = inside_polygon(point_coordinates, P, verbose = verbose)
499
500
501            if len(indices) != point_coordinates.shape[0]:
502                self.cropped_points = True
503                if verbose:
504                    print 'Done - %d points outside mesh have been cropped.'\
505                          %(point_coordinates.shape[0] - len(indices))
506
507            point_coordinates = take(point_coordinates, indices)
508            self.point_indices = indices
509
510
511
512
513        #Build n x m interpolation matrix
514        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
515        n = point_coordinates.shape[0]     #Nbr of data points
516
517        if verbose: print 'Number of datapoints: %d' %n
518        if verbose: print 'Number of basis functions: %d' %m
519
520        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
521        #However, Sparse_CSR does not have the same methods as Sparse yet
522        #The tests will reveal what needs to be done
523
524        #
525        #self.A = Sparse_CSR(Sparse(n,m))
526        #self.AtA = Sparse_CSR(Sparse(m,m))
527        self.A = Sparse(n,m)
528        self.AtA = Sparse(m,m)
529
530        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
531        root = build_quadtree(self.mesh,
532                              max_points_per_cell = max_points_per_cell)
[1941]533        #root.show()
534        self.expanded_quad_searches = []
[1911]535        #Compute matrix elements
536        for i in range(n):
537            #For each data_coordinate point
538
539            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
540            x = point_coordinates[i]
541
542            #Find vertices near x
543            candidate_vertices = root.search(x[0], x[1])
544            is_more_elements = True
545
546            element_found, sigma0, sigma1, sigma2, k = \
547                self.search_triangles_of_vertices(candidate_vertices, x)
[1941]548            first_expansion = True
[1911]549            while not element_found and is_more_elements and expand_search:
[1941]550                if verbose: print 'Expanding search'
551                if first_expansion = True:
552                    expanded_quad_searches.append(1)
553                else:
554                    end = len(expanded_quad_searches) - 1
555                    assert end >= 0
556                    expanded_quad_searches[end] += 1
557                    print "expanded_quad_searches",expanded_quad_searches
558                self.expanded_quad_searches += 1
[1911]559                candidate_vertices, branch = root.expand_search()
560                if branch == []:
561                    # Searching all the verts from the root cell that haven't
562                    # been searched.  This is the last try
563                    element_found, sigma0, sigma1, sigma2, k = \
564                      self.search_triangles_of_vertices(candidate_vertices, x)
565                    is_more_elements = False
566                else:
567                    element_found, sigma0, sigma1, sigma2, k = \
568                      self.search_triangles_of_vertices(candidate_vertices, x)
569
[1941]570               
[1911]571            #Update interpolation matrix A if necessary
572            if element_found is True:
573                #Assign values to matrix A
574
575                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
576                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
577                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
578
579                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
580                js     = [j0,j1,j2]
581
582                for j in js:
583                    self.A[i,j] = sigmas[j]
584                    for k in js:
585                        if interp_only == False:
586                            self.AtA[j,k] += sigmas[j]*sigmas[k]
587            else:
588                pass
589                #Ok if there is no triangle for datapoint
590                #(as in brute force version)
591                #raise 'Could not find triangle for point', x
592
593
594
595    def search_triangles_of_vertices(self, candidate_vertices, x):
596            #Find triangle containing x:
597            element_found = False
598
599            # This will be returned if element_found = False
600            sigma2 = -10.0
601            sigma0 = -10.0
602            sigma1 = -10.0
603            k = -10.0
[1941]604            #print "*$* candidate_vertices", candidate_vertices
[1911]605            #For all vertices in same cell as point x
606            for v in candidate_vertices:
607
608                #for each triangle id (k) which has v as a vertex
609                for k, _ in self.mesh.vertexlist[v]:
610
611                    #Get the three vertex_points of candidate triangle
612                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
613                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
614                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
615
616                    #print "PDSG - k", k
617                    #print "PDSG - xi0", xi0
618                    #print "PDSG - xi1", xi1
619                    #print "PDSG - xi2", xi2
620                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
621                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
622
623                    #Get the three normals
624                    n0 = self.mesh.get_normal(k, 0)
625                    n1 = self.mesh.get_normal(k, 1)
626                    n2 = self.mesh.get_normal(k, 2)
627
628
629                    #Compute interpolation
630                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
631                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
632                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
633
634                    #print "PDSG - sigma0", sigma0
635                    #print "PDSG - sigma1", sigma1
636                    #print "PDSG - sigma2", sigma2
637
638                    #FIXME: Maybe move out to test or something
639                    epsilon = 1.0e-6
640                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
641
642                    #Check that this triangle contains the data point
643
644                    #Sigmas can get negative within
645                    #machine precision on some machines (e.g nautilus)
646                    #Hence the small eps
647                    eps = 1.0e-15
648                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
649                        element_found = True
650                        break
651
652                if element_found is True:
653                    #Don't look for any other triangle
654                    break
655            return element_found, sigma0, sigma1, sigma2, k
656
657
658
659    def build_interpolation_matrix_A_brute(self, point_coordinates):
660        """Build n x m interpolation matrix, where
661        n is the number of data points and
662        m is the number of basis functions phi_k (one per vertex)
663
664        This is the brute force which is too slow for large problems,
665        but could be used for testing
666        """
667
668        from util import ensure_numeric
669
670        #Convert input to Numeric arrays
671        point_coordinates = ensure_numeric(point_coordinates, Float)
672
673        #Build n x m interpolation matrix
674        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
675        n = point_coordinates.shape[0]     #Nbr of data points
676
677        self.A = Sparse(n,m)
678        self.AtA = Sparse(m,m)
679
680        #Compute matrix elements
681        for i in range(n):
682            #For each data_coordinate point
683
684            x = point_coordinates[i]
685            element_found = False
686            k = 0
687            while not element_found and k < len(self.mesh):
688                #For each triangle (brute force)
689                #FIXME: Real algorithm should only visit relevant triangles
690
691                #Get the three vertex_points
692                xi0 = self.mesh.get_vertex_coordinate(k, 0)
693                xi1 = self.mesh.get_vertex_coordinate(k, 1)
694                xi2 = self.mesh.get_vertex_coordinate(k, 2)
695
696                #Get the three normals
697                n0 = self.mesh.get_normal(k, 0)
698                n1 = self.mesh.get_normal(k, 1)
699                n2 = self.mesh.get_normal(k, 2)
700
701                #Compute interpolation
702                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
703                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
704                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
705
706                #FIXME: Maybe move out to test or something
707                epsilon = 1.0e-6
708                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
709
710                #Check that this triangle contains data point
711                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
712                    element_found = True
713                    #Assign values to matrix A
714
715                    j0 = self.mesh.triangles[k,0] #Global vertex id
716                    #self.A[i, j0] = sigma0
717
718                    j1 = self.mesh.triangles[k,1] #Global vertex id
719                    #self.A[i, j1] = sigma1
720
721                    j2 = self.mesh.triangles[k,2] #Global vertex id
722                    #self.A[i, j2] = sigma2
723
724                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
725                    js     = [j0,j1,j2]
726
727                    for j in js:
728                        self.A[i,j] = sigmas[j]
729                        for k in js:
730                            self.AtA[j,k] += sigmas[j]*sigmas[k]
731                k = k+1
732
733
734
735    def get_A(self):
736        return self.A.todense()
737
738    def get_B(self):
739        return self.B.todense()
740
741    def get_D(self):
742        return self.D.todense()
743
744        #FIXME: Remember to re-introduce the 1/n factor in the
745        #interpolation term
746
747    def build_smoothing_matrix_D(self):
748        """Build m x m smoothing matrix, where
749        m is the number of basis functions phi_k (one per vertex)
750
751        The smoothing matrix is defined as
752
753        D = D1 + D2
754
755        where
756
757        [D1]_{k,l} = \int_\Omega
758           \frac{\partial \phi_k}{\partial x}
759           \frac{\partial \phi_l}{\partial x}\,
760           dx dy
761
762        [D2]_{k,l} = \int_\Omega
763           \frac{\partial \phi_k}{\partial y}
764           \frac{\partial \phi_l}{\partial y}\,
765           dx dy
766
767
768        The derivatives \frac{\partial \phi_k}{\partial x},
769        \frac{\partial \phi_k}{\partial x} for a particular triangle
770        are obtained by computing the gradient a_k, b_k for basis function k
771        """
772
773        #FIXME: algorithm might be optimised by computing local 9x9
774        #"element stiffness matrices:
775
776        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
777
778        self.D = Sparse(m,m)
779
780        #For each triangle compute contributions to D = D1+D2
781        for i in range(len(self.mesh)):
782
783            #Get area
784            area = self.mesh.areas[i]
785
786            #Get global vertex indices
787            v0 = self.mesh.triangles[i,0]
788            v1 = self.mesh.triangles[i,1]
789            v2 = self.mesh.triangles[i,2]
790
791            #Get the three vertex_points
792            xi0 = self.mesh.get_vertex_coordinate(i, 0)
793            xi1 = self.mesh.get_vertex_coordinate(i, 1)
794            xi2 = self.mesh.get_vertex_coordinate(i, 2)
795
796            #Compute gradients for each vertex
797            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
798                              1, 0, 0)
799
800            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
801                              0, 1, 0)
802
803            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
804                              0, 0, 1)
805
806            #Compute diagonal contributions
807            self.D[v0,v0] += (a0*a0 + b0*b0)*area
808            self.D[v1,v1] += (a1*a1 + b1*b1)*area
809            self.D[v2,v2] += (a2*a2 + b2*b2)*area
810
811            #Compute contributions for basis functions sharing edges
812            e01 = (a0*a1 + b0*b1)*area
813            self.D[v0,v1] += e01
814            self.D[v1,v0] += e01
815
816            e12 = (a1*a2 + b1*b2)*area
817            self.D[v1,v2] += e12
818            self.D[v2,v1] += e12
819
820            e20 = (a2*a0 + b2*b0)*area
821            self.D[v2,v0] += e20
822            self.D[v0,v2] += e20
823
824
825    def fit(self, z):
826        """Fit a smooth surface to given 1d array of data points z.
827
828        The smooth surface is computed at each vertex in the underlying
829        mesh using the formula given in the module doc string.
830
831        Pre Condition:
832          self.A, self.AtA and self.B have been initialised
833
834        Inputs:
835          z: Single 1d vector or array of data at the point_coordinates.
836        """
837
838        #Convert input to Numeric arrays
[1941]839        from pyvolution.util import ensure_numeric
[1911]840        z = ensure_numeric(z, Float)
841
842        if len(z.shape) > 1 :
843            raise VectorShapeError, 'Can only deal with 1d data vector'
844
845        if self.point_indices is not None:
846            #Remove values for any points that were outside mesh
847            z = take(z, self.point_indices)
848
849        #Compute right hand side based on data
850        #FIXME (DSG-DsG): could Sparse_CSR be used here?  Use this format
851        # after a matrix is built, before calcs.
852        Atz = self.A.trans_mult(z)
853
854
855        #Check sanity
856        n, m = self.A.shape
857        if n<m and self.alpha == 0.0:
858            msg = 'ERROR (least_squares): Too few data points\n'
859            msg += 'There are only %d data points and alpha == 0. ' %n
860            msg += 'Need at least %d\n' %m
861            msg += 'Alternatively, set smoothing parameter alpha to a small '
862            msg += 'positive value,\ne.g. 1.0e-3.'
863            raise msg
864
865
866
867        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
868        #FIXME: Should we store the result here for later use? (ON)
869
870
871    def fit_points(self, z, verbose=False):
872        """Like fit, but more robust when each point has two or more attributes
873        FIXME (Ole): The name fit_points doesn't carry any meaning
874        for me. How about something like fit_multiple or fit_columns?
875        """
876
877        try:
878            if verbose: print 'Solving penalised least_squares problem'
879            return self.fit(z)
880        except VectorShapeError, e:
881            # broadcasting is not supported.
882
883            #Convert input to Numeric arrays
884            from util import ensure_numeric
885            z = ensure_numeric(z, Float)
886
887            #Build n x m interpolation matrix
888            m = self.mesh.coordinates.shape[0] #Number of vertices
889            n = z.shape[1]                     #Number of data points
890
891            f = zeros((m,n), Float) #Resulting columns
892
893            for i in range(z.shape[1]):
894                f[:,i] = self.fit(z[:,i])
895
896            return f
897
898
899    def interpolate(self, f):
900        """Evaluate smooth surface f at data points implied in self.A.
901
902        The mesh values representing a smooth surface are
903        assumed to be specified in f. This argument could,
904        for example have been obtained from the method self.fit()
905
906        Pre Condition:
907          self.A has been initialised
908
909        Inputs:
910          f: Vector or array of data at the mesh vertices.
911          If f is an array, interpolation will be done for each column as
912          per underlying matrix-matrix multiplication
913
914        Output:
915          Interpolated values at data points implied in self.A
916
917        """
918
919        return self.A * f
920
921    def cull_outsiders(self, f):
922        pass
923
924
925
926
927class Interpolation_function:
928    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
929    which is interpolated from time series defined at vertices of
930    triangular mesh (such as those stored in sww files)
931
932    Let m be the number of vertices, n the number of triangles
933    and p the number of timesteps.
934
935    Mandatory input
936        time:               px1 array of monotonously increasing times (Float)
937        quantities:         Dictionary of arrays or 1 array (Float)
938                            The arrays must either have dimensions pxm or mx1.
939                            The resulting function will be time dependent in
[1941]940                            the former case while it will be constan with
[1911]941                            respect to time in the latter case.
942       
943    Optional input:
944        quantity_names:     List of keys into the quantities dictionary
945        vertex_coordinates: mx2 array of coordinates (Float)
946        triangles:          nx3 array of indices into vertex_coordinates (Int)
947        interpolation_points: Nx2 array of coordinates to be interpolated to
948        verbose:            Level of reporting
949   
950   
951    The quantities returned by the callable object are specified by
952    the list quantities which must contain the names of the
953    quantities to be returned and also reflect the order, e.g. for
954    the shallow water wave equation, on would have
955    quantities = ['stage', 'xmomentum', 'ymomentum']
956
957    The parameter interpolation_points decides at which points interpolated
958    quantities are to be computed whenever object is called.
959    If None, return average value
960    """
961
962   
[1941]963   
[1911]964    def __init__(self,
965                 time,
966                 quantities,
967                 quantity_names = None, 
968                 vertex_coordinates = None,
969                 triangles = None,
970                 interpolation_points = None,
971                 verbose = False):
972        """Initialise object and build spatial interpolation if required
973        """
974
975        from Numeric import array, zeros, Float, alltrue, concatenate,\
976             reshape, ArrayType
977
978
979        from util import mean, ensure_numeric
980        from config import time_format
981        import types
982
983
984
985        #Check temporal info
986        time = ensure_numeric(time)       
987        msg = 'Time must be a monotonuosly '
988        msg += 'increasing sequence %s' %time
989        assert alltrue(time[1:] - time[:-1] >= 0 ), msg
990
991
992        #Check if quantities is a single array only
993        if type(quantities) != types.DictType:
994            quantities = ensure_numeric(quantities)
995            quantity_names = ['Attribute']
996
997            #Make it a dictionary
998            quantities = {quantity_names[0]: quantities}
999
1000
1001        #Use keys if no names are specified
1002        if quantity_names is None:
1003            quantity_names = quantities.keys()
1004
1005
1006        #Check spatial info
1007        if vertex_coordinates is None:
1008            self.spatial = False
1009        else:   
1010            vertex_coordinates = ensure_numeric(vertex_coordinates)
1011
1012            assert triangles is not None, 'Triangles array must be specified'
1013            triangles = ensure_numeric(triangles)
1014            self.spatial = True           
1015           
1016
1017 
1018        #Save for use with statistics
1019        self.quantity_names = quantity_names       
1020        self.quantities = quantities       
1021        self.vertex_coordinates = vertex_coordinates
1022        self.interpolation_points = interpolation_points
1023        self.T = time[:]  # Time assumed to be relative to starttime
1024        self.index = 0    # Initial time index
1025        self.precomputed_values = {}
1026           
1027
1028
1029        #Precomputed spatial interpolation if requested
1030        if interpolation_points is not None:
1031            if self.spatial is False:
1032                raise 'Triangles and vertex_coordinates must be specified'
1033           
1034            try:
1035                self.interpolation_points = ensure_numeric(interpolation_points)
1036            except:
1037                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1038                      'or a list of points\n'
1039                msg += 'I got: %s.' %(str(self.interpolation_points)[:60] +\
1040                                      '...')
1041                raise msg
1042
1043
1044            m = len(self.interpolation_points)
1045            p = len(self.T)
1046           
1047            for name in quantity_names:
1048                self.precomputed_values[name] = zeros((p, m), Float)
1049
1050            #Build interpolator
1051            interpol = Interpolation(vertex_coordinates,
1052                                     triangles,
1053                                     point_coordinates = \
1054                                     self.interpolation_points,
1055                                     alpha = 0,
1056                                     precrop = False, 
1057                                     verbose = verbose)
1058
1059            if verbose: print 'Interpolate'
1060            for i, t in enumerate(self.T):
1061                #Interpolate quantities at this timestep
1062                if verbose and i%((p+10)/10)==0:
1063                    print ' time step %d of %d' %(i, p)
1064                   
1065                for name in quantity_names:
1066                    if len(quantities[name].shape) == 2:
1067                        result = interpol.interpolate(quantities[name][i,:])
1068                    else:
1069                       #Assume no time dependency
1070                       result = interpol.interpolate(quantities[name][:])
1071                       
1072                    self.precomputed_values[name][i, :] = result
1073                   
1074                       
1075
1076            #Report
1077            if verbose:
1078                print self.statistics()
1079                #self.print_statistics()
1080           
1081        else:
1082            #Store quantitites as is
1083            for name in quantity_names:
1084                self.precomputed_values[name] = quantities[name]
1085
1086
1087        #else:
1088        #    #Return an average, making this a time series
1089        #    for name in quantity_names:
1090        #        self.values[name] = zeros(len(self.T), Float)
1091        #
1092        #    if verbose: print 'Compute mean values'
1093        #    for i, t in enumerate(self.T):
1094        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1095        #        for name in quantity_names:
1096        #           self.values[name][i] = mean(quantities[name][i,:])
1097
1098
1099
1100
1101    def __repr__(self):
1102        #return 'Interpolation function (spatio-temporal)'
1103        return self.statistics()
1104   
1105
1106    def __call__(self, t, point_id = None, x = None, y = None):
1107        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1108
1109        Inputs:
1110          t: time - Model time. Must lie within existing timesteps
1111          point_id: index of one of the preprocessed points.
1112          x, y:     Overrides location, point_id ignored
1113         
1114          If spatial info is present and all of x,y,point_id
1115          are None an exception is raised
1116                   
1117          If no spatial info is present, point_id and x,y arguments are ignored
1118          making f a function of time only.
1119
1120         
1121          FIXME: point_id could also be a slice
1122          FIXME: What if x and y are vectors?
1123          FIXME: What about f(x,y) without t?
1124        """
1125
1126        from math import pi, cos, sin, sqrt
1127        from Numeric import zeros, Float
1128        from util import mean       
1129
1130        if self.spatial is True:
1131            if point_id is None:
1132                if x is None or y is None:
1133                    msg = 'Either point_id or x and y must be specified'
1134                    raise msg
1135            else:
1136                if self.interpolation_points is None:
1137                    msg = 'Interpolation_function must be instantiated ' +\
1138                          'with a list of interpolation points before parameter ' +\
1139                          'point_id can be used'
1140                    raise msg
1141
1142
1143        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1144        msg += ' does not match model time: %s\n' %t
1145        if t < self.T[0]: raise msg
1146        if t > self.T[-1]: raise msg
1147
1148        oldindex = self.index #Time index
1149
1150        #Find current time slot
1151        while t > self.T[self.index]: self.index += 1
1152        while t < self.T[self.index]: self.index -= 1
1153
1154        if t == self.T[self.index]:
1155            #Protect against case where t == T[-1] (last time)
1156            # - also works in general when t == T[i]
1157            ratio = 0
1158        else:
1159            #t is now between index and index+1
1160            ratio = (t - self.T[self.index])/\
1161                    (self.T[self.index+1] - self.T[self.index])
1162
1163        #Compute interpolated values
1164        q = zeros(len(self.quantity_names), Float)
1165
1166        for i, name in enumerate(self.quantity_names):
1167            Q = self.precomputed_values[name]
1168
1169            if self.spatial is False:
1170                #If there is no spatial info               
1171                assert len(Q.shape) == 1
1172
1173                Q0 = Q[self.index]
1174                if ratio > 0: Q1 = Q[self.index+1]
1175
1176            else:
1177                if x is not None and y is not None:
1178                    #Interpolate to x, y
1179                   
1180                    raise 'x,y interpolation not yet implemented'
1181                else:
1182                    #Use precomputed point
1183                    Q0 = Q[self.index, point_id]
1184                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1185
1186            #Linear temporal interpolation   
1187            if ratio > 0:
1188                q[i] = Q0 + ratio*(Q1 - Q0)
1189            else:
1190                q[i] = Q0
1191
1192
1193        #Return vector of interpolated values
1194        #if len(q) == 1:
1195        #    return q[0]
1196        #else:
1197        #    return q
1198
1199
1200        #Return vector of interpolated values
1201        #FIXME:
1202        if self.spatial is True:
1203            return q
1204        else:
1205            #Replicate q according to x and y
1206            #This is e.g used for Wind_stress
1207            if x == None or y == None: 
1208                return q
1209            else:
1210                try:
1211                    N = len(x)
1212                except:
1213                    return q
1214                else:
1215                    from Numeric import ones, Float
1216                    #x is a vector - Create one constant column for each value
1217                    N = len(x)
1218                    assert len(y) == N, 'x and y must have same length'
1219                    res = []
1220                    for col in q:
1221                        res.append(col*ones(N, Float))
1222                       
1223                return res
1224
1225
1226    def statistics(self):
1227        """Output statistics about interpolation_function
1228        """
1229       
1230        vertex_coordinates = self.vertex_coordinates
1231        interpolation_points = self.interpolation_points               
1232        quantity_names = self.quantity_names
1233        quantities = self.quantities
1234        precomputed_values = self.precomputed_values                 
1235               
1236        x = vertex_coordinates[:,0]
1237        y = vertex_coordinates[:,1]               
1238
1239        str =  '------------------------------------------------\n'
1240        str += 'Interpolation_function (spatio-temporal) statistics:\n'
1241        str += '  Extent:\n'
1242        str += '    x in [%f, %f], len(x) == %d\n'\
1243               %(min(x), max(x), len(x))
1244        str += '    y in [%f, %f], len(y) == %d\n'\
1245               %(min(y), max(y), len(y))
1246        str += '    t in [%f, %f], len(t) == %d\n'\
1247               %(min(self.T), max(self.T), len(self.T))
1248        str += '  Quantities:\n'
1249        for name in quantity_names:
1250            q = quantities[name][:].flat
1251            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1252
1253        if interpolation_points is not None:   
1254            str += '  Interpolation points (xi, eta):'\
1255                   ' number of points == %d\n' %interpolation_points.shape[0]
1256            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1257                                            max(interpolation_points[:,0]))
1258            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1259                                             max(interpolation_points[:,1]))
1260            str += '  Interpolated quantities (over all timesteps):\n'
1261       
1262            for name in quantity_names:
1263                q = precomputed_values[name][:].flat
1264                str += '    %s at interpolation points in [%f, %f]\n'\
1265                       %(name, min(q), max(q))
1266        str += '------------------------------------------------\n'
1267
1268        return str
1269
1270        #FIXME: Delete
1271        #print '------------------------------------------------'
1272        #print 'Interpolation_function statistics:'
1273        #print '  Extent:'
1274        #print '    x in [%f, %f], len(x) == %d'\
1275        #      %(min(x), max(x), len(x))
1276        #print '    y in [%f, %f], len(y) == %d'\
1277        #      %(min(y), max(y), len(y))
1278        #print '    t in [%f, %f], len(t) == %d'\
1279        #      %(min(self.T), max(self.T), len(self.T))
1280        #print '  Quantities:'
1281        #for name in quantity_names:
1282        #    q = quantities[name][:].flat
1283        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1284        #print '  Interpolation points (xi, eta):'\
1285        #      ' number of points == %d ' %interpolation_points.shape[0]
1286        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1287        #                             max(interpolation_points[:,0]))
1288        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1289        #                              max(interpolation_points[:,1]))
1290        #print '  Interpolated quantities (over all timesteps):'
1291        #
1292        #for name in quantity_names:
1293        #    q = precomputed_values[name][:].flat
1294        #    print '    %s at interpolation points in [%f, %f]'\
1295        #          %(name, min(q), max(q))
1296        #print '------------------------------------------------'
1297
1298
1299#-------------------------------------------------------------
1300if __name__ == "__main__":
1301    """
1302    Load in a mesh and data points with attributes.
1303    Fit the attributes to the mesh.
1304    Save a new mesh file.
1305    """
1306    import os, sys
1307    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1308            %os.path.basename(sys.argv[0])
1309
1310    if len(sys.argv) < 4:
1311        print usage
1312    else:
1313        mesh_file = sys.argv[1]
1314        point_file = sys.argv[2]
1315        mesh_output_file = sys.argv[3]
1316
1317        expand_search = False
1318        if len(sys.argv) > 4:
1319            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1320                expand_search = True
1321            else:
1322                expand_search = False
1323
1324        verbose = False
1325        if len(sys.argv) > 5:
1326            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1327                verbose = False
1328            else:
1329                verbose = True
1330
1331        if len(sys.argv) > 6:
1332            alpha = sys.argv[6]
1333        else:
1334            alpha = DEFAULT_ALPHA
1335
1336        # This is used more for testing
1337        if len(sys.argv) > 7:
1338            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1339                display_errors = False
1340            else:
1341                display_errors = True
1342           
1343        t0 = time.time()
1344        try:
1345            fit_to_mesh_file(mesh_file,
1346                         point_file,
1347                         mesh_output_file,
1348                         alpha,
1349                         verbose= verbose,
1350                         expand_search = expand_search,
1351                         display_errors = display_errors)
1352        except IOError,e:
1353            import sys; sys.exit(1)
1354
1355        print 'That took %.2f seconds' %(time.time()-t0)
1356
Note: See TracBrowser for help on using the repository browser.