source: inundation/pyvolution/least_squares.py @ 2378

Last change on this file since 2378 was 2347, checked in by ole, 19 years ago

Change default value for verbose from None to False in set_values

File size: 47.9 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20#import exceptions
21#class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from pyvolution.mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from pyvolution.sparse import Sparse, Sparse_CSR
29from pyvolution.cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME (DSG-ON) reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    NOTE: Throws IOErrors, for a variety of file problems.
75   
76    mesh_origin is the same but refers to the input tsh file.
77    FIXME: When the tsh format contains it own origin, these parameters can go.
78    FIXME: And both origins should be obtained from the specified files.
79    """
80
81    from load_mesh.loadASCII import import_mesh_file, \
82                 import_points_file, export_mesh_file, \
83                 concatinate_attributelist
84
85
86    try:
87        mesh_dict = import_mesh_file(mesh_file)
88    except IOError,e:
89        if display_errors:
90            print "Could not load bad file. ", e
91        raise IOError  #Re-raise exception
92       
93    vertex_coordinates = mesh_dict['vertices']
94    triangles = mesh_dict['triangles']
95    if type(mesh_dict['vertex_attributes']) == ArrayType:
96        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
97    else:
98        old_point_attributes = mesh_dict['vertex_attributes']
99
100    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
101        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
102    else:
103        old_title_list = mesh_dict['vertex_attribute_titles']
104
105    if verbose: print 'tsh file %s loaded' %mesh_file
106
107    # load in the .pts file
108    try:
109        point_dict = import_points_file(point_file, verbose=verbose)
110    except IOError,e:
111        if display_errors:
112            print "Could not load bad file. ", e
113        raise IOError  #Re-raise exception 
114
115    point_coordinates = point_dict['pointlist']
116    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
117
118    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
119        data_origin = point_dict['geo_reference'].get_origin()
120    else:
121        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
122
123    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
124        mesh_origin = mesh_dict['geo_reference'].get_origin()
125    else:
126        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
127
128    if verbose: print "points file loaded"
129    if verbose:print "fitting to mesh"
130    f = fit_to_mesh(vertex_coordinates,
131                    triangles,
132                    point_coordinates,
133                    point_attributes,
134                    alpha = alpha,
135                    verbose = verbose,
136                    expand_search = expand_search,
137                    data_origin = data_origin,
138                    mesh_origin = mesh_origin,
139                    precrop = precrop)
140    if verbose: print "finished fitting to mesh"
141
142    # convert array to list of lists
143    new_point_attributes = f.tolist()
144    #FIXME have this overwrite attributes with the same title - DSG
145    #Put the newer attributes last
146    if old_title_list <> []:
147        old_title_list.extend(title_list)
148        #FIXME can this be done a faster way? - DSG
149        for i in range(len(old_point_attributes)):
150            old_point_attributes[i].extend(new_point_attributes[i])
151        mesh_dict['vertex_attributes'] = old_point_attributes
152        mesh_dict['vertex_attribute_titles'] = old_title_list
153    else:
154        mesh_dict['vertex_attributes'] = new_point_attributes
155        mesh_dict['vertex_attribute_titles'] = title_list
156
157    #FIXME (Ole): Remember to output mesh_origin as well
158    if verbose: print "exporting to file ",mesh_output_file
159
160    try:
161        export_mesh_file(mesh_output_file, mesh_dict)
162    except IOError,e:
163        if display_errors:
164            print "Could not write file. ", e
165        raise IOError
166
167def fit_to_mesh(vertex_coordinates,
168                triangles,
169                point_coordinates,
170                point_attributes,
171                alpha = DEFAULT_ALPHA,
172                verbose = False,
173                expand_search = False,
174                data_origin = None,
175                mesh_origin = None,
176                precrop = False):
177    """
178    Fit a smooth surface to a triangulation,
179    given data points with attributes.
180
181
182        Inputs:
183
184          vertex_coordinates: List of coordinate pairs [xi, eta] of points
185          constituting mesh (or a an m x 2 Numeric array)
186
187          triangles: List of 3-tuples (or a Numeric array) of
188          integers representing indices of all vertices in the mesh.
189
190          point_coordinates: List of coordinate pairs [x, y] of data points
191          (or an nx2 Numeric array)
192
193          alpha: Smoothing parameter.
194
195          point_attributes: Vector or array of data at the point_coordinates.
196
197          data_origin and mesh_origin are 3-tuples consisting of
198          UTM zone, easting and northing. If specified
199          point coordinates and vertex coordinates are assumed to be
200          relative to their respective origins.
201
202    """
203
204
205    interp = Interpolation(vertex_coordinates,
206                           triangles,
207                           point_coordinates,
208                           alpha = alpha,
209                           verbose = verbose,
210                           expand_search = expand_search,
211                           data_origin = data_origin,
212                           mesh_origin = mesh_origin,
213                           precrop = precrop)
214
215    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
216    return vertex_attributes
217
218
219
220def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
221                    verbose = False, reduction = 1):
222    """Fits attributes from pts file to MxN rectangular mesh
223
224    Read pts file and create rectangular mesh of resolution MxN such that
225    it covers all points specified in pts file.
226
227    FIXME: This may be a temporary function until we decide on
228    netcdf formats etc
229
230    FIXME: Uses elevation hardwired
231    """
232
233    import  mesh_factory
234    from load_mesh.loadASCII import import_points_file
235   
236    if verbose: print 'Read pts'
237    points_dict = import_points_file(pts_name)
238    #points, attributes = util.read_xya(pts_name)
239
240    #Reduce number of points a bit
241    points = points_dict['pointlist'][::reduction]
242    elevation = points_dict['attributelist']['elevation']  #Must be elevation
243    elevation = elevation[::reduction]
244
245    if verbose: print 'Got %d data points' %len(points)
246
247    if verbose: print 'Create mesh'
248    #Find extent
249    max_x = min_x = points[0][0]
250    max_y = min_y = points[0][1]
251    for point in points[1:]:
252        x = point[0]
253        if x > max_x: max_x = x
254        if x < min_x: min_x = x
255        y = point[1]
256        if y > max_y: max_y = y
257        if y < min_y: min_y = y
258
259    #Create appropriate mesh
260    vertex_coordinates, triangles, boundary =\
261         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
262                                (min_x, min_y))
263
264    #Fit attributes to mesh
265    vertex_attributes = fit_to_mesh(vertex_coordinates,
266                        triangles,
267                        points,
268                        elevation, alpha=alpha, verbose=verbose)
269
270
271
272    return vertex_coordinates, triangles, boundary, vertex_attributes
273
274
275
276class Interpolation:
277
278    def __init__(self,
279                 vertex_coordinates,
280                 triangles,
281                 point_coordinates = None,
282                 alpha = None,
283                 verbose = False,
284                 expand_search = True,
285                 interp_only = False,
286                 max_points_per_cell = 30,
287                 mesh_origin = None,
288                 data_origin = None,
289                 precrop = False):
290
291
292        """ Build interpolation matrix mapping from
293        function values at vertices to function values at data points
294
295        Inputs:
296
297          vertex_coordinates: List of coordinate pairs [xi, eta] of
298          points constituting mesh (or a an m x 2 Numeric array)
299          Points may appear multiple times
300          (e.g. if vertices have discontinuities)
301
302          triangles: List of 3-tuples (or a Numeric array) of
303          integers representing indices of all vertices in the mesh.
304
305          point_coordinates: List of coordinate pairs [x, y] of
306          data points (or an nx2 Numeric array)
307          If point_coordinates is absent, only smoothing matrix will
308          be built
309
310          alpha: Smoothing parameter
311
312          data_origin and mesh_origin are 3-tuples consisting of
313          UTM zone, easting and northing. If specified
314          point coordinates and vertex coordinates are assumed to be
315          relative to their respective origins.
316
317        """
318        from pyvolution.util import ensure_numeric
319
320        #Convert input to Numeric arrays
321        triangles = ensure_numeric(triangles, Int)
322        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
323
324        #Build underlying mesh
325        if verbose: print 'Building mesh'
326        #self.mesh = General_mesh(vertex_coordinates, triangles,
327        #FIXME: Trying the normal mesh while testing precrop,
328        #       The functionality of boundary_polygon is needed for that
329
330        #FIXME - geo ref does not have to go into mesh.
331        # Change the point co-ords to conform to the
332        # mesh co-ords early in the code
333        if mesh_origin is None:
334            geo = None
335        else:
336            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
337        self.mesh = Mesh(vertex_coordinates, triangles,
338                         geo_reference = geo)
339       
340        self.mesh.check_integrity()
341
342        self.data_origin = data_origin
343
344        self.point_indices = None
345
346        #Smoothing parameter
347        if alpha is None:
348            self.alpha = DEFAULT_ALPHA
349        else:   
350            self.alpha = alpha
351
352
353        if point_coordinates is not None:
354            if verbose: print 'Building interpolation matrix'
355            self.build_interpolation_matrix_A(point_coordinates,
356                                              verbose = verbose,
357                                              expand_search = expand_search,
358                                              interp_only = interp_only, 
359                                              max_points_per_cell =\
360                                              max_points_per_cell,
361                                              data_origin = data_origin,
362                                              precrop = precrop)
363        #Build coefficient matrices
364        if interp_only == False:
365            self.build_coefficient_matrix_B(point_coordinates,
366                                        verbose = verbose,
367                                        expand_search = expand_search,
368                                        max_points_per_cell =\
369                                        max_points_per_cell,
370                                        data_origin = data_origin,
371                                        precrop = precrop)
372
373    def set_point_coordinates(self, point_coordinates,
374                              data_origin = None,
375                              verbose = False,
376                              precrop = True):
377        """
378        A public interface to setting the point co-ordinates.
379        """
380        if point_coordinates is not None:
381            if verbose: print 'Building interpolation matrix'
382            self.build_interpolation_matrix_A(point_coordinates,
383                                              verbose = verbose,
384                                              data_origin = data_origin,
385                                              precrop = precrop)
386        self.build_coefficient_matrix_B(point_coordinates, data_origin)
387
388    def build_coefficient_matrix_B(self, point_coordinates=None,
389                                   verbose = False, expand_search = True,
390                                   max_points_per_cell=30,
391                                   data_origin = None,
392                                   precrop = False):
393        """Build final coefficient matrix"""
394
395
396        if self.alpha <> 0:
397            if verbose: print 'Building smoothing matrix'
398            self.build_smoothing_matrix_D()
399
400        if point_coordinates is not None:
401            if self.alpha <> 0:
402                self.B = self.AtA + self.alpha*self.D
403            else:
404                self.B = self.AtA
405
406            #Convert self.B matrix to CSR format for faster matrix vector
407            self.B = Sparse_CSR(self.B)
408
409    def build_interpolation_matrix_A(self, point_coordinates,
410                                     verbose = False, expand_search = True,
411                                     max_points_per_cell=30,
412                                     data_origin = None,
413                                     precrop = False,
414                                     interp_only = False):
415        """Build n x m interpolation matrix, where
416        n is the number of data points and
417        m is the number of basis functions phi_k (one per vertex)
418
419        This algorithm uses a quad tree data structure for fast binning of data points
420        origin is a 3-tuple consisting of UTM zone, easting and northing.
421        If specified coordinates are assumed to be relative to this origin.
422
423        This one will override any data_origin that may be specified in
424        interpolation instance
425
426        """
427
428
429
430        #FIXME (Ole): Check that this function is memeory efficient.
431        #6 million datapoints and 300000 basis functions
432        #causes out-of-memory situation
433        #First thing to check is whether there is room for self.A and self.AtA
434        #
435        #Maybe we need some sort of blocking
436
437        from pyvolution.quad import build_quadtree
438        from utilities.numerical_tools import ensure_numeric
439        from utilities.polygon import inside_polygon
440       
441
442        if data_origin is None:
443            data_origin = self.data_origin #Use the one from
444                                           #interpolation instance
445
446        #Convert input to Numeric arrays just in case.
447        point_coordinates = ensure_numeric(point_coordinates, Float)
448
449        #Keep track of discarded points (if any).
450        #This is only registered if precrop is True
451        self.cropped_points = False
452
453        #Shift data points to same origin as mesh (if specified)
454
455        #FIXME this will shift if there was no geo_ref.
456        #But all this should be removed anyhow.
457        #change coords before this point
458        mesh_origin = self.mesh.geo_reference.get_origin()
459        if point_coordinates is not None:
460            if data_origin is not None:
461                if mesh_origin is not None:
462
463                    #Transformation:
464                    #
465                    #Let x_0 be the reference point of the point coordinates
466                    #and xi_0 the reference point of the mesh.
467                    #
468                    #A point coordinate (x + x_0) is then made relative
469                    #to xi_0 by
470                    #
471                    # x_new = x + x_0 - xi_0
472                    #
473                    #and similarly for eta
474
475                    x_offset = data_origin[1] - mesh_origin[1]
476                    y_offset = data_origin[2] - mesh_origin[2]
477                else: #Shift back to a zero origin
478                    x_offset = data_origin[1]
479                    y_offset = data_origin[2]
480
481                point_coordinates[:,0] += x_offset
482                point_coordinates[:,1] += y_offset
483            else:
484                if mesh_origin is not None:
485                    #Use mesh origin for data points
486                    point_coordinates[:,0] -= mesh_origin[1]
487                    point_coordinates[:,1] -= mesh_origin[2]
488
489
490
491        #Remove points falling outside mesh boundary
492        #This reduced one example from 1356 seconds to 825 seconds
493
494       
495        if precrop is True:
496            from Numeric import take
497
498            if verbose: print 'Getting boundary polygon'
499            P = self.mesh.get_boundary_polygon()
500
501            if verbose: print 'Getting indices inside mesh boundary'
502            indices = inside_polygon(point_coordinates, P, verbose = verbose)
503
504
505            if len(indices) != point_coordinates.shape[0]:
506                self.cropped_points = True
507                if verbose:
508                    print 'Done - %d points outside mesh have been cropped.'\
509                          %(point_coordinates.shape[0] - len(indices))
510
511            point_coordinates = take(point_coordinates, indices)
512            self.point_indices = indices
513
514
515
516
517        #Build n x m interpolation matrix
518        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
519        n = point_coordinates.shape[0]     #Nbr of data points
520
521        if verbose: print 'Number of datapoints: %d' %n
522        if verbose: print 'Number of basis functions: %d' %m
523
524        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
525        #However, Sparse_CSR does not have the same methods as Sparse yet
526        #The tests will reveal what needs to be done
527
528        #
529        #self.A = Sparse_CSR(Sparse(n,m))
530        #self.AtA = Sparse_CSR(Sparse(m,m))
531        self.A = Sparse(n,m)
532        self.AtA = Sparse(m,m)
533
534        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
535        root = build_quadtree(self.mesh,
536                              max_points_per_cell = max_points_per_cell)
537        #root.show()
538        self.expanded_quad_searches = []
539        #Compute matrix elements
540        for i in range(n):
541            #For each data_coordinate point
542
543            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
544            x = point_coordinates[i]
545
546            #Find vertices near x
547            candidate_vertices = root.search(x[0], x[1])
548            is_more_elements = True
549
550            element_found, sigma0, sigma1, sigma2, k = \
551                self.search_triangles_of_vertices(candidate_vertices, x)
552            first_expansion = True
553            while not element_found and is_more_elements and expand_search:
554                #if verbose: print 'Expanding search'
555                if first_expansion == True:
556                    self.expanded_quad_searches.append(1)
557                    first_expansion = False
558                else:
559                    end = len(self.expanded_quad_searches) - 1
560                    assert end >= 0
561                    self.expanded_quad_searches[end] += 1
562                candidate_vertices, branch = root.expand_search()
563                if branch == []:
564                    # Searching all the verts from the root cell that haven't
565                    # been searched.  This is the last try
566                    element_found, sigma0, sigma1, sigma2, k = \
567                      self.search_triangles_of_vertices(candidate_vertices, x)
568                    is_more_elements = False
569                else:
570                    element_found, sigma0, sigma1, sigma2, k = \
571                      self.search_triangles_of_vertices(candidate_vertices, x)
572
573               
574            #Update interpolation matrix A if necessary
575            if element_found is True:
576                #Assign values to matrix A
577
578                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
579                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
580                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
581
582                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
583                js     = [j0,j1,j2]
584
585                for j in js:
586                    self.A[i,j] = sigmas[j]
587                    for k in js:
588                        if interp_only == False:
589                            self.AtA[j,k] += sigmas[j]*sigmas[k]
590            else:
591                pass
592                #Ok if there is no triangle for datapoint
593                #(as in brute force version)
594                #raise 'Could not find triangle for point', x
595
596
597
598    def search_triangles_of_vertices(self, candidate_vertices, x):
599            #Find triangle containing x:
600            element_found = False
601
602            # This will be returned if element_found = False
603            sigma2 = -10.0
604            sigma0 = -10.0
605            sigma1 = -10.0
606            k = -10.0
607            #print "*$* candidate_vertices", candidate_vertices
608            #For all vertices in same cell as point x
609            for v in candidate_vertices:
610                #FIXME (DSG-DSG): this catches verts with no triangle.
611                #Currently pmesh is producing these.
612                #this should be stopped,
613                if self.mesh.vertexlist[v] is None:
614                    continue
615                #for each triangle id (k) which has v as a vertex
616                for k, _ in self.mesh.vertexlist[v]:
617
618                    #Get the three vertex_points of candidate triangle
619                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
620                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
621                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
622
623                    #print "PDSG - k", k
624                    #print "PDSG - xi0", xi0
625                    #print "PDSG - xi1", xi1
626                    #print "PDSG - xi2", xi2
627                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
628                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
629
630                    #Get the three normals
631                    n0 = self.mesh.get_normal(k, 0)
632                    n1 = self.mesh.get_normal(k, 1)
633                    n2 = self.mesh.get_normal(k, 2)
634
635
636                    #Compute interpolation
637                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
638                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
639                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
640
641                    #print "PDSG - sigma0", sigma0
642                    #print "PDSG - sigma1", sigma1
643                    #print "PDSG - sigma2", sigma2
644
645                    #FIXME: Maybe move out to test or something
646                    epsilon = 1.0e-6
647                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
648
649                    #Check that this triangle contains the data point
650
651                    #Sigmas can get negative within
652                    #machine precision on some machines (e.g nautilus)
653                    #Hence the small eps
654                    eps = 1.0e-15
655                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
656                        element_found = True
657                        break
658
659                if element_found is True:
660                    #Don't look for any other triangle
661                    break
662            return element_found, sigma0, sigma1, sigma2, k
663
664
665
666    def build_interpolation_matrix_A_brute(self, point_coordinates):
667        """Build n x m interpolation matrix, where
668        n is the number of data points and
669        m is the number of basis functions phi_k (one per vertex)
670
671        This is the brute force which is too slow for large problems,
672        but could be used for testing
673        """
674
675        from util import ensure_numeric
676
677        #Convert input to Numeric arrays
678        point_coordinates = ensure_numeric(point_coordinates, Float)
679
680        #Build n x m interpolation matrix
681        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
682        n = point_coordinates.shape[0]     #Nbr of data points
683
684        self.A = Sparse(n,m)
685        self.AtA = Sparse(m,m)
686
687        #Compute matrix elements
688        for i in range(n):
689            #For each data_coordinate point
690
691            x = point_coordinates[i]
692            element_found = False
693            k = 0
694            while not element_found and k < len(self.mesh):
695                #For each triangle (brute force)
696                #FIXME: Real algorithm should only visit relevant triangles
697
698                #Get the three vertex_points
699                xi0 = self.mesh.get_vertex_coordinate(k, 0)
700                xi1 = self.mesh.get_vertex_coordinate(k, 1)
701                xi2 = self.mesh.get_vertex_coordinate(k, 2)
702
703                #Get the three normals
704                n0 = self.mesh.get_normal(k, 0)
705                n1 = self.mesh.get_normal(k, 1)
706                n2 = self.mesh.get_normal(k, 2)
707
708                #Compute interpolation
709                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
710                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
711                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
712
713                #FIXME: Maybe move out to test or something
714                epsilon = 1.0e-6
715                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
716
717                #Check that this triangle contains data point
718                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
719                    element_found = True
720                    #Assign values to matrix A
721
722                    j0 = self.mesh.triangles[k,0] #Global vertex id
723                    #self.A[i, j0] = sigma0
724
725                    j1 = self.mesh.triangles[k,1] #Global vertex id
726                    #self.A[i, j1] = sigma1
727
728                    j2 = self.mesh.triangles[k,2] #Global vertex id
729                    #self.A[i, j2] = sigma2
730
731                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
732                    js     = [j0,j1,j2]
733
734                    for j in js:
735                        self.A[i,j] = sigmas[j]
736                        for k in js:
737                            self.AtA[j,k] += sigmas[j]*sigmas[k]
738                k = k+1
739
740
741
742    def get_A(self):
743        return self.A.todense()
744
745    def get_B(self):
746        return self.B.todense()
747
748    def get_D(self):
749        return self.D.todense()
750
751        #FIXME: Remember to re-introduce the 1/n factor in the
752        #interpolation term
753
754    def build_smoothing_matrix_D(self):
755        """Build m x m smoothing matrix, where
756        m is the number of basis functions phi_k (one per vertex)
757
758        The smoothing matrix is defined as
759
760        D = D1 + D2
761
762        where
763
764        [D1]_{k,l} = \int_\Omega
765           \frac{\partial \phi_k}{\partial x}
766           \frac{\partial \phi_l}{\partial x}\,
767           dx dy
768
769        [D2]_{k,l} = \int_\Omega
770           \frac{\partial \phi_k}{\partial y}
771           \frac{\partial \phi_l}{\partial y}\,
772           dx dy
773
774
775        The derivatives \frac{\partial \phi_k}{\partial x},
776        \frac{\partial \phi_k}{\partial x} for a particular triangle
777        are obtained by computing the gradient a_k, b_k for basis function k
778        """
779
780        #FIXME: algorithm might be optimised by computing local 9x9
781        #"element stiffness matrices:
782
783        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
784
785        self.D = Sparse(m,m)
786
787        #For each triangle compute contributions to D = D1+D2
788        for i in range(len(self.mesh)):
789
790            #Get area
791            area = self.mesh.areas[i]
792
793            #Get global vertex indices
794            v0 = self.mesh.triangles[i,0]
795            v1 = self.mesh.triangles[i,1]
796            v2 = self.mesh.triangles[i,2]
797
798            #Get the three vertex_points
799            xi0 = self.mesh.get_vertex_coordinate(i, 0)
800            xi1 = self.mesh.get_vertex_coordinate(i, 1)
801            xi2 = self.mesh.get_vertex_coordinate(i, 2)
802
803            #Compute gradients for each vertex
804            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
805                              1, 0, 0)
806
807            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
808                              0, 1, 0)
809
810            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
811                              0, 0, 1)
812
813            #Compute diagonal contributions
814            self.D[v0,v0] += (a0*a0 + b0*b0)*area
815            self.D[v1,v1] += (a1*a1 + b1*b1)*area
816            self.D[v2,v2] += (a2*a2 + b2*b2)*area
817
818            #Compute contributions for basis functions sharing edges
819            e01 = (a0*a1 + b0*b1)*area
820            self.D[v0,v1] += e01
821            self.D[v1,v0] += e01
822
823            e12 = (a1*a2 + b1*b2)*area
824            self.D[v1,v2] += e12
825            self.D[v2,v1] += e12
826
827            e20 = (a2*a0 + b2*b0)*area
828            self.D[v2,v0] += e20
829            self.D[v0,v2] += e20
830
831
832    def fit(self, z):
833        """Fit a smooth surface to given 1d array of data points z.
834
835        The smooth surface is computed at each vertex in the underlying
836        mesh using the formula given in the module doc string.
837
838        Pre Condition:
839          self.A, self.AtA and self.B have been initialised
840
841        Inputs:
842          z: Single 1d vector or array of data at the point_coordinates.
843        """
844
845        #Convert input to Numeric arrays
846        from pyvolution.util import ensure_numeric
847        z = ensure_numeric(z, Float)
848
849        if len(z.shape) > 1 :
850            raise VectorShapeError, 'Can only deal with 1d data vector'
851
852        if self.point_indices is not None:
853            #Remove values for any points that were outside mesh
854            z = take(z, self.point_indices)
855
856        #Compute right hand side based on data
857        #FIXME (DSG-DsG): could Sparse_CSR be used here?  Use this format
858        # after a matrix is built, before calcs.
859        Atz = self.A.trans_mult(z)
860
861
862        #Check sanity
863        n, m = self.A.shape
864        if n<m and self.alpha == 0.0:
865            msg = 'ERROR (least_squares): Too few data points\n'
866            msg += 'There are only %d data points and alpha == 0. ' %n
867            msg += 'Need at least %d\n' %m
868            msg += 'Alternatively, set smoothing parameter alpha to a small '
869            msg += 'positive value,\ne.g. 1.0e-3.'
870            raise msg
871
872
873
874        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
875        #FIXME: Should we store the result here for later use? (ON)
876
877
878    def fit_points(self, z, verbose=False):
879        """Like fit, but more robust when each point has two or more attributes
880        FIXME (Ole): The name fit_points doesn't carry any meaning
881        for me. How about something like fit_multiple or fit_columns?
882        """
883
884        try:
885            if verbose: print 'Solving penalised least_squares problem'
886            return self.fit(z)
887        except VectorShapeError, e:
888            # broadcasting is not supported.
889
890            #Convert input to Numeric arrays
891            from util import ensure_numeric
892            z = ensure_numeric(z, Float)
893
894            #Build n x m interpolation matrix
895            m = self.mesh.coordinates.shape[0] #Number of vertices
896            n = z.shape[1]                     #Number of data points
897
898            f = zeros((m,n), Float) #Resulting columns
899
900            for i in range(z.shape[1]):
901                f[:,i] = self.fit(z[:,i])
902
903            return f
904
905
906    def interpolate(self, f):
907        """Evaluate smooth surface f at data points implied in self.A.
908
909        The mesh values representing a smooth surface are
910        assumed to be specified in f. This argument could,
911        for example have been obtained from the method self.fit()
912
913        Pre Condition:
914          self.A has been initialised
915
916        Inputs:
917          f: Vector or array of data at the mesh vertices.
918          If f is an array, interpolation will be done for each column as
919          per underlying matrix-matrix multiplication
920
921        Output:
922          Interpolated values at data points implied in self.A
923
924        """
925
926        return self.A * f
927
928    def cull_outsiders(self, f):
929        pass
930
931
932
933
934class Interpolation_function:
935    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
936    which is interpolated from time series defined at vertices of
937    triangular mesh (such as those stored in sww files)
938
939    Let m be the number of vertices, n the number of triangles
940    and p the number of timesteps.
941
942    Mandatory input
943        time:               px1 array of monotonously increasing times (Float)
944        quantities:         Dictionary of arrays or 1 array (Float)
945                            The arrays must either have dimensions pxm or mx1.
946                            The resulting function will be time dependent in
947                            the former case while it will be constan with
948                            respect to time in the latter case.
949       
950    Optional input:
951        quantity_names:     List of keys into the quantities dictionary
952        vertex_coordinates: mx2 array of coordinates (Float)
953        triangles:          nx3 array of indices into vertex_coordinates (Int)
954        interpolation_points: Nx2 array of coordinates to be interpolated to
955        verbose:            Level of reporting
956   
957   
958    The quantities returned by the callable object are specified by
959    the list quantities which must contain the names of the
960    quantities to be returned and also reflect the order, e.g. for
961    the shallow water wave equation, on would have
962    quantities = ['stage', 'xmomentum', 'ymomentum']
963
964    The parameter interpolation_points decides at which points interpolated
965    quantities are to be computed whenever object is called.
966    If None, return average value
967    """
968
969   
970   
971    def __init__(self,
972                 time,
973                 quantities,
974                 quantity_names = None, 
975                 vertex_coordinates = None,
976                 triangles = None,
977                 interpolation_points = None,
978                 verbose = False):
979        """Initialise object and build spatial interpolation if required
980        """
981
982        from Numeric import array, zeros, Float, alltrue, concatenate,\
983             reshape, ArrayType
984
985
986        from util import mean, ensure_numeric
987        from config import time_format
988        import types
989
990
991
992        #Check temporal info
993        time = ensure_numeric(time)       
994        msg = 'Time must be a monotonuosly '
995        msg += 'increasing sequence %s' %time
996        assert alltrue(time[1:] - time[:-1] >= 0 ), msg
997
998
999        #Check if quantities is a single array only
1000        if type(quantities) != types.DictType:
1001            quantities = ensure_numeric(quantities)
1002            quantity_names = ['Attribute']
1003
1004            #Make it a dictionary
1005            quantities = {quantity_names[0]: quantities}
1006
1007
1008        #Use keys if no names are specified
1009        if quantity_names is None:
1010            quantity_names = quantities.keys()
1011
1012
1013        #Check spatial info
1014        if vertex_coordinates is None:
1015            self.spatial = False
1016        else:   
1017            vertex_coordinates = ensure_numeric(vertex_coordinates)
1018
1019            assert triangles is not None, 'Triangles array must be specified'
1020            triangles = ensure_numeric(triangles)
1021            self.spatial = True           
1022           
1023
1024 
1025        #Save for use with statistics
1026        self.quantity_names = quantity_names       
1027        self.quantities = quantities       
1028        self.vertex_coordinates = vertex_coordinates
1029        self.interpolation_points = interpolation_points
1030        self.T = time[:]  # Time assumed to be relative to starttime
1031        self.index = 0    # Initial time index
1032        self.precomputed_values = {}
1033           
1034
1035
1036        #Precomputed spatial interpolation if requested
1037        if interpolation_points is not None:
1038            if self.spatial is False:
1039                raise 'Triangles and vertex_coordinates must be specified'
1040           
1041            try:
1042                self.interpolation_points = ensure_numeric(interpolation_points)
1043            except:
1044                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1045                      'or a list of points\n'
1046                msg += 'I got: %s.' %(str(self.interpolation_points)[:60] +\
1047                                      '...')
1048                raise msg
1049
1050
1051            m = len(self.interpolation_points)
1052            p = len(self.T)
1053           
1054            for name in quantity_names:
1055                self.precomputed_values[name] = zeros((p, m), Float)
1056
1057            #Build interpolator
1058            interpol = Interpolation(vertex_coordinates,
1059                                     triangles,
1060                                     point_coordinates = \
1061                                     self.interpolation_points,
1062                                     alpha = 0,
1063                                     precrop = False, 
1064                                     verbose = verbose)
1065
1066            if verbose: print 'Interpolate'
1067            for i, t in enumerate(self.T):
1068                #Interpolate quantities at this timestep
1069                if verbose and i%((p+10)/10)==0:
1070                    print ' time step %d of %d' %(i, p)
1071                   
1072                for name in quantity_names:
1073                    if len(quantities[name].shape) == 2:
1074                        result = interpol.interpolate(quantities[name][i,:])
1075                    else:
1076                       #Assume no time dependency
1077                       result = interpol.interpolate(quantities[name][:])
1078                       
1079                    self.precomputed_values[name][i, :] = result
1080                   
1081                       
1082
1083            #Report
1084            if verbose:
1085                print self.statistics()
1086                #self.print_statistics()
1087           
1088        else:
1089            #Store quantitites as is
1090            for name in quantity_names:
1091                self.precomputed_values[name] = quantities[name]
1092
1093
1094        #else:
1095        #    #Return an average, making this a time series
1096        #    for name in quantity_names:
1097        #        self.values[name] = zeros(len(self.T), Float)
1098        #
1099        #    if verbose: print 'Compute mean values'
1100        #    for i, t in enumerate(self.T):
1101        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1102        #        for name in quantity_names:
1103        #           self.values[name][i] = mean(quantities[name][i,:])
1104
1105
1106
1107
1108    def __repr__(self):
1109        #return 'Interpolation function (spatio-temporal)'
1110        return self.statistics()
1111   
1112
1113    def __call__(self, t, point_id = None, x = None, y = None):
1114        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1115
1116        Inputs:
1117          t: time - Model time. Must lie within existing timesteps
1118          point_id: index of one of the preprocessed points.
1119          x, y:     Overrides location, point_id ignored
1120         
1121          If spatial info is present and all of x,y,point_id
1122          are None an exception is raised
1123                   
1124          If no spatial info is present, point_id and x,y arguments are ignored
1125          making f a function of time only.
1126
1127         
1128          FIXME: point_id could also be a slice
1129          FIXME: What if x and y are vectors?
1130          FIXME: What about f(x,y) without t?
1131        """
1132
1133        from math import pi, cos, sin, sqrt
1134        from Numeric import zeros, Float
1135        from util import mean       
1136
1137        if self.spatial is True:
1138            if point_id is None:
1139                if x is None or y is None:
1140                    msg = 'Either point_id or x and y must be specified'
1141                    raise msg
1142            else:
1143                if self.interpolation_points is None:
1144                    msg = 'Interpolation_function must be instantiated ' +\
1145                          'with a list of interpolation points before parameter ' +\
1146                          'point_id can be used'
1147                    raise msg
1148
1149
1150        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1151        msg += ' does not match model time: %s\n' %t
1152        if t < self.T[0]: raise msg
1153        if t > self.T[-1]: raise msg
1154
1155        oldindex = self.index #Time index
1156
1157        #Find current time slot
1158        while t > self.T[self.index]: self.index += 1
1159        while t < self.T[self.index]: self.index -= 1
1160
1161        if t == self.T[self.index]:
1162            #Protect against case where t == T[-1] (last time)
1163            # - also works in general when t == T[i]
1164            ratio = 0
1165        else:
1166            #t is now between index and index+1
1167            ratio = (t - self.T[self.index])/\
1168                    (self.T[self.index+1] - self.T[self.index])
1169
1170        #Compute interpolated values
1171        q = zeros(len(self.quantity_names), Float)
1172
1173        for i, name in enumerate(self.quantity_names):
1174            Q = self.precomputed_values[name]
1175
1176            if self.spatial is False:
1177                #If there is no spatial info               
1178                assert len(Q.shape) == 1
1179
1180                Q0 = Q[self.index]
1181                if ratio > 0: Q1 = Q[self.index+1]
1182
1183            else:
1184                if x is not None and y is not None:
1185                    #Interpolate to x, y
1186                   
1187                    raise 'x,y interpolation not yet implemented'
1188                else:
1189                    #Use precomputed point
1190                    Q0 = Q[self.index, point_id]
1191                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1192
1193            #Linear temporal interpolation   
1194            if ratio > 0:
1195                q[i] = Q0 + ratio*(Q1 - Q0)
1196            else:
1197                q[i] = Q0
1198
1199
1200        #Return vector of interpolated values
1201        #if len(q) == 1:
1202        #    return q[0]
1203        #else:
1204        #    return q
1205
1206
1207        #Return vector of interpolated values
1208        #FIXME:
1209        if self.spatial is True:
1210            return q
1211        else:
1212            #Replicate q according to x and y
1213            #This is e.g used for Wind_stress
1214            if x is None or y is None: 
1215                return q
1216            else:
1217                try:
1218                    N = len(x)
1219                except:
1220                    return q
1221                else:
1222                    from Numeric import ones, Float
1223                    #x is a vector - Create one constant column for each value
1224                    N = len(x)
1225                    assert len(y) == N, 'x and y must have same length'
1226                    res = []
1227                    for col in q:
1228                        res.append(col*ones(N, Float))
1229                       
1230                return res
1231
1232
1233    def statistics(self):
1234        """Output statistics about interpolation_function
1235        """
1236       
1237        vertex_coordinates = self.vertex_coordinates
1238        interpolation_points = self.interpolation_points               
1239        quantity_names = self.quantity_names
1240        quantities = self.quantities
1241        precomputed_values = self.precomputed_values                 
1242               
1243        x = vertex_coordinates[:,0]
1244        y = vertex_coordinates[:,1]               
1245
1246        str =  '------------------------------------------------\n'
1247        str += 'Interpolation_function (spatio-temporal) statistics:\n'
1248        str += '  Extent:\n'
1249        str += '    x in [%f, %f], len(x) == %d\n'\
1250               %(min(x), max(x), len(x))
1251        str += '    y in [%f, %f], len(y) == %d\n'\
1252               %(min(y), max(y), len(y))
1253        str += '    t in [%f, %f], len(t) == %d\n'\
1254               %(min(self.T), max(self.T), len(self.T))
1255        str += '  Quantities:\n'
1256        for name in quantity_names:
1257            q = quantities[name][:].flat
1258            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1259
1260        if interpolation_points is not None:   
1261            str += '  Interpolation points (xi, eta):'\
1262                   ' number of points == %d\n' %interpolation_points.shape[0]
1263            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1264                                            max(interpolation_points[:,0]))
1265            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1266                                             max(interpolation_points[:,1]))
1267            str += '  Interpolated quantities (over all timesteps):\n'
1268       
1269            for name in quantity_names:
1270                q = precomputed_values[name][:].flat
1271                str += '    %s at interpolation points in [%f, %f]\n'\
1272                       %(name, min(q), max(q))
1273        str += '------------------------------------------------\n'
1274
1275        return str
1276
1277        #FIXME: Delete
1278        #print '------------------------------------------------'
1279        #print 'Interpolation_function statistics:'
1280        #print '  Extent:'
1281        #print '    x in [%f, %f], len(x) == %d'\
1282        #      %(min(x), max(x), len(x))
1283        #print '    y in [%f, %f], len(y) == %d'\
1284        #      %(min(y), max(y), len(y))
1285        #print '    t in [%f, %f], len(t) == %d'\
1286        #      %(min(self.T), max(self.T), len(self.T))
1287        #print '  Quantities:'
1288        #for name in quantity_names:
1289        #    q = quantities[name][:].flat
1290        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1291        #print '  Interpolation points (xi, eta):'\
1292        #      ' number of points == %d ' %interpolation_points.shape[0]
1293        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1294        #                             max(interpolation_points[:,0]))
1295        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1296        #                              max(interpolation_points[:,1]))
1297        #print '  Interpolated quantities (over all timesteps):'
1298        #
1299        #for name in quantity_names:
1300        #    q = precomputed_values[name][:].flat
1301        #    print '    %s at interpolation points in [%f, %f]'\
1302        #          %(name, min(q), max(q))
1303        #print '------------------------------------------------'
1304
1305
1306#-------------------------------------------------------------
1307if __name__ == "__main__":
1308    """
1309    Load in a mesh and data points with attributes.
1310    Fit the attributes to the mesh.
1311    Save a new mesh file.
1312    """
1313    import os, sys
1314    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1315            %os.path.basename(sys.argv[0])
1316
1317    if len(sys.argv) < 4:
1318        print usage
1319    else:
1320        mesh_file = sys.argv[1]
1321        point_file = sys.argv[2]
1322        mesh_output_file = sys.argv[3]
1323
1324        expand_search = False
1325        if len(sys.argv) > 4:
1326            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1327                expand_search = True
1328            else:
1329                expand_search = False
1330
1331        verbose = False
1332        if len(sys.argv) > 5:
1333            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1334                verbose = False
1335            else:
1336                verbose = True
1337
1338        if len(sys.argv) > 6:
1339            alpha = sys.argv[6]
1340        else:
1341            alpha = DEFAULT_ALPHA
1342
1343        # This is used more for testing
1344        if len(sys.argv) > 7:
1345            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1346                display_errors = False
1347            else:
1348                display_errors = True
1349           
1350        t0 = time.time()
1351        try:
1352            fit_to_mesh_file(mesh_file,
1353                         point_file,
1354                         mesh_output_file,
1355                         alpha,
1356                         verbose= verbose,
1357                         expand_search = expand_search,
1358                         display_errors = display_errors)
1359        except IOError,e:
1360            import sys; sys.exit(1)
1361
1362        print 'That took %.2f seconds' %(time.time()-t0)
1363
Note: See TracBrowser for help on using the repository browser.