source: inundation/pyvolution/least_squares.py @ 2583

Last change on this file since 2583 was 2583, checked in by nick, 18 years ago

add print out statistics

File size: 49.2 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20#import exceptions
21#class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from pyvolution.mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from utilities.sparse import Sparse, Sparse_CSR
29from utilities.cg_solve import conjugate_gradient, VectorShapeError
30from utilities.numerical_tools import ensure_numeric, mean, gradient
31
32
33from coordinate_transforms.geo_reference import Geo_reference
34
35import time
36
37
38
39
40DEFAULT_ALPHA = 0.001
41
42def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
43                     alpha=DEFAULT_ALPHA, verbose= False,
44                     expand_search = False,
45                     data_origin = None,
46                     mesh_origin = None,
47                     precrop = False,
48                     display_errors = True):
49    """
50    Given a mesh file (tsh) and a point attribute file (xya), fit
51    point attributes to the mesh and write a mesh file with the
52    results.
53
54
55    If data_origin is not None it is assumed to be
56    a 3-tuple with geo referenced
57    UTM coordinates (zone, easting, northing)
58
59    NOTE: Throws IOErrors, for a variety of file problems.
60   
61    mesh_origin is the same but refers to the input tsh file.
62    FIXME: When the tsh format contains it own origin, these parameters can go.
63    FIXME: And both origins should be obtained from the specified files.
64    """
65
66    from load_mesh.loadASCII import import_mesh_file, \
67                 import_points_file, export_mesh_file, \
68                 concatinate_attributelist
69
70
71    try:
72        mesh_dict = import_mesh_file(mesh_file)
73    except IOError,e:
74        if display_errors:
75            print "Could not load bad file. ", e
76        raise IOError  #Re-raise exception
77       
78    vertex_coordinates = mesh_dict['vertices']
79    triangles = mesh_dict['triangles']
80    if type(mesh_dict['vertex_attributes']) == ArrayType:
81        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
82    else:
83        old_point_attributes = mesh_dict['vertex_attributes']
84
85    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
86        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
87    else:
88        old_title_list = mesh_dict['vertex_attribute_titles']
89
90    if verbose: print 'tsh file %s loaded' %mesh_file
91
92    # load in the .pts file
93    try:
94        point_dict = import_points_file(point_file, verbose=verbose)
95    except IOError,e:
96        if display_errors:
97            print "Could not load bad file. ", e
98        raise IOError  #Re-raise exception 
99
100    point_coordinates = point_dict['pointlist']
101    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
102
103    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
104        data_origin = point_dict['geo_reference'].get_origin()
105    else:
106        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
107
108    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
109        mesh_origin = mesh_dict['geo_reference'].get_origin()
110    else:
111        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
112
113    if verbose: print "points file loaded"
114    if verbose: print "fitting to mesh"
115    f = fit_to_mesh(vertex_coordinates,
116                    triangles,
117                    point_coordinates,
118                    point_attributes,
119                    alpha = alpha,
120                    verbose = verbose,
121                    expand_search = expand_search,
122                    data_origin = data_origin,
123                    mesh_origin = mesh_origin,
124                    precrop = precrop)
125    if verbose: print "finished fitting to mesh"
126
127    # convert array to list of lists
128    new_point_attributes = f.tolist()
129    #FIXME have this overwrite attributes with the same title - DSG
130    #Put the newer attributes last
131    if old_title_list <> []:
132        old_title_list.extend(title_list)
133        #FIXME can this be done a faster way? - DSG
134        for i in range(len(old_point_attributes)):
135            old_point_attributes[i].extend(new_point_attributes[i])
136        mesh_dict['vertex_attributes'] = old_point_attributes
137        mesh_dict['vertex_attribute_titles'] = old_title_list
138    else:
139        mesh_dict['vertex_attributes'] = new_point_attributes
140        mesh_dict['vertex_attribute_titles'] = title_list
141
142    #FIXME (Ole): Remember to output mesh_origin as well
143    if verbose: print "exporting to file ", mesh_output_file
144
145    try:
146        export_mesh_file(mesh_output_file, mesh_dict)
147    except IOError,e:
148        if display_errors:
149            print "Could not write file. ", e
150        raise IOError
151
152def fit_to_mesh(vertex_coordinates,
153                triangles,
154                point_coordinates,
155                point_attributes,
156                alpha = DEFAULT_ALPHA,
157                verbose = False,
158                expand_search = False,
159                data_origin = None,
160                mesh_origin = None,
161                precrop = False,
162                use_cache = False):
163    """
164    Fit a smooth surface to a triangulation,
165    given data points with attributes.
166
167
168        Inputs:
169
170          vertex_coordinates: List of coordinate pairs [xi, eta] of points
171          constituting mesh (or a an m x 2 Numeric array)
172
173          triangles: List of 3-tuples (or a Numeric array) of
174          integers representing indices of all vertices in the mesh.
175
176          point_coordinates: List of coordinate pairs [x, y] of data points
177          (or an nx2 Numeric array)
178
179          alpha: Smoothing parameter.
180
181          point_attributes: Vector or array of data at the point_coordinates.
182
183          data_origin and mesh_origin are 3-tuples consisting of
184          UTM zone, easting and northing. If specified
185          point coordinates and vertex coordinates are assumed to be
186          relative to their respective origins.
187
188    """
189
190    if use_cache is True:
191        from caching.caching import cache
192        interp = cache(_interpolation,
193                       (vertex_coordinates,
194                        triangles,
195                        point_coordinates),
196                       {'alpha': alpha,
197                        'verbose': verbose,
198                        'expand_search': expand_search,
199                        'data_origin': data_origin,
200                        'mesh_origin': mesh_origin,
201                        'precrop': precrop},
202                       verbose = verbose)       
203       
204    else:
205        interp = Interpolation(vertex_coordinates,
206                               triangles,
207                               point_coordinates,
208                               alpha = alpha,
209                               verbose = verbose,
210                               expand_search = expand_search,
211                               data_origin = data_origin,
212                               mesh_origin = mesh_origin,
213                               precrop = precrop)
214
215    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
216   
217    if verbose:
218        print '+------------------------------------------------'
219        print 'least squares finished'
220        print 'points: %d points' %(len(point_attributes))
221        print '    x in [%f, %f]'%(min(point_coordinates[:,0]),
222                                   max(point_coordinates[:,0]))
223        print '    y in [%f, %f]'%(min(point_coordinates[:,1]),
224                                   max(point_coordinates[:,1]))
225        print '    z in [%f, %f]'%(min(point_attributes),
226                                   max(point_attributes))
227        print
228        print 'mesh: %d vertices' %(len(vertex_attributes))
229        print '    xi in [%f, %f]'%(min(vertex_coordinates[:,0]),
230                                   max(vertex_coordinates[:,0]))
231        print '    eta in [%f, %f]'%(min(vertex_coordinates[:,1]),
232                                   max(vertex_coordinates[:,1]))
233        print '    zeta in [%f, %f]'%(min(vertex_attributes),
234                                   max(vertex_attributes))
235        print '+------------------------------------------------'
236
237    return vertex_attributes
238
239
240
241def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
242                    verbose = False, reduction = 1):
243    """Fits attributes from pts file to MxN rectangular mesh
244
245    Read pts file and create rectangular mesh of resolution MxN such that
246    it covers all points specified in pts file.
247
248    FIXME: This may be a temporary function until we decide on
249    netcdf formats etc
250
251    FIXME: Uses elevation hardwired
252    """
253
254    import  mesh_factory
255    from load_mesh.loadASCII import import_points_file
256   
257    if verbose: print 'Read pts'
258    points_dict = import_points_file(pts_name)
259    #points, attributes = util.read_xya(pts_name)
260
261    #Reduce number of points a bit
262    points = points_dict['pointlist'][::reduction]
263    elevation = points_dict['attributelist']['elevation']  #Must be elevation
264    elevation = elevation[::reduction]
265
266    if verbose: print 'Got %d data points' %len(points)
267
268    if verbose: print 'Create mesh'
269    #Find extent
270    max_x = min_x = points[0][0]
271    max_y = min_y = points[0][1]
272    for point in points[1:]:
273        x = point[0]
274        if x > max_x: max_x = x
275        if x < min_x: min_x = x
276        y = point[1]
277        if y > max_y: max_y = y
278        if y < min_y: min_y = y
279
280    #Create appropriate mesh
281    vertex_coordinates, triangles, boundary =\
282         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
283                                (min_x, min_y))
284
285    #Fit attributes to mesh
286    vertex_attributes = fit_to_mesh(vertex_coordinates,
287                        triangles,
288                        points,
289                        elevation, alpha=alpha, verbose=verbose)
290
291
292
293    return vertex_coordinates, triangles, boundary, vertex_attributes
294
295
296def _interpolation(*args, **kwargs):
297    """Private function for use with caching. Reason is that classes
298    may change their byte code between runs which is annoying.
299    """
300   
301    return Interpolation(*args, **kwargs)
302
303
304class Interpolation:
305
306    def __init__(self,
307                 vertex_coordinates,
308                 triangles,
309                 point_coordinates = None,
310                 alpha = None,
311                 verbose = False,
312                 expand_search = True,
313                 interp_only = False,
314                 max_points_per_cell = 30,
315                 mesh_origin = None,
316                 data_origin = None,
317                 precrop = False):
318
319
320        """ Build interpolation matrix mapping from
321        function values at vertices to function values at data points
322
323        Inputs:
324
325          vertex_coordinates: List of coordinate pairs [xi, eta] of
326          points constituting mesh (or a an m x 2 Numeric array)
327          Points may appear multiple times
328          (e.g. if vertices have discontinuities)
329
330          triangles: List of 3-tuples (or a Numeric array) of
331          integers representing indices of all vertices in the mesh.
332
333          point_coordinates: List of coordinate pairs [x, y] of
334          data points (or an nx2 Numeric array)
335          If point_coordinates is absent, only smoothing matrix will
336          be built
337
338          alpha: Smoothing parameter
339
340          data_origin and mesh_origin are 3-tuples consisting of
341          UTM zone, easting and northing. If specified
342          point coordinates and vertex coordinates are assumed to be
343          relative to their respective origins.
344
345        """
346
347        #Convert input to Numeric arrays
348        triangles = ensure_numeric(triangles, Int)
349        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
350
351        #Build underlying mesh
352        if verbose: print 'Building mesh'
353        #self.mesh = General_mesh(vertex_coordinates, triangles,
354        #FIXME: Trying the normal mesh while testing precrop,
355        #       The functionality of boundary_polygon is needed for that
356
357        #FIXME - geo ref does not have to go into mesh.
358        # Change the point co-ords to conform to the
359        # mesh co-ords early in the code
360        if mesh_origin is None:
361            geo = None
362        else:
363            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
364        self.mesh = Mesh(vertex_coordinates, triangles,
365                         geo_reference = geo)
366       
367        self.mesh.check_integrity()
368
369        self.data_origin = data_origin
370
371        self.point_indices = None
372
373        #Smoothing parameter
374        if alpha is None:
375            self.alpha = DEFAULT_ALPHA
376        else:   
377            self.alpha = alpha
378
379
380        if point_coordinates is not None:
381            if verbose: print 'Building interpolation matrix'
382            self.build_interpolation_matrix_A(point_coordinates,
383                                              verbose = verbose,
384                                              expand_search = expand_search,
385                                              interp_only = interp_only, 
386                                              max_points_per_cell =\
387                                              max_points_per_cell,
388                                              data_origin = data_origin,
389                                              precrop = precrop)
390        #Build coefficient matrices
391        if interp_only == False:
392            self.build_coefficient_matrix_B(point_coordinates,
393                                        verbose = verbose,
394                                        expand_search = expand_search,
395                                        max_points_per_cell =\
396                                        max_points_per_cell,
397                                        data_origin = data_origin,
398                                        precrop = precrop)
399
400    def set_point_coordinates(self, point_coordinates,
401                              data_origin = None,
402                              verbose = False,
403                              precrop = True):
404        """
405        A public interface to setting the point co-ordinates.
406        """
407        if point_coordinates is not None:
408            if verbose: print 'Building interpolation matrix'
409            self.build_interpolation_matrix_A(point_coordinates,
410                                              verbose = verbose,
411                                              data_origin = data_origin,
412                                              precrop = precrop)
413        self.build_coefficient_matrix_B(point_coordinates, data_origin)
414
415    def build_coefficient_matrix_B(self, point_coordinates=None,
416                                   verbose = False, expand_search = True,
417                                   max_points_per_cell=30,
418                                   data_origin = None,
419                                   precrop = False):
420        """Build final coefficient matrix"""
421
422
423        if self.alpha <> 0:
424            if verbose: print 'Building smoothing matrix'
425            self.build_smoothing_matrix_D()
426
427        if point_coordinates is not None:
428            if self.alpha <> 0:
429                self.B = self.AtA + self.alpha*self.D
430            else:
431                self.B = self.AtA
432
433            #Convert self.B matrix to CSR format for faster matrix vector
434            self.B = Sparse_CSR(self.B)
435
436    def build_interpolation_matrix_A(self, point_coordinates,
437                                     verbose = False, expand_search = True,
438                                     max_points_per_cell=30,
439                                     data_origin = None,
440                                     precrop = False,
441                                     interp_only = False):
442        """Build n x m interpolation matrix, where
443        n is the number of data points and
444        m is the number of basis functions phi_k (one per vertex)
445
446        This algorithm uses a quad tree data structure for fast binning of data points
447        origin is a 3-tuple consisting of UTM zone, easting and northing.
448        If specified coordinates are assumed to be relative to this origin.
449
450        This one will override any data_origin that may be specified in
451        interpolation instance
452
453        """
454
455
456
457        #FIXME (Ole): Check that this function is memeory efficient.
458        #6 million datapoints and 300000 basis functions
459        #causes out-of-memory situation
460        #First thing to check is whether there is room for self.A and self.AtA
461        #
462        #Maybe we need some sort of blocking
463
464        from pyvolution.quad import build_quadtree
465        from utilities.polygon import inside_polygon
466       
467
468        if data_origin is None:
469            data_origin = self.data_origin #Use the one from
470                                           #interpolation instance
471
472        #Convert input to Numeric arrays just in case.
473        point_coordinates = ensure_numeric(point_coordinates, Float)
474
475        #Keep track of discarded points (if any).
476        #This is only registered if precrop is True
477        self.cropped_points = False
478
479        #Shift data points to same origin as mesh (if specified)
480
481        #FIXME this will shift if there was no geo_ref.
482        #But all this should be removed anyhow.
483        #change coords before this point
484        mesh_origin = self.mesh.geo_reference.get_origin()
485        if point_coordinates is not None:
486            if data_origin is not None:
487                if mesh_origin is not None:
488
489                    #Transformation:
490                    #
491                    #Let x_0 be the reference point of the point coordinates
492                    #and xi_0 the reference point of the mesh.
493                    #
494                    #A point coordinate (x + x_0) is then made relative
495                    #to xi_0 by
496                    #
497                    # x_new = x + x_0 - xi_0
498                    #
499                    #and similarly for eta
500
501                    x_offset = data_origin[1] - mesh_origin[1]
502                    y_offset = data_origin[2] - mesh_origin[2]
503                else: #Shift back to a zero origin
504                    x_offset = data_origin[1]
505                    y_offset = data_origin[2]
506
507                point_coordinates[:,0] += x_offset
508                point_coordinates[:,1] += y_offset
509            else:
510                if mesh_origin is not None:
511                    #Use mesh origin for data points
512                    point_coordinates[:,0] -= mesh_origin[1]
513                    point_coordinates[:,1] -= mesh_origin[2]
514
515
516
517        #Remove points falling outside mesh boundary
518        #This reduced one example from 1356 seconds to 825 seconds
519
520       
521        if precrop is True:
522            from Numeric import take
523
524            if verbose: print 'Getting boundary polygon'
525            P = self.mesh.get_boundary_polygon()
526
527            if verbose: print 'Getting indices inside mesh boundary'
528            indices = inside_polygon(point_coordinates, P, verbose = verbose)
529
530
531            if len(indices) != point_coordinates.shape[0]:
532                self.cropped_points = True
533                if verbose:
534                    print 'Done - %d points outside mesh have been cropped.'\
535                          %(point_coordinates.shape[0] - len(indices))
536
537            point_coordinates = take(point_coordinates, indices)
538            self.point_indices = indices
539
540
541
542
543        #Build n x m interpolation matrix
544        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
545        n = point_coordinates.shape[0]     #Nbr of data points
546
547        if verbose: print 'Number of datapoints: %d' %n
548        if verbose: print 'Number of basis functions: %d' %m
549
550        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
551        #However, Sparse_CSR does not have the same methods as Sparse yet
552        #The tests will reveal what needs to be done
553
554        #
555        #self.A = Sparse_CSR(Sparse(n,m))
556        #self.AtA = Sparse_CSR(Sparse(m,m))
557        self.A = Sparse(n,m)
558        self.AtA = Sparse(m,m)
559
560        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
561        root = build_quadtree(self.mesh,
562                              max_points_per_cell = max_points_per_cell)
563        #root.show()
564        self.expanded_quad_searches = []
565        #Compute matrix elements
566        for i in range(n):
567            #For each data_coordinate point
568
569            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
570            x = point_coordinates[i]
571
572            #Find vertices near x
573            candidate_vertices = root.search(x[0], x[1])
574            is_more_elements = True
575
576            element_found, sigma0, sigma1, sigma2, k = \
577                self.search_triangles_of_vertices(candidate_vertices, x)
578            first_expansion = True
579            while not element_found and is_more_elements and expand_search:
580                #if verbose: print 'Expanding search'
581                if first_expansion == True:
582                    self.expanded_quad_searches.append(1)
583                    first_expansion = False
584                else:
585                    end = len(self.expanded_quad_searches) - 1
586                    assert end >= 0
587                    self.expanded_quad_searches[end] += 1
588                candidate_vertices, branch = root.expand_search()
589                if branch == []:
590                    # Searching all the verts from the root cell that haven't
591                    # been searched.  This is the last try
592                    element_found, sigma0, sigma1, sigma2, k = \
593                      self.search_triangles_of_vertices(candidate_vertices, x)
594                    is_more_elements = False
595                else:
596                    element_found, sigma0, sigma1, sigma2, k = \
597                      self.search_triangles_of_vertices(candidate_vertices, x)
598
599               
600            #Update interpolation matrix A if necessary
601            if element_found is True:
602                #Assign values to matrix A
603
604                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
605                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
606                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
607
608                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
609                js     = [j0,j1,j2]
610
611                for j in js:
612                    self.A[i,j] = sigmas[j]
613                    for k in js:
614                        if interp_only == False:
615                            self.AtA[j,k] += sigmas[j]*sigmas[k]
616            else:
617                pass
618                #Ok if there is no triangle for datapoint
619                #(as in brute force version)
620                #raise 'Could not find triangle for point', x
621
622
623
624    def search_triangles_of_vertices(self, candidate_vertices, x):
625            #Find triangle containing x:
626            element_found = False
627
628            # This will be returned if element_found = False
629            sigma2 = -10.0
630            sigma0 = -10.0
631            sigma1 = -10.0
632            k = -10.0
633            #print "*$* candidate_vertices", candidate_vertices
634            #For all vertices in same cell as point x
635            for v in candidate_vertices:
636                #FIXME (DSG-DSG): this catches verts with no triangle.
637                #Currently pmesh is producing these.
638                #this should be stopped,
639                if self.mesh.vertexlist[v] is None:
640                    continue
641                #for each triangle id (k) which has v as a vertex
642                for k, _ in self.mesh.vertexlist[v]:
643
644                    #Get the three vertex_points of candidate triangle
645                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
646                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
647                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
648
649                    #print "PDSG - k", k
650                    #print "PDSG - xi0", xi0
651                    #print "PDSG - xi1", xi1
652                    #print "PDSG - xi2", xi2
653                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
654                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
655
656                    #Get the three normals
657                    n0 = self.mesh.get_normal(k, 0)
658                    n1 = self.mesh.get_normal(k, 1)
659                    n2 = self.mesh.get_normal(k, 2)
660
661
662                    #Compute interpolation
663                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
664                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
665                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
666
667                    #print "PDSG - sigma0", sigma0
668                    #print "PDSG - sigma1", sigma1
669                    #print "PDSG - sigma2", sigma2
670
671                    #FIXME: Maybe move out to test or something
672                    epsilon = 1.0e-6
673                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
674
675                    #Check that this triangle contains the data point
676
677                    #Sigmas can get negative within
678                    #machine precision on some machines (e.g nautilus)
679                    #Hence the small eps
680                    eps = 1.0e-15
681                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
682                        element_found = True
683                        break
684
685                if element_found is True:
686                    #Don't look for any other triangle
687                    break
688            return element_found, sigma0, sigma1, sigma2, k
689
690
691
692    def build_interpolation_matrix_A_brute(self, point_coordinates):
693        """Build n x m interpolation matrix, where
694        n is the number of data points and
695        m is the number of basis functions phi_k (one per vertex)
696
697        This is the brute force which is too slow for large problems,
698        but could be used for testing
699        """
700
701
702        #Convert input to Numeric arrays
703        point_coordinates = ensure_numeric(point_coordinates, Float)
704
705        #Build n x m interpolation matrix
706        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
707        n = point_coordinates.shape[0]     #Nbr of data points
708
709        self.A = Sparse(n,m)
710        self.AtA = Sparse(m,m)
711
712        #Compute matrix elements
713        for i in range(n):
714            #For each data_coordinate point
715
716            x = point_coordinates[i]
717            element_found = False
718            k = 0
719            while not element_found and k < len(self.mesh):
720                #For each triangle (brute force)
721                #FIXME: Real algorithm should only visit relevant triangles
722
723                #Get the three vertex_points
724                xi0 = self.mesh.get_vertex_coordinate(k, 0)
725                xi1 = self.mesh.get_vertex_coordinate(k, 1)
726                xi2 = self.mesh.get_vertex_coordinate(k, 2)
727
728                #Get the three normals
729                n0 = self.mesh.get_normal(k, 0)
730                n1 = self.mesh.get_normal(k, 1)
731                n2 = self.mesh.get_normal(k, 2)
732
733                #Compute interpolation
734                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
735                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
736                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
737
738                #FIXME: Maybe move out to test or something
739                epsilon = 1.0e-6
740                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
741
742                #Check that this triangle contains data point
743                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
744                    element_found = True
745                    #Assign values to matrix A
746
747                    j0 = self.mesh.triangles[k,0] #Global vertex id
748                    #self.A[i, j0] = sigma0
749
750                    j1 = self.mesh.triangles[k,1] #Global vertex id
751                    #self.A[i, j1] = sigma1
752
753                    j2 = self.mesh.triangles[k,2] #Global vertex id
754                    #self.A[i, j2] = sigma2
755
756                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
757                    js     = [j0,j1,j2]
758
759                    for j in js:
760                        self.A[i,j] = sigmas[j]
761                        for k in js:
762                            self.AtA[j,k] += sigmas[j]*sigmas[k]
763                k = k+1
764
765
766
767    def get_A(self):
768        return self.A.todense()
769
770    def get_B(self):
771        return self.B.todense()
772
773    def get_D(self):
774        return self.D.todense()
775
776        #FIXME: Remember to re-introduce the 1/n factor in the
777        #interpolation term
778
779    def build_smoothing_matrix_D(self):
780        """Build m x m smoothing matrix, where
781        m is the number of basis functions phi_k (one per vertex)
782
783        The smoothing matrix is defined as
784
785        D = D1 + D2
786
787        where
788
789        [D1]_{k,l} = \int_\Omega
790           \frac{\partial \phi_k}{\partial x}
791           \frac{\partial \phi_l}{\partial x}\,
792           dx dy
793
794        [D2]_{k,l} = \int_\Omega
795           \frac{\partial \phi_k}{\partial y}
796           \frac{\partial \phi_l}{\partial y}\,
797           dx dy
798
799
800        The derivatives \frac{\partial \phi_k}{\partial x},
801        \frac{\partial \phi_k}{\partial x} for a particular triangle
802        are obtained by computing the gradient a_k, b_k for basis function k
803        """
804
805        #FIXME: algorithm might be optimised by computing local 9x9
806        #"element stiffness matrices:
807
808        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
809
810        self.D = Sparse(m,m)
811
812        #For each triangle compute contributions to D = D1+D2
813        for i in range(len(self.mesh)):
814
815            #Get area
816            area = self.mesh.areas[i]
817
818            #Get global vertex indices
819            v0 = self.mesh.triangles[i,0]
820            v1 = self.mesh.triangles[i,1]
821            v2 = self.mesh.triangles[i,2]
822
823            #Get the three vertex_points
824            xi0 = self.mesh.get_vertex_coordinate(i, 0)
825            xi1 = self.mesh.get_vertex_coordinate(i, 1)
826            xi2 = self.mesh.get_vertex_coordinate(i, 2)
827
828            #Compute gradients for each vertex
829            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
830                              1, 0, 0)
831
832            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
833                              0, 1, 0)
834
835            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
836                              0, 0, 1)
837
838            #Compute diagonal contributions
839            self.D[v0,v0] += (a0*a0 + b0*b0)*area
840            self.D[v1,v1] += (a1*a1 + b1*b1)*area
841            self.D[v2,v2] += (a2*a2 + b2*b2)*area
842
843            #Compute contributions for basis functions sharing edges
844            e01 = (a0*a1 + b0*b1)*area
845            self.D[v0,v1] += e01
846            self.D[v1,v0] += e01
847
848            e12 = (a1*a2 + b1*b2)*area
849            self.D[v1,v2] += e12
850            self.D[v2,v1] += e12
851
852            e20 = (a2*a0 + b2*b0)*area
853            self.D[v2,v0] += e20
854            self.D[v0,v2] += e20
855
856
857    def fit(self, z):
858        """Fit a smooth surface to given 1d array of data points z.
859
860        The smooth surface is computed at each vertex in the underlying
861        mesh using the formula given in the module doc string.
862
863        Pre Condition:
864          self.A, self.AtA and self.B have been initialised
865
866        Inputs:
867          z: Single 1d vector or array of data at the point_coordinates.
868        """
869
870        #Convert input to Numeric arrays
871        z = ensure_numeric(z, Float)
872
873        if len(z.shape) > 1 :
874            raise VectorShapeError, 'Can only deal with 1d data vector'
875
876        if self.point_indices is not None:
877            #Remove values for any points that were outside mesh
878            z = take(z, self.point_indices)
879
880        #Compute right hand side based on data
881        #FIXME (DSG-DsG): could Sparse_CSR be used here?  Use this format
882        # after a matrix is built, before calcs.
883        Atz = self.A.trans_mult(z)
884
885
886        #Check sanity
887        n, m = self.A.shape
888        if n<m and self.alpha == 0.0:
889            msg = 'ERROR (least_squares): Too few data points\n'
890            msg += 'There are only %d data points and alpha == 0. ' %n
891            msg += 'Need at least %d\n' %m
892            msg += 'Alternatively, set smoothing parameter alpha to a small '
893            msg += 'positive value,\ne.g. 1.0e-3.'
894            raise msg
895
896
897
898        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
899        #FIXME: Should we store the result here for later use? (ON)
900
901
902    def fit_points(self, z, verbose=False):
903        """Like fit, but more robust when each point has two or more attributes
904        FIXME (Ole): The name fit_points doesn't carry any meaning
905        for me. How about something like fit_multiple or fit_columns?
906        """
907
908        try:
909            if verbose: print 'Solving penalised least_squares problem'
910            return self.fit(z)
911        except VectorShapeError, e:
912            # broadcasting is not supported.
913
914            #Convert input to Numeric arrays
915            z = ensure_numeric(z, Float)
916
917            #Build n x m interpolation matrix
918            m = self.mesh.coordinates.shape[0] #Number of vertices
919            n = z.shape[1]                     #Number of data points
920
921            f = zeros((m,n), Float) #Resulting columns
922
923            for i in range(z.shape[1]):
924                f[:,i] = self.fit(z[:,i])
925
926            return f
927
928
929    def interpolate(self, f):
930        """Evaluate smooth surface f at data points implied in self.A.
931
932        The mesh values representing a smooth surface are
933        assumed to be specified in f. This argument could,
934        for example have been obtained from the method self.fit()
935
936        Pre Condition:
937          self.A has been initialised
938
939        Inputs:
940          f: Vector or array of data at the mesh vertices.
941          If f is an array, interpolation will be done for each column as
942          per underlying matrix-matrix multiplication
943
944        Output:
945          Interpolated values at data points implied in self.A
946
947        """
948
949        return self.A * f
950
951    def cull_outsiders(self, f):
952        pass
953
954
955
956
957class Interpolation_function:
958    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
959    which is interpolated from time series defined at vertices of
960    triangular mesh (such as those stored in sww files)
961
962    Let m be the number of vertices, n the number of triangles
963    and p the number of timesteps.
964
965    Mandatory input
966        time:               px1 array of monotonously increasing times (Float)
967        quantities:         Dictionary of arrays or 1 array (Float)
968                            The arrays must either have dimensions pxm or mx1.
969                            The resulting function will be time dependent in
970                            the former case while it will be constan with
971                            respect to time in the latter case.
972       
973    Optional input:
974        quantity_names:     List of keys into the quantities dictionary
975        vertex_coordinates: mx2 array of coordinates (Float)
976        triangles:          nx3 array of indices into vertex_coordinates (Int)
977        interpolation_points: Nx2 array of coordinates to be interpolated to
978        verbose:            Level of reporting
979   
980   
981    The quantities returned by the callable object are specified by
982    the list quantities which must contain the names of the
983    quantities to be returned and also reflect the order, e.g. for
984    the shallow water wave equation, on would have
985    quantities = ['stage', 'xmomentum', 'ymomentum']
986
987    The parameter interpolation_points decides at which points interpolated
988    quantities are to be computed whenever object is called.
989    If None, return average value
990    """
991
992   
993   
994    def __init__(self,
995                 time,
996                 quantities,
997                 quantity_names = None, 
998                 vertex_coordinates = None,
999                 triangles = None,
1000                 interpolation_points = None,
1001                 verbose = False):
1002        """Initialise object and build spatial interpolation if required
1003        """
1004
1005        from Numeric import array, zeros, Float, alltrue, concatenate,\
1006             reshape, ArrayType
1007
1008
1009        from config import time_format
1010        import types
1011
1012
1013
1014        #Check temporal info
1015        time = ensure_numeric(time)       
1016        msg = 'Time must be a monotonuosly '
1017        msg += 'increasing sequence %s' %time
1018        assert alltrue(time[1:] - time[:-1] >= 0 ), msg
1019
1020
1021        #Check if quantities is a single array only
1022        if type(quantities) != types.DictType:
1023            quantities = ensure_numeric(quantities)
1024            quantity_names = ['Attribute']
1025
1026            #Make it a dictionary
1027            quantities = {quantity_names[0]: quantities}
1028
1029
1030        #Use keys if no names are specified
1031        if quantity_names is None:
1032            quantity_names = quantities.keys()
1033
1034
1035        #Check spatial info
1036        if vertex_coordinates is None:
1037            self.spatial = False
1038        else:   
1039            vertex_coordinates = ensure_numeric(vertex_coordinates)
1040
1041            assert triangles is not None, 'Triangles array must be specified'
1042            triangles = ensure_numeric(triangles)
1043            self.spatial = True           
1044           
1045
1046 
1047        #Save for use with statistics
1048        self.quantity_names = quantity_names       
1049        self.quantities = quantities       
1050        self.vertex_coordinates = vertex_coordinates
1051        self.interpolation_points = interpolation_points
1052        self.T = time[:]  # Time assumed to be relative to starttime
1053        self.index = 0    # Initial time index
1054        self.precomputed_values = {}
1055           
1056
1057
1058        #Precomputed spatial interpolation if requested
1059        if interpolation_points is not None:
1060            if self.spatial is False:
1061                raise 'Triangles and vertex_coordinates must be specified'
1062           
1063            try:
1064                self.interpolation_points = ensure_numeric(interpolation_points)
1065            except:
1066                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1067                      'or a list of points\n'
1068                msg += 'I got: %s.' %(str(self.interpolation_points)[:60] +\
1069                                      '...')
1070                raise msg
1071
1072
1073            m = len(self.interpolation_points)
1074            p = len(self.T)
1075           
1076            for name in quantity_names:
1077                self.precomputed_values[name] = zeros((p, m), Float)
1078
1079            #Build interpolator
1080            interpol = Interpolation(vertex_coordinates,
1081                                     triangles,
1082                                     point_coordinates = \
1083                                     self.interpolation_points,
1084                                     alpha = 0,
1085                                     precrop = False, 
1086                                     verbose = verbose)
1087
1088            if verbose: print 'Interpolate'
1089            for i, t in enumerate(self.T):
1090                #Interpolate quantities at this timestep
1091                if verbose and i%((p+10)/10)==0:
1092                    print ' time step %d of %d' %(i, p)
1093                   
1094                for name in quantity_names:
1095                    if len(quantities[name].shape) == 2:
1096                        result = interpol.interpolate(quantities[name][i,:])
1097                    else:
1098                       #Assume no time dependency
1099                       result = interpol.interpolate(quantities[name][:])
1100                       
1101                    self.precomputed_values[name][i, :] = result
1102                   
1103                       
1104
1105            #Report
1106            if verbose:
1107                print self.statistics()
1108                #self.print_statistics()
1109           
1110        else:
1111            #Store quantitites as is
1112            for name in quantity_names:
1113                self.precomputed_values[name] = quantities[name]
1114
1115
1116        #else:
1117        #    #Return an average, making this a time series
1118        #    for name in quantity_names:
1119        #        self.values[name] = zeros(len(self.T), Float)
1120        #
1121        #    if verbose: print 'Compute mean values'
1122        #    for i, t in enumerate(self.T):
1123        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1124        #        for name in quantity_names:
1125        #           self.values[name][i] = mean(quantities[name][i,:])
1126
1127
1128
1129
1130    def __repr__(self):
1131        #return 'Interpolation function (spatio-temporal)'
1132        return self.statistics()
1133   
1134
1135    def __call__(self, t, point_id = None, x = None, y = None):
1136        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1137
1138        Inputs:
1139          t: time - Model time. Must lie within existing timesteps
1140          point_id: index of one of the preprocessed points.
1141          x, y:     Overrides location, point_id ignored
1142         
1143          If spatial info is present and all of x,y,point_id
1144          are None an exception is raised
1145                   
1146          If no spatial info is present, point_id and x,y arguments are ignored
1147          making f a function of time only.
1148
1149         
1150          FIXME: point_id could also be a slice
1151          FIXME: What if x and y are vectors?
1152          FIXME: What about f(x,y) without t?
1153        """
1154
1155        from math import pi, cos, sin, sqrt
1156        from Numeric import zeros, Float
1157        from utilities.numerical_tools import mean       
1158
1159        if self.spatial is True:
1160            if point_id is None:
1161                if x is None or y is None:
1162                    msg = 'Either point_id or x and y must be specified'
1163                    raise msg
1164            else:
1165                if self.interpolation_points is None:
1166                    msg = 'Interpolation_function must be instantiated ' +\
1167                          'with a list of interpolation points before parameter ' +\
1168                          'point_id can be used'
1169                    raise msg
1170
1171
1172        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1173        msg += ' does not match model time: %s\n' %t
1174        if t < self.T[0]: raise msg
1175        if t > self.T[-1]: raise msg
1176
1177        oldindex = self.index #Time index
1178
1179        #Find current time slot
1180        while t > self.T[self.index]: self.index += 1
1181        while t < self.T[self.index]: self.index -= 1
1182
1183        if t == self.T[self.index]:
1184            #Protect against case where t == T[-1] (last time)
1185            # - also works in general when t == T[i]
1186            ratio = 0
1187        else:
1188            #t is now between index and index+1
1189            ratio = (t - self.T[self.index])/\
1190                    (self.T[self.index+1] - self.T[self.index])
1191
1192        #Compute interpolated values
1193        q = zeros(len(self.quantity_names), Float)
1194
1195        for i, name in enumerate(self.quantity_names):
1196            Q = self.precomputed_values[name]
1197
1198            if self.spatial is False:
1199                #If there is no spatial info               
1200                assert len(Q.shape) == 1
1201
1202                Q0 = Q[self.index]
1203                if ratio > 0: Q1 = Q[self.index+1]
1204
1205            else:
1206                if x is not None and y is not None:
1207                    #Interpolate to x, y
1208                   
1209                    raise 'x,y interpolation not yet implemented'
1210                else:
1211                    #Use precomputed point
1212                    Q0 = Q[self.index, point_id]
1213                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1214
1215            #Linear temporal interpolation   
1216            if ratio > 0:
1217                q[i] = Q0 + ratio*(Q1 - Q0)
1218            else:
1219                q[i] = Q0
1220
1221
1222        #Return vector of interpolated values
1223        #if len(q) == 1:
1224        #    return q[0]
1225        #else:
1226        #    return q
1227
1228
1229        #Return vector of interpolated values
1230        #FIXME:
1231        if self.spatial is True:
1232            return q
1233        else:
1234            #Replicate q according to x and y
1235            #This is e.g used for Wind_stress
1236            if x is None or y is None: 
1237                return q
1238            else:
1239                try:
1240                    N = len(x)
1241                except:
1242                    return q
1243                else:
1244                    from Numeric import ones, Float
1245                    #x is a vector - Create one constant column for each value
1246                    N = len(x)
1247                    assert len(y) == N, 'x and y must have same length'
1248                    res = []
1249                    for col in q:
1250                        res.append(col*ones(N, Float))
1251                       
1252                return res
1253
1254
1255    def statistics(self):
1256        """Output statistics about interpolation_function
1257        """
1258       
1259        vertex_coordinates = self.vertex_coordinates
1260        interpolation_points = self.interpolation_points               
1261        quantity_names = self.quantity_names
1262        quantities = self.quantities
1263        precomputed_values = self.precomputed_values                 
1264               
1265        x = vertex_coordinates[:,0]
1266        y = vertex_coordinates[:,1]               
1267
1268        str =  '------------------------------------------------\n'
1269        str += 'Interpolation_function (spatio-temporal) statistics:\n'
1270        str += '  Extent:\n'
1271        str += '    x in [%f, %f], len(x) == %d\n'\
1272               %(min(x), max(x), len(x))
1273        str += '    y in [%f, %f], len(y) == %d\n'\
1274               %(min(y), max(y), len(y))
1275        str += '    t in [%f, %f], len(t) == %d\n'\
1276               %(min(self.T), max(self.T), len(self.T))
1277        str += '  Quantities:\n'
1278        for name in quantity_names:
1279            q = quantities[name][:].flat
1280            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1281
1282        if interpolation_points is not None:   
1283            str += '  Interpolation points (xi, eta):'\
1284                   ' number of points == %d\n' %interpolation_points.shape[0]
1285            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1286                                            max(interpolation_points[:,0]))
1287            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1288                                             max(interpolation_points[:,1]))
1289            str += '  Interpolated quantities (over all timesteps):\n'
1290       
1291            for name in quantity_names:
1292                q = precomputed_values[name][:].flat
1293                str += '    %s at interpolation points in [%f, %f]\n'\
1294                       %(name, min(q), max(q))
1295        str += '------------------------------------------------\n'
1296
1297        return str
1298
1299        #FIXME: Delete
1300        #print '------------------------------------------------'
1301        #print 'Interpolation_function statistics:'
1302        #print '  Extent:'
1303        #print '    x in [%f, %f], len(x) == %d'\
1304        #      %(min(x), max(x), len(x))
1305        #print '    y in [%f, %f], len(y) == %d'\
1306        #      %(min(y), max(y), len(y))
1307        #print '    t in [%f, %f], len(t) == %d'\
1308        #      %(min(self.T), max(self.T), len(self.T))
1309        #print '  Quantities:'
1310        #for name in quantity_names:
1311        #    q = quantities[name][:].flat
1312        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1313        #print '  Interpolation points (xi, eta):'\
1314        #      ' number of points == %d ' %interpolation_points.shape[0]
1315        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1316        #                             max(interpolation_points[:,0]))
1317        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1318        #                              max(interpolation_points[:,1]))
1319        #print '  Interpolated quantities (over all timesteps):'
1320        #
1321        #for name in quantity_names:
1322        #    q = precomputed_values[name][:].flat
1323        #    print '    %s at interpolation points in [%f, %f]'\
1324        #          %(name, min(q), max(q))
1325        #print '------------------------------------------------'
1326
1327
1328#-------------------------------------------------------------
1329if __name__ == "__main__":
1330    """
1331    Load in a mesh and data points with attributes.
1332    Fit the attributes to the mesh.
1333    Save a new mesh file.
1334    """
1335    import os, sys
1336    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1337            %os.path.basename(sys.argv[0])
1338
1339    if len(sys.argv) < 4:
1340        print usage
1341    else:
1342        mesh_file = sys.argv[1]
1343        point_file = sys.argv[2]
1344        mesh_output_file = sys.argv[3]
1345
1346        expand_search = False
1347        if len(sys.argv) > 4:
1348            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1349                expand_search = True
1350            else:
1351                expand_search = False
1352
1353        verbose = False
1354        if len(sys.argv) > 5:
1355            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1356                verbose = False
1357            else:
1358                verbose = True
1359
1360        if len(sys.argv) > 6:
1361            alpha = sys.argv[6]
1362        else:
1363            alpha = DEFAULT_ALPHA
1364
1365        # This is used more for testing
1366        if len(sys.argv) > 7:
1367            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1368                display_errors = False
1369            else:
1370                display_errors = True
1371           
1372        t0 = time.time()
1373        try:
1374            fit_to_mesh_file(mesh_file,
1375                         point_file,
1376                         mesh_output_file,
1377                         alpha,
1378                         verbose= verbose,
1379                         expand_search = expand_search,
1380                         display_errors = display_errors)
1381        except IOError,e:
1382            import sys; sys.exit(1)
1383
1384        print 'That took %.2f seconds' %(time.time()-t0)
1385
Note: See TracBrowser for help on using the repository browser.