source: inundation/pyvolution/least_squares.py @ 1907

Last change on this file since 1907 was 1907, checked in by duncan, 19 years ago

removed dead code

File size: 48.4 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20#import exceptions
21#class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME (DSG-ON) reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    NOTE: Throws IOErrors, for a variety of file problems.
75   
76    mesh_origin is the same but refers to the input tsh file.
77    FIXME: When the tsh format contains it own origin, these parameters can go.
78    FIXME: And both origins should be obtained from the specified files.
79    """
80
81    from load_mesh.loadASCII import import_mesh_file, \
82                 import_points_file, export_mesh_file, \
83                 concatinate_attributelist
84
85
86    try:
87        mesh_dict = import_mesh_file(mesh_file)
88    except IOError,e:
89        if display_errors:
90            print "Could not load bad file. ", e
91        raise IOError  #Re-raise exception
92       
93    vertex_coordinates = mesh_dict['vertices']
94    triangles = mesh_dict['triangles']
95    if type(mesh_dict['vertex_attributes']) == ArrayType:
96        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
97    else:
98        old_point_attributes = mesh_dict['vertex_attributes']
99
100    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
101        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
102    else:
103        old_title_list = mesh_dict['vertex_attribute_titles']
104
105    if verbose: print 'tsh file %s loaded' %mesh_file
106
107    # load in the .pts file
108    try:
109        point_dict = import_points_file(point_file, verbose=verbose)
110    except IOError,e:
111        if display_errors:
112            print "Could not load bad file. ", e
113        raise IOError  #Re-raise exception 
114
115    point_coordinates = point_dict['pointlist']
116    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
117
118    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
119        data_origin = point_dict['geo_reference'].get_origin()
120    else:
121        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
122
123    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
124        mesh_origin = mesh_dict['geo_reference'].get_origin()
125    else:
126        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
127
128    if verbose: print "points file loaded"
129    if verbose:print "fitting to mesh"
130    f = fit_to_mesh(vertex_coordinates,
131                    triangles,
132                    point_coordinates,
133                    point_attributes,
134                    alpha = alpha,
135                    verbose = verbose,
136                    expand_search = expand_search,
137                    data_origin = data_origin,
138                    mesh_origin = mesh_origin,
139                    precrop = precrop)
140    if verbose: print "finished fitting to mesh"
141
142    # convert array to list of lists
143    new_point_attributes = f.tolist()
144    #FIXME have this overwrite attributes with the same title - DSG
145    #Put the newer attributes last
146    if old_title_list <> []:
147        old_title_list.extend(title_list)
148        #FIXME can this be done a faster way? - DSG
149        for i in range(len(old_point_attributes)):
150            old_point_attributes[i].extend(new_point_attributes[i])
151        mesh_dict['vertex_attributes'] = old_point_attributes
152        mesh_dict['vertex_attribute_titles'] = old_title_list
153    else:
154        mesh_dict['vertex_attributes'] = new_point_attributes
155        mesh_dict['vertex_attribute_titles'] = title_list
156
157    #FIXME (Ole): Remember to output mesh_origin as well
158    if verbose: print "exporting to file ",mesh_output_file
159
160    try:
161        export_mesh_file(mesh_output_file, mesh_dict)
162    except IOError,e:
163        if display_errors:
164            print "Could not write file. ", e
165        raise IOError
166
167def fit_to_mesh(vertex_coordinates,
168                triangles,
169                point_coordinates,
170                point_attributes,
171                alpha = DEFAULT_ALPHA,
172                verbose = False,
173                expand_search = False,
174                data_origin = None,
175                mesh_origin = None,
176                precrop = False):
177    """
178    Fit a smooth surface to a triangulation,
179    given data points with attributes.
180
181
182        Inputs:
183
184          vertex_coordinates: List of coordinate pairs [xi, eta] of points
185          constituting mesh (or a an m x 2 Numeric array)
186
187          triangles: List of 3-tuples (or a Numeric array) of
188          integers representing indices of all vertices in the mesh.
189
190          point_coordinates: List of coordinate pairs [x, y] of data points
191          (or an nx2 Numeric array)
192
193          alpha: Smoothing parameter.
194
195          point_attributes: Vector or array of data at the point_coordinates.
196
197          data_origin and mesh_origin are 3-tuples consisting of
198          UTM zone, easting and northing. If specified
199          point coordinates and vertex coordinates are assumed to be
200          relative to their respective origins.
201
202    """
203    interp = Interpolation(vertex_coordinates,
204                           triangles,
205                           point_coordinates,
206                           alpha = alpha,
207                           verbose = verbose,
208                           expand_search = expand_search,
209                           data_origin = data_origin,
210                           mesh_origin = mesh_origin,
211                           precrop = precrop)
212
213    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
214    return vertex_attributes
215
216
217
218def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
219                    verbose = False, reduction = 1):
220    """Fits attributes from pts file to MxN rectangular mesh
221
222    Read pts file and create rectangular mesh of resolution MxN such that
223    it covers all points specified in pts file.
224
225    FIXME: This may be a temporary function until we decide on
226    netcdf formats etc
227
228    FIXME: Uses elevation hardwired
229    """
230
231    import  mesh_factory
232    from load_mesh.loadASCII import import_points_file
233   
234    if verbose: print 'Read pts'
235    points_dict = import_points_file(pts_name)
236    #points, attributes = util.read_xya(pts_name)
237
238    #Reduce number of points a bit
239    points = points_dict['pointlist'][::reduction]
240    elevation = points_dict['attributelist']['elevation']  #Must be elevation
241    elevation = elevation[::reduction]
242
243    if verbose: print 'Got %d data points' %len(points)
244
245    if verbose: print 'Create mesh'
246    #Find extent
247    max_x = min_x = points[0][0]
248    max_y = min_y = points[0][1]
249    for point in points[1:]:
250        x = point[0]
251        if x > max_x: max_x = x
252        if x < min_x: min_x = x
253        y = point[1]
254        if y > max_y: max_y = y
255        if y < min_y: min_y = y
256
257    #Create appropriate mesh
258    vertex_coordinates, triangles, boundary =\
259         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
260                                (min_x, min_y))
261
262    #Fit attributes to mesh
263    vertex_attributes = fit_to_mesh(vertex_coordinates,
264                        triangles,
265                        points,
266                        elevation, alpha=alpha, verbose=verbose)
267
268
269
270    return vertex_coordinates, triangles, boundary, vertex_attributes
271
272
273
274class Interpolation:
275
276    def __init__(self,
277                 vertex_coordinates,
278                 triangles,
279                 point_coordinates = None,
280                 alpha = None,
281                 verbose = False,
282                 expand_search = True,
283                 interp_only = False,
284                 max_points_per_cell = 30,
285                 mesh_origin = None,
286                 data_origin = None,
287                 precrop = False):
288
289
290        """ Build interpolation matrix mapping from
291        function values at vertices to function values at data points
292
293        Inputs:
294
295          vertex_coordinates: List of coordinate pairs [xi, eta] of
296          points constituting mesh (or a an m x 2 Numeric array)
297          Points may appear multiple times
298          (e.g. if vertices have discontinuities)
299
300          triangles: List of 3-tuples (or a Numeric array) of
301          integers representing indices of all vertices in the mesh.
302
303          point_coordinates: List of coordinate pairs [x, y] of
304          data points (or an nx2 Numeric array)
305          If point_coordinates is absent, only smoothing matrix will
306          be built
307
308          alpha: Smoothing parameter
309
310          data_origin and mesh_origin are 3-tuples consisting of
311          UTM zone, easting and northing. If specified
312          point coordinates and vertex coordinates are assumed to be
313          relative to their respective origins.
314
315        """
316        from util import ensure_numeric
317
318        #Convert input to Numeric arrays
319        triangles = ensure_numeric(triangles, Int)
320        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
321
322        #Build underlying mesh
323        if verbose: print 'Building mesh'
324        #self.mesh = General_mesh(vertex_coordinates, triangles,
325        #FIXME: Trying the normal mesh while testing precrop,
326        #       The functionality of boundary_polygon is needed for that
327
328        #FIXME - geo ref does not have to go into mesh.
329        # Change the point co-ords to conform to the
330        # mesh co-ords early in the code
331        if mesh_origin == None:
332            geo = None
333        else:
334            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
335        self.mesh = Mesh(vertex_coordinates, triangles,
336                         geo_reference = geo)
337       
338        self.mesh.check_integrity()
339
340        self.data_origin = data_origin
341
342        self.point_indices = None
343
344        #Smoothing parameter
345        if alpha is None:
346            self.alpha = DEFAULT_ALPHA
347        else:   
348            self.alpha = alpha
349
350
351        if point_coordinates is not None:
352            if verbose: print 'Building interpolation matrix'
353            self.build_interpolation_matrix_A(point_coordinates,
354                                              verbose = verbose,
355                                              expand_search = expand_search,
356                                              interp_only = interp_only, 
357                                              max_points_per_cell =\
358                                              max_points_per_cell,
359                                              data_origin = data_origin,
360                                              precrop = precrop)
361        #Build coefficient matrices
362        if interp_only == False:
363            self.build_coefficient_matrix_B(point_coordinates,
364                                        verbose = verbose,
365                                        expand_search = expand_search,
366                                        max_points_per_cell =\
367                                        max_points_per_cell,
368                                        data_origin = data_origin,
369                                        precrop = precrop)
370
371
372    def set_point_coordinates(self, point_coordinates,
373                              data_origin = None,
374                              verbose = False,
375                              precrop = True):
376        """
377        A public interface to setting the point co-ordinates.
378        """
379        if point_coordinates is not None:
380            if verbose: print 'Building interpolation matrix'
381            self.build_interpolation_matrix_A(point_coordinates,
382                                              verbose = verbose,
383                                              data_origin = data_origin,
384                                              precrop = precrop)
385        self.build_coefficient_matrix_B(point_coordinates, data_origin)
386
387    def build_coefficient_matrix_B(self, point_coordinates=None,
388                                   verbose = False, expand_search = True,
389                                   max_points_per_cell=30,
390                                   data_origin = None,
391                                   precrop = False):
392        """Build final coefficient matrix"""
393
394
395        if self.alpha <> 0:
396            if verbose: print 'Building smoothing matrix'
397            self.build_smoothing_matrix_D()
398
399        if point_coordinates is not None:
400            if self.alpha <> 0:
401                self.B = self.AtA + self.alpha*self.D
402            else:
403                self.B = self.AtA
404
405            #Convert self.B matrix to CSR format for faster matrix vector
406            self.B = Sparse_CSR(self.B)
407
408    def build_interpolation_matrix_A(self, point_coordinates,
409                                     verbose = False, expand_search = True,
410                                     max_points_per_cell=30,
411                                     data_origin = None,
412                                     precrop = False,
413                                     interp_only = False):
414        """Build n x m interpolation matrix, where
415        n is the number of data points and
416        m is the number of basis functions phi_k (one per vertex)
417
418        This algorithm uses a quad tree data structure for fast binning of data points
419        origin is a 3-tuple consisting of UTM zone, easting and northing.
420        If specified coordinates are assumed to be relative to this origin.
421
422        This one will override any data_origin that may be specified in
423        interpolation instance
424
425        """
426
427
428        #FIXME (Ole): Check that this function is memeory efficient.
429        #6 million datapoints and 300000 basis functions
430        #causes out-of-memory situation
431        #First thing to check is whether there is room for self.A and self.AtA
432        #
433        #Maybe we need some sort of blocking
434
435        from quad import build_quadtree
436        from util import ensure_numeric
437
438        if data_origin is None:
439            data_origin = self.data_origin #Use the one from
440                                           #interpolation instance
441
442        #Convert input to Numeric arrays just in case.
443        point_coordinates = ensure_numeric(point_coordinates, Float)
444
445        #Keep track of discarded points (if any).
446        #This is only registered if precrop is True
447        self.cropped_points = False
448
449        #Shift data points to same origin as mesh (if specified)
450
451        #FIXME this will shift if there was no geo_ref.
452        #But all this should be removed anyhow.
453        #change coords before this point
454        mesh_origin = self.mesh.geo_reference.get_origin()
455        if point_coordinates is not None:
456            if data_origin is not None:
457                if mesh_origin is not None:
458
459                    #Transformation:
460                    #
461                    #Let x_0 be the reference point of the point coordinates
462                    #and xi_0 the reference point of the mesh.
463                    #
464                    #A point coordinate (x + x_0) is then made relative
465                    #to xi_0 by
466                    #
467                    # x_new = x + x_0 - xi_0
468                    #
469                    #and similarly for eta
470
471                    x_offset = data_origin[1] - mesh_origin[1]
472                    y_offset = data_origin[2] - mesh_origin[2]
473                else: #Shift back to a zero origin
474                    x_offset = data_origin[1]
475                    y_offset = data_origin[2]
476
477                point_coordinates[:,0] += x_offset
478                point_coordinates[:,1] += y_offset
479            else:
480                if mesh_origin is not None:
481                    #Use mesh origin for data points
482                    point_coordinates[:,0] -= mesh_origin[1]
483                    point_coordinates[:,1] -= mesh_origin[2]
484
485
486
487        #Remove points falling outside mesh boundary
488        #This reduced one example from 1356 seconds to 825 seconds
489        if precrop is True:
490            from Numeric import take
491            from util import inside_polygon
492
493            if verbose: print 'Getting boundary polygon'
494            P = self.mesh.get_boundary_polygon()
495
496            if verbose: print 'Getting indices inside mesh boundary'
497            indices = inside_polygon(point_coordinates, P, verbose = verbose)
498
499
500            if len(indices) != point_coordinates.shape[0]:
501                self.cropped_points = True
502                if verbose:
503                    print 'Done - %d points outside mesh have been cropped.'\
504                          %(point_coordinates.shape[0] - len(indices))
505
506            point_coordinates = take(point_coordinates, indices)
507            self.point_indices = indices
508
509
510
511
512        #Build n x m interpolation matrix
513        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
514        n = point_coordinates.shape[0]     #Nbr of data points
515
516        if verbose: print 'Number of datapoints: %d' %n
517        if verbose: print 'Number of basis functions: %d' %m
518
519        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
520        #However, Sparse_CSR does not have the same methods as Sparse yet
521        #The tests will reveal what needs to be done
522
523        #
524        #self.A = Sparse_CSR(Sparse(n,m))
525        #self.AtA = Sparse_CSR(Sparse(m,m))
526        self.A = Sparse(n,m)
527        self.AtA = Sparse(m,m)
528
529        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
530        root = build_quadtree(self.mesh,
531                              max_points_per_cell = max_points_per_cell)
532
533        #Compute matrix elements
534        for i in range(n):
535            #For each data_coordinate point
536
537            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
538            x = point_coordinates[i]
539
540            #Find vertices near x
541            candidate_vertices = root.search(x[0], x[1])
542            is_more_elements = True
543
544            element_found, sigma0, sigma1, sigma2, k = \
545                self.search_triangles_of_vertices(candidate_vertices, x)
546            while not element_found and is_more_elements and expand_search:
547                #if verbose: print 'Expanding search'
548                candidate_vertices, branch = root.expand_search()
549                if branch == []:
550                    # Searching all the verts from the root cell that haven't
551                    # been searched.  This is the last try
552                    element_found, sigma0, sigma1, sigma2, k = \
553                      self.search_triangles_of_vertices(candidate_vertices, x)
554                    is_more_elements = False
555                else:
556                    element_found, sigma0, sigma1, sigma2, k = \
557                      self.search_triangles_of_vertices(candidate_vertices, x)
558
559
560            #Update interpolation matrix A if necessary
561            if element_found is True:
562                #Assign values to matrix A
563
564                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
565                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
566                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
567
568                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
569                js     = [j0,j1,j2]
570
571                for j in js:
572                    self.A[i,j] = sigmas[j]
573                    for k in js:
574                        if interp_only == False:
575                            self.AtA[j,k] += sigmas[j]*sigmas[k]
576            else:
577                pass
578                #Ok if there is no triangle for datapoint
579                #(as in brute force version)
580                #raise 'Could not find triangle for point', x
581
582
583
584    def search_triangles_of_vertices(self, candidate_vertices, x):
585            #Find triangle containing x:
586            element_found = False
587
588            # This will be returned if element_found = False
589            sigma2 = -10.0
590            sigma0 = -10.0
591            sigma1 = -10.0
592            k = -10.0
593
594            #For all vertices in same cell as point x
595            for v in candidate_vertices:
596
597                #for each triangle id (k) which has v as a vertex
598                for k, _ in self.mesh.vertexlist[v]:
599
600                    #Get the three vertex_points of candidate triangle
601                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
602                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
603                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
604
605                    #print "PDSG - k", k
606                    #print "PDSG - xi0", xi0
607                    #print "PDSG - xi1", xi1
608                    #print "PDSG - xi2", xi2
609                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
610                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
611
612                    #Get the three normals
613                    n0 = self.mesh.get_normal(k, 0)
614                    n1 = self.mesh.get_normal(k, 1)
615                    n2 = self.mesh.get_normal(k, 2)
616
617
618                    #Compute interpolation
619                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
620                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
621                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
622
623                    #print "PDSG - sigma0", sigma0
624                    #print "PDSG - sigma1", sigma1
625                    #print "PDSG - sigma2", sigma2
626
627                    #FIXME: Maybe move out to test or something
628                    epsilon = 1.0e-6
629                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
630
631                    #Check that this triangle contains the data point
632
633                    #Sigmas can get negative within
634                    #machine precision on some machines (e.g nautilus)
635                    #Hence the small eps
636                    eps = 1.0e-15
637                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
638                        element_found = True
639                        break
640
641                if element_found is True:
642                    #Don't look for any other triangle
643                    break
644            return element_found, sigma0, sigma1, sigma2, k
645
646
647
648    def build_interpolation_matrix_A_brute(self, point_coordinates):
649        """Build n x m interpolation matrix, where
650        n is the number of data points and
651        m is the number of basis functions phi_k (one per vertex)
652
653        This is the brute force which is too slow for large problems,
654        but could be used for testing
655        """
656
657        from util import ensure_numeric
658
659        #Convert input to Numeric arrays
660        point_coordinates = ensure_numeric(point_coordinates, Float)
661
662        #Build n x m interpolation matrix
663        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
664        n = point_coordinates.shape[0]     #Nbr of data points
665
666        self.A = Sparse(n,m)
667        self.AtA = Sparse(m,m)
668
669        #Compute matrix elements
670        for i in range(n):
671            #For each data_coordinate point
672
673            x = point_coordinates[i]
674            element_found = False
675            k = 0
676            while not element_found and k < len(self.mesh):
677                #For each triangle (brute force)
678                #FIXME: Real algorithm should only visit relevant triangles
679
680                #Get the three vertex_points
681                xi0 = self.mesh.get_vertex_coordinate(k, 0)
682                xi1 = self.mesh.get_vertex_coordinate(k, 1)
683                xi2 = self.mesh.get_vertex_coordinate(k, 2)
684
685                #Get the three normals
686                n0 = self.mesh.get_normal(k, 0)
687                n1 = self.mesh.get_normal(k, 1)
688                n2 = self.mesh.get_normal(k, 2)
689
690                #Compute interpolation
691                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
692                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
693                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
694
695                #FIXME: Maybe move out to test or something
696                epsilon = 1.0e-6
697                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
698
699                #Check that this triangle contains data point
700                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
701                    element_found = True
702                    #Assign values to matrix A
703
704                    j0 = self.mesh.triangles[k,0] #Global vertex id
705                    #self.A[i, j0] = sigma0
706
707                    j1 = self.mesh.triangles[k,1] #Global vertex id
708                    #self.A[i, j1] = sigma1
709
710                    j2 = self.mesh.triangles[k,2] #Global vertex id
711                    #self.A[i, j2] = sigma2
712
713                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
714                    js     = [j0,j1,j2]
715
716                    for j in js:
717                        self.A[i,j] = sigmas[j]
718                        for k in js:
719                            self.AtA[j,k] += sigmas[j]*sigmas[k]
720                k = k+1
721
722
723
724    def get_A(self):
725        return self.A.todense()
726
727    def get_B(self):
728        return self.B.todense()
729
730    def get_D(self):
731        return self.D.todense()
732
733        #FIXME: Remember to re-introduce the 1/n factor in the
734        #interpolation term
735
736    def build_smoothing_matrix_D(self):
737        """Build m x m smoothing matrix, where
738        m is the number of basis functions phi_k (one per vertex)
739
740        The smoothing matrix is defined as
741
742        D = D1 + D2
743
744        where
745
746        [D1]_{k,l} = \int_\Omega
747           \frac{\partial \phi_k}{\partial x}
748           \frac{\partial \phi_l}{\partial x}\,
749           dx dy
750
751        [D2]_{k,l} = \int_\Omega
752           \frac{\partial \phi_k}{\partial y}
753           \frac{\partial \phi_l}{\partial y}\,
754           dx dy
755
756
757        The derivatives \frac{\partial \phi_k}{\partial x},
758        \frac{\partial \phi_k}{\partial x} for a particular triangle
759        are obtained by computing the gradient a_k, b_k for basis function k
760        """
761
762        #FIXME: algorithm might be optimised by computing local 9x9
763        #"element stiffness matrices:
764
765        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
766
767        self.D = Sparse(m,m)
768
769        #For each triangle compute contributions to D = D1+D2
770        for i in range(len(self.mesh)):
771
772            #Get area
773            area = self.mesh.areas[i]
774
775            #Get global vertex indices
776            v0 = self.mesh.triangles[i,0]
777            v1 = self.mesh.triangles[i,1]
778            v2 = self.mesh.triangles[i,2]
779
780            #Get the three vertex_points
781            xi0 = self.mesh.get_vertex_coordinate(i, 0)
782            xi1 = self.mesh.get_vertex_coordinate(i, 1)
783            xi2 = self.mesh.get_vertex_coordinate(i, 2)
784
785            #Compute gradients for each vertex
786            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
787                              1, 0, 0)
788
789            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
790                              0, 1, 0)
791
792            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
793                              0, 0, 1)
794
795            #Compute diagonal contributions
796            self.D[v0,v0] += (a0*a0 + b0*b0)*area
797            self.D[v1,v1] += (a1*a1 + b1*b1)*area
798            self.D[v2,v2] += (a2*a2 + b2*b2)*area
799
800            #Compute contributions for basis functions sharing edges
801            e01 = (a0*a1 + b0*b1)*area
802            self.D[v0,v1] += e01
803            self.D[v1,v0] += e01
804
805            e12 = (a1*a2 + b1*b2)*area
806            self.D[v1,v2] += e12
807            self.D[v2,v1] += e12
808
809            e20 = (a2*a0 + b2*b0)*area
810            self.D[v2,v0] += e20
811            self.D[v0,v2] += e20
812
813
814    def fit(self, z):
815        """Fit a smooth surface to given 1d array of data points z.
816
817        The smooth surface is computed at each vertex in the underlying
818        mesh using the formula given in the module doc string.
819
820        Pre Condition:
821          self.A, self.AtA and self.B have been initialised
822
823        Inputs:
824          z: Single 1d vector or array of data at the point_coordinates.
825        """
826
827        #Convert input to Numeric arrays
828        from util import ensure_numeric
829        z = ensure_numeric(z, Float)
830
831        if len(z.shape) > 1 :
832            raise VectorShapeError, 'Can only deal with 1d data vector'
833
834        if self.point_indices is not None:
835            #Remove values for any points that were outside mesh
836            z = take(z, self.point_indices)
837
838        #Compute right hand side based on data
839        #FIXME (DSG-DsG): could Sparse_CSR be used here?  Use this format
840        # after a matrix is built, before calcs.
841        Atz = self.A.trans_mult(z)
842
843
844        #Check sanity
845        n, m = self.A.shape
846        if n<m and self.alpha == 0.0:
847            msg = 'ERROR (least_squares): Too few data points\n'
848            msg += 'There are only %d data points and alpha == 0. ' %n
849            msg += 'Need at least %d\n' %m
850            msg += 'Alternatively, set smoothing parameter alpha to a small '
851            msg += 'positive value,\ne.g. 1.0e-3.'
852            raise msg
853
854
855
856        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
857        #FIXME: Should we store the result here for later use? (ON)
858
859
860    def fit_points(self, z, verbose=False):
861        """Like fit, but more robust when each point has two or more attributes
862        FIXME (Ole): The name fit_points doesn't carry any meaning
863        for me. How about something like fit_multiple or fit_columns?
864        """
865
866        try:
867            if verbose: print 'Solving penalised least_squares problem'
868            return self.fit(z)
869        except VectorShapeError, e:
870            # broadcasting is not supported.
871
872            #Convert input to Numeric arrays
873            from util import ensure_numeric
874            z = ensure_numeric(z, Float)
875
876            #Build n x m interpolation matrix
877            m = self.mesh.coordinates.shape[0] #Number of vertices
878            n = z.shape[1]                     #Number of data points
879
880            f = zeros((m,n), Float) #Resulting columns
881
882            for i in range(z.shape[1]):
883                f[:,i] = self.fit(z[:,i])
884
885            return f
886
887
888    def interpolate(self, f):
889        """Evaluate smooth surface f at data points implied in self.A.
890
891        The mesh values representing a smooth surface are
892        assumed to be specified in f. This argument could,
893        for example have been obtained from the method self.fit()
894
895        Pre Condition:
896          self.A has been initialised
897
898        Inputs:
899          f: Vector or array of data at the mesh vertices.
900          If f is an array, interpolation will be done for each column as
901          per underlying matrix-matrix multiplication
902
903        Output:
904          Interpolated values at data points implied in self.A
905
906        """
907
908        return self.A * f
909
910    def cull_outsiders(self, f):
911        pass
912
913
914
915
916class Interpolation_function:
917    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
918    which is interpolated from time series defined at vertices of
919    triangular mesh (such as those stored in sww files)
920
921    Let m be the number of vertices, n the number of triangles
922    and p the number of timesteps.
923
924    Mandatory input
925        time:               px1 array of monotonously increasing times (Float)
926        quantities:         Dictionary of arrays or 1 array (Float)
927                            The arrays must either have dimensions pxm or mx1.
928                            The resulting function will be time dependent in
929                            the former case while it will be constan with
930                            respect to time in the latter case.
931       
932    Optional input:
933        quantity_names:     List of keys into the quantities dictionary
934        vertex_coordinates: mx2 array of coordinates (Float)
935        triangles:          nx3 array of indices into vertex_coordinates (Int)
936        interpolation_points: Nx2 array of coordinates to be interpolated to
937        verbose:            Level of reporting
938   
939   
940    The quantities returned by the callable object are specified by
941    the list quantities which must contain the names of the
942    quantities to be returned and also reflect the order, e.g. for
943    the shallow water wave equation, on would have
944    quantities = ['stage', 'xmomentum', 'ymomentum']
945
946    The parameter interpolation_points decides at which points interpolated
947    quantities are to be computed whenever object is called.
948    If None, return average value
949    """
950
951   
952   
953    def __init__(self,
954                 time,
955                 quantities,
956                 quantity_names = None, 
957                 vertex_coordinates = None,
958                 triangles = None,
959                 interpolation_points = None,
960                 verbose = False):
961        """Initialise object and build spatial interpolation if required
962        """
963
964        from Numeric import array, zeros, Float, alltrue, concatenate,\
965             reshape, ArrayType
966
967
968        from util import mean, ensure_numeric
969        from config import time_format
970        import types
971
972
973
974        #Check temporal info
975        time = ensure_numeric(time)       
976        msg = 'Time must be a monotonuosly '
977        msg += 'increasing sequence %s' %time
978        assert alltrue(time[1:] - time[:-1] >= 0 ), msg
979
980
981        #Check if quantities is a single array only
982        if type(quantities) != types.DictType:
983            quantities = ensure_numeric(quantities)
984            quantity_names = ['Attribute']
985
986            #Make it a dictionary
987            quantities = {quantity_names[0]: quantities}
988
989
990        #Use keys if no names are specified
991        if quantity_names is None:
992            quantity_names = quantities.keys()
993
994
995        #Check spatial info
996        if vertex_coordinates is None:
997            self.spatial = False
998        else:   
999            vertex_coordinates = ensure_numeric(vertex_coordinates)
1000
1001            assert triangles is not None, 'Triangles array must be specified'
1002            triangles = ensure_numeric(triangles)
1003            self.spatial = True           
1004           
1005
1006 
1007        #Save for use with statistics
1008        self.quantity_names = quantity_names       
1009        self.quantities = quantities       
1010        self.vertex_coordinates = vertex_coordinates
1011        self.interpolation_points = interpolation_points
1012        self.T = time[:]  # Time assumed to be relative to starttime
1013        self.index = 0    # Initial time index
1014        self.precomputed_values = {}
1015           
1016
1017
1018        #Precomputed spatial interpolation if requested
1019        if interpolation_points is not None:
1020            if self.spatial is False:
1021                raise 'Triangles and vertex_coordinates must be specified'
1022           
1023            try:
1024                self.interpolation_points = ensure_numeric(interpolation_points)
1025            except:
1026                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1027                      'or a list of points\n'
1028                msg += 'I got: %s.' %(str(self.interpolation_points)[:60] +\
1029                                      '...')
1030                raise msg
1031
1032
1033            m = len(self.interpolation_points)
1034            p = len(self.T)
1035           
1036            for name in quantity_names:
1037                self.precomputed_values[name] = zeros((p, m), Float)
1038
1039            #Build interpolator
1040            interpol = Interpolation(vertex_coordinates,
1041                                     triangles,
1042                                     point_coordinates = \
1043                                     self.interpolation_points,
1044                                     alpha = 0,
1045                                     precrop = False, 
1046                                     verbose = verbose)
1047
1048            if verbose: print 'Interpolate'
1049            for i, t in enumerate(self.T):
1050                #Interpolate quantities at this timestep
1051                if verbose and i%((p+10)/10)==0:
1052                    print ' time step %d of %d' %(i, p)
1053                   
1054                for name in quantity_names:
1055                    if len(quantities[name].shape) == 2:
1056                        result = interpol.interpolate(quantities[name][i,:])
1057                    else:
1058                       #Assume no time dependency
1059                       result = interpol.interpolate(quantities[name][:])
1060                       
1061                    self.precomputed_values[name][i, :] = result
1062                   
1063                       
1064
1065            #Report
1066            if verbose:
1067                print self.statistics()
1068                #self.print_statistics()
1069           
1070        else:
1071            #Store quantitites as is
1072            for name in quantity_names:
1073                self.precomputed_values[name] = quantities[name]
1074
1075
1076        #else:
1077        #    #Return an average, making this a time series
1078        #    for name in quantity_names:
1079        #        self.values[name] = zeros(len(self.T), Float)
1080        #
1081        #    if verbose: print 'Compute mean values'
1082        #    for i, t in enumerate(self.T):
1083        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1084        #        for name in quantity_names:
1085        #           self.values[name][i] = mean(quantities[name][i,:])
1086
1087
1088
1089
1090    def __repr__(self):
1091        #return 'Interpolation function (spatio-temporal)'
1092        return self.statistics()
1093   
1094
1095    def __call__(self, t, point_id = None, x = None, y = None):
1096        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1097
1098        Inputs:
1099          t: time - Model time. Must lie within existing timesteps
1100          point_id: index of one of the preprocessed points.
1101          x, y:     Overrides location, point_id ignored
1102         
1103          If spatial info is present and all of x,y,point_id
1104          are None an exception is raised
1105                   
1106          If no spatial info is present, point_id and x,y arguments are ignored
1107          making f a function of time only.
1108
1109         
1110          FIXME: point_id could also be a slice
1111          FIXME: What if x and y are vectors?
1112          FIXME: What about f(x,y) without t?
1113        """
1114
1115        from math import pi, cos, sin, sqrt
1116        from Numeric import zeros, Float
1117        from util import mean       
1118
1119        if self.spatial is True:
1120            if point_id is None:
1121                if x is None or y is None:
1122                    msg = 'Either point_id or x and y must be specified'
1123                    raise msg
1124            else:
1125                if self.interpolation_points is None:
1126                    msg = 'Interpolation_function must be instantiated ' +\
1127                          'with a list of interpolation points before parameter ' +\
1128                          'point_id can be used'
1129                    raise msg
1130
1131
1132        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1133        msg += ' does not match model time: %s\n' %t
1134        if t < self.T[0]: raise msg
1135        if t > self.T[-1]: raise msg
1136
1137        oldindex = self.index #Time index
1138
1139        #Find current time slot
1140        while t > self.T[self.index]: self.index += 1
1141        while t < self.T[self.index]: self.index -= 1
1142
1143        if t == self.T[self.index]:
1144            #Protect against case where t == T[-1] (last time)
1145            # - also works in general when t == T[i]
1146            ratio = 0
1147        else:
1148            #t is now between index and index+1
1149            ratio = (t - self.T[self.index])/\
1150                    (self.T[self.index+1] - self.T[self.index])
1151
1152        #Compute interpolated values
1153        q = zeros(len(self.quantity_names), Float)
1154
1155        for i, name in enumerate(self.quantity_names):
1156            Q = self.precomputed_values[name]
1157
1158            if self.spatial is False:
1159                #If there is no spatial info               
1160                assert len(Q.shape) == 1
1161
1162                Q0 = Q[self.index]
1163                if ratio > 0: Q1 = Q[self.index+1]
1164
1165            else:
1166                if x is not None and y is not None:
1167                    #Interpolate to x, y
1168                   
1169                    raise 'x,y interpolation not yet implemented'
1170                else:
1171                    #Use precomputed point
1172                    Q0 = Q[self.index, point_id]
1173                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1174
1175            #Linear temporal interpolation   
1176            if ratio > 0:
1177                q[i] = Q0 + ratio*(Q1 - Q0)
1178            else:
1179                q[i] = Q0
1180
1181
1182        #Return vector of interpolated values
1183        #if len(q) == 1:
1184        #    return q[0]
1185        #else:
1186        #    return q
1187
1188
1189        #Return vector of interpolated values
1190        #FIXME:
1191        if self.spatial is True:
1192            return q
1193        else:
1194            #Replicate q according to x and y
1195            #This is e.g used for Wind_stress
1196            if x == None or y == None: 
1197                return q
1198            else:
1199                try:
1200                    N = len(x)
1201                except:
1202                    return q
1203                else:
1204                    from Numeric import ones, Float
1205                    #x is a vector - Create one constant column for each value
1206                    N = len(x)
1207                    assert len(y) == N, 'x and y must have same length'
1208                    res = []
1209                    for col in q:
1210                        res.append(col*ones(N, Float))
1211                       
1212                return res
1213
1214
1215    def statistics(self):
1216        """Output statistics about interpolation_function
1217        """
1218       
1219        vertex_coordinates = self.vertex_coordinates
1220        interpolation_points = self.interpolation_points               
1221        quantity_names = self.quantity_names
1222        quantities = self.quantities
1223        precomputed_values = self.precomputed_values                 
1224               
1225        x = vertex_coordinates[:,0]
1226        y = vertex_coordinates[:,1]               
1227
1228        str =  '------------------------------------------------\n'
1229        str += 'Interpolation_function (spatio-temporal) statistics:\n'
1230        str += '  Extent:\n'
1231        str += '    x in [%f, %f], len(x) == %d\n'\
1232               %(min(x), max(x), len(x))
1233        str += '    y in [%f, %f], len(y) == %d\n'\
1234               %(min(y), max(y), len(y))
1235        str += '    t in [%f, %f], len(t) == %d\n'\
1236               %(min(self.T), max(self.T), len(self.T))
1237        str += '  Quantities:\n'
1238        for name in quantity_names:
1239            q = quantities[name][:].flat
1240            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1241
1242        if interpolation_points is not None:   
1243            str += '  Interpolation points (xi, eta):'\
1244                   ' number of points == %d\n' %interpolation_points.shape[0]
1245            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1246                                            max(interpolation_points[:,0]))
1247            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1248                                             max(interpolation_points[:,1]))
1249            str += '  Interpolated quantities (over all timesteps):\n'
1250       
1251            for name in quantity_names:
1252                q = precomputed_values[name][:].flat
1253                str += '    %s at interpolation points in [%f, %f]\n'\
1254                       %(name, min(q), max(q))
1255        str += '------------------------------------------------\n'
1256
1257        return str
1258
1259        #FIXME: Delete
1260        #print '------------------------------------------------'
1261        #print 'Interpolation_function statistics:'
1262        #print '  Extent:'
1263        #print '    x in [%f, %f], len(x) == %d'\
1264        #      %(min(x), max(x), len(x))
1265        #print '    y in [%f, %f], len(y) == %d'\
1266        #      %(min(y), max(y), len(y))
1267        #print '    t in [%f, %f], len(t) == %d'\
1268        #      %(min(self.T), max(self.T), len(self.T))
1269        #print '  Quantities:'
1270        #for name in quantity_names:
1271        #    q = quantities[name][:].flat
1272        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1273        #print '  Interpolation points (xi, eta):'\
1274        #      ' number of points == %d ' %interpolation_points.shape[0]
1275        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1276        #                             max(interpolation_points[:,0]))
1277        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1278        #                              max(interpolation_points[:,1]))
1279        #print '  Interpolated quantities (over all timesteps):'
1280        #
1281        #for name in quantity_names:
1282        #    q = precomputed_values[name][:].flat
1283        #    print '    %s at interpolation points in [%f, %f]'\
1284        #          %(name, min(q), max(q))
1285        #print '------------------------------------------------'
1286
1287
1288#-------------------------------------------------------------
1289if __name__ == "__main__":
1290    """
1291    Load in a mesh and data points with attributes.
1292    Fit the attributes to the mesh.
1293    Save a new mesh file.
1294    """
1295    import os, sys
1296    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1297            %os.path.basename(sys.argv[0])
1298
1299    if len(sys.argv) < 4:
1300        print usage
1301    else:
1302        mesh_file = sys.argv[1]
1303        point_file = sys.argv[2]
1304        mesh_output_file = sys.argv[3]
1305
1306        expand_search = False
1307        if len(sys.argv) > 4:
1308            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1309                expand_search = True
1310            else:
1311                expand_search = False
1312
1313        verbose = False
1314        if len(sys.argv) > 5:
1315            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1316                verbose = False
1317            else:
1318                verbose = True
1319
1320        if len(sys.argv) > 6:
1321            alpha = sys.argv[6]
1322        else:
1323            alpha = DEFAULT_ALPHA
1324
1325        # This is used more for testing
1326        if len(sys.argv) > 7:
1327            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1328                display_errors = False
1329            else:
1330                display_errors = True
1331           
1332        t0 = time.time()
1333        try:
1334            fit_to_mesh_file(mesh_file,
1335                         point_file,
1336                         mesh_output_file,
1337                         alpha,
1338                         verbose= verbose,
1339                         expand_search = expand_search,
1340                         display_errors = display_errors)
1341        except IOError,e:
1342            import sys; sys.exit(1)
1343
1344        print 'That took %.2f seconds' %(time.time()-t0)
1345
Note: See TracBrowser for help on using the repository browser.