source: inundation/pyvolution/least_squares.py @ 1872

Last change on this file since 1872 was 1872, checked in by ole, 18 years ago

Worked on exception catching. A test still fails on cyclone

File size: 46.4 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20import exceptions
21class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    mesh_origin is the same but refers to the input tsh file.
75    FIXME: When the tsh format contains it own origin, these parameters can go.
76    FIXME: And both origins should be obtained from the specified files.
77    """
78
79    from load_mesh.loadASCII import import_mesh_file, \
80                 import_points_file, export_mesh_file, \
81                 concatinate_attributelist
82
83
84    try:
85        mesh_dict = import_mesh_file(mesh_file)
86    except IOError,e:
87        if display_errors:
88            print "Could not load bad file. ", e
89        raise IOError  #Re-raise exception
90       
91    vertex_coordinates = mesh_dict['vertices']
92    triangles = mesh_dict['triangles']
93    if type(mesh_dict['vertex_attributes']) == ArrayType:
94        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
95    else:
96        old_point_attributes = mesh_dict['vertex_attributes']
97
98    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
99        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
100    else:
101        old_title_list = mesh_dict['vertex_attribute_titles']
102
103    if verbose: print 'tsh file %s loaded' %mesh_file
104
105    # load in the .pts file
106    try:
107        point_dict = import_points_file(point_file, verbose=verbose)
108    except IOError,e:
109        if display_errors:
110            print "Could not load bad file. ", e
111        raise IOError  #Re-raise exception       
112
113    point_coordinates = point_dict['pointlist']
114    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
115
116    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
117        data_origin = point_dict['geo_reference'].get_origin()
118    else:
119        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
120
121    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
122        mesh_origin = mesh_dict['geo_reference'].get_origin()
123    else:
124        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
125
126    if verbose: print "points file loaded"
127    if verbose:print "fitting to mesh"
128    f = fit_to_mesh(vertex_coordinates,
129                    triangles,
130                    point_coordinates,
131                    point_attributes,
132                    alpha = alpha,
133                    verbose = verbose,
134                    expand_search = expand_search,
135                    data_origin = data_origin,
136                    mesh_origin = mesh_origin,
137                    precrop = precrop)
138    if verbose: print "finished fitting to mesh"
139
140    # convert array to list of lists
141    new_point_attributes = f.tolist()
142    #FIXME have this overwrite attributes with the same title - DSG
143    #Put the newer attributes last
144    if old_title_list <> []:
145        old_title_list.extend(title_list)
146        #FIXME can this be done a faster way? - DSG
147        for i in range(len(old_point_attributes)):
148            old_point_attributes[i].extend(new_point_attributes[i])
149        mesh_dict['vertex_attributes'] = old_point_attributes
150        mesh_dict['vertex_attribute_titles'] = old_title_list
151    else:
152        mesh_dict['vertex_attributes'] = new_point_attributes
153        mesh_dict['vertex_attribute_titles'] = title_list
154
155    #FIXME (Ole): Remember to output mesh_origin as well
156    if verbose: print "exporting to file ",mesh_output_file
157
158    try:
159        export_mesh_file(mesh_output_file, mesh_dict)
160    except IOError,e:
161        if display_errors:
162            print "Could not write file. ", e
163        #import sys; sys.exit()
164        raise IOError
165
166def fit_to_mesh(vertex_coordinates,
167                triangles,
168                point_coordinates,
169                point_attributes,
170                alpha = DEFAULT_ALPHA,
171                verbose = False,
172                expand_search = False,
173                data_origin = None,
174                mesh_origin = None,
175                precrop = False):
176    """
177    Fit a smooth surface to a triangulation,
178    given data points with attributes.
179
180
181        Inputs:
182
183          vertex_coordinates: List of coordinate pairs [xi, eta] of points
184          constituting mesh (or a an m x 2 Numeric array)
185
186          triangles: List of 3-tuples (or a Numeric array) of
187          integers representing indices of all vertices in the mesh.
188
189          point_coordinates: List of coordinate pairs [x, y] of data points
190          (or an nx2 Numeric array)
191
192          alpha: Smoothing parameter.
193
194          point_attributes: Vector or array of data at the point_coordinates.
195
196          data_origin and mesh_origin are 3-tuples consisting of
197          UTM zone, easting and northing. If specified
198          point coordinates and vertex coordinates are assumed to be
199          relative to their respective origins.
200
201    """
202    interp = Interpolation(vertex_coordinates,
203                           triangles,
204                           point_coordinates,
205                           alpha = alpha,
206                           verbose = verbose,
207                           expand_search = expand_search,
208                           data_origin = data_origin,
209                           mesh_origin = mesh_origin,
210                           precrop = precrop)
211
212    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
213    return vertex_attributes
214
215
216
217def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
218                    verbose = False, reduction = 1, format = 'netcdf'):
219    """Fits attributes from pts file to MxN rectangular mesh
220
221    Read pts file and create rectangular mesh of resolution MxN such that
222    it covers all points specified in pts file.
223
224    FIXME: This may be a temporary function until we decide on
225    netcdf formats etc
226
227    FIXME: Uses elevation hardwired
228    """
229
230    import util, mesh_factory
231
232    if verbose: print 'Read pts'
233    points, attributes = util.read_xya(pts_name, format)
234
235    #Reduce number of points a bit
236    points = points[::reduction]
237    elevation = attributes['elevation']  #Must be elevation
238    elevation = elevation[::reduction]
239
240    if verbose: print 'Got %d data points' %len(points)
241
242    if verbose: print 'Create mesh'
243    #Find extent
244    max_x = min_x = points[0][0]
245    max_y = min_y = points[0][1]
246    for point in points[1:]:
247        x = point[0]
248        if x > max_x: max_x = x
249        if x < min_x: min_x = x
250        y = point[1]
251        if y > max_y: max_y = y
252        if y < min_y: min_y = y
253
254    #Create appropriate mesh
255    vertex_coordinates, triangles, boundary =\
256         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
257                                (min_x, min_y))
258
259    #Fit attributes to mesh
260    vertex_attributes = fit_to_mesh(vertex_coordinates,
261                        triangles,
262                        points,
263                        elevation, alpha=alpha, verbose=verbose)
264
265
266
267    return vertex_coordinates, triangles, boundary, vertex_attributes
268
269
270
271class Interpolation:
272
273    def __init__(self,
274                 vertex_coordinates,
275                 triangles,
276                 point_coordinates = None,
277                 alpha = None,
278                 verbose = False,
279                 expand_search = True,
280                 max_points_per_cell = 30,
281                 mesh_origin = None,
282                 data_origin = None,
283                 precrop = False):
284
285
286        """ Build interpolation matrix mapping from
287        function values at vertices to function values at data points
288
289        Inputs:
290
291          vertex_coordinates: List of coordinate pairs [xi, eta] of
292          points constituting mesh (or a an m x 2 Numeric array)
293          Points may appear multiple times
294          (e.g. if vertices have discontinuities)
295
296          triangles: List of 3-tuples (or a Numeric array) of
297          integers representing indices of all vertices in the mesh.
298
299          point_coordinates: List of coordinate pairs [x, y] of
300          data points (or an nx2 Numeric array)
301          If point_coordinates is absent, only smoothing matrix will
302          be built
303
304          alpha: Smoothing parameter
305
306          data_origin and mesh_origin are 3-tuples consisting of
307          UTM zone, easting and northing. If specified
308          point coordinates and vertex coordinates are assumed to be
309          relative to their respective origins.
310
311        """
312        from util import ensure_numeric
313
314        #Convert input to Numeric arrays
315        triangles = ensure_numeric(triangles, Int)
316        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
317
318        #Build underlying mesh
319        if verbose: print 'Building mesh'
320        #self.mesh = General_mesh(vertex_coordinates, triangles,
321        #FIXME: Trying the normal mesh while testing precrop,
322        #       The functionality of boundary_polygon is needed for that
323
324        #FIXME - geo ref does not have to go into mesh.
325        # Change the point co-ords to conform to the
326        # mesh co-ords early in the code
327        if mesh_origin == None:
328            geo = None
329        else:
330            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
331        self.mesh = Mesh(vertex_coordinates, triangles,
332                         geo_reference = geo)
333       
334        self.mesh.check_integrity()
335
336        self.data_origin = data_origin
337
338        self.point_indices = None
339
340        #Smoothing parameter
341        if alpha is None:
342            self.alpha = DEFAULT_ALPHA
343        else:   
344            self.alpha = alpha
345
346        #Build coefficient matrices
347        self.build_coefficient_matrix_B(point_coordinates,
348                                        verbose = verbose,
349                                        expand_search = expand_search,
350                                        max_points_per_cell =\
351                                        max_points_per_cell,
352                                        data_origin = data_origin,
353                                        precrop = precrop)
354
355
356    def set_point_coordinates(self, point_coordinates,
357                              data_origin = None):
358        """
359        A public interface to setting the point co-ordinates.
360        """
361        self.build_coefficient_matrix_B(point_coordinates, data_origin)
362
363    def build_coefficient_matrix_B(self, point_coordinates=None,
364                                   verbose = False, expand_search = True,
365                                   max_points_per_cell=30,
366                                   data_origin = None,
367                                   precrop = False):
368        """Build final coefficient matrix"""
369
370
371        if self.alpha <> 0:
372            if verbose: print 'Building smoothing matrix'
373            self.build_smoothing_matrix_D()
374
375        if point_coordinates is not None:
376
377            if verbose: print 'Building interpolation matrix'
378            self.build_interpolation_matrix_A(point_coordinates,
379                                              verbose = verbose,
380                                              expand_search = expand_search,
381                                              max_points_per_cell =\
382                                              max_points_per_cell,
383                                              data_origin = data_origin,
384                                              precrop = precrop)
385
386            if self.alpha <> 0:
387                self.B = self.AtA + self.alpha*self.D
388            else:
389                self.B = self.AtA
390
391            #Convert self.B matrix to CSR format for faster matrix vector
392            self.B = Sparse_CSR(self.B)
393
394    def build_interpolation_matrix_A(self, point_coordinates,
395                                     verbose = False, expand_search = True,
396                                     max_points_per_cell=30,
397                                     data_origin = None,
398                                     precrop = False):
399        """Build n x m interpolation matrix, where
400        n is the number of data points and
401        m is the number of basis functions phi_k (one per vertex)
402
403        This algorithm uses a quad tree data structure for fast binning of data points
404        origin is a 3-tuple consisting of UTM zone, easting and northing.
405        If specified coordinates are assumed to be relative to this origin.
406
407        This one will override any data_origin that may be specified in
408        interpolation instance
409
410        """
411
412
413        #FIXME (Ole): Check that this function is memeory efficient.
414        #6 million datapoints and 300000 basis functions
415        #causes out-of-memory situation
416        #First thing to check is whether there is room for self.A and self.AtA
417        #
418        #Maybe we need some sort of blocking
419
420        from quad import build_quadtree
421        from util import ensure_numeric
422
423        if data_origin is None:
424            data_origin = self.data_origin #Use the one from
425                                           #interpolation instance
426
427        #Convert input to Numeric arrays just in case.
428        point_coordinates = ensure_numeric(point_coordinates, Float)
429
430        #Keep track of discarded points (if any).
431        #This is only registered if precrop is True
432        self.cropped_points = False
433
434        #Shift data points to same origin as mesh (if specified)
435
436        #FIXME this will shift if there was no geo_ref.
437        #But all this should be removed anyhow.
438        #change coords before this point
439        mesh_origin = self.mesh.geo_reference.get_origin()
440        if point_coordinates is not None:
441            if data_origin is not None:
442                if mesh_origin is not None:
443
444                    #Transformation:
445                    #
446                    #Let x_0 be the reference point of the point coordinates
447                    #and xi_0 the reference point of the mesh.
448                    #
449                    #A point coordinate (x + x_0) is then made relative
450                    #to xi_0 by
451                    #
452                    # x_new = x + x_0 - xi_0
453                    #
454                    #and similarly for eta
455
456                    x_offset = data_origin[1] - mesh_origin[1]
457                    y_offset = data_origin[2] - mesh_origin[2]
458                else: #Shift back to a zero origin
459                    x_offset = data_origin[1]
460                    y_offset = data_origin[2]
461
462                point_coordinates[:,0] += x_offset
463                point_coordinates[:,1] += y_offset
464            else:
465                if mesh_origin is not None:
466                    #Use mesh origin for data points
467                    point_coordinates[:,0] -= mesh_origin[1]
468                    point_coordinates[:,1] -= mesh_origin[2]
469
470
471
472        #Remove points falling outside mesh boundary
473        #This reduced one example from 1356 seconds to 825 seconds
474        if precrop is True:
475            from Numeric import take
476            from util import inside_polygon
477
478            if verbose: print 'Getting boundary polygon'
479            P = self.mesh.get_boundary_polygon()
480
481            if verbose: print 'Getting indices inside mesh boundary'
482            indices = inside_polygon(point_coordinates, P, verbose = verbose)
483
484
485            if len(indices) != point_coordinates.shape[0]:
486                self.cropped_points = True
487                if verbose:
488                    print 'Done - %d points outside mesh have been cropped.'\
489                          %(point_coordinates.shape[0] - len(indices))
490
491            point_coordinates = take(point_coordinates, indices)
492            self.point_indices = indices
493
494
495
496
497        #Build n x m interpolation matrix
498        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
499        n = point_coordinates.shape[0]     #Nbr of data points
500
501        if verbose: print 'Number of datapoints: %d' %n
502        if verbose: print 'Number of basis functions: %d' %m
503
504        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
505        #However, Sparse_CSR does not have the same methods as Sparse yet
506        #The tests will reveal what needs to be done
507        #self.A = Sparse_CSR(Sparse(n,m))
508        #self.AtA = Sparse_CSR(Sparse(m,m))
509        self.A = Sparse(n,m)
510        self.AtA = Sparse(m,m)
511
512        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
513        root = build_quadtree(self.mesh,
514                              max_points_per_cell = max_points_per_cell)
515
516        #Compute matrix elements
517        for i in range(n):
518            #For each data_coordinate point
519
520            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
521            x = point_coordinates[i]
522
523            #Find vertices near x
524            candidate_vertices = root.search(x[0], x[1])
525            is_more_elements = True
526
527            element_found, sigma0, sigma1, sigma2, k = \
528                self.search_triangles_of_vertices(candidate_vertices, x)
529            while not element_found and is_more_elements and expand_search:
530                #if verbose: print 'Expanding search'
531                candidate_vertices, branch = root.expand_search()
532                if branch == []:
533                    # Searching all the verts from the root cell that haven't
534                    # been searched.  This is the last try
535                    element_found, sigma0, sigma1, sigma2, k = \
536                      self.search_triangles_of_vertices(candidate_vertices, x)
537                    is_more_elements = False
538                else:
539                    element_found, sigma0, sigma1, sigma2, k = \
540                      self.search_triangles_of_vertices(candidate_vertices, x)
541
542
543            #Update interpolation matrix A if necessary
544            if element_found is True:
545                #Assign values to matrix A
546
547                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
548                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
549                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
550
551                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
552                js     = [j0,j1,j2]
553
554                for j in js:
555                    self.A[i,j] = sigmas[j]
556                    for k in js:
557                        self.AtA[j,k] += sigmas[j]*sigmas[k]
558            else:
559                pass
560                #Ok if there is no triangle for datapoint
561                #(as in brute force version)
562                #raise 'Could not find triangle for point', x
563
564
565
566    def search_triangles_of_vertices(self, candidate_vertices, x):
567            #Find triangle containing x:
568            element_found = False
569
570            # This will be returned if element_found = False
571            sigma2 = -10.0
572            sigma0 = -10.0
573            sigma1 = -10.0
574            k = -10.0
575
576            #For all vertices in same cell as point x
577            for v in candidate_vertices:
578
579                #for each triangle id (k) which has v as a vertex
580                for k, _ in self.mesh.vertexlist[v]:
581
582                    #Get the three vertex_points of candidate triangle
583                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
584                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
585                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
586
587                    #print "PDSG - k", k
588                    #print "PDSG - xi0", xi0
589                    #print "PDSG - xi1", xi1
590                    #print "PDSG - xi2", xi2
591                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
592                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
593
594                    #Get the three normals
595                    n0 = self.mesh.get_normal(k, 0)
596                    n1 = self.mesh.get_normal(k, 1)
597                    n2 = self.mesh.get_normal(k, 2)
598
599
600                    #Compute interpolation
601                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
602                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
603                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
604
605                    #print "PDSG - sigma0", sigma0
606                    #print "PDSG - sigma1", sigma1
607                    #print "PDSG - sigma2", sigma2
608
609                    #FIXME: Maybe move out to test or something
610                    epsilon = 1.0e-6
611                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
612
613                    #Check that this triangle contains the data point
614
615                    #Sigmas can get negative within
616                    #machine precision on some machines (e.g nautilus)
617                    #Hence the small eps
618                    eps = 1.0e-15
619                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
620                        element_found = True
621                        break
622
623                if element_found is True:
624                    #Don't look for any other triangle
625                    break
626            return element_found, sigma0, sigma1, sigma2, k
627
628
629
630    def build_interpolation_matrix_A_brute(self, point_coordinates):
631        """Build n x m interpolation matrix, where
632        n is the number of data points and
633        m is the number of basis functions phi_k (one per vertex)
634
635        This is the brute force which is too slow for large problems,
636        but could be used for testing
637        """
638
639        from util import ensure_numeric
640
641        #Convert input to Numeric arrays
642        point_coordinates = ensure_numeric(point_coordinates, Float)
643
644        #Build n x m interpolation matrix
645        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
646        n = point_coordinates.shape[0]     #Nbr of data points
647
648        self.A = Sparse(n,m)
649        self.AtA = Sparse(m,m)
650
651        #Compute matrix elements
652        for i in range(n):
653            #For each data_coordinate point
654
655            x = point_coordinates[i]
656            element_found = False
657            k = 0
658            while not element_found and k < len(self.mesh):
659                #For each triangle (brute force)
660                #FIXME: Real algorithm should only visit relevant triangles
661
662                #Get the three vertex_points
663                xi0 = self.mesh.get_vertex_coordinate(k, 0)
664                xi1 = self.mesh.get_vertex_coordinate(k, 1)
665                xi2 = self.mesh.get_vertex_coordinate(k, 2)
666
667                #Get the three normals
668                n0 = self.mesh.get_normal(k, 0)
669                n1 = self.mesh.get_normal(k, 1)
670                n2 = self.mesh.get_normal(k, 2)
671
672                #Compute interpolation
673                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
674                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
675                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
676
677                #FIXME: Maybe move out to test or something
678                epsilon = 1.0e-6
679                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
680
681                #Check that this triangle contains data point
682                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
683                    element_found = True
684                    #Assign values to matrix A
685
686                    j0 = self.mesh.triangles[k,0] #Global vertex id
687                    #self.A[i, j0] = sigma0
688
689                    j1 = self.mesh.triangles[k,1] #Global vertex id
690                    #self.A[i, j1] = sigma1
691
692                    j2 = self.mesh.triangles[k,2] #Global vertex id
693                    #self.A[i, j2] = sigma2
694
695                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
696                    js     = [j0,j1,j2]
697
698                    for j in js:
699                        self.A[i,j] = sigmas[j]
700                        for k in js:
701                            self.AtA[j,k] += sigmas[j]*sigmas[k]
702                k = k+1
703
704
705
706    def get_A(self):
707        return self.A.todense()
708
709    def get_B(self):
710        return self.B.todense()
711
712    def get_D(self):
713        return self.D.todense()
714
715        #FIXME: Remember to re-introduce the 1/n factor in the
716        #interpolation term
717
718    def build_smoothing_matrix_D(self):
719        """Build m x m smoothing matrix, where
720        m is the number of basis functions phi_k (one per vertex)
721
722        The smoothing matrix is defined as
723
724        D = D1 + D2
725
726        where
727
728        [D1]_{k,l} = \int_\Omega
729           \frac{\partial \phi_k}{\partial x}
730           \frac{\partial \phi_l}{\partial x}\,
731           dx dy
732
733        [D2]_{k,l} = \int_\Omega
734           \frac{\partial \phi_k}{\partial y}
735           \frac{\partial \phi_l}{\partial y}\,
736           dx dy
737
738
739        The derivatives \frac{\partial \phi_k}{\partial x},
740        \frac{\partial \phi_k}{\partial x} for a particular triangle
741        are obtained by computing the gradient a_k, b_k for basis function k
742        """
743
744        #FIXME: algorithm might be optimised by computing local 9x9
745        #"element stiffness matrices:
746
747        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
748
749        self.D = Sparse(m,m)
750
751        #For each triangle compute contributions to D = D1+D2
752        for i in range(len(self.mesh)):
753
754            #Get area
755            area = self.mesh.areas[i]
756
757            #Get global vertex indices
758            v0 = self.mesh.triangles[i,0]
759            v1 = self.mesh.triangles[i,1]
760            v2 = self.mesh.triangles[i,2]
761
762            #Get the three vertex_points
763            xi0 = self.mesh.get_vertex_coordinate(i, 0)
764            xi1 = self.mesh.get_vertex_coordinate(i, 1)
765            xi2 = self.mesh.get_vertex_coordinate(i, 2)
766
767            #Compute gradients for each vertex
768            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
769                              1, 0, 0)
770
771            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
772                              0, 1, 0)
773
774            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
775                              0, 0, 1)
776
777            #Compute diagonal contributions
778            self.D[v0,v0] += (a0*a0 + b0*b0)*area
779            self.D[v1,v1] += (a1*a1 + b1*b1)*area
780            self.D[v2,v2] += (a2*a2 + b2*b2)*area
781
782            #Compute contributions for basis functions sharing edges
783            e01 = (a0*a1 + b0*b1)*area
784            self.D[v0,v1] += e01
785            self.D[v1,v0] += e01
786
787            e12 = (a1*a2 + b1*b2)*area
788            self.D[v1,v2] += e12
789            self.D[v2,v1] += e12
790
791            e20 = (a2*a0 + b2*b0)*area
792            self.D[v2,v0] += e20
793            self.D[v0,v2] += e20
794
795
796    def fit(self, z):
797        """Fit a smooth surface to given 1d array of data points z.
798
799        The smooth surface is computed at each vertex in the underlying
800        mesh using the formula given in the module doc string.
801
802        Pre Condition:
803          self.A, self.AtA and self.B have been initialised
804
805        Inputs:
806          z: Single 1d vector or array of data at the point_coordinates.
807        """
808
809        #Convert input to Numeric arrays
810        from util import ensure_numeric
811        z = ensure_numeric(z, Float)
812
813        if len(z.shape) > 1 :
814            raise VectorShapeError, 'Can only deal with 1d data vector'
815
816        if self.point_indices is not None:
817            #Remove values for any points that were outside mesh
818            z = take(z, self.point_indices)
819
820        #Compute right hand side based on data
821        Atz = self.A.trans_mult(z)
822
823
824        #Check sanity
825        n, m = self.A.shape
826        if n<m and self.alpha == 0.0:
827            msg = 'ERROR (least_squares): Too few data points\n'
828            msg += 'There are only %d data points and alpha == 0. ' %n
829            msg += 'Need at least %d\n' %m
830            msg += 'Alternatively, set smoothing parameter alpha to a small '
831            msg += 'positive value,\ne.g. 1.0e-3.'
832            raise msg
833
834
835
836        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
837        #FIXME: Should we store the result here for later use? (ON)
838
839
840    def fit_points(self, z, verbose=False):
841        """Like fit, but more robust when each point has two or more attributes
842        FIXME (Ole): The name fit_points doesn't carry any meaning
843        for me. How about something like fit_multiple or fit_columns?
844        """
845
846        try:
847            if verbose: print 'Solving penalised least_squares problem'
848            return self.fit(z)
849        except VectorShapeError, e:
850            # broadcasting is not supported.
851
852            #Convert input to Numeric arrays
853            from util import ensure_numeric
854            z = ensure_numeric(z, Float)
855
856            #Build n x m interpolation matrix
857            m = self.mesh.coordinates.shape[0] #Number of vertices
858            n = z.shape[1]                     #Number of data points
859
860            f = zeros((m,n), Float) #Resulting columns
861
862            for i in range(z.shape[1]):
863                f[:,i] = self.fit(z[:,i])
864
865            return f
866
867
868    def interpolate(self, f):
869        """Evaluate smooth surface f at data points implied in self.A.
870
871        The mesh values representing a smooth surface are
872        assumed to be specified in f. This argument could,
873        for example have been obtained from the method self.fit()
874
875        Pre Condition:
876          self.A has been initialised
877
878        Inputs:
879          f: Vector or array of data at the mesh vertices.
880          If f is an array, interpolation will be done for each column as
881          per underlying matrix-matrix multiplication
882
883        Output:
884          Interpolated values at data points implied in self.A
885
886        """
887
888        return self.A * f
889
890    def cull_outsiders(self, f):
891        pass
892
893
894
895
896class Interpolation_function:
897    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
898    which is interpolated from time series defined at vertices of
899    triangular mesh (such as those stored in sww files)
900
901    Let m be the number of vertices, n the number of triangles
902    and p the number of timesteps.
903
904    Mandatory input
905        time:               px1 array of monotonously increasing times (Float)
906        quantities:         Dictionary of pxm arrays or 1 pxm array (Float)
907       
908    Optional input:
909        quantity_names:     List of keys into the quantities dictionary
910        vertex_coordinates: mx2 array of coordinates (Float)
911        triangles:          nx3 array of indices into vertex_coordinates (Int)
912        interpolation_points: array of coordinates to be interpolated to
913        verbose:            Level of reporting
914   
915   
916    The quantities returned by the callable object are specified by
917    the list quantities which must contain the names of the
918    quantities to be returned and also reflect the order, e.g. for
919    the shallow water wave equation, on would have
920    quantities = ['stage', 'xmomentum', 'ymomentum']
921
922    The parameter interpolation_points decides at which points interpolated
923    quantities are to be computed whenever object is called.
924    If None, return average value
925    """
926
927   
928   
929    def __init__(self,
930                 time,
931                 quantities,
932                 quantity_names = None, 
933                 vertex_coordinates = None,
934                 triangles = None,
935                 interpolation_points = None,
936                 verbose = False):
937        """Initialise object and build spatial interpolation if required
938        """
939
940        from Numeric import array, zeros, Float, alltrue, concatenate,\
941             reshape, ArrayType
942
943        from util import mean, ensure_numeric
944        from config import time_format
945        import types
946
947
948
949        #Check temporal info
950        time = ensure_numeric(time)       
951        msg = 'Time must be a monotonuosly '
952        msg += 'increasing sequence %s' %time
953        assert alltrue(time[1:] - time[:-1] > 0 ), msg
954
955
956        #Check if quantities is a single array only
957        if type(quantities) != types.DictType:
958            quantities = ensure_numeric(quantities)
959            quantity_names = ['Attribute']
960
961            #Make it a dictionary
962            quantities = {quantity_names[0]: quantities}
963
964
965        #Use keys if no names are specified
966        if quantity_names is None:
967            quantity_names = quantities.keys()
968
969
970        #Check spatial info
971        if vertex_coordinates is None:
972            self.spatial = False
973        else:   
974            vertex_coordinates = ensure_numeric(vertex_coordinates)
975
976            assert triangles is not None, 'Triangles array must be specified'
977            triangles = ensure_numeric(triangles)
978            self.spatial = True           
979           
980
981 
982        #Save for use with statistics
983        self.quantity_names = quantity_names       
984        self.quantities = quantities       
985        self.vertex_coordinates = vertex_coordinates
986        self.interpolation_points = interpolation_points
987        self.T = time[:]  #Time assumed to be relative to starttime
988        self.index = 0    #Initial time index
989        self.precomputed_values = {}
990           
991
992
993        #Precomputed spatial interpolation if requested
994        if interpolation_points is not None:
995            if self.spatial is False:
996                raise 'Triangles and vertex_coordinates must be specified'
997           
998
999            try:
1000                self.interpolation_points =\
1001                                          ensure_numeric(self.interpolation_points)
1002            except:
1003                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1004                      'or a list of points\n'
1005                msg += 'I got: %s.' %( str(self.interpolation_points)[:60] + '...')
1006                raise msg
1007
1008
1009            for name in quantity_names:
1010                self.precomputed_values[name] =\
1011                                              zeros((len(self.T),
1012                                                     len(self.interpolation_points)),
1013                                                    Float)
1014
1015            #Build interpolator
1016            interpol = Interpolation(vertex_coordinates,
1017                                     triangles,
1018                                     point_coordinates = self.interpolation_points,
1019                                     alpha = 0,
1020                                     precrop = False, 
1021                                     verbose = verbose)
1022
1023            if verbose: print 'Interpolate'
1024            n = len(self.T)
1025            for i, t in enumerate(self.T):
1026                #Interpolate quantities at this timestep
1027                if verbose and i%((n+10)/10)==0:
1028                    print ' time step %d of %d' %(i, n)
1029                   
1030                for name in quantity_names:
1031                    self.precomputed_values[name][i, :] =\
1032                    interpol.interpolate(quantities[name][i,:])
1033
1034            #Report
1035            if verbose:
1036                print self.statistics()
1037                #self.print_statistics()
1038           
1039        else:
1040            #Store quantitites as is
1041            for name in quantity_names:
1042                self.precomputed_values[name] = quantities[name]
1043
1044
1045        #else:
1046        #    #Return an average, making this a time series
1047        #    for name in quantity_names:
1048        #        self.values[name] = zeros(len(self.T), Float)
1049        #
1050        #    if verbose: print 'Compute mean values'
1051        #    for i, t in enumerate(self.T):
1052        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1053        #        for name in quantity_names:
1054        #           self.values[name][i] = mean(quantities[name][i,:])
1055
1056
1057
1058
1059    def __repr__(self):
1060        #return 'Interpolation function (spation-temporal)'
1061        return self.statistics()
1062   
1063
1064    def __call__(self, t, point_id = None, x = None, y = None):
1065        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1066
1067        Inputs:
1068          t: time - Model time. Must lie within existing timesteps
1069          point_id: index of one of the preprocessed points.
1070          x, y:     Overrides location, point_id ignored
1071         
1072          If spatial info is present and all of x,y,point_id
1073          are None an exception is raised
1074                   
1075          If no spatial info is present, point_id and x,y arguments are ignored
1076          making f a function of time only.
1077
1078         
1079          FIXME: point_id could also be a slice
1080          FIXME: What if x and y are vectors?
1081          FIXME: What about f(x,y) without t?
1082        """
1083
1084        from math import pi, cos, sin, sqrt
1085        from Numeric import zeros, Float
1086        from util import mean       
1087
1088        if self.spatial is True:
1089            if point_id is None:
1090                if x is None or y is None:
1091                    msg = 'Either point_id or x and y must be specified'
1092                    raise msg
1093            else:
1094                if self.interpolation_points is None:
1095                    msg = 'Interpolation_function must be instantiated ' +\
1096                          'with a list of interpolation points before parameter ' +\
1097                          'point_id can be used'
1098                    raise msg
1099
1100
1101        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1102        msg += ' does not match model time: %s\n' %t
1103        if t < self.T[0]: raise msg
1104        if t > self.T[-1]: raise msg
1105
1106        oldindex = self.index #Time index
1107
1108        #Find current time slot
1109        while t > self.T[self.index]: self.index += 1
1110        while t < self.T[self.index]: self.index -= 1
1111
1112        if t == self.T[self.index]:
1113            #Protect against case where t == T[-1] (last time)
1114            # - also works in general when t == T[i]
1115            ratio = 0
1116        else:
1117            #t is now between index and index+1
1118            ratio = (t - self.T[self.index])/\
1119                    (self.T[self.index+1] - self.T[self.index])
1120
1121        #Compute interpolated values
1122        q = zeros(len(self.quantity_names), Float)
1123
1124        for i, name in enumerate(self.quantity_names):
1125            Q = self.precomputed_values[name]
1126
1127            if self.spatial is False:
1128                #If there is no spatial info               
1129                assert len(Q.shape) == 1
1130
1131                Q0 = Q[self.index]
1132                if ratio > 0: Q1 = Q[self.index+1]
1133
1134            else:
1135                if x is not None and y is not None:
1136                    #Interpolate to x, y
1137                   
1138                    raise 'x,y interpolation not yet implemented'
1139                else:
1140                    #Use precomputed point
1141                    Q0 = Q[self.index, point_id]
1142                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1143
1144            #Linear temporal interpolation   
1145            if ratio > 0:
1146                q[i] = Q0 + ratio*(Q1 - Q0)
1147            else:
1148                q[i] = Q0
1149
1150
1151        #Return vector of interpolated values
1152        #if len(q) == 1:
1153        #    return q[0]
1154        #else:
1155        #    return q
1156
1157
1158        #Return vector of interpolated values
1159        #FIXME:
1160        if self.spatial is True:
1161            return q
1162        else:
1163            #Replicate q according to x and y
1164            #This is e.g used for Wind_stress
1165            if x == None or y == None: 
1166                return q
1167            else:
1168                try:
1169                    N = len(x)
1170                except:
1171                    return q
1172                else:
1173                    from Numeric import ones, Float
1174                    #x is a vector - Create one constant column for each value
1175                    N = len(x)
1176                    assert len(y) == N, 'x and y must have same length'
1177                    res = []
1178                    for col in q:
1179                        res.append(col*ones(N, Float))
1180                       
1181                return res
1182
1183
1184    def statistics(self):
1185        """Output statistics about interpolation_function
1186        """
1187       
1188        vertex_coordinates = self.vertex_coordinates
1189        interpolation_points = self.interpolation_points               
1190        quantity_names = self.quantity_names
1191        quantities = self.quantities
1192        precomputed_values = self.precomputed_values                 
1193               
1194        x = vertex_coordinates[:,0]
1195        y = vertex_coordinates[:,1]               
1196
1197        str =  '------------------------------------------------\n'
1198        str += 'Interpolation_function (spation-temporal) statistics:\n'
1199        str += '  Extent:\n'
1200        str += '    x in [%f, %f], len(x) == %d\n'\
1201               %(min(x), max(x), len(x))
1202        str += '    y in [%f, %f], len(y) == %d\n'\
1203               %(min(y), max(y), len(y))
1204        str += '    t in [%f, %f], len(t) == %d\n'\
1205               %(min(self.T), max(self.T), len(self.T))
1206        str += '  Quantities:\n'
1207        for name in quantity_names:
1208            q = quantities[name][:].flat
1209            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1210
1211        if interpolation_points is not None:   
1212            str += '  Interpolation points (xi, eta):'\
1213                   ' number of points == %d\n' %interpolation_points.shape[0]
1214            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1215                                            max(interpolation_points[:,0]))
1216            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1217                                             max(interpolation_points[:,1]))
1218            str += '  Interpolated quantities (over all timesteps):\n'
1219       
1220            for name in quantity_names:
1221                q = precomputed_values[name][:].flat
1222                str += '    %s at interpolation points in [%f, %f]\n'\
1223                       %(name, min(q), max(q))
1224        str += '------------------------------------------------\n'
1225
1226        return str
1227
1228        #FIXME: Delete
1229        #print '------------------------------------------------'
1230        #print 'Interpolation_function statistics:'
1231        #print '  Extent:'
1232        #print '    x in [%f, %f], len(x) == %d'\
1233        #      %(min(x), max(x), len(x))
1234        #print '    y in [%f, %f], len(y) == %d'\
1235        #      %(min(y), max(y), len(y))
1236        #print '    t in [%f, %f], len(t) == %d'\
1237        #      %(min(self.T), max(self.T), len(self.T))
1238        #print '  Quantities:'
1239        #for name in quantity_names:
1240        #    q = quantities[name][:].flat
1241        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1242        #print '  Interpolation points (xi, eta):'\
1243        #      ' number of points == %d ' %interpolation_points.shape[0]
1244        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1245        #                             max(interpolation_points[:,0]))
1246        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1247        #                              max(interpolation_points[:,1]))
1248        #print '  Interpolated quantities (over all timesteps):'
1249        #
1250        #for name in quantity_names:
1251        #    q = precomputed_values[name][:].flat
1252        #    print '    %s at interpolation points in [%f, %f]'\
1253        #          %(name, min(q), max(q))
1254        #print '------------------------------------------------'
1255
1256
1257#-------------------------------------------------------------
1258if __name__ == "__main__":
1259    """
1260    Load in a mesh and data points with attributes.
1261    Fit the attributes to the mesh.
1262    Save a new mesh file.
1263    """
1264    import os, sys
1265    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha]"\
1266            %os.path.basename(sys.argv[0])
1267
1268    if len(sys.argv) < 4:
1269        print usage
1270    else:
1271        mesh_file = sys.argv[1]
1272        point_file = sys.argv[2]
1273        mesh_output_file = sys.argv[3]
1274
1275        expand_search = False
1276        if len(sys.argv) > 4:
1277            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1278                expand_search = True
1279            else:
1280                expand_search = False
1281
1282        verbose = False
1283        if len(sys.argv) > 5:
1284            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1285                verbose = False
1286            else:
1287                verbose = True
1288
1289        if len(sys.argv) > 6:
1290            alpha = sys.argv[6]
1291        else:
1292            alpha = DEFAULT_ALPHA
1293
1294        t0 = time.time()
1295        fit_to_mesh_file(mesh_file,
1296                         point_file,
1297                         mesh_output_file,
1298                         alpha,
1299                         verbose= verbose,
1300                         expand_search = expand_search)
1301
1302        print 'That took %.2f seconds' %(time.time()-t0)
1303
Note: See TracBrowser for help on using the repository browser.