source: inundation/pyvolution/least_squares.py @ 1903

Last change on this file since 1903 was 1903, checked in by duncan, 19 years ago

removing / gutting duplication of methods reading xya files

File size: 47.8 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20import exceptions
21class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME (DSG-ON) reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    NOTE: Throws IOErrors, for a variety of file problems.
75   
76    mesh_origin is the same but refers to the input tsh file.
77    FIXME: When the tsh format contains it own origin, these parameters can go.
78    FIXME: And both origins should be obtained from the specified files.
79    """
80
81    from load_mesh.loadASCII import import_mesh_file, \
82                 import_points_file, export_mesh_file, \
83                 concatinate_attributelist
84
85
86    try:
87        mesh_dict = import_mesh_file(mesh_file)
88    except IOError,e:
89        if display_errors:
90            print "Could not load bad file. ", e
91        raise IOError  #Re-raise exception
92       
93    vertex_coordinates = mesh_dict['vertices']
94    triangles = mesh_dict['triangles']
95    if type(mesh_dict['vertex_attributes']) == ArrayType:
96        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
97    else:
98        old_point_attributes = mesh_dict['vertex_attributes']
99
100    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
101        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
102    else:
103        old_title_list = mesh_dict['vertex_attribute_titles']
104
105    if verbose: print 'tsh file %s loaded' %mesh_file
106
107    # load in the .pts file
108    try:
109        point_dict = import_points_file(point_file, verbose=verbose)
110    except IOError,e:
111        if display_errors:
112            print "Could not load bad file. ", e
113        raise IOError  #Re-raise exception 
114
115    point_coordinates = point_dict['pointlist']
116    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
117
118    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
119        data_origin = point_dict['geo_reference'].get_origin()
120    else:
121        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
122
123    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
124        mesh_origin = mesh_dict['geo_reference'].get_origin()
125    else:
126        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
127
128    if verbose: print "points file loaded"
129    if verbose:print "fitting to mesh"
130    f = fit_to_mesh(vertex_coordinates,
131                    triangles,
132                    point_coordinates,
133                    point_attributes,
134                    alpha = alpha,
135                    verbose = verbose,
136                    expand_search = expand_search,
137                    data_origin = data_origin,
138                    mesh_origin = mesh_origin,
139                    precrop = precrop)
140    if verbose: print "finished fitting to mesh"
141
142    # convert array to list of lists
143    new_point_attributes = f.tolist()
144    #FIXME have this overwrite attributes with the same title - DSG
145    #Put the newer attributes last
146    if old_title_list <> []:
147        old_title_list.extend(title_list)
148        #FIXME can this be done a faster way? - DSG
149        for i in range(len(old_point_attributes)):
150            old_point_attributes[i].extend(new_point_attributes[i])
151        mesh_dict['vertex_attributes'] = old_point_attributes
152        mesh_dict['vertex_attribute_titles'] = old_title_list
153    else:
154        mesh_dict['vertex_attributes'] = new_point_attributes
155        mesh_dict['vertex_attribute_titles'] = title_list
156
157    #FIXME (Ole): Remember to output mesh_origin as well
158    if verbose: print "exporting to file ",mesh_output_file
159
160    try:
161        export_mesh_file(mesh_output_file, mesh_dict)
162    except IOError,e:
163        if display_errors:
164            print "Could not write file. ", e
165        raise IOError
166
167def fit_to_mesh(vertex_coordinates,
168                triangles,
169                point_coordinates,
170                point_attributes,
171                alpha = DEFAULT_ALPHA,
172                verbose = False,
173                expand_search = False,
174                data_origin = None,
175                mesh_origin = None,
176                precrop = False):
177    """
178    Fit a smooth surface to a triangulation,
179    given data points with attributes.
180
181
182        Inputs:
183
184          vertex_coordinates: List of coordinate pairs [xi, eta] of points
185          constituting mesh (or a an m x 2 Numeric array)
186
187          triangles: List of 3-tuples (or a Numeric array) of
188          integers representing indices of all vertices in the mesh.
189
190          point_coordinates: List of coordinate pairs [x, y] of data points
191          (or an nx2 Numeric array)
192
193          alpha: Smoothing parameter.
194
195          point_attributes: Vector or array of data at the point_coordinates.
196
197          data_origin and mesh_origin are 3-tuples consisting of
198          UTM zone, easting and northing. If specified
199          point coordinates and vertex coordinates are assumed to be
200          relative to their respective origins.
201
202    """
203    interp = Interpolation(vertex_coordinates,
204                           triangles,
205                           point_coordinates,
206                           alpha = alpha,
207                           verbose = verbose,
208                           expand_search = expand_search,
209                           data_origin = data_origin,
210                           mesh_origin = mesh_origin,
211                           precrop = precrop)
212
213    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
214    return vertex_attributes
215
216
217
218def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
219                    verbose = False, reduction = 1):
220    """Fits attributes from pts file to MxN rectangular mesh
221
222    Read pts file and create rectangular mesh of resolution MxN such that
223    it covers all points specified in pts file.
224
225    FIXME: This may be a temporary function until we decide on
226    netcdf formats etc
227
228    FIXME: Uses elevation hardwired
229    """
230
231    import  mesh_factory
232    from load_mesh.loadASCII import import_points_file
233   
234    if verbose: print 'Read pts'
235    points_dict = import_points_file(pts_name)
236    #points, attributes = util.read_xya(pts_name)
237
238    #Reduce number of points a bit
239    points = points_dict['pointlist'][::reduction]
240    elevation = points_dict['attributelist']['elevation']  #Must be elevation
241    elevation = elevation[::reduction]
242
243    if verbose: print 'Got %d data points' %len(points)
244
245    if verbose: print 'Create mesh'
246    #Find extent
247    max_x = min_x = points[0][0]
248    max_y = min_y = points[0][1]
249    for point in points[1:]:
250        x = point[0]
251        if x > max_x: max_x = x
252        if x < min_x: min_x = x
253        y = point[1]
254        if y > max_y: max_y = y
255        if y < min_y: min_y = y
256
257    #Create appropriate mesh
258    vertex_coordinates, triangles, boundary =\
259         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
260                                (min_x, min_y))
261
262    #Fit attributes to mesh
263    vertex_attributes = fit_to_mesh(vertex_coordinates,
264                        triangles,
265                        points,
266                        elevation, alpha=alpha, verbose=verbose)
267
268
269
270    return vertex_coordinates, triangles, boundary, vertex_attributes
271
272
273
274class Interpolation:
275
276    def __init__(self,
277                 vertex_coordinates,
278                 triangles,
279                 point_coordinates = None,
280                 alpha = None,
281                 verbose = False,
282                 expand_search = True,
283                 interp_only = False,
284                 max_points_per_cell = 30,
285                 mesh_origin = None,
286                 data_origin = None,
287                 precrop = False):
288
289
290        """ Build interpolation matrix mapping from
291        function values at vertices to function values at data points
292
293        Inputs:
294
295          vertex_coordinates: List of coordinate pairs [xi, eta] of
296          points constituting mesh (or a an m x 2 Numeric array)
297          Points may appear multiple times
298          (e.g. if vertices have discontinuities)
299
300          triangles: List of 3-tuples (or a Numeric array) of
301          integers representing indices of all vertices in the mesh.
302
303          point_coordinates: List of coordinate pairs [x, y] of
304          data points (or an nx2 Numeric array)
305          If point_coordinates is absent, only smoothing matrix will
306          be built
307
308          alpha: Smoothing parameter
309
310          data_origin and mesh_origin are 3-tuples consisting of
311          UTM zone, easting and northing. If specified
312          point coordinates and vertex coordinates are assumed to be
313          relative to their respective origins.
314
315        """
316        from util import ensure_numeric
317
318        #Convert input to Numeric arrays
319        triangles = ensure_numeric(triangles, Int)
320        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
321
322        #Build underlying mesh
323        if verbose: print 'Building mesh'
324        #self.mesh = General_mesh(vertex_coordinates, triangles,
325        #FIXME: Trying the normal mesh while testing precrop,
326        #       The functionality of boundary_polygon is needed for that
327
328        #FIXME - geo ref does not have to go into mesh.
329        # Change the point co-ords to conform to the
330        # mesh co-ords early in the code
331        if mesh_origin == None:
332            geo = None
333        else:
334            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
335        self.mesh = Mesh(vertex_coordinates, triangles,
336                         geo_reference = geo)
337       
338        self.mesh.check_integrity()
339
340        self.data_origin = data_origin
341
342        self.point_indices = None
343
344        #Smoothing parameter
345        if alpha is None:
346            self.alpha = DEFAULT_ALPHA
347        else:   
348            self.alpha = alpha
349
350
351        if point_coordinates is not None:
352            if verbose: print 'Building interpolation matrix'
353            self.build_interpolation_matrix_A(point_coordinates,
354                                              verbose = verbose,
355                                              expand_search = expand_search,
356                                              interp_only = interp_only, 
357                                              max_points_per_cell =\
358                                              max_points_per_cell,
359                                              data_origin = data_origin,
360                                              precrop = precrop)
361        #Build coefficient matrices
362        if interp_only == False:
363            self.build_coefficient_matrix_B(point_coordinates,
364                                        verbose = verbose,
365                                        expand_search = expand_search,
366                                        max_points_per_cell =\
367                                        max_points_per_cell,
368                                        data_origin = data_origin,
369                                        precrop = precrop)
370
371
372    def set_point_coordinates(self, point_coordinates,
373                              data_origin = None,
374                              verbose = False,
375                              precrop = True):
376        """
377        A public interface to setting the point co-ordinates.
378        """
379        if point_coordinates is not None:
380            if verbose: print 'Building interpolation matrix'
381            self.build_interpolation_matrix_A(point_coordinates,
382                                              verbose = verbose,
383                                              data_origin = data_origin,
384                                              precrop = precrop)
385        self.build_coefficient_matrix_B(point_coordinates, data_origin)
386
387    def build_coefficient_matrix_B(self, point_coordinates=None,
388                                   verbose = False, expand_search = True,
389                                   max_points_per_cell=30,
390                                   data_origin = None,
391                                   precrop = False):
392        """Build final coefficient matrix"""
393
394
395        if self.alpha <> 0:
396            if verbose: print 'Building smoothing matrix'
397            self.build_smoothing_matrix_D()
398
399        if point_coordinates is not None:
400            if self.alpha <> 0:
401                self.B = self.AtA + self.alpha*self.D
402            else:
403                self.B = self.AtA
404
405            #Convert self.B matrix to CSR format for faster matrix vector
406            self.B = Sparse_CSR(self.B)
407
408    def build_interpolation_matrix_A(self, point_coordinates,
409                                     verbose = False, expand_search = True,
410                                     max_points_per_cell=30,
411                                     data_origin = None,
412                                     precrop = False,
413                                     interp_only = False):
414        """Build n x m interpolation matrix, where
415        n is the number of data points and
416        m is the number of basis functions phi_k (one per vertex)
417
418        This algorithm uses a quad tree data structure for fast binning of data points
419        origin is a 3-tuple consisting of UTM zone, easting and northing.
420        If specified coordinates are assumed to be relative to this origin.
421
422        This one will override any data_origin that may be specified in
423        interpolation instance
424
425        """
426
427
428        #FIXME (Ole): Check that this function is memeory efficient.
429        #6 million datapoints and 300000 basis functions
430        #causes out-of-memory situation
431        #First thing to check is whether there is room for self.A and self.AtA
432        #
433        #Maybe we need some sort of blocking
434
435        from quad import build_quadtree
436        from util import ensure_numeric
437
438        if data_origin is None:
439            data_origin = self.data_origin #Use the one from
440                                           #interpolation instance
441
442        #Convert input to Numeric arrays just in case.
443        point_coordinates = ensure_numeric(point_coordinates, Float)
444
445        #Keep track of discarded points (if any).
446        #This is only registered if precrop is True
447        self.cropped_points = False
448
449        #Shift data points to same origin as mesh (if specified)
450
451        #FIXME this will shift if there was no geo_ref.
452        #But all this should be removed anyhow.
453        #change coords before this point
454        mesh_origin = self.mesh.geo_reference.get_origin()
455        if point_coordinates is not None:
456            if data_origin is not None:
457                if mesh_origin is not None:
458
459                    #Transformation:
460                    #
461                    #Let x_0 be the reference point of the point coordinates
462                    #and xi_0 the reference point of the mesh.
463                    #
464                    #A point coordinate (x + x_0) is then made relative
465                    #to xi_0 by
466                    #
467                    # x_new = x + x_0 - xi_0
468                    #
469                    #and similarly for eta
470
471                    x_offset = data_origin[1] - mesh_origin[1]
472                    y_offset = data_origin[2] - mesh_origin[2]
473                else: #Shift back to a zero origin
474                    x_offset = data_origin[1]
475                    y_offset = data_origin[2]
476
477                point_coordinates[:,0] += x_offset
478                point_coordinates[:,1] += y_offset
479            else:
480                if mesh_origin is not None:
481                    #Use mesh origin for data points
482                    point_coordinates[:,0] -= mesh_origin[1]
483                    point_coordinates[:,1] -= mesh_origin[2]
484
485
486
487        #Remove points falling outside mesh boundary
488        #This reduced one example from 1356 seconds to 825 seconds
489        if precrop is True:
490            from Numeric import take
491            from util import inside_polygon
492
493            if verbose: print 'Getting boundary polygon'
494            P = self.mesh.get_boundary_polygon()
495
496            if verbose: print 'Getting indices inside mesh boundary'
497            indices = inside_polygon(point_coordinates, P, verbose = verbose)
498
499
500            if len(indices) != point_coordinates.shape[0]:
501                self.cropped_points = True
502                if verbose:
503                    print 'Done - %d points outside mesh have been cropped.'\
504                          %(point_coordinates.shape[0] - len(indices))
505
506            point_coordinates = take(point_coordinates, indices)
507            self.point_indices = indices
508
509
510
511
512        #Build n x m interpolation matrix
513        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
514        n = point_coordinates.shape[0]     #Nbr of data points
515
516        if verbose: print 'Number of datapoints: %d' %n
517        if verbose: print 'Number of basis functions: %d' %m
518
519        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
520        #However, Sparse_CSR does not have the same methods as Sparse yet
521        #The tests will reveal what needs to be done
522        #self.A = Sparse_CSR(Sparse(n,m))
523        #self.AtA = Sparse_CSR(Sparse(m,m))
524        self.A = Sparse(n,m)
525        self.AtA = Sparse(m,m)
526
527        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
528        root = build_quadtree(self.mesh,
529                              max_points_per_cell = max_points_per_cell)
530
531        #Compute matrix elements
532        for i in range(n):
533            #For each data_coordinate point
534
535            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
536            x = point_coordinates[i]
537
538            #Find vertices near x
539            candidate_vertices = root.search(x[0], x[1])
540            is_more_elements = True
541
542            element_found, sigma0, sigma1, sigma2, k = \
543                self.search_triangles_of_vertices(candidate_vertices, x)
544            while not element_found and is_more_elements and expand_search:
545                #if verbose: print 'Expanding search'
546                candidate_vertices, branch = root.expand_search()
547                if branch == []:
548                    # Searching all the verts from the root cell that haven't
549                    # been searched.  This is the last try
550                    element_found, sigma0, sigma1, sigma2, k = \
551                      self.search_triangles_of_vertices(candidate_vertices, x)
552                    is_more_elements = False
553                else:
554                    element_found, sigma0, sigma1, sigma2, k = \
555                      self.search_triangles_of_vertices(candidate_vertices, x)
556
557
558            #Update interpolation matrix A if necessary
559            if element_found is True:
560                #Assign values to matrix A
561
562                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
563                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
564                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
565
566                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
567                js     = [j0,j1,j2]
568
569                for j in js:
570                    self.A[i,j] = sigmas[j]
571                    for k in js:
572                        if interp_only == False:
573                            self.AtA[j,k] += sigmas[j]*sigmas[k]
574            else:
575                pass
576                #Ok if there is no triangle for datapoint
577                #(as in brute force version)
578                #raise 'Could not find triangle for point', x
579
580
581
582    def search_triangles_of_vertices(self, candidate_vertices, x):
583            #Find triangle containing x:
584            element_found = False
585
586            # This will be returned if element_found = False
587            sigma2 = -10.0
588            sigma0 = -10.0
589            sigma1 = -10.0
590            k = -10.0
591
592            #For all vertices in same cell as point x
593            for v in candidate_vertices:
594
595                #for each triangle id (k) which has v as a vertex
596                for k, _ in self.mesh.vertexlist[v]:
597
598                    #Get the three vertex_points of candidate triangle
599                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
600                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
601                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
602
603                    #print "PDSG - k", k
604                    #print "PDSG - xi0", xi0
605                    #print "PDSG - xi1", xi1
606                    #print "PDSG - xi2", xi2
607                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
608                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
609
610                    #Get the three normals
611                    n0 = self.mesh.get_normal(k, 0)
612                    n1 = self.mesh.get_normal(k, 1)
613                    n2 = self.mesh.get_normal(k, 2)
614
615
616                    #Compute interpolation
617                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
618                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
619                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
620
621                    #print "PDSG - sigma0", sigma0
622                    #print "PDSG - sigma1", sigma1
623                    #print "PDSG - sigma2", sigma2
624
625                    #FIXME: Maybe move out to test or something
626                    epsilon = 1.0e-6
627                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
628
629                    #Check that this triangle contains the data point
630
631                    #Sigmas can get negative within
632                    #machine precision on some machines (e.g nautilus)
633                    #Hence the small eps
634                    eps = 1.0e-15
635                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
636                        element_found = True
637                        break
638
639                if element_found is True:
640                    #Don't look for any other triangle
641                    break
642            return element_found, sigma0, sigma1, sigma2, k
643
644
645
646    def build_interpolation_matrix_A_brute(self, point_coordinates):
647        """Build n x m interpolation matrix, where
648        n is the number of data points and
649        m is the number of basis functions phi_k (one per vertex)
650
651        This is the brute force which is too slow for large problems,
652        but could be used for testing
653        """
654
655        from util import ensure_numeric
656
657        #Convert input to Numeric arrays
658        point_coordinates = ensure_numeric(point_coordinates, Float)
659
660        #Build n x m interpolation matrix
661        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
662        n = point_coordinates.shape[0]     #Nbr of data points
663
664        self.A = Sparse(n,m)
665        self.AtA = Sparse(m,m)
666
667        #Compute matrix elements
668        for i in range(n):
669            #For each data_coordinate point
670
671            x = point_coordinates[i]
672            element_found = False
673            k = 0
674            while not element_found and k < len(self.mesh):
675                #For each triangle (brute force)
676                #FIXME: Real algorithm should only visit relevant triangles
677
678                #Get the three vertex_points
679                xi0 = self.mesh.get_vertex_coordinate(k, 0)
680                xi1 = self.mesh.get_vertex_coordinate(k, 1)
681                xi2 = self.mesh.get_vertex_coordinate(k, 2)
682
683                #Get the three normals
684                n0 = self.mesh.get_normal(k, 0)
685                n1 = self.mesh.get_normal(k, 1)
686                n2 = self.mesh.get_normal(k, 2)
687
688                #Compute interpolation
689                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
690                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
691                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
692
693                #FIXME: Maybe move out to test or something
694                epsilon = 1.0e-6
695                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
696
697                #Check that this triangle contains data point
698                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
699                    element_found = True
700                    #Assign values to matrix A
701
702                    j0 = self.mesh.triangles[k,0] #Global vertex id
703                    #self.A[i, j0] = sigma0
704
705                    j1 = self.mesh.triangles[k,1] #Global vertex id
706                    #self.A[i, j1] = sigma1
707
708                    j2 = self.mesh.triangles[k,2] #Global vertex id
709                    #self.A[i, j2] = sigma2
710
711                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
712                    js     = [j0,j1,j2]
713
714                    for j in js:
715                        self.A[i,j] = sigmas[j]
716                        for k in js:
717                            self.AtA[j,k] += sigmas[j]*sigmas[k]
718                k = k+1
719
720
721
722    def get_A(self):
723        return self.A.todense()
724
725    def get_B(self):
726        return self.B.todense()
727
728    def get_D(self):
729        return self.D.todense()
730
731        #FIXME: Remember to re-introduce the 1/n factor in the
732        #interpolation term
733
734    def build_smoothing_matrix_D(self):
735        """Build m x m smoothing matrix, where
736        m is the number of basis functions phi_k (one per vertex)
737
738        The smoothing matrix is defined as
739
740        D = D1 + D2
741
742        where
743
744        [D1]_{k,l} = \int_\Omega
745           \frac{\partial \phi_k}{\partial x}
746           \frac{\partial \phi_l}{\partial x}\,
747           dx dy
748
749        [D2]_{k,l} = \int_\Omega
750           \frac{\partial \phi_k}{\partial y}
751           \frac{\partial \phi_l}{\partial y}\,
752           dx dy
753
754
755        The derivatives \frac{\partial \phi_k}{\partial x},
756        \frac{\partial \phi_k}{\partial x} for a particular triangle
757        are obtained by computing the gradient a_k, b_k for basis function k
758        """
759
760        #FIXME: algorithm might be optimised by computing local 9x9
761        #"element stiffness matrices:
762
763        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
764
765        self.D = Sparse(m,m)
766
767        #For each triangle compute contributions to D = D1+D2
768        for i in range(len(self.mesh)):
769
770            #Get area
771            area = self.mesh.areas[i]
772
773            #Get global vertex indices
774            v0 = self.mesh.triangles[i,0]
775            v1 = self.mesh.triangles[i,1]
776            v2 = self.mesh.triangles[i,2]
777
778            #Get the three vertex_points
779            xi0 = self.mesh.get_vertex_coordinate(i, 0)
780            xi1 = self.mesh.get_vertex_coordinate(i, 1)
781            xi2 = self.mesh.get_vertex_coordinate(i, 2)
782
783            #Compute gradients for each vertex
784            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
785                              1, 0, 0)
786
787            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
788                              0, 1, 0)
789
790            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
791                              0, 0, 1)
792
793            #Compute diagonal contributions
794            self.D[v0,v0] += (a0*a0 + b0*b0)*area
795            self.D[v1,v1] += (a1*a1 + b1*b1)*area
796            self.D[v2,v2] += (a2*a2 + b2*b2)*area
797
798            #Compute contributions for basis functions sharing edges
799            e01 = (a0*a1 + b0*b1)*area
800            self.D[v0,v1] += e01
801            self.D[v1,v0] += e01
802
803            e12 = (a1*a2 + b1*b2)*area
804            self.D[v1,v2] += e12
805            self.D[v2,v1] += e12
806
807            e20 = (a2*a0 + b2*b0)*area
808            self.D[v2,v0] += e20
809            self.D[v0,v2] += e20
810
811
812    def fit(self, z):
813        """Fit a smooth surface to given 1d array of data points z.
814
815        The smooth surface is computed at each vertex in the underlying
816        mesh using the formula given in the module doc string.
817
818        Pre Condition:
819          self.A, self.AtA and self.B have been initialised
820
821        Inputs:
822          z: Single 1d vector or array of data at the point_coordinates.
823        """
824
825        #Convert input to Numeric arrays
826        from util import ensure_numeric
827        z = ensure_numeric(z, Float)
828
829        if len(z.shape) > 1 :
830            raise VectorShapeError, 'Can only deal with 1d data vector'
831
832        if self.point_indices is not None:
833            #Remove values for any points that were outside mesh
834            z = take(z, self.point_indices)
835
836        #Compute right hand side based on data
837        Atz = self.A.trans_mult(z)
838
839
840        #Check sanity
841        n, m = self.A.shape
842        if n<m and self.alpha == 0.0:
843            msg = 'ERROR (least_squares): Too few data points\n'
844            msg += 'There are only %d data points and alpha == 0. ' %n
845            msg += 'Need at least %d\n' %m
846            msg += 'Alternatively, set smoothing parameter alpha to a small '
847            msg += 'positive value,\ne.g. 1.0e-3.'
848            raise msg
849
850
851
852        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
853        #FIXME: Should we store the result here for later use? (ON)
854
855
856    def fit_points(self, z, verbose=False):
857        """Like fit, but more robust when each point has two or more attributes
858        FIXME (Ole): The name fit_points doesn't carry any meaning
859        for me. How about something like fit_multiple or fit_columns?
860        """
861
862        try:
863            if verbose: print 'Solving penalised least_squares problem'
864            return self.fit(z)
865        except VectorShapeError, e:
866            # broadcasting is not supported.
867
868            #Convert input to Numeric arrays
869            from util import ensure_numeric
870            z = ensure_numeric(z, Float)
871
872            #Build n x m interpolation matrix
873            m = self.mesh.coordinates.shape[0] #Number of vertices
874            n = z.shape[1]                     #Number of data points
875
876            f = zeros((m,n), Float) #Resulting columns
877
878            for i in range(z.shape[1]):
879                f[:,i] = self.fit(z[:,i])
880
881            return f
882
883
884    def interpolate(self, f):
885        """Evaluate smooth surface f at data points implied in self.A.
886
887        The mesh values representing a smooth surface are
888        assumed to be specified in f. This argument could,
889        for example have been obtained from the method self.fit()
890
891        Pre Condition:
892          self.A has been initialised
893
894        Inputs:
895          f: Vector or array of data at the mesh vertices.
896          If f is an array, interpolation will be done for each column as
897          per underlying matrix-matrix multiplication
898
899        Output:
900          Interpolated values at data points implied in self.A
901
902        """
903
904        return self.A * f
905
906    def cull_outsiders(self, f):
907        pass
908
909
910
911
912class Interpolation_function:
913    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
914    which is interpolated from time series defined at vertices of
915    triangular mesh (such as those stored in sww files)
916
917    Let m be the number of vertices, n the number of triangles
918    and p the number of timesteps.
919
920    Mandatory input
921        time:               px1 array of monotonously increasing times (Float)
922        quantities:         Dictionary of pxm arrays or 1 pxm array (Float)
923       
924    Optional input:
925        quantity_names:     List of keys into the quantities dictionary
926        vertex_coordinates: mx2 array of coordinates (Float)
927        triangles:          nx3 array of indices into vertex_coordinates (Int)
928        interpolation_points: array of coordinates to be interpolated to
929        verbose:            Level of reporting
930   
931   
932    The quantities returned by the callable object are specified by
933    the list quantities which must contain the names of the
934    quantities to be returned and also reflect the order, e.g. for
935    the shallow water wave equation, on would have
936    quantities = ['stage', 'xmomentum', 'ymomentum']
937
938    The parameter interpolation_points decides at which points interpolated
939    quantities are to be computed whenever object is called.
940    If None, return average value
941    """
942
943   
944   
945    def __init__(self,
946                 time,
947                 quantities,
948                 quantity_names = None, 
949                 vertex_coordinates = None,
950                 triangles = None,
951                 interpolation_points = None,
952                 verbose = False):
953        """Initialise object and build spatial interpolation if required
954        """
955
956        from Numeric import array, zeros, Float, alltrue, concatenate,\
957             reshape, ArrayType
958
959
960        from util import mean, ensure_numeric
961        from config import time_format
962        import types
963
964
965
966        #Check temporal info
967        time = ensure_numeric(time)       
968        msg = 'Time must be a monotonuosly '
969        msg += 'increasing sequence %s' %time
970        assert alltrue(time[1:] - time[:-1] >= 0 ), msg
971
972
973        #Check if quantities is a single array only
974        if type(quantities) != types.DictType:
975            quantities = ensure_numeric(quantities)
976            quantity_names = ['Attribute']
977
978            #Make it a dictionary
979            quantities = {quantity_names[0]: quantities}
980
981
982        #Use keys if no names are specified
983        if quantity_names is None:
984            quantity_names = quantities.keys()
985
986
987        #Check spatial info
988        if vertex_coordinates is None:
989            self.spatial = False
990        else:   
991            vertex_coordinates = ensure_numeric(vertex_coordinates)
992
993            assert triangles is not None, 'Triangles array must be specified'
994            triangles = ensure_numeric(triangles)
995            self.spatial = True           
996           
997
998 
999        #Save for use with statistics
1000        self.quantity_names = quantity_names       
1001        self.quantities = quantities       
1002        self.vertex_coordinates = vertex_coordinates
1003        self.interpolation_points = interpolation_points
1004        self.T = time[:]  #Time assumed to be relative to starttime
1005        self.index = 0    #Initial time index
1006        self.precomputed_values = {}
1007           
1008
1009
1010        #Precomputed spatial interpolation if requested
1011        if interpolation_points is not None:
1012            if self.spatial is False:
1013                raise 'Triangles and vertex_coordinates must be specified'
1014           
1015
1016            try:
1017                self.interpolation_points =\
1018                                          ensure_numeric(self.interpolation_points)
1019            except:
1020                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1021                      'or a list of points\n'
1022                msg += 'I got: %s.' %( str(self.interpolation_points)[:60] + '...')
1023                raise msg
1024
1025
1026            for name in quantity_names:
1027                self.precomputed_values[name] =\
1028                                              zeros((len(self.T),
1029                                                     len(self.interpolation_points)),
1030                                                    Float)
1031
1032            #Build interpolator
1033            interpol = Interpolation(vertex_coordinates,
1034                                     triangles,
1035                                     point_coordinates = self.interpolation_points,
1036                                     alpha = 0,
1037                                     precrop = False, 
1038                                     verbose = verbose)
1039
1040            if verbose: print 'Interpolate'
1041            n = len(self.T)
1042            for i, t in enumerate(self.T):
1043                #Interpolate quantities at this timestep
1044                if verbose and i%((n+10)/10)==0:
1045                    print ' time step %d of %d' %(i, n)
1046                   
1047                for name in quantity_names:
1048                    self.precomputed_values[name][i, :] =\
1049                    interpol.interpolate(quantities[name][i,:])
1050
1051            #Report
1052            if verbose:
1053                print self.statistics()
1054                #self.print_statistics()
1055           
1056        else:
1057            #Store quantitites as is
1058            for name in quantity_names:
1059                self.precomputed_values[name] = quantities[name]
1060
1061
1062        #else:
1063        #    #Return an average, making this a time series
1064        #    for name in quantity_names:
1065        #        self.values[name] = zeros(len(self.T), Float)
1066        #
1067        #    if verbose: print 'Compute mean values'
1068        #    for i, t in enumerate(self.T):
1069        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1070        #        for name in quantity_names:
1071        #           self.values[name][i] = mean(quantities[name][i,:])
1072
1073
1074
1075
1076    def __repr__(self):
1077        #return 'Interpolation function (spatio-temporal)'
1078        return self.statistics()
1079   
1080
1081    def __call__(self, t, point_id = None, x = None, y = None):
1082        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1083
1084        Inputs:
1085          t: time - Model time. Must lie within existing timesteps
1086          point_id: index of one of the preprocessed points.
1087          x, y:     Overrides location, point_id ignored
1088         
1089          If spatial info is present and all of x,y,point_id
1090          are None an exception is raised
1091                   
1092          If no spatial info is present, point_id and x,y arguments are ignored
1093          making f a function of time only.
1094
1095         
1096          FIXME: point_id could also be a slice
1097          FIXME: What if x and y are vectors?
1098          FIXME: What about f(x,y) without t?
1099        """
1100
1101        from math import pi, cos, sin, sqrt
1102        from Numeric import zeros, Float
1103        from util import mean       
1104
1105        if self.spatial is True:
1106            if point_id is None:
1107                if x is None or y is None:
1108                    msg = 'Either point_id or x and y must be specified'
1109                    raise msg
1110            else:
1111                if self.interpolation_points is None:
1112                    msg = 'Interpolation_function must be instantiated ' +\
1113                          'with a list of interpolation points before parameter ' +\
1114                          'point_id can be used'
1115                    raise msg
1116
1117
1118        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1119        msg += ' does not match model time: %s\n' %t
1120        if t < self.T[0]: raise msg
1121        if t > self.T[-1]: raise msg
1122
1123        oldindex = self.index #Time index
1124
1125        #Find current time slot
1126        while t > self.T[self.index]: self.index += 1
1127        while t < self.T[self.index]: self.index -= 1
1128
1129        if t == self.T[self.index]:
1130            #Protect against case where t == T[-1] (last time)
1131            # - also works in general when t == T[i]
1132            ratio = 0
1133        else:
1134            #t is now between index and index+1
1135            ratio = (t - self.T[self.index])/\
1136                    (self.T[self.index+1] - self.T[self.index])
1137
1138        #Compute interpolated values
1139        q = zeros(len(self.quantity_names), Float)
1140
1141        for i, name in enumerate(self.quantity_names):
1142            Q = self.precomputed_values[name]
1143
1144            if self.spatial is False:
1145                #If there is no spatial info               
1146                assert len(Q.shape) == 1
1147
1148                Q0 = Q[self.index]
1149                if ratio > 0: Q1 = Q[self.index+1]
1150
1151            else:
1152                if x is not None and y is not None:
1153                    #Interpolate to x, y
1154                   
1155                    raise 'x,y interpolation not yet implemented'
1156                else:
1157                    #Use precomputed point
1158                    Q0 = Q[self.index, point_id]
1159                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1160
1161            #Linear temporal interpolation   
1162            if ratio > 0:
1163                q[i] = Q0 + ratio*(Q1 - Q0)
1164            else:
1165                q[i] = Q0
1166
1167
1168        #Return vector of interpolated values
1169        #if len(q) == 1:
1170        #    return q[0]
1171        #else:
1172        #    return q
1173
1174
1175        #Return vector of interpolated values
1176        #FIXME:
1177        if self.spatial is True:
1178            return q
1179        else:
1180            #Replicate q according to x and y
1181            #This is e.g used for Wind_stress
1182            if x == None or y == None: 
1183                return q
1184            else:
1185                try:
1186                    N = len(x)
1187                except:
1188                    return q
1189                else:
1190                    from Numeric import ones, Float
1191                    #x is a vector - Create one constant column for each value
1192                    N = len(x)
1193                    assert len(y) == N, 'x and y must have same length'
1194                    res = []
1195                    for col in q:
1196                        res.append(col*ones(N, Float))
1197                       
1198                return res
1199
1200
1201    def statistics(self):
1202        """Output statistics about interpolation_function
1203        """
1204       
1205        vertex_coordinates = self.vertex_coordinates
1206        interpolation_points = self.interpolation_points               
1207        quantity_names = self.quantity_names
1208        quantities = self.quantities
1209        precomputed_values = self.precomputed_values                 
1210               
1211        x = vertex_coordinates[:,0]
1212        y = vertex_coordinates[:,1]               
1213
1214        str =  '------------------------------------------------\n'
1215        str += 'Interpolation_function (spatio-temporal) statistics:\n'
1216        str += '  Extent:\n'
1217        str += '    x in [%f, %f], len(x) == %d\n'\
1218               %(min(x), max(x), len(x))
1219        str += '    y in [%f, %f], len(y) == %d\n'\
1220               %(min(y), max(y), len(y))
1221        str += '    t in [%f, %f], len(t) == %d\n'\
1222               %(min(self.T), max(self.T), len(self.T))
1223        str += '  Quantities:\n'
1224        for name in quantity_names:
1225            q = quantities[name][:].flat
1226            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1227
1228        if interpolation_points is not None:   
1229            str += '  Interpolation points (xi, eta):'\
1230                   ' number of points == %d\n' %interpolation_points.shape[0]
1231            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1232                                            max(interpolation_points[:,0]))
1233            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1234                                             max(interpolation_points[:,1]))
1235            str += '  Interpolated quantities (over all timesteps):\n'
1236       
1237            for name in quantity_names:
1238                q = precomputed_values[name][:].flat
1239                str += '    %s at interpolation points in [%f, %f]\n'\
1240                       %(name, min(q), max(q))
1241        str += '------------------------------------------------\n'
1242
1243        return str
1244
1245        #FIXME: Delete
1246        #print '------------------------------------------------'
1247        #print 'Interpolation_function statistics:'
1248        #print '  Extent:'
1249        #print '    x in [%f, %f], len(x) == %d'\
1250        #      %(min(x), max(x), len(x))
1251        #print '    y in [%f, %f], len(y) == %d'\
1252        #      %(min(y), max(y), len(y))
1253        #print '    t in [%f, %f], len(t) == %d'\
1254        #      %(min(self.T), max(self.T), len(self.T))
1255        #print '  Quantities:'
1256        #for name in quantity_names:
1257        #    q = quantities[name][:].flat
1258        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1259        #print '  Interpolation points (xi, eta):'\
1260        #      ' number of points == %d ' %interpolation_points.shape[0]
1261        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1262        #                             max(interpolation_points[:,0]))
1263        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1264        #                              max(interpolation_points[:,1]))
1265        #print '  Interpolated quantities (over all timesteps):'
1266        #
1267        #for name in quantity_names:
1268        #    q = precomputed_values[name][:].flat
1269        #    print '    %s at interpolation points in [%f, %f]'\
1270        #          %(name, min(q), max(q))
1271        #print '------------------------------------------------'
1272
1273
1274#-------------------------------------------------------------
1275if __name__ == "__main__":
1276    """
1277    Load in a mesh and data points with attributes.
1278    Fit the attributes to the mesh.
1279    Save a new mesh file.
1280    """
1281    import os, sys
1282    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1283            %os.path.basename(sys.argv[0])
1284
1285    if len(sys.argv) < 4:
1286        print usage
1287    else:
1288        mesh_file = sys.argv[1]
1289        point_file = sys.argv[2]
1290        mesh_output_file = sys.argv[3]
1291
1292        expand_search = False
1293        if len(sys.argv) > 4:
1294            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1295                expand_search = True
1296            else:
1297                expand_search = False
1298
1299        verbose = False
1300        if len(sys.argv) > 5:
1301            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1302                verbose = False
1303            else:
1304                verbose = True
1305
1306        if len(sys.argv) > 6:
1307            alpha = sys.argv[6]
1308        else:
1309            alpha = DEFAULT_ALPHA
1310
1311        # This is used more for testing
1312        if len(sys.argv) > 7:
1313            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1314                display_errors = False
1315            else:
1316                display_errors = True
1317           
1318        t0 = time.time()
1319        try:
1320            fit_to_mesh_file(mesh_file,
1321                         point_file,
1322                         mesh_output_file,
1323                         alpha,
1324                         verbose= verbose,
1325                         expand_search = expand_search,
1326                         display_errors = display_errors)
1327        except IOError,e:
1328            import sys; sys.exit(1)
1329
1330        print 'That took %.2f seconds' %(time.time()-t0)
1331
Note: See TracBrowser for help on using the repository browser.