source: inundation/pyvolution/least_squares.py @ 1933

Last change on this file since 1933 was 1933, checked in by ole, 19 years ago

Comments

File size: 47.3 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20#import exceptions
21#class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME (DSG-ON) reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    NOTE: Throws IOErrors, for a variety of file problems.
75   
76    mesh_origin is the same but refers to the input tsh file.
77    FIXME: When the tsh format contains it own origin, these parameters can go.
78    FIXME: And both origins should be obtained from the specified files.
79    """
80
81    from load_mesh.loadASCII import import_mesh_file, \
82                 import_points_file, export_mesh_file, \
83                 concatinate_attributelist
84
85
86    try:
87        mesh_dict = import_mesh_file(mesh_file)
88    except IOError,e:
89        if display_errors:
90            print "Could not load bad file. ", e
91        raise IOError  #Re-raise exception
92       
93    vertex_coordinates = mesh_dict['vertices']
94    triangles = mesh_dict['triangles']
95    if type(mesh_dict['vertex_attributes']) == ArrayType:
96        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
97    else:
98        old_point_attributes = mesh_dict['vertex_attributes']
99
100    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
101        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
102    else:
103        old_title_list = mesh_dict['vertex_attribute_titles']
104
105    if verbose: print 'tsh file %s loaded' %mesh_file
106
107    # load in the .pts file
108    try:
109        point_dict = import_points_file(point_file, verbose=verbose)
110    except IOError,e:
111        if display_errors:
112            print "Could not load bad file. ", e
113        raise IOError  #Re-raise exception 
114
115    point_coordinates = point_dict['pointlist']
116    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
117
118    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
119        data_origin = point_dict['geo_reference'].get_origin()
120    else:
121        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
122
123    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
124        mesh_origin = mesh_dict['geo_reference'].get_origin()
125    else:
126        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
127
128    if verbose: print "points file loaded"
129    if verbose:print "fitting to mesh"
130    f = fit_to_mesh(vertex_coordinates,
131                    triangles,
132                    point_coordinates,
133                    point_attributes,
134                    alpha = alpha,
135                    verbose = verbose,
136                    expand_search = expand_search,
137                    data_origin = data_origin,
138                    mesh_origin = mesh_origin,
139                    precrop = precrop)
140    if verbose: print "finished fitting to mesh"
141
142    # convert array to list of lists
143    new_point_attributes = f.tolist()
144    #FIXME have this overwrite attributes with the same title - DSG
145    #Put the newer attributes last
146    if old_title_list <> []:
147        old_title_list.extend(title_list)
148        #FIXME can this be done a faster way? - DSG
149        for i in range(len(old_point_attributes)):
150            old_point_attributes[i].extend(new_point_attributes[i])
151        mesh_dict['vertex_attributes'] = old_point_attributes
152        mesh_dict['vertex_attribute_titles'] = old_title_list
153    else:
154        mesh_dict['vertex_attributes'] = new_point_attributes
155        mesh_dict['vertex_attribute_titles'] = title_list
156
157    #FIXME (Ole): Remember to output mesh_origin as well
158    if verbose: print "exporting to file ",mesh_output_file
159
160    try:
161        export_mesh_file(mesh_output_file, mesh_dict)
162    except IOError,e:
163        if display_errors:
164            print "Could not write file. ", e
165        raise IOError
166
167def fit_to_mesh(vertex_coordinates,
168                triangles,
169                point_coordinates,
170                point_attributes,
171                alpha = DEFAULT_ALPHA,
172                verbose = False,
173                expand_search = False,
174                data_origin = None,
175                mesh_origin = None,
176                precrop = False):
177    """
178    Fit a smooth surface to a triangulation,
179    given data points with attributes.
180
181
182        Inputs:
183
184          vertex_coordinates: List of coordinate pairs [xi, eta] of points
185          constituting mesh (or a an m x 2 Numeric array)
186
187          triangles: List of 3-tuples (or a Numeric array) of
188          integers representing indices of all vertices in the mesh.
189
190          point_coordinates: List of coordinate pairs [x, y] of data points
191          (or an nx2 Numeric array)
192
193          alpha: Smoothing parameter.
194
195          point_attributes: Vector or array of data at the point_coordinates.
196
197          data_origin and mesh_origin are 3-tuples consisting of
198          UTM zone, easting and northing. If specified
199          point coordinates and vertex coordinates are assumed to be
200          relative to their respective origins.
201
202    """
203    interp = Interpolation(vertex_coordinates,
204                           triangles,
205                           point_coordinates,
206                           alpha = alpha,
207                           verbose = verbose,
208                           expand_search = expand_search,
209                           data_origin = data_origin,
210                           mesh_origin = mesh_origin,
211                           precrop = precrop)
212
213    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
214    return vertex_attributes
215
216
217
218def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
219                    verbose = False, reduction = 1):
220    """Fits attributes from pts file to MxN rectangular mesh
221
222    Read pts file and create rectangular mesh of resolution MxN such that
223    it covers all points specified in pts file.
224
225    FIXME: This may be a temporary function until we decide on
226    netcdf formats etc
227
228    FIXME: Uses elevation hardwired
229    """
230
231    import  mesh_factory
232    from load_mesh.loadASCII import import_points_file
233   
234    if verbose: print 'Read pts'
235    points_dict = import_points_file(pts_name)
236    #points, attributes = util.read_xya(pts_name)
237
238    #Reduce number of points a bit
239    points = points_dict['pointlist'][::reduction]
240    elevation = points_dict['attributelist']['elevation']  #Must be elevation
241    elevation = elevation[::reduction]
242
243    if verbose: print 'Got %d data points' %len(points)
244
245    if verbose: print 'Create mesh'
246    #Find extent
247    max_x = min_x = points[0][0]
248    max_y = min_y = points[0][1]
249    for point in points[1:]:
250        x = point[0]
251        if x > max_x: max_x = x
252        if x < min_x: min_x = x
253        y = point[1]
254        if y > max_y: max_y = y
255        if y < min_y: min_y = y
256
257    #Create appropriate mesh
258    vertex_coordinates, triangles, boundary =\
259         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
260                                (min_x, min_y))
261
262    #Fit attributes to mesh
263    vertex_attributes = fit_to_mesh(vertex_coordinates,
264                        triangles,
265                        points,
266                        elevation, alpha=alpha, verbose=verbose)
267
268
269
270    return vertex_coordinates, triangles, boundary, vertex_attributes
271
272
273
274class Interpolation:
275
276    def __init__(self,
277                 vertex_coordinates,
278                 triangles,
279                 point_coordinates = None,
280                 alpha = None,
281                 verbose = False,
282                 expand_search = True,
283                 interp_only = False,
284                 max_points_per_cell = 30,
285                 mesh_origin = None,
286                 data_origin = None,
287                 precrop = False):
288
289
290        """ Build interpolation matrix mapping from
291        function values at vertices to function values at data points
292
293        Inputs:
294
295          vertex_coordinates: List of coordinate pairs [xi, eta] of
296          points constituting mesh (or a an m x 2 Numeric array)
297          Points may appear multiple times
298          (e.g. if vertices have discontinuities)
299
300          triangles: List of 3-tuples (or a Numeric array) of
301          integers representing indices of all vertices in the mesh.
302
303          point_coordinates: List of coordinate pairs [x, y] of
304          data points (or an nx2 Numeric array)
305          If point_coordinates is absent, only smoothing matrix will
306          be built
307
308          alpha: Smoothing parameter
309
310          data_origin and mesh_origin are 3-tuples consisting of
311          UTM zone, easting and northing. If specified
312          point coordinates and vertex coordinates are assumed to be
313          relative to their respective origins.
314
315        """
316        from util import ensure_numeric
317
318        #Convert input to Numeric arrays
319        triangles = ensure_numeric(triangles, Int)
320        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
321
322        #Build underlying mesh
323        if verbose: print 'Building mesh'
324        #self.mesh = General_mesh(vertex_coordinates, triangles,
325        #FIXME: Trying the normal mesh while testing precrop,
326        #       The functionality of boundary_polygon is needed for that
327
328        #FIXME - geo ref does not have to go into mesh.
329        # Change the point co-ords to conform to the
330        # mesh co-ords early in the code
331        if mesh_origin == None:
332            geo = None
333        else:
334            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
335        self.mesh = Mesh(vertex_coordinates, triangles,
336                         geo_reference = geo)
337       
338        self.mesh.check_integrity()
339
340        self.data_origin = data_origin
341
342        self.point_indices = None
343
344        #Smoothing parameter
345        if alpha is None:
346            self.alpha = DEFAULT_ALPHA
347        else:   
348            self.alpha = alpha
349
350
351        if point_coordinates is not None:
352            if verbose: print 'Building interpolation matrix'
353            self.build_interpolation_matrix_A(point_coordinates,
354                                              verbose = verbose,
355                                              expand_search = expand_search,
356                                              interp_only = interp_only, 
357                                              max_points_per_cell =\
358                                              max_points_per_cell,
359                                              data_origin = data_origin,
360                                              precrop = precrop)
361        #Build coefficient matrices
362        if interp_only == False:
363            self.build_coefficient_matrix_B(point_coordinates,
364                                        verbose = verbose,
365                                        expand_search = expand_search,
366                                        max_points_per_cell =\
367                                        max_points_per_cell,
368                                        data_origin = data_origin,
369                                        precrop = precrop)
370
371
372    def set_point_coordinates(self, point_coordinates,
373                              data_origin = None,
374                              verbose = False,
375                              precrop = True):
376        """
377        A public interface to setting the point co-ordinates.
378        """
379        if point_coordinates is not None:
380            if verbose: print 'Building interpolation matrix'
381            self.build_interpolation_matrix_A(point_coordinates,
382                                              verbose = verbose,
383                                              data_origin = data_origin,
384                                              precrop = precrop)
385        self.build_coefficient_matrix_B(point_coordinates, data_origin)
386
387    def build_coefficient_matrix_B(self, point_coordinates=None,
388                                   verbose = False, expand_search = True,
389                                   max_points_per_cell=30,
390                                   data_origin = None,
391                                   precrop = False):
392        """Build final coefficient matrix"""
393
394
395        if self.alpha <> 0:
396            if verbose: print 'Building smoothing matrix'
397            self.build_smoothing_matrix_D()
398
399        if point_coordinates is not None:
400            if self.alpha <> 0:
401                self.B = self.AtA + self.alpha*self.D
402            else:
403                self.B = self.AtA
404
405            #Convert self.B matrix to CSR format for faster matrix vector
406            self.B = Sparse_CSR(self.B)
407
408    def build_interpolation_matrix_A(self, point_coordinates,
409                                     verbose = False, expand_search = True,
410                                     max_points_per_cell=30,
411                                     data_origin = None,
412                                     precrop = False,
413                                     interp_only = False):
414        """Build n x m interpolation matrix, where
415        n is the number of data points and
416        m is the number of basis functions phi_k (one per vertex)
417
418        This algorithm uses a quad tree data structure for fast binning of data points
419        origin is a 3-tuple consisting of UTM zone, easting and northing.
420        If specified coordinates are assumed to be relative to this origin.
421
422        This one will override any data_origin that may be specified in
423        interpolation instance
424
425        """
426
427
428        #FIXME (Ole): Check that this function is memeory efficient.
429        #6 million datapoints and 300000 basis functions
430        #causes out-of-memory situation
431        #First thing to check is whether there is room for self.A and self.AtA
432        #
433        #Maybe we need some sort of blocking
434
435        from quad import build_quadtree
436        from utilities.numerical_tools import ensure_numeric
437        from utilities.polygon import inside_polygon
438       
439
440        if data_origin is None:
441            data_origin = self.data_origin #Use the one from
442                                           #interpolation instance
443
444        #Convert input to Numeric arrays just in case.
445        point_coordinates = ensure_numeric(point_coordinates, Float)
446
447        #Keep track of discarded points (if any).
448        #This is only registered if precrop is True
449        self.cropped_points = False
450
451        #Shift data points to same origin as mesh (if specified)
452
453        #FIXME this will shift if there was no geo_ref.
454        #But all this should be removed anyhow.
455        #change coords before this point
456        mesh_origin = self.mesh.geo_reference.get_origin()
457        if point_coordinates is not None:
458            if data_origin is not None:
459                if mesh_origin is not None:
460
461                    #Transformation:
462                    #
463                    #Let x_0 be the reference point of the point coordinates
464                    #and xi_0 the reference point of the mesh.
465                    #
466                    #A point coordinate (x + x_0) is then made relative
467                    #to xi_0 by
468                    #
469                    # x_new = x + x_0 - xi_0
470                    #
471                    #and similarly for eta
472
473                    x_offset = data_origin[1] - mesh_origin[1]
474                    y_offset = data_origin[2] - mesh_origin[2]
475                else: #Shift back to a zero origin
476                    x_offset = data_origin[1]
477                    y_offset = data_origin[2]
478
479                point_coordinates[:,0] += x_offset
480                point_coordinates[:,1] += y_offset
481            else:
482                if mesh_origin is not None:
483                    #Use mesh origin for data points
484                    point_coordinates[:,0] -= mesh_origin[1]
485                    point_coordinates[:,1] -= mesh_origin[2]
486
487
488
489        #Remove points falling outside mesh boundary
490        #This reduced one example from 1356 seconds to 825 seconds
491        if precrop is True:
492            from Numeric import take
493
494            if verbose: print 'Getting boundary polygon'
495            P = self.mesh.get_boundary_polygon()
496
497            if verbose: print 'Getting indices inside mesh boundary'
498            indices = inside_polygon(point_coordinates, P, verbose = verbose)
499
500
501            if len(indices) != point_coordinates.shape[0]:
502                self.cropped_points = True
503                if verbose:
504                    print 'Done - %d points outside mesh have been cropped.'\
505                          %(point_coordinates.shape[0] - len(indices))
506
507            point_coordinates = take(point_coordinates, indices)
508            self.point_indices = indices
509
510
511
512
513        #Build n x m interpolation matrix
514        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
515        n = point_coordinates.shape[0]     #Nbr of data points
516
517        if verbose: print 'Number of datapoints: %d' %n
518        if verbose: print 'Number of basis functions: %d' %m
519
520        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
521        #However, Sparse_CSR does not have the same methods as Sparse yet
522        #The tests will reveal what needs to be done
523
524        #
525        #self.A = Sparse_CSR(Sparse(n,m))
526        #self.AtA = Sparse_CSR(Sparse(m,m))
527        self.A = Sparse(n,m)
528        self.AtA = Sparse(m,m)
529
530        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
531        root = build_quadtree(self.mesh,
532                              max_points_per_cell = max_points_per_cell)
533
534        #Compute matrix elements
535        for i in range(n):
536            #For each data_coordinate point
537
538            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
539            x = point_coordinates[i]
540
541            #Find vertices near x
542            candidate_vertices = root.search(x[0], x[1])
543            is_more_elements = True
544
545            element_found, sigma0, sigma1, sigma2, k = \
546                self.search_triangles_of_vertices(candidate_vertices, x)
547            while not element_found and is_more_elements and expand_search:
548                #if verbose: print 'Expanding search'
549                candidate_vertices, branch = root.expand_search()
550                if branch == []:
551                    # Searching all the verts from the root cell that haven't
552                    # been searched.  This is the last try
553                    element_found, sigma0, sigma1, sigma2, k = \
554                      self.search_triangles_of_vertices(candidate_vertices, x)
555                    is_more_elements = False
556                else:
557                    element_found, sigma0, sigma1, sigma2, k = \
558                      self.search_triangles_of_vertices(candidate_vertices, x)
559
560
561            #Update interpolation matrix A if necessary
562            if element_found is True:
563                #Assign values to matrix A
564
565                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
566                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
567                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
568
569                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
570                js     = [j0,j1,j2]
571
572                for j in js:
573                    self.A[i,j] = sigmas[j]
574                    for k in js:
575                        if interp_only == False:
576                            self.AtA[j,k] += sigmas[j]*sigmas[k]
577            else:
578                pass
579                #Ok if there is no triangle for datapoint
580                #(as in brute force version)
581                #raise 'Could not find triangle for point', x
582
583
584
585    def search_triangles_of_vertices(self, candidate_vertices, x):
586            #Find triangle containing x:
587            element_found = False
588
589            # This will be returned if element_found = False
590            sigma2 = -10.0
591            sigma0 = -10.0
592            sigma1 = -10.0
593            k = -10.0
594
595            #For all vertices in same cell as point x
596            for v in candidate_vertices:
597
598                #for each triangle id (k) which has v as a vertex
599                for k, _ in self.mesh.vertexlist[v]:
600
601                    #Get the three vertex_points of candidate triangle
602                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
603                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
604                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
605
606                    #print "PDSG - k", k
607                    #print "PDSG - xi0", xi0
608                    #print "PDSG - xi1", xi1
609                    #print "PDSG - xi2", xi2
610                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
611                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
612
613                    #Get the three normals
614                    n0 = self.mesh.get_normal(k, 0)
615                    n1 = self.mesh.get_normal(k, 1)
616                    n2 = self.mesh.get_normal(k, 2)
617
618
619                    #Compute interpolation
620                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
621                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
622                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
623
624                    #print "PDSG - sigma0", sigma0
625                    #print "PDSG - sigma1", sigma1
626                    #print "PDSG - sigma2", sigma2
627
628                    #FIXME: Maybe move out to test or something
629                    epsilon = 1.0e-6
630                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
631
632                    #Check that this triangle contains the data point
633
634                    #Sigmas can get negative within
635                    #machine precision on some machines (e.g nautilus)
636                    #Hence the small eps
637                    eps = 1.0e-15
638                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
639                        element_found = True
640                        break
641
642                if element_found is True:
643                    #Don't look for any other triangle
644                    break
645            return element_found, sigma0, sigma1, sigma2, k
646
647
648
649    def build_interpolation_matrix_A_brute(self, point_coordinates):
650        """Build n x m interpolation matrix, where
651        n is the number of data points and
652        m is the number of basis functions phi_k (one per vertex)
653
654        This is the brute force which is too slow for large problems,
655        but could be used for testing
656        """
657
658        from util import ensure_numeric
659
660        #Convert input to Numeric arrays
661        point_coordinates = ensure_numeric(point_coordinates, Float)
662
663        #Build n x m interpolation matrix
664        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
665        n = point_coordinates.shape[0]     #Nbr of data points
666
667        self.A = Sparse(n,m)
668        self.AtA = Sparse(m,m)
669
670        #Compute matrix elements
671        for i in range(n):
672            #For each data_coordinate point
673
674            x = point_coordinates[i]
675            element_found = False
676            k = 0
677            while not element_found and k < len(self.mesh):
678                #For each triangle (brute force)
679                #FIXME: Real algorithm should only visit relevant triangles
680
681                #Get the three vertex_points
682                xi0 = self.mesh.get_vertex_coordinate(k, 0)
683                xi1 = self.mesh.get_vertex_coordinate(k, 1)
684                xi2 = self.mesh.get_vertex_coordinate(k, 2)
685
686                #Get the three normals
687                n0 = self.mesh.get_normal(k, 0)
688                n1 = self.mesh.get_normal(k, 1)
689                n2 = self.mesh.get_normal(k, 2)
690
691                #Compute interpolation
692                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
693                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
694                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
695
696                #FIXME: Maybe move out to test or something
697                epsilon = 1.0e-6
698                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
699
700                #Check that this triangle contains data point
701                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
702                    element_found = True
703                    #Assign values to matrix A
704
705                    j0 = self.mesh.triangles[k,0] #Global vertex id
706                    #self.A[i, j0] = sigma0
707
708                    j1 = self.mesh.triangles[k,1] #Global vertex id
709                    #self.A[i, j1] = sigma1
710
711                    j2 = self.mesh.triangles[k,2] #Global vertex id
712                    #self.A[i, j2] = sigma2
713
714                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
715                    js     = [j0,j1,j2]
716
717                    for j in js:
718                        self.A[i,j] = sigmas[j]
719                        for k in js:
720                            self.AtA[j,k] += sigmas[j]*sigmas[k]
721                k = k+1
722
723
724
725    def get_A(self):
726        return self.A.todense()
727
728    def get_B(self):
729        return self.B.todense()
730
731    def get_D(self):
732        return self.D.todense()
733
734        #FIXME: Remember to re-introduce the 1/n factor in the
735        #interpolation term
736
737    def build_smoothing_matrix_D(self):
738        """Build m x m smoothing matrix, where
739        m is the number of basis functions phi_k (one per vertex)
740
741        The smoothing matrix is defined as
742
743        D = D1 + D2
744
745        where
746
747        [D1]_{k,l} = \int_\Omega
748           \frac{\partial \phi_k}{\partial x}
749           \frac{\partial \phi_l}{\partial x}\,
750           dx dy
751
752        [D2]_{k,l} = \int_\Omega
753           \frac{\partial \phi_k}{\partial y}
754           \frac{\partial \phi_l}{\partial y}\,
755           dx dy
756
757
758        The derivatives \frac{\partial \phi_k}{\partial x},
759        \frac{\partial \phi_k}{\partial x} for a particular triangle
760        are obtained by computing the gradient a_k, b_k for basis function k
761        """
762
763        #FIXME: algorithm might be optimised by computing local 9x9
764        #"element stiffness matrices:
765
766        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
767
768        self.D = Sparse(m,m)
769
770        #For each triangle compute contributions to D = D1+D2
771        for i in range(len(self.mesh)):
772
773            #Get area
774            area = self.mesh.areas[i]
775
776            #Get global vertex indices
777            v0 = self.mesh.triangles[i,0]
778            v1 = self.mesh.triangles[i,1]
779            v2 = self.mesh.triangles[i,2]
780
781            #Get the three vertex_points
782            xi0 = self.mesh.get_vertex_coordinate(i, 0)
783            xi1 = self.mesh.get_vertex_coordinate(i, 1)
784            xi2 = self.mesh.get_vertex_coordinate(i, 2)
785
786            #Compute gradients for each vertex
787            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
788                              1, 0, 0)
789
790            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
791                              0, 1, 0)
792
793            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
794                              0, 0, 1)
795
796            #Compute diagonal contributions
797            self.D[v0,v0] += (a0*a0 + b0*b0)*area
798            self.D[v1,v1] += (a1*a1 + b1*b1)*area
799            self.D[v2,v2] += (a2*a2 + b2*b2)*area
800
801            #Compute contributions for basis functions sharing edges
802            e01 = (a0*a1 + b0*b1)*area
803            self.D[v0,v1] += e01
804            self.D[v1,v0] += e01
805
806            e12 = (a1*a2 + b1*b2)*area
807            self.D[v1,v2] += e12
808            self.D[v2,v1] += e12
809
810            e20 = (a2*a0 + b2*b0)*area
811            self.D[v2,v0] += e20
812            self.D[v0,v2] += e20
813
814
815    def fit(self, z):
816        """Fit a smooth surface to given 1d array of data points z.
817
818        The smooth surface is computed at each vertex in the underlying
819        mesh using the formula given in the module doc string.
820
821        Pre Condition:
822          self.A, self.AtA and self.B have been initialised
823
824        Inputs:
825          z: Single 1d vector or array of data at the point_coordinates.
826        """
827
828        #Convert input to Numeric arrays
829        from util import ensure_numeric
830        z = ensure_numeric(z, Float)
831
832        if len(z.shape) > 1 :
833            raise VectorShapeError, 'Can only deal with 1d data vector'
834
835        if self.point_indices is not None:
836            #Remove values for any points that were outside mesh
837            z = take(z, self.point_indices)
838
839        #Compute right hand side based on data
840        #FIXME (DSG-DsG): could Sparse_CSR be used here?  Use this format
841        # after a matrix is built, before calcs.
842        Atz = self.A.trans_mult(z)
843
844
845        #Check sanity
846        n, m = self.A.shape
847        if n<m and self.alpha == 0.0:
848            msg = 'ERROR (least_squares): Too few data points\n'
849            msg += 'There are only %d data points and alpha == 0. ' %n
850            msg += 'Need at least %d\n' %m
851            msg += 'Alternatively, set smoothing parameter alpha to a small '
852            msg += 'positive value,\ne.g. 1.0e-3.'
853            raise msg
854
855
856
857        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
858        #FIXME: Should we store the result here for later use? (ON)
859
860
861    def fit_points(self, z, verbose=False):
862        """Like fit, but more robust when each point has two or more attributes
863        FIXME (Ole): The name fit_points doesn't carry any meaning
864        for me. How about something like fit_multiple or fit_columns?
865        """
866
867        try:
868            if verbose: print 'Solving penalised least_squares problem'
869            return self.fit(z)
870        except VectorShapeError, e:
871            # broadcasting is not supported.
872
873            #Convert input to Numeric arrays
874            from util import ensure_numeric
875            z = ensure_numeric(z, Float)
876
877            #Build n x m interpolation matrix
878            m = self.mesh.coordinates.shape[0] #Number of vertices
879            n = z.shape[1]                     #Number of data points
880
881            f = zeros((m,n), Float) #Resulting columns
882
883            for i in range(z.shape[1]):
884                f[:,i] = self.fit(z[:,i])
885
886            return f
887
888
889    def interpolate(self, f):
890        """Evaluate smooth surface f at data points implied in self.A.
891
892        The mesh values representing a smooth surface are
893        assumed to be specified in f. This argument could,
894        for example have been obtained from the method self.fit()
895
896        Pre Condition:
897          self.A has been initialised
898
899        Inputs:
900          f: Vector or array of data at the mesh vertices.
901          If f is an array, interpolation will be done for each column as
902          per underlying matrix-matrix multiplication
903
904        Output:
905          Interpolated values at data points implied in self.A
906
907        """
908
909        return self.A * f
910
911    def cull_outsiders(self, f):
912        pass
913
914
915
916
917class Interpolation_function:
918    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
919    which is interpolated from time series defined at vertices of
920    triangular mesh (such as those stored in sww files)
921
922    Let m be the number of vertices, n the number of triangles
923    and p the number of timesteps.
924
925    Mandatory input
926        time:               px1 array of monotonously increasing times (Float)
927        quantities:         Dictionary of arrays or 1 array (Float)
928                            The arrays must either have dimensions pxm or mx1.
929                            The resulting function will be time dependent in
930                            the former case while it will be constant with
931                            respect to time in the latter case.
932       
933    Optional input:
934        quantity_names:     List of keys into the quantities dictionary
935        vertex_coordinates: mx2 array of coordinates (Float)
936        triangles:          nx3 array of indices into vertex_coordinates (Int)
937        interpolation_points: Nx2 array of coordinates to be interpolated to
938        verbose:            Level of reporting
939   
940   
941    The quantities returned by the callable object are specified by
942    the list quantities which must contain the names of the
943    quantities to be returned and also reflect the order, e.g. for
944    the shallow water wave equation, on would have
945    quantities = ['stage', 'xmomentum', 'ymomentum']
946
947    The parameter interpolation_points decides at which points interpolated
948    quantities are to be computed whenever object is called.
949    If None, return average value
950    """
951
952
953    #FIXME: Pending use case, Implement arbitrary expressions as with
954    #set_quantity and sww2dem using the underlying functionality for
955    #doing that.
956   
957    def __init__(self,
958                 time,
959                 quantities,
960                 quantity_names = None, 
961                 vertex_coordinates = None,
962                 triangles = None,
963                 interpolation_points = None,
964                 verbose = False):
965        """Initialise object and build spatial interpolation if required
966        """
967
968        from Numeric import array, zeros, Float, alltrue, concatenate,\
969             reshape, ArrayType
970
971
972        from util import mean, ensure_numeric
973        from config import time_format
974        import types
975
976
977
978        #Check temporal info
979        time = ensure_numeric(time)       
980        msg = 'Time must be a monotonuosly '
981        msg += 'increasing sequence %s' %time
982        assert alltrue(time[1:] - time[:-1] >= 0 ), msg
983
984
985        #Check if quantities is a single array only
986        if type(quantities) != types.DictType:
987            quantities = ensure_numeric(quantities)
988            quantity_names = ['Attribute']
989
990            #Make it a dictionary
991            quantities = {quantity_names[0]: quantities}
992
993
994        #Use keys if no names are specified
995        if quantity_names is None:
996            quantity_names = quantities.keys()
997
998
999        #Check spatial info
1000        if vertex_coordinates is None:
1001            self.spatial = False
1002        else:   
1003            vertex_coordinates = ensure_numeric(vertex_coordinates)
1004
1005            assert triangles is not None, 'Triangles array must be specified'
1006            triangles = ensure_numeric(triangles)
1007            self.spatial = True           
1008           
1009
1010 
1011        #Save for use with statistics
1012        self.quantity_names = quantity_names       
1013        self.quantities = quantities       
1014        self.vertex_coordinates = vertex_coordinates
1015        self.interpolation_points = interpolation_points
1016        self.T = time[:]  # Time assumed to be relative to starttime
1017        self.index = 0    # Initial time index
1018        self.precomputed_values = {}
1019           
1020
1021
1022        #Precomputed spatial interpolation if requested
1023        if interpolation_points is not None:
1024            if self.spatial is False:
1025                raise 'Triangles and vertex_coordinates must be specified'
1026           
1027            try:
1028                self.interpolation_points = ensure_numeric(interpolation_points)
1029            except:
1030                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1031                      'or a list of points\n'
1032                msg += 'I got: %s.' %(str(self.interpolation_points)[:60] +\
1033                                      '...')
1034                raise msg
1035
1036
1037            m = len(self.interpolation_points)
1038            p = len(self.T)
1039           
1040            for name in quantity_names:
1041                self.precomputed_values[name] = zeros((p, m), Float)
1042
1043            #Build interpolator
1044            interpol = Interpolation(vertex_coordinates,
1045                                     triangles,
1046                                     point_coordinates = \
1047                                     self.interpolation_points,
1048                                     alpha = 0,
1049                                     precrop = False, 
1050                                     verbose = verbose)
1051
1052            if verbose: print 'Interpolate'
1053            for i, t in enumerate(self.T):
1054                #Interpolate quantities at this timestep
1055                if verbose and i%((p+10)/10)==0:
1056                    print ' time step %d of %d' %(i, p)
1057                   
1058                for name in quantity_names:
1059                    if len(quantities[name].shape) == 2:
1060                        result = interpol.interpolate(quantities[name][i,:])
1061                    else:
1062                       #Assume no time dependency
1063                       result = interpol.interpolate(quantities[name][:])
1064                       
1065                    self.precomputed_values[name][i, :] = result
1066                   
1067                       
1068
1069            #Report
1070            if verbose:
1071                print self.statistics()
1072                #self.print_statistics()
1073           
1074        else:
1075            #Store quantitites as is
1076            for name in quantity_names:
1077                self.precomputed_values[name] = quantities[name]
1078
1079
1080        #else:
1081        #    #Return an average, making this a time series
1082        #    for name in quantity_names:
1083        #        self.values[name] = zeros(len(self.T), Float)
1084        #
1085        #    if verbose: print 'Compute mean values'
1086        #    for i, t in enumerate(self.T):
1087        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1088        #        for name in quantity_names:
1089        #           self.values[name][i] = mean(quantities[name][i,:])
1090
1091
1092
1093
1094    def __repr__(self):
1095        #return 'Interpolation function (spatio-temporal)'
1096        return self.statistics()
1097   
1098
1099    def __call__(self, t, point_id = None, x = None, y = None):
1100        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1101
1102        Inputs:
1103          t: time - Model time. Must lie within existing timesteps
1104          point_id: index of one of the preprocessed points.
1105          x, y:     Overrides location, point_id ignored
1106         
1107          If spatial info is present and all of x,y,point_id
1108          are None an exception is raised
1109                   
1110          If no spatial info is present, point_id and x,y arguments are ignored
1111          making f a function of time only.
1112
1113         
1114          FIXME: point_id could also be a slice
1115          FIXME: What if x and y are vectors?
1116          FIXME: What about f(x,y) without t?
1117        """
1118
1119        from math import pi, cos, sin, sqrt
1120        from Numeric import zeros, Float
1121        from util import mean       
1122
1123        if self.spatial is True:
1124            if point_id is None:
1125                if x is None or y is None:
1126                    msg = 'Either point_id or x and y must be specified'
1127                    raise msg
1128            else:
1129                if self.interpolation_points is None:
1130                    msg = 'Interpolation_function must be instantiated ' +\
1131                          'with a list of interpolation points before parameter ' +\
1132                          'point_id can be used'
1133                    raise msg
1134
1135
1136        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1137        msg += ' does not match model time: %s\n' %t
1138        if t < self.T[0]: raise msg
1139        if t > self.T[-1]: raise msg
1140
1141        oldindex = self.index #Time index
1142
1143        #Find current time slot
1144        while t > self.T[self.index]: self.index += 1
1145        while t < self.T[self.index]: self.index -= 1
1146
1147        if t == self.T[self.index]:
1148            #Protect against case where t == T[-1] (last time)
1149            # - also works in general when t == T[i]
1150            ratio = 0
1151        else:
1152            #t is now between index and index+1
1153            ratio = (t - self.T[self.index])/\
1154                    (self.T[self.index+1] - self.T[self.index])
1155
1156        #Compute interpolated values
1157        q = zeros(len(self.quantity_names), Float)
1158
1159        for i, name in enumerate(self.quantity_names):
1160            Q = self.precomputed_values[name]
1161
1162            if self.spatial is False:
1163                #If there is no spatial info               
1164                assert len(Q.shape) == 1
1165
1166                Q0 = Q[self.index]
1167                if ratio > 0: Q1 = Q[self.index+1]
1168
1169            else:
1170                if x is not None and y is not None:
1171                    #Interpolate to x, y
1172                   
1173                    raise 'x,y interpolation not yet implemented'
1174                else:
1175                    #Use precomputed point
1176                    Q0 = Q[self.index, point_id]
1177                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1178
1179            #Linear temporal interpolation   
1180            if ratio > 0:
1181                q[i] = Q0 + ratio*(Q1 - Q0)
1182            else:
1183                q[i] = Q0
1184
1185
1186        #Return vector of interpolated values
1187        #if len(q) == 1:
1188        #    return q[0]
1189        #else:
1190        #    return q
1191
1192
1193        #Return vector of interpolated values
1194        #FIXME:
1195        if self.spatial is True:
1196            return q
1197        else:
1198            #Replicate q according to x and y
1199            #This is e.g used for Wind_stress
1200            if x == None or y == None: 
1201                return q
1202            else:
1203                try:
1204                    N = len(x)
1205                except:
1206                    return q
1207                else:
1208                    from Numeric import ones, Float
1209                    #x is a vector - Create one constant column for each value
1210                    N = len(x)
1211                    assert len(y) == N, 'x and y must have same length'
1212                    res = []
1213                    for col in q:
1214                        res.append(col*ones(N, Float))
1215                       
1216                return res
1217
1218
1219    def statistics(self):
1220        """Output statistics about interpolation_function
1221        """
1222       
1223        vertex_coordinates = self.vertex_coordinates
1224        interpolation_points = self.interpolation_points               
1225        quantity_names = self.quantity_names
1226        quantities = self.quantities
1227        precomputed_values = self.precomputed_values                 
1228               
1229        x = vertex_coordinates[:,0]
1230        y = vertex_coordinates[:,1]               
1231
1232        str =  '------------------------------------------------\n'
1233        str += 'Interpolation_function (spatio-temporal) statistics:\n'
1234        str += '  Extent:\n'
1235        str += '    x in [%f, %f], len(x) == %d\n'\
1236               %(min(x), max(x), len(x))
1237        str += '    y in [%f, %f], len(y) == %d\n'\
1238               %(min(y), max(y), len(y))
1239        str += '    t in [%f, %f], len(t) == %d\n'\
1240               %(min(self.T), max(self.T), len(self.T))
1241        str += '  Quantities:\n'
1242        for name in quantity_names:
1243            q = quantities[name][:].flat
1244            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1245
1246        if interpolation_points is not None:   
1247            str += '  Interpolation points (xi, eta):'\
1248                   ' number of points == %d\n' %interpolation_points.shape[0]
1249            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1250                                            max(interpolation_points[:,0]))
1251            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1252                                             max(interpolation_points[:,1]))
1253            str += '  Interpolated quantities (over all timesteps):\n'
1254       
1255            for name in quantity_names:
1256                q = precomputed_values[name][:].flat
1257                str += '    %s at interpolation points in [%f, %f]\n'\
1258                       %(name, min(q), max(q))
1259        str += '------------------------------------------------\n'
1260
1261        return str
1262
1263        #FIXME: Delete
1264        #print '------------------------------------------------'
1265        #print 'Interpolation_function statistics:'
1266        #print '  Extent:'
1267        #print '    x in [%f, %f], len(x) == %d'\
1268        #      %(min(x), max(x), len(x))
1269        #print '    y in [%f, %f], len(y) == %d'\
1270        #      %(min(y), max(y), len(y))
1271        #print '    t in [%f, %f], len(t) == %d'\
1272        #      %(min(self.T), max(self.T), len(self.T))
1273        #print '  Quantities:'
1274        #for name in quantity_names:
1275        #    q = quantities[name][:].flat
1276        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1277        #print '  Interpolation points (xi, eta):'\
1278        #      ' number of points == %d ' %interpolation_points.shape[0]
1279        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1280        #                             max(interpolation_points[:,0]))
1281        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1282        #                              max(interpolation_points[:,1]))
1283        #print '  Interpolated quantities (over all timesteps):'
1284        #
1285        #for name in quantity_names:
1286        #    q = precomputed_values[name][:].flat
1287        #    print '    %s at interpolation points in [%f, %f]'\
1288        #          %(name, min(q), max(q))
1289        #print '------------------------------------------------'
1290
1291
1292#-------------------------------------------------------------
1293if __name__ == "__main__":
1294    """
1295    Load in a mesh and data points with attributes.
1296    Fit the attributes to the mesh.
1297    Save a new mesh file.
1298    """
1299    import os, sys
1300    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1301            %os.path.basename(sys.argv[0])
1302
1303    if len(sys.argv) < 4:
1304        print usage
1305    else:
1306        mesh_file = sys.argv[1]
1307        point_file = sys.argv[2]
1308        mesh_output_file = sys.argv[3]
1309
1310        expand_search = False
1311        if len(sys.argv) > 4:
1312            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1313                expand_search = True
1314            else:
1315                expand_search = False
1316
1317        verbose = False
1318        if len(sys.argv) > 5:
1319            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1320                verbose = False
1321            else:
1322                verbose = True
1323
1324        if len(sys.argv) > 6:
1325            alpha = sys.argv[6]
1326        else:
1327            alpha = DEFAULT_ALPHA
1328
1329        # This is used more for testing
1330        if len(sys.argv) > 7:
1331            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1332                display_errors = False
1333            else:
1334                display_errors = True
1335           
1336        t0 = time.time()
1337        try:
1338            fit_to_mesh_file(mesh_file,
1339                         point_file,
1340                         mesh_output_file,
1341                         alpha,
1342                         verbose= verbose,
1343                         expand_search = expand_search,
1344                         display_errors = display_errors)
1345        except IOError,e:
1346            import sys; sys.exit(1)
1347
1348        print 'That took %.2f seconds' %(time.time()-t0)
1349
Note: See TracBrowser for help on using the repository browser.