source: inundation/pyvolution/least_squares.py @ 1979

Last change on this file since 1979 was 1979, checked in by ole, 19 years ago

Enforced compliance with "comparisons to singletons" as per PEP 8 style guide

File size: 47.9 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20#import exceptions
21#class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from pyvolution.mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from pyvolution.sparse import Sparse, Sparse_CSR
29from pyvolution.cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME (DSG-ON) reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    NOTE: Throws IOErrors, for a variety of file problems.
75   
76    mesh_origin is the same but refers to the input tsh file.
77    FIXME: When the tsh format contains it own origin, these parameters can go.
78    FIXME: And both origins should be obtained from the specified files.
79    """
80
81    from load_mesh.loadASCII import import_mesh_file, \
82                 import_points_file, export_mesh_file, \
83                 concatinate_attributelist
84
85
86    try:
87        mesh_dict = import_mesh_file(mesh_file)
88    except IOError,e:
89        if display_errors:
90            print "Could not load bad file. ", e
91        raise IOError  #Re-raise exception
92       
93    vertex_coordinates = mesh_dict['vertices']
94    triangles = mesh_dict['triangles']
95    if type(mesh_dict['vertex_attributes']) == ArrayType:
96        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
97    else:
98        old_point_attributes = mesh_dict['vertex_attributes']
99
100    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
101        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
102    else:
103        old_title_list = mesh_dict['vertex_attribute_titles']
104
105    if verbose: print 'tsh file %s loaded' %mesh_file
106
107    # load in the .pts file
108    try:
109        point_dict = import_points_file(point_file, verbose=verbose)
110    except IOError,e:
111        if display_errors:
112            print "Could not load bad file. ", e
113        raise IOError  #Re-raise exception 
114
115    point_coordinates = point_dict['pointlist']
116    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
117
118    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
119        data_origin = point_dict['geo_reference'].get_origin()
120    else:
121        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
122
123    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
124        mesh_origin = mesh_dict['geo_reference'].get_origin()
125    else:
126        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
127
128    if verbose: print "points file loaded"
129    if verbose:print "fitting to mesh"
130    f = fit_to_mesh(vertex_coordinates,
131                    triangles,
132                    point_coordinates,
133                    point_attributes,
134                    alpha = alpha,
135                    verbose = verbose,
136                    expand_search = expand_search,
137                    data_origin = data_origin,
138                    mesh_origin = mesh_origin,
139                    precrop = precrop)
140    if verbose: print "finished fitting to mesh"
141
142    # convert array to list of lists
143    new_point_attributes = f.tolist()
144    #FIXME have this overwrite attributes with the same title - DSG
145    #Put the newer attributes last
146    if old_title_list <> []:
147        old_title_list.extend(title_list)
148        #FIXME can this be done a faster way? - DSG
149        for i in range(len(old_point_attributes)):
150            old_point_attributes[i].extend(new_point_attributes[i])
151        mesh_dict['vertex_attributes'] = old_point_attributes
152        mesh_dict['vertex_attribute_titles'] = old_title_list
153    else:
154        mesh_dict['vertex_attributes'] = new_point_attributes
155        mesh_dict['vertex_attribute_titles'] = title_list
156
157    #FIXME (Ole): Remember to output mesh_origin as well
158    if verbose: print "exporting to file ",mesh_output_file
159
160    try:
161        export_mesh_file(mesh_output_file, mesh_dict)
162    except IOError,e:
163        if display_errors:
164            print "Could not write file. ", e
165        raise IOError
166
167def fit_to_mesh(vertex_coordinates,
168                triangles,
169                point_coordinates,
170                point_attributes,
171                alpha = DEFAULT_ALPHA,
172                verbose = False,
173                expand_search = False,
174                data_origin = None,
175                mesh_origin = None,
176                precrop = False):
177    """
178    Fit a smooth surface to a triangulation,
179    given data points with attributes.
180
181
182        Inputs:
183
184          vertex_coordinates: List of coordinate pairs [xi, eta] of points
185          constituting mesh (or a an m x 2 Numeric array)
186
187          triangles: List of 3-tuples (or a Numeric array) of
188          integers representing indices of all vertices in the mesh.
189
190          point_coordinates: List of coordinate pairs [x, y] of data points
191          (or an nx2 Numeric array)
192
193          alpha: Smoothing parameter.
194
195          point_attributes: Vector or array of data at the point_coordinates.
196
197          data_origin and mesh_origin are 3-tuples consisting of
198          UTM zone, easting and northing. If specified
199          point coordinates and vertex coordinates are assumed to be
200          relative to their respective origins.
201
202    """
203    interp = Interpolation(vertex_coordinates,
204                           triangles,
205                           point_coordinates,
206                           alpha = alpha,
207                           verbose = verbose,
208                           expand_search = expand_search,
209                           data_origin = data_origin,
210                           mesh_origin = mesh_origin,
211                           precrop = precrop)
212
213    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
214    return vertex_attributes
215
216
217
218def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
219                    verbose = False, reduction = 1):
220    """Fits attributes from pts file to MxN rectangular mesh
221
222    Read pts file and create rectangular mesh of resolution MxN such that
223    it covers all points specified in pts file.
224
225    FIXME: This may be a temporary function until we decide on
226    netcdf formats etc
227
228    FIXME: Uses elevation hardwired
229    """
230
231    import  mesh_factory
232    from load_mesh.loadASCII import import_points_file
233   
234    if verbose: print 'Read pts'
235    points_dict = import_points_file(pts_name)
236    #points, attributes = util.read_xya(pts_name)
237
238    #Reduce number of points a bit
239    points = points_dict['pointlist'][::reduction]
240    elevation = points_dict['attributelist']['elevation']  #Must be elevation
241    elevation = elevation[::reduction]
242
243    if verbose: print 'Got %d data points' %len(points)
244
245    if verbose: print 'Create mesh'
246    #Find extent
247    max_x = min_x = points[0][0]
248    max_y = min_y = points[0][1]
249    for point in points[1:]:
250        x = point[0]
251        if x > max_x: max_x = x
252        if x < min_x: min_x = x
253        y = point[1]
254        if y > max_y: max_y = y
255        if y < min_y: min_y = y
256
257    #Create appropriate mesh
258    vertex_coordinates, triangles, boundary =\
259         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
260                                (min_x, min_y))
261
262    #Fit attributes to mesh
263    vertex_attributes = fit_to_mesh(vertex_coordinates,
264                        triangles,
265                        points,
266                        elevation, alpha=alpha, verbose=verbose)
267
268
269
270    return vertex_coordinates, triangles, boundary, vertex_attributes
271
272
273
274class Interpolation:
275
276    def __init__(self,
277                 vertex_coordinates,
278                 triangles,
279                 point_coordinates = None,
280                 alpha = None,
281                 verbose = False,
282                 expand_search = True,
283                 interp_only = False,
284                 max_points_per_cell = 30,
285                 mesh_origin = None,
286                 data_origin = None,
287                 precrop = False):
288
289
290        """ Build interpolation matrix mapping from
291        function values at vertices to function values at data points
292
293        Inputs:
294
295          vertex_coordinates: List of coordinate pairs [xi, eta] of
296          points constituting mesh (or a an m x 2 Numeric array)
297          Points may appear multiple times
298          (e.g. if vertices have discontinuities)
299
300          triangles: List of 3-tuples (or a Numeric array) of
301          integers representing indices of all vertices in the mesh.
302
303          point_coordinates: List of coordinate pairs [x, y] of
304          data points (or an nx2 Numeric array)
305          If point_coordinates is absent, only smoothing matrix will
306          be built
307
308          alpha: Smoothing parameter
309
310          data_origin and mesh_origin are 3-tuples consisting of
311          UTM zone, easting and northing. If specified
312          point coordinates and vertex coordinates are assumed to be
313          relative to their respective origins.
314
315        """
316        from pyvolution.util import ensure_numeric
317
318        #Convert input to Numeric arrays
319        triangles = ensure_numeric(triangles, Int)
320        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
321
322        #Build underlying mesh
323        if verbose: print 'Building mesh'
324        #self.mesh = General_mesh(vertex_coordinates, triangles,
325        #FIXME: Trying the normal mesh while testing precrop,
326        #       The functionality of boundary_polygon is needed for that
327
328        #FIXME - geo ref does not have to go into mesh.
329        # Change the point co-ords to conform to the
330        # mesh co-ords early in the code
331        if mesh_origin is None:
332            geo = None
333        else:
334            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
335        self.mesh = Mesh(vertex_coordinates, triangles,
336                         geo_reference = geo)
337       
338        self.mesh.check_integrity()
339
340        self.data_origin = data_origin
341
342        self.point_indices = None
343
344        #Smoothing parameter
345        if alpha is None:
346            self.alpha = DEFAULT_ALPHA
347        else:   
348            self.alpha = alpha
349
350
351        if point_coordinates is not None:
352            if verbose: print 'Building interpolation matrix'
353            self.build_interpolation_matrix_A(point_coordinates,
354                                              verbose = verbose,
355                                              expand_search = expand_search,
356                                              interp_only = interp_only, 
357                                              max_points_per_cell =\
358                                              max_points_per_cell,
359                                              data_origin = data_origin,
360                                              precrop = precrop)
361        #Build coefficient matrices
362        if interp_only == False:
363            self.build_coefficient_matrix_B(point_coordinates,
364                                        verbose = verbose,
365                                        expand_search = expand_search,
366                                        max_points_per_cell =\
367                                        max_points_per_cell,
368                                        data_origin = data_origin,
369                                        precrop = precrop)
370
371    def set_point_coordinates(self, point_coordinates,
372                              data_origin = None,
373                              verbose = False,
374                              precrop = True):
375        """
376        A public interface to setting the point co-ordinates.
377        """
378        if point_coordinates is not None:
379            if verbose: print 'Building interpolation matrix'
380            self.build_interpolation_matrix_A(point_coordinates,
381                                              verbose = verbose,
382                                              data_origin = data_origin,
383                                              precrop = precrop)
384        self.build_coefficient_matrix_B(point_coordinates, data_origin)
385
386    def build_coefficient_matrix_B(self, point_coordinates=None,
387                                   verbose = False, expand_search = True,
388                                   max_points_per_cell=30,
389                                   data_origin = None,
390                                   precrop = False):
391        """Build final coefficient matrix"""
392
393
394        if self.alpha <> 0:
395            if verbose: print 'Building smoothing matrix'
396            self.build_smoothing_matrix_D()
397
398        if point_coordinates is not None:
399            if self.alpha <> 0:
400                self.B = self.AtA + self.alpha*self.D
401            else:
402                self.B = self.AtA
403
404            #Convert self.B matrix to CSR format for faster matrix vector
405            self.B = Sparse_CSR(self.B)
406
407    def build_interpolation_matrix_A(self, point_coordinates,
408                                     verbose = False, expand_search = True,
409                                     max_points_per_cell=30,
410                                     data_origin = None,
411                                     precrop = False,
412                                     interp_only = False):
413        """Build n x m interpolation matrix, where
414        n is the number of data points and
415        m is the number of basis functions phi_k (one per vertex)
416
417        This algorithm uses a quad tree data structure for fast binning of data points
418        origin is a 3-tuple consisting of UTM zone, easting and northing.
419        If specified coordinates are assumed to be relative to this origin.
420
421        This one will override any data_origin that may be specified in
422        interpolation instance
423
424        """
425
426
427        #FIXME (Ole): Check that this function is memeory efficient.
428        #6 million datapoints and 300000 basis functions
429        #causes out-of-memory situation
430        #First thing to check is whether there is room for self.A and self.AtA
431        #
432        #Maybe we need some sort of blocking
433
434        from pyvolution.quad import build_quadtree
435        from utilities.numerical_tools import ensure_numeric
436        from utilities.polygon import inside_polygon
437       
438
439        if data_origin is None:
440            data_origin = self.data_origin #Use the one from
441                                           #interpolation instance
442
443        #Convert input to Numeric arrays just in case.
444        point_coordinates = ensure_numeric(point_coordinates, Float)
445
446        #Keep track of discarded points (if any).
447        #This is only registered if precrop is True
448        self.cropped_points = False
449
450        #Shift data points to same origin as mesh (if specified)
451
452        #FIXME this will shift if there was no geo_ref.
453        #But all this should be removed anyhow.
454        #change coords before this point
455        mesh_origin = self.mesh.geo_reference.get_origin()
456        if point_coordinates is not None:
457            if data_origin is not None:
458                if mesh_origin is not None:
459
460                    #Transformation:
461                    #
462                    #Let x_0 be the reference point of the point coordinates
463                    #and xi_0 the reference point of the mesh.
464                    #
465                    #A point coordinate (x + x_0) is then made relative
466                    #to xi_0 by
467                    #
468                    # x_new = x + x_0 - xi_0
469                    #
470                    #and similarly for eta
471
472                    x_offset = data_origin[1] - mesh_origin[1]
473                    y_offset = data_origin[2] - mesh_origin[2]
474                else: #Shift back to a zero origin
475                    x_offset = data_origin[1]
476                    y_offset = data_origin[2]
477
478                point_coordinates[:,0] += x_offset
479                point_coordinates[:,1] += y_offset
480            else:
481                if mesh_origin is not None:
482                    #Use mesh origin for data points
483                    point_coordinates[:,0] -= mesh_origin[1]
484                    point_coordinates[:,1] -= mesh_origin[2]
485
486
487
488        #Remove points falling outside mesh boundary
489        #This reduced one example from 1356 seconds to 825 seconds
490        if precrop is True:
491            from Numeric import take
492
493            if verbose: print 'Getting boundary polygon'
494            P = self.mesh.get_boundary_polygon()
495
496            if verbose: print 'Getting indices inside mesh boundary'
497            indices = inside_polygon(point_coordinates, P, verbose = verbose)
498
499
500            if len(indices) != point_coordinates.shape[0]:
501                self.cropped_points = True
502                if verbose:
503                    print 'Done - %d points outside mesh have been cropped.'\
504                          %(point_coordinates.shape[0] - len(indices))
505
506            point_coordinates = take(point_coordinates, indices)
507            self.point_indices = indices
508
509
510
511
512        #Build n x m interpolation matrix
513        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
514        n = point_coordinates.shape[0]     #Nbr of data points
515
516        if verbose: print 'Number of datapoints: %d' %n
517        if verbose: print 'Number of basis functions: %d' %m
518
519        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
520        #However, Sparse_CSR does not have the same methods as Sparse yet
521        #The tests will reveal what needs to be done
522
523        #
524        #self.A = Sparse_CSR(Sparse(n,m))
525        #self.AtA = Sparse_CSR(Sparse(m,m))
526        self.A = Sparse(n,m)
527        self.AtA = Sparse(m,m)
528
529        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
530        root = build_quadtree(self.mesh,
531                              max_points_per_cell = max_points_per_cell)
532        #root.show()
533        self.expanded_quad_searches = []
534        #Compute matrix elements
535        for i in range(n):
536            #For each data_coordinate point
537
538            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
539            x = point_coordinates[i]
540
541            #Find vertices near x
542            candidate_vertices = root.search(x[0], x[1])
543            is_more_elements = True
544
545            element_found, sigma0, sigma1, sigma2, k = \
546                self.search_triangles_of_vertices(candidate_vertices, x)
547            first_expansion = True
548            while not element_found and is_more_elements and expand_search:
549                #if verbose: print 'Expanding search'
550                if first_expansion == True:
551                    self.expanded_quad_searches.append(1)
552                    first_expansion = False
553                else:
554                    end = len(self.expanded_quad_searches) - 1
555                    assert end >= 0
556                    self.expanded_quad_searches[end] += 1
557                candidate_vertices, branch = root.expand_search()
558                if branch == []:
559                    # Searching all the verts from the root cell that haven't
560                    # been searched.  This is the last try
561                    element_found, sigma0, sigma1, sigma2, k = \
562                      self.search_triangles_of_vertices(candidate_vertices, x)
563                    is_more_elements = False
564                else:
565                    element_found, sigma0, sigma1, sigma2, k = \
566                      self.search_triangles_of_vertices(candidate_vertices, x)
567
568               
569            #Update interpolation matrix A if necessary
570            if element_found is True:
571                #Assign values to matrix A
572
573                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
574                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
575                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
576
577                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
578                js     = [j0,j1,j2]
579
580                for j in js:
581                    self.A[i,j] = sigmas[j]
582                    for k in js:
583                        if interp_only == False:
584                            self.AtA[j,k] += sigmas[j]*sigmas[k]
585            else:
586                pass
587                #Ok if there is no triangle for datapoint
588                #(as in brute force version)
589                #raise 'Could not find triangle for point', x
590
591
592
593    def search_triangles_of_vertices(self, candidate_vertices, x):
594            #Find triangle containing x:
595            element_found = False
596
597            # This will be returned if element_found = False
598            sigma2 = -10.0
599            sigma0 = -10.0
600            sigma1 = -10.0
601            k = -10.0
602            #print "*$* candidate_vertices", candidate_vertices
603            #For all vertices in same cell as point x
604            for v in candidate_vertices:
605                #FIXME (DSG-DSG): this catches verts with no triangle.
606                #Currently pmesh is producing these.
607                #this should be stopped,
608                if self.mesh.vertexlist[v] is None:
609                    continue
610                #for each triangle id (k) which has v as a vertex
611                for k, _ in self.mesh.vertexlist[v]:
612
613                    #Get the three vertex_points of candidate triangle
614                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
615                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
616                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
617
618                    #print "PDSG - k", k
619                    #print "PDSG - xi0", xi0
620                    #print "PDSG - xi1", xi1
621                    #print "PDSG - xi2", xi2
622                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
623                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
624
625                    #Get the three normals
626                    n0 = self.mesh.get_normal(k, 0)
627                    n1 = self.mesh.get_normal(k, 1)
628                    n2 = self.mesh.get_normal(k, 2)
629
630
631                    #Compute interpolation
632                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
633                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
634                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
635
636                    #print "PDSG - sigma0", sigma0
637                    #print "PDSG - sigma1", sigma1
638                    #print "PDSG - sigma2", sigma2
639
640                    #FIXME: Maybe move out to test or something
641                    epsilon = 1.0e-6
642                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
643
644                    #Check that this triangle contains the data point
645
646                    #Sigmas can get negative within
647                    #machine precision on some machines (e.g nautilus)
648                    #Hence the small eps
649                    eps = 1.0e-15
650                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
651                        element_found = True
652                        break
653
654                if element_found is True:
655                    #Don't look for any other triangle
656                    break
657            return element_found, sigma0, sigma1, sigma2, k
658
659
660
661    def build_interpolation_matrix_A_brute(self, point_coordinates):
662        """Build n x m interpolation matrix, where
663        n is the number of data points and
664        m is the number of basis functions phi_k (one per vertex)
665
666        This is the brute force which is too slow for large problems,
667        but could be used for testing
668        """
669
670        from util import ensure_numeric
671
672        #Convert input to Numeric arrays
673        point_coordinates = ensure_numeric(point_coordinates, Float)
674
675        #Build n x m interpolation matrix
676        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
677        n = point_coordinates.shape[0]     #Nbr of data points
678
679        self.A = Sparse(n,m)
680        self.AtA = Sparse(m,m)
681
682        #Compute matrix elements
683        for i in range(n):
684            #For each data_coordinate point
685
686            x = point_coordinates[i]
687            element_found = False
688            k = 0
689            while not element_found and k < len(self.mesh):
690                #For each triangle (brute force)
691                #FIXME: Real algorithm should only visit relevant triangles
692
693                #Get the three vertex_points
694                xi0 = self.mesh.get_vertex_coordinate(k, 0)
695                xi1 = self.mesh.get_vertex_coordinate(k, 1)
696                xi2 = self.mesh.get_vertex_coordinate(k, 2)
697
698                #Get the three normals
699                n0 = self.mesh.get_normal(k, 0)
700                n1 = self.mesh.get_normal(k, 1)
701                n2 = self.mesh.get_normal(k, 2)
702
703                #Compute interpolation
704                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
705                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
706                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
707
708                #FIXME: Maybe move out to test or something
709                epsilon = 1.0e-6
710                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
711
712                #Check that this triangle contains data point
713                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
714                    element_found = True
715                    #Assign values to matrix A
716
717                    j0 = self.mesh.triangles[k,0] #Global vertex id
718                    #self.A[i, j0] = sigma0
719
720                    j1 = self.mesh.triangles[k,1] #Global vertex id
721                    #self.A[i, j1] = sigma1
722
723                    j2 = self.mesh.triangles[k,2] #Global vertex id
724                    #self.A[i, j2] = sigma2
725
726                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
727                    js     = [j0,j1,j2]
728
729                    for j in js:
730                        self.A[i,j] = sigmas[j]
731                        for k in js:
732                            self.AtA[j,k] += sigmas[j]*sigmas[k]
733                k = k+1
734
735
736
737    def get_A(self):
738        return self.A.todense()
739
740    def get_B(self):
741        return self.B.todense()
742
743    def get_D(self):
744        return self.D.todense()
745
746        #FIXME: Remember to re-introduce the 1/n factor in the
747        #interpolation term
748
749    def build_smoothing_matrix_D(self):
750        """Build m x m smoothing matrix, where
751        m is the number of basis functions phi_k (one per vertex)
752
753        The smoothing matrix is defined as
754
755        D = D1 + D2
756
757        where
758
759        [D1]_{k,l} = \int_\Omega
760           \frac{\partial \phi_k}{\partial x}
761           \frac{\partial \phi_l}{\partial x}\,
762           dx dy
763
764        [D2]_{k,l} = \int_\Omega
765           \frac{\partial \phi_k}{\partial y}
766           \frac{\partial \phi_l}{\partial y}\,
767           dx dy
768
769
770        The derivatives \frac{\partial \phi_k}{\partial x},
771        \frac{\partial \phi_k}{\partial x} for a particular triangle
772        are obtained by computing the gradient a_k, b_k for basis function k
773        """
774
775        #FIXME: algorithm might be optimised by computing local 9x9
776        #"element stiffness matrices:
777
778        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
779
780        self.D = Sparse(m,m)
781
782        #For each triangle compute contributions to D = D1+D2
783        for i in range(len(self.mesh)):
784
785            #Get area
786            area = self.mesh.areas[i]
787
788            #Get global vertex indices
789            v0 = self.mesh.triangles[i,0]
790            v1 = self.mesh.triangles[i,1]
791            v2 = self.mesh.triangles[i,2]
792
793            #Get the three vertex_points
794            xi0 = self.mesh.get_vertex_coordinate(i, 0)
795            xi1 = self.mesh.get_vertex_coordinate(i, 1)
796            xi2 = self.mesh.get_vertex_coordinate(i, 2)
797
798            #Compute gradients for each vertex
799            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
800                              1, 0, 0)
801
802            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
803                              0, 1, 0)
804
805            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
806                              0, 0, 1)
807
808            #Compute diagonal contributions
809            self.D[v0,v0] += (a0*a0 + b0*b0)*area
810            self.D[v1,v1] += (a1*a1 + b1*b1)*area
811            self.D[v2,v2] += (a2*a2 + b2*b2)*area
812
813            #Compute contributions for basis functions sharing edges
814            e01 = (a0*a1 + b0*b1)*area
815            self.D[v0,v1] += e01
816            self.D[v1,v0] += e01
817
818            e12 = (a1*a2 + b1*b2)*area
819            self.D[v1,v2] += e12
820            self.D[v2,v1] += e12
821
822            e20 = (a2*a0 + b2*b0)*area
823            self.D[v2,v0] += e20
824            self.D[v0,v2] += e20
825
826
827    def fit(self, z):
828        """Fit a smooth surface to given 1d array of data points z.
829
830        The smooth surface is computed at each vertex in the underlying
831        mesh using the formula given in the module doc string.
832
833        Pre Condition:
834          self.A, self.AtA and self.B have been initialised
835
836        Inputs:
837          z: Single 1d vector or array of data at the point_coordinates.
838        """
839
840        #Convert input to Numeric arrays
841        from pyvolution.util import ensure_numeric
842        z = ensure_numeric(z, Float)
843
844        if len(z.shape) > 1 :
845            raise VectorShapeError, 'Can only deal with 1d data vector'
846
847        if self.point_indices is not None:
848            #Remove values for any points that were outside mesh
849            z = take(z, self.point_indices)
850
851        #Compute right hand side based on data
852        #FIXME (DSG-DsG): could Sparse_CSR be used here?  Use this format
853        # after a matrix is built, before calcs.
854        Atz = self.A.trans_mult(z)
855
856
857        #Check sanity
858        n, m = self.A.shape
859        if n<m and self.alpha == 0.0:
860            msg = 'ERROR (least_squares): Too few data points\n'
861            msg += 'There are only %d data points and alpha == 0. ' %n
862            msg += 'Need at least %d\n' %m
863            msg += 'Alternatively, set smoothing parameter alpha to a small '
864            msg += 'positive value,\ne.g. 1.0e-3.'
865            raise msg
866
867
868
869        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
870        #FIXME: Should we store the result here for later use? (ON)
871
872
873    def fit_points(self, z, verbose=False):
874        """Like fit, but more robust when each point has two or more attributes
875        FIXME (Ole): The name fit_points doesn't carry any meaning
876        for me. How about something like fit_multiple or fit_columns?
877        """
878
879        try:
880            if verbose: print 'Solving penalised least_squares problem'
881            return self.fit(z)
882        except VectorShapeError, e:
883            # broadcasting is not supported.
884
885            #Convert input to Numeric arrays
886            from util import ensure_numeric
887            z = ensure_numeric(z, Float)
888
889            #Build n x m interpolation matrix
890            m = self.mesh.coordinates.shape[0] #Number of vertices
891            n = z.shape[1]                     #Number of data points
892
893            f = zeros((m,n), Float) #Resulting columns
894
895            for i in range(z.shape[1]):
896                f[:,i] = self.fit(z[:,i])
897
898            return f
899
900
901    def interpolate(self, f):
902        """Evaluate smooth surface f at data points implied in self.A.
903
904        The mesh values representing a smooth surface are
905        assumed to be specified in f. This argument could,
906        for example have been obtained from the method self.fit()
907
908        Pre Condition:
909          self.A has been initialised
910
911        Inputs:
912          f: Vector or array of data at the mesh vertices.
913          If f is an array, interpolation will be done for each column as
914          per underlying matrix-matrix multiplication
915
916        Output:
917          Interpolated values at data points implied in self.A
918
919        """
920
921        return self.A * f
922
923    def cull_outsiders(self, f):
924        pass
925
926
927
928
929class Interpolation_function:
930    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
931    which is interpolated from time series defined at vertices of
932    triangular mesh (such as those stored in sww files)
933
934    Let m be the number of vertices, n the number of triangles
935    and p the number of timesteps.
936
937    Mandatory input
938        time:               px1 array of monotonously increasing times (Float)
939        quantities:         Dictionary of arrays or 1 array (Float)
940                            The arrays must either have dimensions pxm or mx1.
941                            The resulting function will be time dependent in
942                            the former case while it will be constan with
943                            respect to time in the latter case.
944       
945    Optional input:
946        quantity_names:     List of keys into the quantities dictionary
947        vertex_coordinates: mx2 array of coordinates (Float)
948        triangles:          nx3 array of indices into vertex_coordinates (Int)
949        interpolation_points: Nx2 array of coordinates to be interpolated to
950        verbose:            Level of reporting
951   
952   
953    The quantities returned by the callable object are specified by
954    the list quantities which must contain the names of the
955    quantities to be returned and also reflect the order, e.g. for
956    the shallow water wave equation, on would have
957    quantities = ['stage', 'xmomentum', 'ymomentum']
958
959    The parameter interpolation_points decides at which points interpolated
960    quantities are to be computed whenever object is called.
961    If None, return average value
962    """
963
964   
965   
966    def __init__(self,
967                 time,
968                 quantities,
969                 quantity_names = None, 
970                 vertex_coordinates = None,
971                 triangles = None,
972                 interpolation_points = None,
973                 verbose = False):
974        """Initialise object and build spatial interpolation if required
975        """
976
977        from Numeric import array, zeros, Float, alltrue, concatenate,\
978             reshape, ArrayType
979
980
981        from util import mean, ensure_numeric
982        from config import time_format
983        import types
984
985
986
987        #Check temporal info
988        time = ensure_numeric(time)       
989        msg = 'Time must be a monotonuosly '
990        msg += 'increasing sequence %s' %time
991        assert alltrue(time[1:] - time[:-1] >= 0 ), msg
992
993
994        #Check if quantities is a single array only
995        if type(quantities) != types.DictType:
996            quantities = ensure_numeric(quantities)
997            quantity_names = ['Attribute']
998
999            #Make it a dictionary
1000            quantities = {quantity_names[0]: quantities}
1001
1002
1003        #Use keys if no names are specified
1004        if quantity_names is None:
1005            quantity_names = quantities.keys()
1006
1007
1008        #Check spatial info
1009        if vertex_coordinates is None:
1010            self.spatial = False
1011        else:   
1012            vertex_coordinates = ensure_numeric(vertex_coordinates)
1013
1014            assert triangles is not None, 'Triangles array must be specified'
1015            triangles = ensure_numeric(triangles)
1016            self.spatial = True           
1017           
1018
1019 
1020        #Save for use with statistics
1021        self.quantity_names = quantity_names       
1022        self.quantities = quantities       
1023        self.vertex_coordinates = vertex_coordinates
1024        self.interpolation_points = interpolation_points
1025        self.T = time[:]  # Time assumed to be relative to starttime
1026        self.index = 0    # Initial time index
1027        self.precomputed_values = {}
1028           
1029
1030
1031        #Precomputed spatial interpolation if requested
1032        if interpolation_points is not None:
1033            if self.spatial is False:
1034                raise 'Triangles and vertex_coordinates must be specified'
1035           
1036            try:
1037                self.interpolation_points = ensure_numeric(interpolation_points)
1038            except:
1039                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1040                      'or a list of points\n'
1041                msg += 'I got: %s.' %(str(self.interpolation_points)[:60] +\
1042                                      '...')
1043                raise msg
1044
1045
1046            m = len(self.interpolation_points)
1047            p = len(self.T)
1048           
1049            for name in quantity_names:
1050                self.precomputed_values[name] = zeros((p, m), Float)
1051
1052            #Build interpolator
1053            interpol = Interpolation(vertex_coordinates,
1054                                     triangles,
1055                                     point_coordinates = \
1056                                     self.interpolation_points,
1057                                     alpha = 0,
1058                                     precrop = False, 
1059                                     verbose = verbose)
1060
1061            if verbose: print 'Interpolate'
1062            for i, t in enumerate(self.T):
1063                #Interpolate quantities at this timestep
1064                if verbose and i%((p+10)/10)==0:
1065                    print ' time step %d of %d' %(i, p)
1066                   
1067                for name in quantity_names:
1068                    if len(quantities[name].shape) == 2:
1069                        result = interpol.interpolate(quantities[name][i,:])
1070                    else:
1071                       #Assume no time dependency
1072                       result = interpol.interpolate(quantities[name][:])
1073                       
1074                    self.precomputed_values[name][i, :] = result
1075                   
1076                       
1077
1078            #Report
1079            if verbose:
1080                print self.statistics()
1081                #self.print_statistics()
1082           
1083        else:
1084            #Store quantitites as is
1085            for name in quantity_names:
1086                self.precomputed_values[name] = quantities[name]
1087
1088
1089        #else:
1090        #    #Return an average, making this a time series
1091        #    for name in quantity_names:
1092        #        self.values[name] = zeros(len(self.T), Float)
1093        #
1094        #    if verbose: print 'Compute mean values'
1095        #    for i, t in enumerate(self.T):
1096        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1097        #        for name in quantity_names:
1098        #           self.values[name][i] = mean(quantities[name][i,:])
1099
1100
1101
1102
1103    def __repr__(self):
1104        #return 'Interpolation function (spatio-temporal)'
1105        return self.statistics()
1106   
1107
1108    def __call__(self, t, point_id = None, x = None, y = None):
1109        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1110
1111        Inputs:
1112          t: time - Model time. Must lie within existing timesteps
1113          point_id: index of one of the preprocessed points.
1114          x, y:     Overrides location, point_id ignored
1115         
1116          If spatial info is present and all of x,y,point_id
1117          are None an exception is raised
1118                   
1119          If no spatial info is present, point_id and x,y arguments are ignored
1120          making f a function of time only.
1121
1122         
1123          FIXME: point_id could also be a slice
1124          FIXME: What if x and y are vectors?
1125          FIXME: What about f(x,y) without t?
1126        """
1127
1128        from math import pi, cos, sin, sqrt
1129        from Numeric import zeros, Float
1130        from util import mean       
1131
1132        if self.spatial is True:
1133            if point_id is None:
1134                if x is None or y is None:
1135                    msg = 'Either point_id or x and y must be specified'
1136                    raise msg
1137            else:
1138                if self.interpolation_points is None:
1139                    msg = 'Interpolation_function must be instantiated ' +\
1140                          'with a list of interpolation points before parameter ' +\
1141                          'point_id can be used'
1142                    raise msg
1143
1144
1145        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1146        msg += ' does not match model time: %s\n' %t
1147        if t < self.T[0]: raise msg
1148        if t > self.T[-1]: raise msg
1149
1150        oldindex = self.index #Time index
1151
1152        #Find current time slot
1153        while t > self.T[self.index]: self.index += 1
1154        while t < self.T[self.index]: self.index -= 1
1155
1156        if t == self.T[self.index]:
1157            #Protect against case where t == T[-1] (last time)
1158            # - also works in general when t == T[i]
1159            ratio = 0
1160        else:
1161            #t is now between index and index+1
1162            ratio = (t - self.T[self.index])/\
1163                    (self.T[self.index+1] - self.T[self.index])
1164
1165        #Compute interpolated values
1166        q = zeros(len(self.quantity_names), Float)
1167
1168        for i, name in enumerate(self.quantity_names):
1169            Q = self.precomputed_values[name]
1170
1171            if self.spatial is False:
1172                #If there is no spatial info               
1173                assert len(Q.shape) == 1
1174
1175                Q0 = Q[self.index]
1176                if ratio > 0: Q1 = Q[self.index+1]
1177
1178            else:
1179                if x is not None and y is not None:
1180                    #Interpolate to x, y
1181                   
1182                    raise 'x,y interpolation not yet implemented'
1183                else:
1184                    #Use precomputed point
1185                    Q0 = Q[self.index, point_id]
1186                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1187
1188            #Linear temporal interpolation   
1189            if ratio > 0:
1190                q[i] = Q0 + ratio*(Q1 - Q0)
1191            else:
1192                q[i] = Q0
1193
1194
1195        #Return vector of interpolated values
1196        #if len(q) == 1:
1197        #    return q[0]
1198        #else:
1199        #    return q
1200
1201
1202        #Return vector of interpolated values
1203        #FIXME:
1204        if self.spatial is True:
1205            return q
1206        else:
1207            #Replicate q according to x and y
1208            #This is e.g used for Wind_stress
1209            if x is None or y is None: 
1210                return q
1211            else:
1212                try:
1213                    N = len(x)
1214                except:
1215                    return q
1216                else:
1217                    from Numeric import ones, Float
1218                    #x is a vector - Create one constant column for each value
1219                    N = len(x)
1220                    assert len(y) == N, 'x and y must have same length'
1221                    res = []
1222                    for col in q:
1223                        res.append(col*ones(N, Float))
1224                       
1225                return res
1226
1227
1228    def statistics(self):
1229        """Output statistics about interpolation_function
1230        """
1231       
1232        vertex_coordinates = self.vertex_coordinates
1233        interpolation_points = self.interpolation_points               
1234        quantity_names = self.quantity_names
1235        quantities = self.quantities
1236        precomputed_values = self.precomputed_values                 
1237               
1238        x = vertex_coordinates[:,0]
1239        y = vertex_coordinates[:,1]               
1240
1241        str =  '------------------------------------------------\n'
1242        str += 'Interpolation_function (spatio-temporal) statistics:\n'
1243        str += '  Extent:\n'
1244        str += '    x in [%f, %f], len(x) == %d\n'\
1245               %(min(x), max(x), len(x))
1246        str += '    y in [%f, %f], len(y) == %d\n'\
1247               %(min(y), max(y), len(y))
1248        str += '    t in [%f, %f], len(t) == %d\n'\
1249               %(min(self.T), max(self.T), len(self.T))
1250        str += '  Quantities:\n'
1251        for name in quantity_names:
1252            q = quantities[name][:].flat
1253            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1254
1255        if interpolation_points is not None:   
1256            str += '  Interpolation points (xi, eta):'\
1257                   ' number of points == %d\n' %interpolation_points.shape[0]
1258            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1259                                            max(interpolation_points[:,0]))
1260            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1261                                             max(interpolation_points[:,1]))
1262            str += '  Interpolated quantities (over all timesteps):\n'
1263       
1264            for name in quantity_names:
1265                q = precomputed_values[name][:].flat
1266                str += '    %s at interpolation points in [%f, %f]\n'\
1267                       %(name, min(q), max(q))
1268        str += '------------------------------------------------\n'
1269
1270        return str
1271
1272        #FIXME: Delete
1273        #print '------------------------------------------------'
1274        #print 'Interpolation_function statistics:'
1275        #print '  Extent:'
1276        #print '    x in [%f, %f], len(x) == %d'\
1277        #      %(min(x), max(x), len(x))
1278        #print '    y in [%f, %f], len(y) == %d'\
1279        #      %(min(y), max(y), len(y))
1280        #print '    t in [%f, %f], len(t) == %d'\
1281        #      %(min(self.T), max(self.T), len(self.T))
1282        #print '  Quantities:'
1283        #for name in quantity_names:
1284        #    q = quantities[name][:].flat
1285        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1286        #print '  Interpolation points (xi, eta):'\
1287        #      ' number of points == %d ' %interpolation_points.shape[0]
1288        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1289        #                             max(interpolation_points[:,0]))
1290        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1291        #                              max(interpolation_points[:,1]))
1292        #print '  Interpolated quantities (over all timesteps):'
1293        #
1294        #for name in quantity_names:
1295        #    q = precomputed_values[name][:].flat
1296        #    print '    %s at interpolation points in [%f, %f]'\
1297        #          %(name, min(q), max(q))
1298        #print '------------------------------------------------'
1299
1300
1301#-------------------------------------------------------------
1302if __name__ == "__main__":
1303    """
1304    Load in a mesh and data points with attributes.
1305    Fit the attributes to the mesh.
1306    Save a new mesh file.
1307    """
1308    import os, sys
1309    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1310            %os.path.basename(sys.argv[0])
1311
1312    if len(sys.argv) < 4:
1313        print usage
1314    else:
1315        mesh_file = sys.argv[1]
1316        point_file = sys.argv[2]
1317        mesh_output_file = sys.argv[3]
1318
1319        expand_search = False
1320        if len(sys.argv) > 4:
1321            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1322                expand_search = True
1323            else:
1324                expand_search = False
1325
1326        verbose = False
1327        if len(sys.argv) > 5:
1328            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1329                verbose = False
1330            else:
1331                verbose = True
1332
1333        if len(sys.argv) > 6:
1334            alpha = sys.argv[6]
1335        else:
1336            alpha = DEFAULT_ALPHA
1337
1338        # This is used more for testing
1339        if len(sys.argv) > 7:
1340            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1341                display_errors = False
1342            else:
1343                display_errors = True
1344           
1345        t0 = time.time()
1346        try:
1347            fit_to_mesh_file(mesh_file,
1348                         point_file,
1349                         mesh_output_file,
1350                         alpha,
1351                         verbose= verbose,
1352                         expand_search = expand_search,
1353                         display_errors = display_errors)
1354        except IOError,e:
1355            import sys; sys.exit(1)
1356
1357        print 'That took %.2f seconds' %(time.time()-t0)
1358
Note: See TracBrowser for help on using the repository browser.