source: inundation/pyvolution/least_squares.py @ 1942

Last change on this file since 1942 was 1942, checked in by duncan, 18 years ago

bug fix

File size: 47.7 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20#import exceptions
21#class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from pyvolution.mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from pyvolution.sparse import Sparse, Sparse_CSR
29from pyvolution.cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME (DSG-ON) reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    NOTE: Throws IOErrors, for a variety of file problems.
75   
76    mesh_origin is the same but refers to the input tsh file.
77    FIXME: When the tsh format contains it own origin, these parameters can go.
78    FIXME: And both origins should be obtained from the specified files.
79    """
80
81    from load_mesh.loadASCII import import_mesh_file, \
82                 import_points_file, export_mesh_file, \
83                 concatinate_attributelist
84
85
86    try:
87        mesh_dict = import_mesh_file(mesh_file)
88    except IOError,e:
89        if display_errors:
90            print "Could not load bad file. ", e
91        raise IOError  #Re-raise exception
92       
93    vertex_coordinates = mesh_dict['vertices']
94    triangles = mesh_dict['triangles']
95    if type(mesh_dict['vertex_attributes']) == ArrayType:
96        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
97    else:
98        old_point_attributes = mesh_dict['vertex_attributes']
99
100    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
101        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
102    else:
103        old_title_list = mesh_dict['vertex_attribute_titles']
104
105    if verbose: print 'tsh file %s loaded' %mesh_file
106
107    # load in the .pts file
108    try:
109        point_dict = import_points_file(point_file, verbose=verbose)
110    except IOError,e:
111        if display_errors:
112            print "Could not load bad file. ", e
113        raise IOError  #Re-raise exception 
114
115    point_coordinates = point_dict['pointlist']
116    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
117
118    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
119        data_origin = point_dict['geo_reference'].get_origin()
120    else:
121        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
122
123    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
124        mesh_origin = mesh_dict['geo_reference'].get_origin()
125    else:
126        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
127
128    if verbose: print "points file loaded"
129    if verbose:print "fitting to mesh"
130    f = fit_to_mesh(vertex_coordinates,
131                    triangles,
132                    point_coordinates,
133                    point_attributes,
134                    alpha = alpha,
135                    verbose = verbose,
136                    expand_search = expand_search,
137                    data_origin = data_origin,
138                    mesh_origin = mesh_origin,
139                    precrop = precrop)
140    if verbose: print "finished fitting to mesh"
141
142    # convert array to list of lists
143    new_point_attributes = f.tolist()
144    #FIXME have this overwrite attributes with the same title - DSG
145    #Put the newer attributes last
146    if old_title_list <> []:
147        old_title_list.extend(title_list)
148        #FIXME can this be done a faster way? - DSG
149        for i in range(len(old_point_attributes)):
150            old_point_attributes[i].extend(new_point_attributes[i])
151        mesh_dict['vertex_attributes'] = old_point_attributes
152        mesh_dict['vertex_attribute_titles'] = old_title_list
153    else:
154        mesh_dict['vertex_attributes'] = new_point_attributes
155        mesh_dict['vertex_attribute_titles'] = title_list
156
157    #FIXME (Ole): Remember to output mesh_origin as well
158    if verbose: print "exporting to file ",mesh_output_file
159
160    try:
161        export_mesh_file(mesh_output_file, mesh_dict)
162    except IOError,e:
163        if display_errors:
164            print "Could not write file. ", e
165        raise IOError
166
167def fit_to_mesh(vertex_coordinates,
168                triangles,
169                point_coordinates,
170                point_attributes,
171                alpha = DEFAULT_ALPHA,
172                verbose = False,
173                expand_search = False,
174                data_origin = None,
175                mesh_origin = None,
176                precrop = False):
177    """
178    Fit a smooth surface to a triangulation,
179    given data points with attributes.
180
181
182        Inputs:
183
184          vertex_coordinates: List of coordinate pairs [xi, eta] of points
185          constituting mesh (or a an m x 2 Numeric array)
186
187          triangles: List of 3-tuples (or a Numeric array) of
188          integers representing indices of all vertices in the mesh.
189
190          point_coordinates: List of coordinate pairs [x, y] of data points
191          (or an nx2 Numeric array)
192
193          alpha: Smoothing parameter.
194
195          point_attributes: Vector or array of data at the point_coordinates.
196
197          data_origin and mesh_origin are 3-tuples consisting of
198          UTM zone, easting and northing. If specified
199          point coordinates and vertex coordinates are assumed to be
200          relative to their respective origins.
201
202    """
203    interp = Interpolation(vertex_coordinates,
204                           triangles,
205                           point_coordinates,
206                           alpha = alpha,
207                           verbose = verbose,
208                           expand_search = expand_search,
209                           data_origin = data_origin,
210                           mesh_origin = mesh_origin,
211                           precrop = precrop)
212
213    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
214    return vertex_attributes
215
216
217
218def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
219                    verbose = False, reduction = 1):
220    """Fits attributes from pts file to MxN rectangular mesh
221
222    Read pts file and create rectangular mesh of resolution MxN such that
223    it covers all points specified in pts file.
224
225    FIXME: This may be a temporary function until we decide on
226    netcdf formats etc
227
228    FIXME: Uses elevation hardwired
229    """
230
231    import  mesh_factory
232    from load_mesh.loadASCII import import_points_file
233   
234    if verbose: print 'Read pts'
235    points_dict = import_points_file(pts_name)
236    #points, attributes = util.read_xya(pts_name)
237
238    #Reduce number of points a bit
239    points = points_dict['pointlist'][::reduction]
240    elevation = points_dict['attributelist']['elevation']  #Must be elevation
241    elevation = elevation[::reduction]
242
243    if verbose: print 'Got %d data points' %len(points)
244
245    if verbose: print 'Create mesh'
246    #Find extent
247    max_x = min_x = points[0][0]
248    max_y = min_y = points[0][1]
249    for point in points[1:]:
250        x = point[0]
251        if x > max_x: max_x = x
252        if x < min_x: min_x = x
253        y = point[1]
254        if y > max_y: max_y = y
255        if y < min_y: min_y = y
256
257    #Create appropriate mesh
258    vertex_coordinates, triangles, boundary =\
259         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
260                                (min_x, min_y))
261
262    #Fit attributes to mesh
263    vertex_attributes = fit_to_mesh(vertex_coordinates,
264                        triangles,
265                        points,
266                        elevation, alpha=alpha, verbose=verbose)
267
268
269
270    return vertex_coordinates, triangles, boundary, vertex_attributes
271
272
273
274class Interpolation:
275
276    def __init__(self,
277                 vertex_coordinates,
278                 triangles,
279                 point_coordinates = None,
280                 alpha = None,
281                 verbose = False,
282                 expand_search = True,
283                 interp_only = False,
284                 max_points_per_cell = 30,
285                 mesh_origin = None,
286                 data_origin = None,
287                 precrop = False):
288
289
290        """ Build interpolation matrix mapping from
291        function values at vertices to function values at data points
292
293        Inputs:
294
295          vertex_coordinates: List of coordinate pairs [xi, eta] of
296          points constituting mesh (or a an m x 2 Numeric array)
297          Points may appear multiple times
298          (e.g. if vertices have discontinuities)
299
300          triangles: List of 3-tuples (or a Numeric array) of
301          integers representing indices of all vertices in the mesh.
302
303          point_coordinates: List of coordinate pairs [x, y] of
304          data points (or an nx2 Numeric array)
305          If point_coordinates is absent, only smoothing matrix will
306          be built
307
308          alpha: Smoothing parameter
309
310          data_origin and mesh_origin are 3-tuples consisting of
311          UTM zone, easting and northing. If specified
312          point coordinates and vertex coordinates are assumed to be
313          relative to their respective origins.
314
315        """
316        from pyvolution.util import ensure_numeric
317
318        #Convert input to Numeric arrays
319        triangles = ensure_numeric(triangles, Int)
320        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
321
322        #Build underlying mesh
323        if verbose: print 'Building mesh'
324        #self.mesh = General_mesh(vertex_coordinates, triangles,
325        #FIXME: Trying the normal mesh while testing precrop,
326        #       The functionality of boundary_polygon is needed for that
327
328        #FIXME - geo ref does not have to go into mesh.
329        # Change the point co-ords to conform to the
330        # mesh co-ords early in the code
331        if mesh_origin == None:
332            geo = None
333        else:
334            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
335        self.mesh = Mesh(vertex_coordinates, triangles,
336                         geo_reference = geo)
337       
338        self.mesh.check_integrity()
339
340        self.data_origin = data_origin
341
342        self.point_indices = None
343
344        #Smoothing parameter
345        if alpha is None:
346            self.alpha = DEFAULT_ALPHA
347        else:   
348            self.alpha = alpha
349
350
351        if point_coordinates is not None:
352            if verbose: print 'Building interpolation matrix'
353            self.build_interpolation_matrix_A(point_coordinates,
354                                              verbose = verbose,
355                                              expand_search = expand_search,
356                                              interp_only = interp_only, 
357                                              max_points_per_cell =\
358                                              max_points_per_cell,
359                                              data_origin = data_origin,
360                                              precrop = precrop)
361        #Build coefficient matrices
362        if interp_only == False:
363            self.build_coefficient_matrix_B(point_coordinates,
364                                        verbose = verbose,
365                                        expand_search = expand_search,
366                                        max_points_per_cell =\
367                                        max_points_per_cell,
368                                        data_origin = data_origin,
369                                        precrop = precrop)
370
371
372    def set_point_coordinates(self, point_coordinates,
373                              data_origin = None,
374                              verbose = False,
375                              precrop = True):
376        """
377        A public interface to setting the point co-ordinates.
378        """
379        if point_coordinates is not None:
380            if verbose: print 'Building interpolation matrix'
381            self.build_interpolation_matrix_A(point_coordinates,
382                                              verbose = verbose,
383                                              data_origin = data_origin,
384                                              precrop = precrop)
385        self.build_coefficient_matrix_B(point_coordinates, data_origin)
386
387    def build_coefficient_matrix_B(self, point_coordinates=None,
388                                   verbose = False, expand_search = True,
389                                   max_points_per_cell=30,
390                                   data_origin = None,
391                                   precrop = False):
392        """Build final coefficient matrix"""
393
394
395        if self.alpha <> 0:
396            if verbose: print 'Building smoothing matrix'
397            self.build_smoothing_matrix_D()
398
399        if point_coordinates is not None:
400            if self.alpha <> 0:
401                self.B = self.AtA + self.alpha*self.D
402            else:
403                self.B = self.AtA
404
405            #Convert self.B matrix to CSR format for faster matrix vector
406            self.B = Sparse_CSR(self.B)
407
408    def build_interpolation_matrix_A(self, point_coordinates,
409                                     verbose = False, expand_search = True,
410                                     max_points_per_cell=30,
411                                     data_origin = None,
412                                     precrop = False,
413                                     interp_only = False):
414        """Build n x m interpolation matrix, where
415        n is the number of data points and
416        m is the number of basis functions phi_k (one per vertex)
417
418        This algorithm uses a quad tree data structure for fast binning of data points
419        origin is a 3-tuple consisting of UTM zone, easting and northing.
420        If specified coordinates are assumed to be relative to this origin.
421
422        This one will override any data_origin that may be specified in
423        interpolation instance
424
425        """
426
427
428        #FIXME (Ole): Check that this function is memeory efficient.
429        #6 million datapoints and 300000 basis functions
430        #causes out-of-memory situation
431        #First thing to check is whether there is room for self.A and self.AtA
432        #
433        #Maybe we need some sort of blocking
434
435        from pyvolution.quad import build_quadtree
436        from utilities.numerical_tools import ensure_numeric
437        from utilities.polygon import inside_polygon
438       
439
440        if data_origin is None:
441            data_origin = self.data_origin #Use the one from
442                                           #interpolation instance
443
444        #Convert input to Numeric arrays just in case.
445        point_coordinates = ensure_numeric(point_coordinates, Float)
446
447        #Keep track of discarded points (if any).
448        #This is only registered if precrop is True
449        self.cropped_points = False
450
451        #Shift data points to same origin as mesh (if specified)
452
453        #FIXME this will shift if there was no geo_ref.
454        #But all this should be removed anyhow.
455        #change coords before this point
456        mesh_origin = self.mesh.geo_reference.get_origin()
457        if point_coordinates is not None:
458            if data_origin is not None:
459                if mesh_origin is not None:
460
461                    #Transformation:
462                    #
463                    #Let x_0 be the reference point of the point coordinates
464                    #and xi_0 the reference point of the mesh.
465                    #
466                    #A point coordinate (x + x_0) is then made relative
467                    #to xi_0 by
468                    #
469                    # x_new = x + x_0 - xi_0
470                    #
471                    #and similarly for eta
472
473                    x_offset = data_origin[1] - mesh_origin[1]
474                    y_offset = data_origin[2] - mesh_origin[2]
475                else: #Shift back to a zero origin
476                    x_offset = data_origin[1]
477                    y_offset = data_origin[2]
478
479                point_coordinates[:,0] += x_offset
480                point_coordinates[:,1] += y_offset
481            else:
482                if mesh_origin is not None:
483                    #Use mesh origin for data points
484                    point_coordinates[:,0] -= mesh_origin[1]
485                    point_coordinates[:,1] -= mesh_origin[2]
486
487
488
489        #Remove points falling outside mesh boundary
490        #This reduced one example from 1356 seconds to 825 seconds
491        if precrop is True:
492            from Numeric import take
493
494            if verbose: print 'Getting boundary polygon'
495            P = self.mesh.get_boundary_polygon()
496
497            if verbose: print 'Getting indices inside mesh boundary'
498            indices = inside_polygon(point_coordinates, P, verbose = verbose)
499
500
501            if len(indices) != point_coordinates.shape[0]:
502                self.cropped_points = True
503                if verbose:
504                    print 'Done - %d points outside mesh have been cropped.'\
505                          %(point_coordinates.shape[0] - len(indices))
506
507            point_coordinates = take(point_coordinates, indices)
508            self.point_indices = indices
509
510
511
512
513        #Build n x m interpolation matrix
514        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
515        n = point_coordinates.shape[0]     #Nbr of data points
516
517        if verbose: print 'Number of datapoints: %d' %n
518        if verbose: print 'Number of basis functions: %d' %m
519
520        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
521        #However, Sparse_CSR does not have the same methods as Sparse yet
522        #The tests will reveal what needs to be done
523
524        #
525        #self.A = Sparse_CSR(Sparse(n,m))
526        #self.AtA = Sparse_CSR(Sparse(m,m))
527        self.A = Sparse(n,m)
528        self.AtA = Sparse(m,m)
529
530        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
531        root = build_quadtree(self.mesh,
532                              max_points_per_cell = max_points_per_cell)
533        #root.show()
534        self.expanded_quad_searches = []
535        #Compute matrix elements
536        for i in range(n):
537            #For each data_coordinate point
538
539            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
540            x = point_coordinates[i]
541
542            #Find vertices near x
543            candidate_vertices = root.search(x[0], x[1])
544            is_more_elements = True
545
546            element_found, sigma0, sigma1, sigma2, k = \
547                self.search_triangles_of_vertices(candidate_vertices, x)
548            first_expansion = True
549            while not element_found and is_more_elements and expand_search:
550                if verbose: print 'Expanding search'
551                if first_expansion == True:
552                    self.expanded_quad_searches.append(1)
553                else:
554                    end = len(expanded_quad_searches) - 1
555                    assert end >= 0
556                    self.expanded_quad_searches[end] += 1
557                    print "expanded_quad_searches",expanded_quad_searches
558                candidate_vertices, branch = root.expand_search()
559                if branch == []:
560                    # Searching all the verts from the root cell that haven't
561                    # been searched.  This is the last try
562                    element_found, sigma0, sigma1, sigma2, k = \
563                      self.search_triangles_of_vertices(candidate_vertices, x)
564                    is_more_elements = False
565                else:
566                    element_found, sigma0, sigma1, sigma2, k = \
567                      self.search_triangles_of_vertices(candidate_vertices, x)
568
569               
570            #Update interpolation matrix A if necessary
571            if element_found is True:
572                #Assign values to matrix A
573
574                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
575                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
576                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
577
578                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
579                js     = [j0,j1,j2]
580
581                for j in js:
582                    self.A[i,j] = sigmas[j]
583                    for k in js:
584                        if interp_only == False:
585                            self.AtA[j,k] += sigmas[j]*sigmas[k]
586            else:
587                pass
588                #Ok if there is no triangle for datapoint
589                #(as in brute force version)
590                #raise 'Could not find triangle for point', x
591
592
593
594    def search_triangles_of_vertices(self, candidate_vertices, x):
595            #Find triangle containing x:
596            element_found = False
597
598            # This will be returned if element_found = False
599            sigma2 = -10.0
600            sigma0 = -10.0
601            sigma1 = -10.0
602            k = -10.0
603            #print "*$* candidate_vertices", candidate_vertices
604            #For all vertices in same cell as point x
605            for v in candidate_vertices:
606
607                #for each triangle id (k) which has v as a vertex
608                for k, _ in self.mesh.vertexlist[v]:
609
610                    #Get the three vertex_points of candidate triangle
611                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
612                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
613                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
614
615                    #print "PDSG - k", k
616                    #print "PDSG - xi0", xi0
617                    #print "PDSG - xi1", xi1
618                    #print "PDSG - xi2", xi2
619                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
620                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
621
622                    #Get the three normals
623                    n0 = self.mesh.get_normal(k, 0)
624                    n1 = self.mesh.get_normal(k, 1)
625                    n2 = self.mesh.get_normal(k, 2)
626
627
628                    #Compute interpolation
629                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
630                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
631                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
632
633                    #print "PDSG - sigma0", sigma0
634                    #print "PDSG - sigma1", sigma1
635                    #print "PDSG - sigma2", sigma2
636
637                    #FIXME: Maybe move out to test or something
638                    epsilon = 1.0e-6
639                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
640
641                    #Check that this triangle contains the data point
642
643                    #Sigmas can get negative within
644                    #machine precision on some machines (e.g nautilus)
645                    #Hence the small eps
646                    eps = 1.0e-15
647                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
648                        element_found = True
649                        break
650
651                if element_found is True:
652                    #Don't look for any other triangle
653                    break
654            return element_found, sigma0, sigma1, sigma2, k
655
656
657
658    def build_interpolation_matrix_A_brute(self, point_coordinates):
659        """Build n x m interpolation matrix, where
660        n is the number of data points and
661        m is the number of basis functions phi_k (one per vertex)
662
663        This is the brute force which is too slow for large problems,
664        but could be used for testing
665        """
666
667        from util import ensure_numeric
668
669        #Convert input to Numeric arrays
670        point_coordinates = ensure_numeric(point_coordinates, Float)
671
672        #Build n x m interpolation matrix
673        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
674        n = point_coordinates.shape[0]     #Nbr of data points
675
676        self.A = Sparse(n,m)
677        self.AtA = Sparse(m,m)
678
679        #Compute matrix elements
680        for i in range(n):
681            #For each data_coordinate point
682
683            x = point_coordinates[i]
684            element_found = False
685            k = 0
686            while not element_found and k < len(self.mesh):
687                #For each triangle (brute force)
688                #FIXME: Real algorithm should only visit relevant triangles
689
690                #Get the three vertex_points
691                xi0 = self.mesh.get_vertex_coordinate(k, 0)
692                xi1 = self.mesh.get_vertex_coordinate(k, 1)
693                xi2 = self.mesh.get_vertex_coordinate(k, 2)
694
695                #Get the three normals
696                n0 = self.mesh.get_normal(k, 0)
697                n1 = self.mesh.get_normal(k, 1)
698                n2 = self.mesh.get_normal(k, 2)
699
700                #Compute interpolation
701                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
702                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
703                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
704
705                #FIXME: Maybe move out to test or something
706                epsilon = 1.0e-6
707                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
708
709                #Check that this triangle contains data point
710                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
711                    element_found = True
712                    #Assign values to matrix A
713
714                    j0 = self.mesh.triangles[k,0] #Global vertex id
715                    #self.A[i, j0] = sigma0
716
717                    j1 = self.mesh.triangles[k,1] #Global vertex id
718                    #self.A[i, j1] = sigma1
719
720                    j2 = self.mesh.triangles[k,2] #Global vertex id
721                    #self.A[i, j2] = sigma2
722
723                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
724                    js     = [j0,j1,j2]
725
726                    for j in js:
727                        self.A[i,j] = sigmas[j]
728                        for k in js:
729                            self.AtA[j,k] += sigmas[j]*sigmas[k]
730                k = k+1
731
732
733
734    def get_A(self):
735        return self.A.todense()
736
737    def get_B(self):
738        return self.B.todense()
739
740    def get_D(self):
741        return self.D.todense()
742
743        #FIXME: Remember to re-introduce the 1/n factor in the
744        #interpolation term
745
746    def build_smoothing_matrix_D(self):
747        """Build m x m smoothing matrix, where
748        m is the number of basis functions phi_k (one per vertex)
749
750        The smoothing matrix is defined as
751
752        D = D1 + D2
753
754        where
755
756        [D1]_{k,l} = \int_\Omega
757           \frac{\partial \phi_k}{\partial x}
758           \frac{\partial \phi_l}{\partial x}\,
759           dx dy
760
761        [D2]_{k,l} = \int_\Omega
762           \frac{\partial \phi_k}{\partial y}
763           \frac{\partial \phi_l}{\partial y}\,
764           dx dy
765
766
767        The derivatives \frac{\partial \phi_k}{\partial x},
768        \frac{\partial \phi_k}{\partial x} for a particular triangle
769        are obtained by computing the gradient a_k, b_k for basis function k
770        """
771
772        #FIXME: algorithm might be optimised by computing local 9x9
773        #"element stiffness matrices:
774
775        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
776
777        self.D = Sparse(m,m)
778
779        #For each triangle compute contributions to D = D1+D2
780        for i in range(len(self.mesh)):
781
782            #Get area
783            area = self.mesh.areas[i]
784
785            #Get global vertex indices
786            v0 = self.mesh.triangles[i,0]
787            v1 = self.mesh.triangles[i,1]
788            v2 = self.mesh.triangles[i,2]
789
790            #Get the three vertex_points
791            xi0 = self.mesh.get_vertex_coordinate(i, 0)
792            xi1 = self.mesh.get_vertex_coordinate(i, 1)
793            xi2 = self.mesh.get_vertex_coordinate(i, 2)
794
795            #Compute gradients for each vertex
796            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
797                              1, 0, 0)
798
799            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
800                              0, 1, 0)
801
802            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
803                              0, 0, 1)
804
805            #Compute diagonal contributions
806            self.D[v0,v0] += (a0*a0 + b0*b0)*area
807            self.D[v1,v1] += (a1*a1 + b1*b1)*area
808            self.D[v2,v2] += (a2*a2 + b2*b2)*area
809
810            #Compute contributions for basis functions sharing edges
811            e01 = (a0*a1 + b0*b1)*area
812            self.D[v0,v1] += e01
813            self.D[v1,v0] += e01
814
815            e12 = (a1*a2 + b1*b2)*area
816            self.D[v1,v2] += e12
817            self.D[v2,v1] += e12
818
819            e20 = (a2*a0 + b2*b0)*area
820            self.D[v2,v0] += e20
821            self.D[v0,v2] += e20
822
823
824    def fit(self, z):
825        """Fit a smooth surface to given 1d array of data points z.
826
827        The smooth surface is computed at each vertex in the underlying
828        mesh using the formula given in the module doc string.
829
830        Pre Condition:
831          self.A, self.AtA and self.B have been initialised
832
833        Inputs:
834          z: Single 1d vector or array of data at the point_coordinates.
835        """
836
837        #Convert input to Numeric arrays
838        from pyvolution.util import ensure_numeric
839        z = ensure_numeric(z, Float)
840
841        if len(z.shape) > 1 :
842            raise VectorShapeError, 'Can only deal with 1d data vector'
843
844        if self.point_indices is not None:
845            #Remove values for any points that were outside mesh
846            z = take(z, self.point_indices)
847
848        #Compute right hand side based on data
849        #FIXME (DSG-DsG): could Sparse_CSR be used here?  Use this format
850        # after a matrix is built, before calcs.
851        Atz = self.A.trans_mult(z)
852
853
854        #Check sanity
855        n, m = self.A.shape
856        if n<m and self.alpha == 0.0:
857            msg = 'ERROR (least_squares): Too few data points\n'
858            msg += 'There are only %d data points and alpha == 0. ' %n
859            msg += 'Need at least %d\n' %m
860            msg += 'Alternatively, set smoothing parameter alpha to a small '
861            msg += 'positive value,\ne.g. 1.0e-3.'
862            raise msg
863
864
865
866        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
867        #FIXME: Should we store the result here for later use? (ON)
868
869
870    def fit_points(self, z, verbose=False):
871        """Like fit, but more robust when each point has two or more attributes
872        FIXME (Ole): The name fit_points doesn't carry any meaning
873        for me. How about something like fit_multiple or fit_columns?
874        """
875
876        try:
877            if verbose: print 'Solving penalised least_squares problem'
878            return self.fit(z)
879        except VectorShapeError, e:
880            # broadcasting is not supported.
881
882            #Convert input to Numeric arrays
883            from util import ensure_numeric
884            z = ensure_numeric(z, Float)
885
886            #Build n x m interpolation matrix
887            m = self.mesh.coordinates.shape[0] #Number of vertices
888            n = z.shape[1]                     #Number of data points
889
890            f = zeros((m,n), Float) #Resulting columns
891
892            for i in range(z.shape[1]):
893                f[:,i] = self.fit(z[:,i])
894
895            return f
896
897
898    def interpolate(self, f):
899        """Evaluate smooth surface f at data points implied in self.A.
900
901        The mesh values representing a smooth surface are
902        assumed to be specified in f. This argument could,
903        for example have been obtained from the method self.fit()
904
905        Pre Condition:
906          self.A has been initialised
907
908        Inputs:
909          f: Vector or array of data at the mesh vertices.
910          If f is an array, interpolation will be done for each column as
911          per underlying matrix-matrix multiplication
912
913        Output:
914          Interpolated values at data points implied in self.A
915
916        """
917
918        return self.A * f
919
920    def cull_outsiders(self, f):
921        pass
922
923
924
925
926class Interpolation_function:
927    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
928    which is interpolated from time series defined at vertices of
929    triangular mesh (such as those stored in sww files)
930
931    Let m be the number of vertices, n the number of triangles
932    and p the number of timesteps.
933
934    Mandatory input
935        time:               px1 array of monotonously increasing times (Float)
936        quantities:         Dictionary of arrays or 1 array (Float)
937                            The arrays must either have dimensions pxm or mx1.
938                            The resulting function will be time dependent in
939                            the former case while it will be constan with
940                            respect to time in the latter case.
941       
942    Optional input:
943        quantity_names:     List of keys into the quantities dictionary
944        vertex_coordinates: mx2 array of coordinates (Float)
945        triangles:          nx3 array of indices into vertex_coordinates (Int)
946        interpolation_points: Nx2 array of coordinates to be interpolated to
947        verbose:            Level of reporting
948   
949   
950    The quantities returned by the callable object are specified by
951    the list quantities which must contain the names of the
952    quantities to be returned and also reflect the order, e.g. for
953    the shallow water wave equation, on would have
954    quantities = ['stage', 'xmomentum', 'ymomentum']
955
956    The parameter interpolation_points decides at which points interpolated
957    quantities are to be computed whenever object is called.
958    If None, return average value
959    """
960
961   
962   
963    def __init__(self,
964                 time,
965                 quantities,
966                 quantity_names = None, 
967                 vertex_coordinates = None,
968                 triangles = None,
969                 interpolation_points = None,
970                 verbose = False):
971        """Initialise object and build spatial interpolation if required
972        """
973
974        from Numeric import array, zeros, Float, alltrue, concatenate,\
975             reshape, ArrayType
976
977
978        from util import mean, ensure_numeric
979        from config import time_format
980        import types
981
982
983
984        #Check temporal info
985        time = ensure_numeric(time)       
986        msg = 'Time must be a monotonuosly '
987        msg += 'increasing sequence %s' %time
988        assert alltrue(time[1:] - time[:-1] >= 0 ), msg
989
990
991        #Check if quantities is a single array only
992        if type(quantities) != types.DictType:
993            quantities = ensure_numeric(quantities)
994            quantity_names = ['Attribute']
995
996            #Make it a dictionary
997            quantities = {quantity_names[0]: quantities}
998
999
1000        #Use keys if no names are specified
1001        if quantity_names is None:
1002            quantity_names = quantities.keys()
1003
1004
1005        #Check spatial info
1006        if vertex_coordinates is None:
1007            self.spatial = False
1008        else:   
1009            vertex_coordinates = ensure_numeric(vertex_coordinates)
1010
1011            assert triangles is not None, 'Triangles array must be specified'
1012            triangles = ensure_numeric(triangles)
1013            self.spatial = True           
1014           
1015
1016 
1017        #Save for use with statistics
1018        self.quantity_names = quantity_names       
1019        self.quantities = quantities       
1020        self.vertex_coordinates = vertex_coordinates
1021        self.interpolation_points = interpolation_points
1022        self.T = time[:]  # Time assumed to be relative to starttime
1023        self.index = 0    # Initial time index
1024        self.precomputed_values = {}
1025           
1026
1027
1028        #Precomputed spatial interpolation if requested
1029        if interpolation_points is not None:
1030            if self.spatial is False:
1031                raise 'Triangles and vertex_coordinates must be specified'
1032           
1033            try:
1034                self.interpolation_points = ensure_numeric(interpolation_points)
1035            except:
1036                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1037                      'or a list of points\n'
1038                msg += 'I got: %s.' %(str(self.interpolation_points)[:60] +\
1039                                      '...')
1040                raise msg
1041
1042
1043            m = len(self.interpolation_points)
1044            p = len(self.T)
1045           
1046            for name in quantity_names:
1047                self.precomputed_values[name] = zeros((p, m), Float)
1048
1049            #Build interpolator
1050            interpol = Interpolation(vertex_coordinates,
1051                                     triangles,
1052                                     point_coordinates = \
1053                                     self.interpolation_points,
1054                                     alpha = 0,
1055                                     precrop = False, 
1056                                     verbose = verbose)
1057
1058            if verbose: print 'Interpolate'
1059            for i, t in enumerate(self.T):
1060                #Interpolate quantities at this timestep
1061                if verbose and i%((p+10)/10)==0:
1062                    print ' time step %d of %d' %(i, p)
1063                   
1064                for name in quantity_names:
1065                    if len(quantities[name].shape) == 2:
1066                        result = interpol.interpolate(quantities[name][i,:])
1067                    else:
1068                       #Assume no time dependency
1069                       result = interpol.interpolate(quantities[name][:])
1070                       
1071                    self.precomputed_values[name][i, :] = result
1072                   
1073                       
1074
1075            #Report
1076            if verbose:
1077                print self.statistics()
1078                #self.print_statistics()
1079           
1080        else:
1081            #Store quantitites as is
1082            for name in quantity_names:
1083                self.precomputed_values[name] = quantities[name]
1084
1085
1086        #else:
1087        #    #Return an average, making this a time series
1088        #    for name in quantity_names:
1089        #        self.values[name] = zeros(len(self.T), Float)
1090        #
1091        #    if verbose: print 'Compute mean values'
1092        #    for i, t in enumerate(self.T):
1093        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1094        #        for name in quantity_names:
1095        #           self.values[name][i] = mean(quantities[name][i,:])
1096
1097
1098
1099
1100    def __repr__(self):
1101        #return 'Interpolation function (spatio-temporal)'
1102        return self.statistics()
1103   
1104
1105    def __call__(self, t, point_id = None, x = None, y = None):
1106        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1107
1108        Inputs:
1109          t: time - Model time. Must lie within existing timesteps
1110          point_id: index of one of the preprocessed points.
1111          x, y:     Overrides location, point_id ignored
1112         
1113          If spatial info is present and all of x,y,point_id
1114          are None an exception is raised
1115                   
1116          If no spatial info is present, point_id and x,y arguments are ignored
1117          making f a function of time only.
1118
1119         
1120          FIXME: point_id could also be a slice
1121          FIXME: What if x and y are vectors?
1122          FIXME: What about f(x,y) without t?
1123        """
1124
1125        from math import pi, cos, sin, sqrt
1126        from Numeric import zeros, Float
1127        from util import mean       
1128
1129        if self.spatial is True:
1130            if point_id is None:
1131                if x is None or y is None:
1132                    msg = 'Either point_id or x and y must be specified'
1133                    raise msg
1134            else:
1135                if self.interpolation_points is None:
1136                    msg = 'Interpolation_function must be instantiated ' +\
1137                          'with a list of interpolation points before parameter ' +\
1138                          'point_id can be used'
1139                    raise msg
1140
1141
1142        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1143        msg += ' does not match model time: %s\n' %t
1144        if t < self.T[0]: raise msg
1145        if t > self.T[-1]: raise msg
1146
1147        oldindex = self.index #Time index
1148
1149        #Find current time slot
1150        while t > self.T[self.index]: self.index += 1
1151        while t < self.T[self.index]: self.index -= 1
1152
1153        if t == self.T[self.index]:
1154            #Protect against case where t == T[-1] (last time)
1155            # - also works in general when t == T[i]
1156            ratio = 0
1157        else:
1158            #t is now between index and index+1
1159            ratio = (t - self.T[self.index])/\
1160                    (self.T[self.index+1] - self.T[self.index])
1161
1162        #Compute interpolated values
1163        q = zeros(len(self.quantity_names), Float)
1164
1165        for i, name in enumerate(self.quantity_names):
1166            Q = self.precomputed_values[name]
1167
1168            if self.spatial is False:
1169                #If there is no spatial info               
1170                assert len(Q.shape) == 1
1171
1172                Q0 = Q[self.index]
1173                if ratio > 0: Q1 = Q[self.index+1]
1174
1175            else:
1176                if x is not None and y is not None:
1177                    #Interpolate to x, y
1178                   
1179                    raise 'x,y interpolation not yet implemented'
1180                else:
1181                    #Use precomputed point
1182                    Q0 = Q[self.index, point_id]
1183                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1184
1185            #Linear temporal interpolation   
1186            if ratio > 0:
1187                q[i] = Q0 + ratio*(Q1 - Q0)
1188            else:
1189                q[i] = Q0
1190
1191
1192        #Return vector of interpolated values
1193        #if len(q) == 1:
1194        #    return q[0]
1195        #else:
1196        #    return q
1197
1198
1199        #Return vector of interpolated values
1200        #FIXME:
1201        if self.spatial is True:
1202            return q
1203        else:
1204            #Replicate q according to x and y
1205            #This is e.g used for Wind_stress
1206            if x == None or y == None: 
1207                return q
1208            else:
1209                try:
1210                    N = len(x)
1211                except:
1212                    return q
1213                else:
1214                    from Numeric import ones, Float
1215                    #x is a vector - Create one constant column for each value
1216                    N = len(x)
1217                    assert len(y) == N, 'x and y must have same length'
1218                    res = []
1219                    for col in q:
1220                        res.append(col*ones(N, Float))
1221                       
1222                return res
1223
1224
1225    def statistics(self):
1226        """Output statistics about interpolation_function
1227        """
1228       
1229        vertex_coordinates = self.vertex_coordinates
1230        interpolation_points = self.interpolation_points               
1231        quantity_names = self.quantity_names
1232        quantities = self.quantities
1233        precomputed_values = self.precomputed_values                 
1234               
1235        x = vertex_coordinates[:,0]
1236        y = vertex_coordinates[:,1]               
1237
1238        str =  '------------------------------------------------\n'
1239        str += 'Interpolation_function (spatio-temporal) statistics:\n'
1240        str += '  Extent:\n'
1241        str += '    x in [%f, %f], len(x) == %d\n'\
1242               %(min(x), max(x), len(x))
1243        str += '    y in [%f, %f], len(y) == %d\n'\
1244               %(min(y), max(y), len(y))
1245        str += '    t in [%f, %f], len(t) == %d\n'\
1246               %(min(self.T), max(self.T), len(self.T))
1247        str += '  Quantities:\n'
1248        for name in quantity_names:
1249            q = quantities[name][:].flat
1250            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1251
1252        if interpolation_points is not None:   
1253            str += '  Interpolation points (xi, eta):'\
1254                   ' number of points == %d\n' %interpolation_points.shape[0]
1255            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1256                                            max(interpolation_points[:,0]))
1257            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1258                                             max(interpolation_points[:,1]))
1259            str += '  Interpolated quantities (over all timesteps):\n'
1260       
1261            for name in quantity_names:
1262                q = precomputed_values[name][:].flat
1263                str += '    %s at interpolation points in [%f, %f]\n'\
1264                       %(name, min(q), max(q))
1265        str += '------------------------------------------------\n'
1266
1267        return str
1268
1269        #FIXME: Delete
1270        #print '------------------------------------------------'
1271        #print 'Interpolation_function statistics:'
1272        #print '  Extent:'
1273        #print '    x in [%f, %f], len(x) == %d'\
1274        #      %(min(x), max(x), len(x))
1275        #print '    y in [%f, %f], len(y) == %d'\
1276        #      %(min(y), max(y), len(y))
1277        #print '    t in [%f, %f], len(t) == %d'\
1278        #      %(min(self.T), max(self.T), len(self.T))
1279        #print '  Quantities:'
1280        #for name in quantity_names:
1281        #    q = quantities[name][:].flat
1282        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1283        #print '  Interpolation points (xi, eta):'\
1284        #      ' number of points == %d ' %interpolation_points.shape[0]
1285        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1286        #                             max(interpolation_points[:,0]))
1287        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1288        #                              max(interpolation_points[:,1]))
1289        #print '  Interpolated quantities (over all timesteps):'
1290        #
1291        #for name in quantity_names:
1292        #    q = precomputed_values[name][:].flat
1293        #    print '    %s at interpolation points in [%f, %f]'\
1294        #          %(name, min(q), max(q))
1295        #print '------------------------------------------------'
1296
1297
1298#-------------------------------------------------------------
1299if __name__ == "__main__":
1300    """
1301    Load in a mesh and data points with attributes.
1302    Fit the attributes to the mesh.
1303    Save a new mesh file.
1304    """
1305    import os, sys
1306    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1307            %os.path.basename(sys.argv[0])
1308
1309    if len(sys.argv) < 4:
1310        print usage
1311    else:
1312        mesh_file = sys.argv[1]
1313        point_file = sys.argv[2]
1314        mesh_output_file = sys.argv[3]
1315
1316        expand_search = False
1317        if len(sys.argv) > 4:
1318            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1319                expand_search = True
1320            else:
1321                expand_search = False
1322
1323        verbose = False
1324        if len(sys.argv) > 5:
1325            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1326                verbose = False
1327            else:
1328                verbose = True
1329
1330        if len(sys.argv) > 6:
1331            alpha = sys.argv[6]
1332        else:
1333            alpha = DEFAULT_ALPHA
1334
1335        # This is used more for testing
1336        if len(sys.argv) > 7:
1337            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1338                display_errors = False
1339            else:
1340                display_errors = True
1341           
1342        t0 = time.time()
1343        try:
1344            fit_to_mesh_file(mesh_file,
1345                         point_file,
1346                         mesh_output_file,
1347                         alpha,
1348                         verbose= verbose,
1349                         expand_search = expand_search,
1350                         display_errors = display_errors)
1351        except IOError,e:
1352            import sys; sys.exit(1)
1353
1354        print 'That took %.2f seconds' %(time.time()-t0)
1355
Note: See TracBrowser for help on using the repository browser.