source: inundation/pyvolution/least_squares.py @ 1884

Last change on this file since 1884 was 1884, checked in by ole, 19 years ago

Made interpolation_points in file_function absolute UTM coordinates and wrote test

File size: 46.4 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20import exceptions
21class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    mesh_origin is the same but refers to the input tsh file.
75    FIXME: When the tsh format contains it own origin, these parameters can go.
76    FIXME: And both origins should be obtained from the specified files.
77    """
78
79    from load_mesh.loadASCII import import_mesh_file, \
80                 import_points_file, export_mesh_file, \
81                 concatinate_attributelist
82
83
84    try:
85        mesh_dict = import_mesh_file(mesh_file)
86    except IOError,e:
87        if display_errors:
88            print "Could not load bad file. ", e
89        raise IOError  #Re-raise exception
90       
91    vertex_coordinates = mesh_dict['vertices']
92    triangles = mesh_dict['triangles']
93    if type(mesh_dict['vertex_attributes']) == ArrayType:
94        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
95    else:
96        old_point_attributes = mesh_dict['vertex_attributes']
97
98    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
99        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
100    else:
101        old_title_list = mesh_dict['vertex_attribute_titles']
102
103    if verbose: print 'tsh file %s loaded' %mesh_file
104
105    # load in the .pts file
106    try:
107        point_dict = import_points_file(point_file, verbose=verbose)
108    except IOError,e:
109        if display_errors:
110            print "Could not load bad file. ", e
111        raise IOError  #Re-raise exception 
112
113    point_coordinates = point_dict['pointlist']
114    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
115
116    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
117        data_origin = point_dict['geo_reference'].get_origin()
118    else:
119        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
120
121    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
122        mesh_origin = mesh_dict['geo_reference'].get_origin()
123    else:
124        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
125
126    if verbose: print "points file loaded"
127    if verbose:print "fitting to mesh"
128    f = fit_to_mesh(vertex_coordinates,
129                    triangles,
130                    point_coordinates,
131                    point_attributes,
132                    alpha = alpha,
133                    verbose = verbose,
134                    expand_search = expand_search,
135                    data_origin = data_origin,
136                    mesh_origin = mesh_origin,
137                    precrop = precrop)
138    if verbose: print "finished fitting to mesh"
139
140    # convert array to list of lists
141    new_point_attributes = f.tolist()
142    #FIXME have this overwrite attributes with the same title - DSG
143    #Put the newer attributes last
144    if old_title_list <> []:
145        old_title_list.extend(title_list)
146        #FIXME can this be done a faster way? - DSG
147        for i in range(len(old_point_attributes)):
148            old_point_attributes[i].extend(new_point_attributes[i])
149        mesh_dict['vertex_attributes'] = old_point_attributes
150        mesh_dict['vertex_attribute_titles'] = old_title_list
151    else:
152        mesh_dict['vertex_attributes'] = new_point_attributes
153        mesh_dict['vertex_attribute_titles'] = title_list
154
155    #FIXME (Ole): Remember to output mesh_origin as well
156    if verbose: print "exporting to file ",mesh_output_file
157
158    try:
159        export_mesh_file(mesh_output_file, mesh_dict)
160    except IOError,e:
161        if display_errors:
162            print "Could not write file. ", e
163        raise IOError
164
165def fit_to_mesh(vertex_coordinates,
166                triangles,
167                point_coordinates,
168                point_attributes,
169                alpha = DEFAULT_ALPHA,
170                verbose = False,
171                expand_search = False,
172                data_origin = None,
173                mesh_origin = None,
174                precrop = False):
175    """
176    Fit a smooth surface to a triangulation,
177    given data points with attributes.
178
179
180        Inputs:
181
182          vertex_coordinates: List of coordinate pairs [xi, eta] of points
183          constituting mesh (or a an m x 2 Numeric array)
184
185          triangles: List of 3-tuples (or a Numeric array) of
186          integers representing indices of all vertices in the mesh.
187
188          point_coordinates: List of coordinate pairs [x, y] of data points
189          (or an nx2 Numeric array)
190
191          alpha: Smoothing parameter.
192
193          point_attributes: Vector or array of data at the point_coordinates.
194
195          data_origin and mesh_origin are 3-tuples consisting of
196          UTM zone, easting and northing. If specified
197          point coordinates and vertex coordinates are assumed to be
198          relative to their respective origins.
199
200    """
201    interp = Interpolation(vertex_coordinates,
202                           triangles,
203                           point_coordinates,
204                           alpha = alpha,
205                           verbose = verbose,
206                           expand_search = expand_search,
207                           data_origin = data_origin,
208                           mesh_origin = mesh_origin,
209                           precrop = precrop)
210
211    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
212    return vertex_attributes
213
214
215
216def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
217                    verbose = False, reduction = 1, format = 'netcdf'):
218    """Fits attributes from pts file to MxN rectangular mesh
219
220    Read pts file and create rectangular mesh of resolution MxN such that
221    it covers all points specified in pts file.
222
223    FIXME: This may be a temporary function until we decide on
224    netcdf formats etc
225
226    FIXME: Uses elevation hardwired
227    """
228
229    import util, mesh_factory
230
231    if verbose: print 'Read pts'
232    points, attributes = util.read_xya(pts_name, format)
233
234    #Reduce number of points a bit
235    points = points[::reduction]
236    elevation = attributes['elevation']  #Must be elevation
237    elevation = elevation[::reduction]
238
239    if verbose: print 'Got %d data points' %len(points)
240
241    if verbose: print 'Create mesh'
242    #Find extent
243    max_x = min_x = points[0][0]
244    max_y = min_y = points[0][1]
245    for point in points[1:]:
246        x = point[0]
247        if x > max_x: max_x = x
248        if x < min_x: min_x = x
249        y = point[1]
250        if y > max_y: max_y = y
251        if y < min_y: min_y = y
252
253    #Create appropriate mesh
254    vertex_coordinates, triangles, boundary =\
255         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
256                                (min_x, min_y))
257
258    #Fit attributes to mesh
259    vertex_attributes = fit_to_mesh(vertex_coordinates,
260                        triangles,
261                        points,
262                        elevation, alpha=alpha, verbose=verbose)
263
264
265
266    return vertex_coordinates, triangles, boundary, vertex_attributes
267
268
269
270class Interpolation:
271
272    def __init__(self,
273                 vertex_coordinates,
274                 triangles,
275                 point_coordinates = None,
276                 alpha = None,
277                 verbose = False,
278                 expand_search = True,
279                 max_points_per_cell = 30,
280                 mesh_origin = None,
281                 data_origin = None,
282                 precrop = False):
283
284
285        """ Build interpolation matrix mapping from
286        function values at vertices to function values at data points
287
288        Inputs:
289
290          vertex_coordinates: List of coordinate pairs [xi, eta] of
291          points constituting mesh (or a an m x 2 Numeric array)
292          Points may appear multiple times
293          (e.g. if vertices have discontinuities)
294
295          triangles: List of 3-tuples (or a Numeric array) of
296          integers representing indices of all vertices in the mesh.
297
298          point_coordinates: List of coordinate pairs [x, y] of
299          data points (or an nx2 Numeric array)
300          If point_coordinates is absent, only smoothing matrix will
301          be built
302
303          alpha: Smoothing parameter
304
305          data_origin and mesh_origin are 3-tuples consisting of
306          UTM zone, easting and northing. If specified
307          point coordinates and vertex coordinates are assumed to be
308          relative to their respective origins.
309
310        """
311        from util import ensure_numeric
312
313        #Convert input to Numeric arrays
314        triangles = ensure_numeric(triangles, Int)
315        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
316
317        #Build underlying mesh
318        if verbose: print 'Building mesh'
319        #self.mesh = General_mesh(vertex_coordinates, triangles,
320        #FIXME: Trying the normal mesh while testing precrop,
321        #       The functionality of boundary_polygon is needed for that
322
323        #FIXME - geo ref does not have to go into mesh.
324        # Change the point co-ords to conform to the
325        # mesh co-ords early in the code
326        if mesh_origin == None:
327            geo = None
328        else:
329            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
330        self.mesh = Mesh(vertex_coordinates, triangles,
331                         geo_reference = geo)
332       
333        self.mesh.check_integrity()
334
335        self.data_origin = data_origin
336
337        self.point_indices = None
338
339        #Smoothing parameter
340        if alpha is None:
341            self.alpha = DEFAULT_ALPHA
342        else:   
343            self.alpha = alpha
344
345        #Build coefficient matrices
346        self.build_coefficient_matrix_B(point_coordinates,
347                                        verbose = verbose,
348                                        expand_search = expand_search,
349                                        max_points_per_cell =\
350                                        max_points_per_cell,
351                                        data_origin = data_origin,
352                                        precrop = precrop)
353
354
355    def set_point_coordinates(self, point_coordinates,
356                              data_origin = None):
357        """
358        A public interface to setting the point co-ordinates.
359        """
360        self.build_coefficient_matrix_B(point_coordinates, data_origin)
361
362    def build_coefficient_matrix_B(self, point_coordinates=None,
363                                   verbose = False, expand_search = True,
364                                   max_points_per_cell=30,
365                                   data_origin = None,
366                                   precrop = False):
367        """Build final coefficient matrix"""
368
369
370        if self.alpha <> 0:
371            if verbose: print 'Building smoothing matrix'
372            self.build_smoothing_matrix_D()
373
374        if point_coordinates is not None:
375
376            if verbose: print 'Building interpolation matrix'
377            self.build_interpolation_matrix_A(point_coordinates,
378                                              verbose = verbose,
379                                              expand_search = expand_search,
380                                              max_points_per_cell =\
381                                              max_points_per_cell,
382                                              data_origin = data_origin,
383                                              precrop = precrop)
384
385            if self.alpha <> 0:
386                self.B = self.AtA + self.alpha*self.D
387            else:
388                self.B = self.AtA
389
390            #Convert self.B matrix to CSR format for faster matrix vector
391            self.B = Sparse_CSR(self.B)
392
393    def build_interpolation_matrix_A(self, point_coordinates,
394                                     verbose = False, expand_search = True,
395                                     max_points_per_cell=30,
396                                     data_origin = None,
397                                     precrop = False):
398        """Build n x m interpolation matrix, where
399        n is the number of data points and
400        m is the number of basis functions phi_k (one per vertex)
401
402        This algorithm uses a quad tree data structure for fast binning of data points
403        origin is a 3-tuple consisting of UTM zone, easting and northing.
404        If specified coordinates are assumed to be relative to this origin.
405
406        This one will override any data_origin that may be specified in
407        interpolation instance
408
409        """
410
411
412        #FIXME (Ole): Check that this function is memeory efficient.
413        #6 million datapoints and 300000 basis functions
414        #causes out-of-memory situation
415        #First thing to check is whether there is room for self.A and self.AtA
416        #
417        #Maybe we need some sort of blocking
418
419        from quad import build_quadtree
420        from util import ensure_numeric
421
422        if data_origin is None:
423            data_origin = self.data_origin #Use the one from
424                                           #interpolation instance
425
426        #Convert input to Numeric arrays just in case.
427        point_coordinates = ensure_numeric(point_coordinates, Float)
428
429        #Keep track of discarded points (if any).
430        #This is only registered if precrop is True
431        self.cropped_points = False
432
433        #Shift data points to same origin as mesh (if specified)
434
435        #FIXME this will shift if there was no geo_ref.
436        #But all this should be removed anyhow.
437        #change coords before this point
438        mesh_origin = self.mesh.geo_reference.get_origin()
439        if point_coordinates is not None:
440            if data_origin is not None:
441                if mesh_origin is not None:
442
443                    #Transformation:
444                    #
445                    #Let x_0 be the reference point of the point coordinates
446                    #and xi_0 the reference point of the mesh.
447                    #
448                    #A point coordinate (x + x_0) is then made relative
449                    #to xi_0 by
450                    #
451                    # x_new = x + x_0 - xi_0
452                    #
453                    #and similarly for eta
454
455                    x_offset = data_origin[1] - mesh_origin[1]
456                    y_offset = data_origin[2] - mesh_origin[2]
457                else: #Shift back to a zero origin
458                    x_offset = data_origin[1]
459                    y_offset = data_origin[2]
460
461                point_coordinates[:,0] += x_offset
462                point_coordinates[:,1] += y_offset
463            else:
464                if mesh_origin is not None:
465                    #Use mesh origin for data points
466                    point_coordinates[:,0] -= mesh_origin[1]
467                    point_coordinates[:,1] -= mesh_origin[2]
468
469
470
471        #Remove points falling outside mesh boundary
472        #This reduced one example from 1356 seconds to 825 seconds
473        if precrop is True:
474            from Numeric import take
475            from util import inside_polygon
476
477            if verbose: print 'Getting boundary polygon'
478            P = self.mesh.get_boundary_polygon()
479
480            if verbose: print 'Getting indices inside mesh boundary'
481            indices = inside_polygon(point_coordinates, P, verbose = verbose)
482
483
484            if len(indices) != point_coordinates.shape[0]:
485                self.cropped_points = True
486                if verbose:
487                    print 'Done - %d points outside mesh have been cropped.'\
488                          %(point_coordinates.shape[0] - len(indices))
489
490            point_coordinates = take(point_coordinates, indices)
491            self.point_indices = indices
492
493
494
495
496        #Build n x m interpolation matrix
497        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
498        n = point_coordinates.shape[0]     #Nbr of data points
499
500        if verbose: print 'Number of datapoints: %d' %n
501        if verbose: print 'Number of basis functions: %d' %m
502
503        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
504        #However, Sparse_CSR does not have the same methods as Sparse yet
505        #The tests will reveal what needs to be done
506        #self.A = Sparse_CSR(Sparse(n,m))
507        #self.AtA = Sparse_CSR(Sparse(m,m))
508        self.A = Sparse(n,m)
509        self.AtA = Sparse(m,m)
510
511        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
512        root = build_quadtree(self.mesh,
513                              max_points_per_cell = max_points_per_cell)
514
515        #Compute matrix elements
516        for i in range(n):
517            #For each data_coordinate point
518
519            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
520            x = point_coordinates[i]
521
522            #Find vertices near x
523            candidate_vertices = root.search(x[0], x[1])
524            is_more_elements = True
525
526            element_found, sigma0, sigma1, sigma2, k = \
527                self.search_triangles_of_vertices(candidate_vertices, x)
528            while not element_found and is_more_elements and expand_search:
529                #if verbose: print 'Expanding search'
530                candidate_vertices, branch = root.expand_search()
531                if branch == []:
532                    # Searching all the verts from the root cell that haven't
533                    # been searched.  This is the last try
534                    element_found, sigma0, sigma1, sigma2, k = \
535                      self.search_triangles_of_vertices(candidate_vertices, x)
536                    is_more_elements = False
537                else:
538                    element_found, sigma0, sigma1, sigma2, k = \
539                      self.search_triangles_of_vertices(candidate_vertices, x)
540
541
542            #Update interpolation matrix A if necessary
543            if element_found is True:
544                #Assign values to matrix A
545
546                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
547                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
548                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
549
550                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
551                js     = [j0,j1,j2]
552
553                for j in js:
554                    self.A[i,j] = sigmas[j]
555                    for k in js:
556                        self.AtA[j,k] += sigmas[j]*sigmas[k]
557            else:
558                pass
559                #Ok if there is no triangle for datapoint
560                #(as in brute force version)
561                #raise 'Could not find triangle for point', x
562
563
564
565    def search_triangles_of_vertices(self, candidate_vertices, x):
566            #Find triangle containing x:
567            element_found = False
568
569            # This will be returned if element_found = False
570            sigma2 = -10.0
571            sigma0 = -10.0
572            sigma1 = -10.0
573            k = -10.0
574
575            #For all vertices in same cell as point x
576            for v in candidate_vertices:
577
578                #for each triangle id (k) which has v as a vertex
579                for k, _ in self.mesh.vertexlist[v]:
580
581                    #Get the three vertex_points of candidate triangle
582                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
583                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
584                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
585
586                    #print "PDSG - k", k
587                    #print "PDSG - xi0", xi0
588                    #print "PDSG - xi1", xi1
589                    #print "PDSG - xi2", xi2
590                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
591                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
592
593                    #Get the three normals
594                    n0 = self.mesh.get_normal(k, 0)
595                    n1 = self.mesh.get_normal(k, 1)
596                    n2 = self.mesh.get_normal(k, 2)
597
598
599                    #Compute interpolation
600                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
601                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
602                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
603
604                    #print "PDSG - sigma0", sigma0
605                    #print "PDSG - sigma1", sigma1
606                    #print "PDSG - sigma2", sigma2
607
608                    #FIXME: Maybe move out to test or something
609                    epsilon = 1.0e-6
610                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
611
612                    #Check that this triangle contains the data point
613
614                    #Sigmas can get negative within
615                    #machine precision on some machines (e.g nautilus)
616                    #Hence the small eps
617                    eps = 1.0e-15
618                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
619                        element_found = True
620                        break
621
622                if element_found is True:
623                    #Don't look for any other triangle
624                    break
625            return element_found, sigma0, sigma1, sigma2, k
626
627
628
629    def build_interpolation_matrix_A_brute(self, point_coordinates):
630        """Build n x m interpolation matrix, where
631        n is the number of data points and
632        m is the number of basis functions phi_k (one per vertex)
633
634        This is the brute force which is too slow for large problems,
635        but could be used for testing
636        """
637
638        from util import ensure_numeric
639
640        #Convert input to Numeric arrays
641        point_coordinates = ensure_numeric(point_coordinates, Float)
642
643        #Build n x m interpolation matrix
644        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
645        n = point_coordinates.shape[0]     #Nbr of data points
646
647        self.A = Sparse(n,m)
648        self.AtA = Sparse(m,m)
649
650        #Compute matrix elements
651        for i in range(n):
652            #For each data_coordinate point
653
654            x = point_coordinates[i]
655            element_found = False
656            k = 0
657            while not element_found and k < len(self.mesh):
658                #For each triangle (brute force)
659                #FIXME: Real algorithm should only visit relevant triangles
660
661                #Get the three vertex_points
662                xi0 = self.mesh.get_vertex_coordinate(k, 0)
663                xi1 = self.mesh.get_vertex_coordinate(k, 1)
664                xi2 = self.mesh.get_vertex_coordinate(k, 2)
665
666                #Get the three normals
667                n0 = self.mesh.get_normal(k, 0)
668                n1 = self.mesh.get_normal(k, 1)
669                n2 = self.mesh.get_normal(k, 2)
670
671                #Compute interpolation
672                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
673                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
674                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
675
676                #FIXME: Maybe move out to test or something
677                epsilon = 1.0e-6
678                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
679
680                #Check that this triangle contains data point
681                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
682                    element_found = True
683                    #Assign values to matrix A
684
685                    j0 = self.mesh.triangles[k,0] #Global vertex id
686                    #self.A[i, j0] = sigma0
687
688                    j1 = self.mesh.triangles[k,1] #Global vertex id
689                    #self.A[i, j1] = sigma1
690
691                    j2 = self.mesh.triangles[k,2] #Global vertex id
692                    #self.A[i, j2] = sigma2
693
694                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
695                    js     = [j0,j1,j2]
696
697                    for j in js:
698                        self.A[i,j] = sigmas[j]
699                        for k in js:
700                            self.AtA[j,k] += sigmas[j]*sigmas[k]
701                k = k+1
702
703
704
705    def get_A(self):
706        return self.A.todense()
707
708    def get_B(self):
709        return self.B.todense()
710
711    def get_D(self):
712        return self.D.todense()
713
714        #FIXME: Remember to re-introduce the 1/n factor in the
715        #interpolation term
716
717    def build_smoothing_matrix_D(self):
718        """Build m x m smoothing matrix, where
719        m is the number of basis functions phi_k (one per vertex)
720
721        The smoothing matrix is defined as
722
723        D = D1 + D2
724
725        where
726
727        [D1]_{k,l} = \int_\Omega
728           \frac{\partial \phi_k}{\partial x}
729           \frac{\partial \phi_l}{\partial x}\,
730           dx dy
731
732        [D2]_{k,l} = \int_\Omega
733           \frac{\partial \phi_k}{\partial y}
734           \frac{\partial \phi_l}{\partial y}\,
735           dx dy
736
737
738        The derivatives \frac{\partial \phi_k}{\partial x},
739        \frac{\partial \phi_k}{\partial x} for a particular triangle
740        are obtained by computing the gradient a_k, b_k for basis function k
741        """
742
743        #FIXME: algorithm might be optimised by computing local 9x9
744        #"element stiffness matrices:
745
746        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
747
748        self.D = Sparse(m,m)
749
750        #For each triangle compute contributions to D = D1+D2
751        for i in range(len(self.mesh)):
752
753            #Get area
754            area = self.mesh.areas[i]
755
756            #Get global vertex indices
757            v0 = self.mesh.triangles[i,0]
758            v1 = self.mesh.triangles[i,1]
759            v2 = self.mesh.triangles[i,2]
760
761            #Get the three vertex_points
762            xi0 = self.mesh.get_vertex_coordinate(i, 0)
763            xi1 = self.mesh.get_vertex_coordinate(i, 1)
764            xi2 = self.mesh.get_vertex_coordinate(i, 2)
765
766            #Compute gradients for each vertex
767            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
768                              1, 0, 0)
769
770            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
771                              0, 1, 0)
772
773            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
774                              0, 0, 1)
775
776            #Compute diagonal contributions
777            self.D[v0,v0] += (a0*a0 + b0*b0)*area
778            self.D[v1,v1] += (a1*a1 + b1*b1)*area
779            self.D[v2,v2] += (a2*a2 + b2*b2)*area
780
781            #Compute contributions for basis functions sharing edges
782            e01 = (a0*a1 + b0*b1)*area
783            self.D[v0,v1] += e01
784            self.D[v1,v0] += e01
785
786            e12 = (a1*a2 + b1*b2)*area
787            self.D[v1,v2] += e12
788            self.D[v2,v1] += e12
789
790            e20 = (a2*a0 + b2*b0)*area
791            self.D[v2,v0] += e20
792            self.D[v0,v2] += e20
793
794
795    def fit(self, z):
796        """Fit a smooth surface to given 1d array of data points z.
797
798        The smooth surface is computed at each vertex in the underlying
799        mesh using the formula given in the module doc string.
800
801        Pre Condition:
802          self.A, self.AtA and self.B have been initialised
803
804        Inputs:
805          z: Single 1d vector or array of data at the point_coordinates.
806        """
807
808        #Convert input to Numeric arrays
809        from util import ensure_numeric
810        z = ensure_numeric(z, Float)
811
812        if len(z.shape) > 1 :
813            raise VectorShapeError, 'Can only deal with 1d data vector'
814
815        if self.point_indices is not None:
816            #Remove values for any points that were outside mesh
817            z = take(z, self.point_indices)
818
819        #Compute right hand side based on data
820        Atz = self.A.trans_mult(z)
821
822
823        #Check sanity
824        n, m = self.A.shape
825        if n<m and self.alpha == 0.0:
826            msg = 'ERROR (least_squares): Too few data points\n'
827            msg += 'There are only %d data points and alpha == 0. ' %n
828            msg += 'Need at least %d\n' %m
829            msg += 'Alternatively, set smoothing parameter alpha to a small '
830            msg += 'positive value,\ne.g. 1.0e-3.'
831            raise msg
832
833
834
835        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
836        #FIXME: Should we store the result here for later use? (ON)
837
838
839    def fit_points(self, z, verbose=False):
840        """Like fit, but more robust when each point has two or more attributes
841        FIXME (Ole): The name fit_points doesn't carry any meaning
842        for me. How about something like fit_multiple or fit_columns?
843        """
844
845        try:
846            if verbose: print 'Solving penalised least_squares problem'
847            return self.fit(z)
848        except VectorShapeError, e:
849            # broadcasting is not supported.
850
851            #Convert input to Numeric arrays
852            from util import ensure_numeric
853            z = ensure_numeric(z, Float)
854
855            #Build n x m interpolation matrix
856            m = self.mesh.coordinates.shape[0] #Number of vertices
857            n = z.shape[1]                     #Number of data points
858
859            f = zeros((m,n), Float) #Resulting columns
860
861            for i in range(z.shape[1]):
862                f[:,i] = self.fit(z[:,i])
863
864            return f
865
866
867    def interpolate(self, f):
868        """Evaluate smooth surface f at data points implied in self.A.
869
870        The mesh values representing a smooth surface are
871        assumed to be specified in f. This argument could,
872        for example have been obtained from the method self.fit()
873
874        Pre Condition:
875          self.A has been initialised
876
877        Inputs:
878          f: Vector or array of data at the mesh vertices.
879          If f is an array, interpolation will be done for each column as
880          per underlying matrix-matrix multiplication
881
882        Output:
883          Interpolated values at data points implied in self.A
884
885        """
886
887        return self.A * f
888
889    def cull_outsiders(self, f):
890        pass
891
892
893
894
895class Interpolation_function:
896    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
897    which is interpolated from time series defined at vertices of
898    triangular mesh (such as those stored in sww files)
899
900    Let m be the number of vertices, n the number of triangles
901    and p the number of timesteps.
902
903    Mandatory input
904        time:               px1 array of monotonously increasing times (Float)
905        quantities:         Dictionary of pxm arrays or 1 pxm array (Float)
906       
907    Optional input:
908        quantity_names:     List of keys into the quantities dictionary
909        vertex_coordinates: mx2 array of coordinates (Float)
910        triangles:          nx3 array of indices into vertex_coordinates (Int)
911        interpolation_points: array of coordinates to be interpolated to
912        verbose:            Level of reporting
913   
914   
915    The quantities returned by the callable object are specified by
916    the list quantities which must contain the names of the
917    quantities to be returned and also reflect the order, e.g. for
918    the shallow water wave equation, on would have
919    quantities = ['stage', 'xmomentum', 'ymomentum']
920
921    The parameter interpolation_points decides at which points interpolated
922    quantities are to be computed whenever object is called.
923    If None, return average value
924    """
925
926   
927   
928    def __init__(self,
929                 time,
930                 quantities,
931                 quantity_names = None, 
932                 vertex_coordinates = None,
933                 triangles = None,
934                 interpolation_points = None,
935                 verbose = False):
936        """Initialise object and build spatial interpolation if required
937        """
938
939        from Numeric import array, zeros, Float, alltrue, concatenate,\
940             reshape, ArrayType
941
942
943        from util import mean, ensure_numeric
944        from config import time_format
945        import types
946
947
948
949        #Check temporal info
950        time = ensure_numeric(time)       
951        msg = 'Time must be a monotonuosly '
952        msg += 'increasing sequence %s' %time
953        assert alltrue(time[1:] - time[:-1] > 0 ), msg
954
955
956        #Check if quantities is a single array only
957        if type(quantities) != types.DictType:
958            quantities = ensure_numeric(quantities)
959            quantity_names = ['Attribute']
960
961            #Make it a dictionary
962            quantities = {quantity_names[0]: quantities}
963
964
965        #Use keys if no names are specified
966        if quantity_names is None:
967            quantity_names = quantities.keys()
968
969
970        #Check spatial info
971        if vertex_coordinates is None:
972            self.spatial = False
973        else:   
974            vertex_coordinates = ensure_numeric(vertex_coordinates)
975
976            assert triangles is not None, 'Triangles array must be specified'
977            triangles = ensure_numeric(triangles)
978            self.spatial = True           
979           
980
981 
982        #Save for use with statistics
983        self.quantity_names = quantity_names       
984        self.quantities = quantities       
985        self.vertex_coordinates = vertex_coordinates
986        self.interpolation_points = interpolation_points
987        self.T = time[:]  #Time assumed to be relative to starttime
988        self.index = 0    #Initial time index
989        self.precomputed_values = {}
990           
991
992
993        #Precomputed spatial interpolation if requested
994        if interpolation_points is not None:
995            if self.spatial is False:
996                raise 'Triangles and vertex_coordinates must be specified'
997           
998
999            try:
1000                self.interpolation_points =\
1001                                          ensure_numeric(self.interpolation_points)
1002            except:
1003                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1004                      'or a list of points\n'
1005                msg += 'I got: %s.' %( str(self.interpolation_points)[:60] + '...')
1006                raise msg
1007
1008
1009            for name in quantity_names:
1010                self.precomputed_values[name] =\
1011                                              zeros((len(self.T),
1012                                                     len(self.interpolation_points)),
1013                                                    Float)
1014
1015            #Build interpolator
1016            interpol = Interpolation(vertex_coordinates,
1017                                     triangles,
1018                                     point_coordinates = self.interpolation_points,
1019                                     alpha = 0,
1020                                     precrop = False, 
1021                                     verbose = verbose)
1022
1023            if verbose: print 'Interpolate'
1024            n = len(self.T)
1025            for i, t in enumerate(self.T):
1026                #Interpolate quantities at this timestep
1027                if verbose and i%((n+10)/10)==0:
1028                    print ' time step %d of %d' %(i, n)
1029                   
1030                for name in quantity_names:
1031                    self.precomputed_values[name][i, :] =\
1032                    interpol.interpolate(quantities[name][i,:])
1033
1034            #Report
1035            if verbose:
1036                print self.statistics()
1037                #self.print_statistics()
1038           
1039        else:
1040            #Store quantitites as is
1041            for name in quantity_names:
1042                self.precomputed_values[name] = quantities[name]
1043
1044
1045        #else:
1046        #    #Return an average, making this a time series
1047        #    for name in quantity_names:
1048        #        self.values[name] = zeros(len(self.T), Float)
1049        #
1050        #    if verbose: print 'Compute mean values'
1051        #    for i, t in enumerate(self.T):
1052        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1053        #        for name in quantity_names:
1054        #           self.values[name][i] = mean(quantities[name][i,:])
1055
1056
1057
1058
1059    def __repr__(self):
1060        #return 'Interpolation function (spation-temporal)'
1061        return self.statistics()
1062   
1063
1064    def __call__(self, t, point_id = None, x = None, y = None):
1065        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1066
1067        Inputs:
1068          t: time - Model time. Must lie within existing timesteps
1069          point_id: index of one of the preprocessed points.
1070          x, y:     Overrides location, point_id ignored
1071         
1072          If spatial info is present and all of x,y,point_id
1073          are None an exception is raised
1074                   
1075          If no spatial info is present, point_id and x,y arguments are ignored
1076          making f a function of time only.
1077
1078         
1079          FIXME: point_id could also be a slice
1080          FIXME: What if x and y are vectors?
1081          FIXME: What about f(x,y) without t?
1082        """
1083
1084        from math import pi, cos, sin, sqrt
1085        from Numeric import zeros, Float
1086        from util import mean       
1087
1088        if self.spatial is True:
1089            if point_id is None:
1090                if x is None or y is None:
1091                    msg = 'Either point_id or x and y must be specified'
1092                    raise msg
1093            else:
1094                if self.interpolation_points is None:
1095                    msg = 'Interpolation_function must be instantiated ' +\
1096                          'with a list of interpolation points before parameter ' +\
1097                          'point_id can be used'
1098                    raise msg
1099
1100
1101        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1102        msg += ' does not match model time: %s\n' %t
1103        if t < self.T[0]: raise msg
1104        if t > self.T[-1]: raise msg
1105
1106        oldindex = self.index #Time index
1107
1108        #Find current time slot
1109        while t > self.T[self.index]: self.index += 1
1110        while t < self.T[self.index]: self.index -= 1
1111
1112        if t == self.T[self.index]:
1113            #Protect against case where t == T[-1] (last time)
1114            # - also works in general when t == T[i]
1115            ratio = 0
1116        else:
1117            #t is now between index and index+1
1118            ratio = (t - self.T[self.index])/\
1119                    (self.T[self.index+1] - self.T[self.index])
1120
1121        #Compute interpolated values
1122        q = zeros(len(self.quantity_names), Float)
1123
1124        for i, name in enumerate(self.quantity_names):
1125            Q = self.precomputed_values[name]
1126
1127            if self.spatial is False:
1128                #If there is no spatial info               
1129                assert len(Q.shape) == 1
1130
1131                Q0 = Q[self.index]
1132                if ratio > 0: Q1 = Q[self.index+1]
1133
1134            else:
1135                if x is not None and y is not None:
1136                    #Interpolate to x, y
1137                   
1138                    raise 'x,y interpolation not yet implemented'
1139                else:
1140                    #Use precomputed point
1141                    Q0 = Q[self.index, point_id]
1142                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1143
1144            #Linear temporal interpolation   
1145            if ratio > 0:
1146                q[i] = Q0 + ratio*(Q1 - Q0)
1147            else:
1148                q[i] = Q0
1149
1150
1151        #Return vector of interpolated values
1152        #if len(q) == 1:
1153        #    return q[0]
1154        #else:
1155        #    return q
1156
1157
1158        #Return vector of interpolated values
1159        #FIXME:
1160        if self.spatial is True:
1161            return q
1162        else:
1163            #Replicate q according to x and y
1164            #This is e.g used for Wind_stress
1165            if x == None or y == None: 
1166                return q
1167            else:
1168                try:
1169                    N = len(x)
1170                except:
1171                    return q
1172                else:
1173                    from Numeric import ones, Float
1174                    #x is a vector - Create one constant column for each value
1175                    N = len(x)
1176                    assert len(y) == N, 'x and y must have same length'
1177                    res = []
1178                    for col in q:
1179                        res.append(col*ones(N, Float))
1180                       
1181                return res
1182
1183
1184    def statistics(self):
1185        """Output statistics about interpolation_function
1186        """
1187       
1188        vertex_coordinates = self.vertex_coordinates
1189        interpolation_points = self.interpolation_points               
1190        quantity_names = self.quantity_names
1191        quantities = self.quantities
1192        precomputed_values = self.precomputed_values                 
1193               
1194        x = vertex_coordinates[:,0]
1195        y = vertex_coordinates[:,1]               
1196
1197        str =  '------------------------------------------------\n'
1198        str += 'Interpolation_function (spation-temporal) statistics:\n'
1199        str += '  Extent:\n'
1200        str += '    x in [%f, %f], len(x) == %d\n'\
1201               %(min(x), max(x), len(x))
1202        str += '    y in [%f, %f], len(y) == %d\n'\
1203               %(min(y), max(y), len(y))
1204        str += '    t in [%f, %f], len(t) == %d\n'\
1205               %(min(self.T), max(self.T), len(self.T))
1206        str += '  Quantities:\n'
1207        for name in quantity_names:
1208            q = quantities[name][:].flat
1209            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1210
1211        if interpolation_points is not None:   
1212            str += '  Interpolation points (xi, eta):'\
1213                   ' number of points == %d\n' %interpolation_points.shape[0]
1214            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1215                                            max(interpolation_points[:,0]))
1216            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1217                                             max(interpolation_points[:,1]))
1218            str += '  Interpolated quantities (over all timesteps):\n'
1219       
1220            for name in quantity_names:
1221                q = precomputed_values[name][:].flat
1222                str += '    %s at interpolation points in [%f, %f]\n'\
1223                       %(name, min(q), max(q))
1224        str += '------------------------------------------------\n'
1225
1226        return str
1227
1228        #FIXME: Delete
1229        #print '------------------------------------------------'
1230        #print 'Interpolation_function statistics:'
1231        #print '  Extent:'
1232        #print '    x in [%f, %f], len(x) == %d'\
1233        #      %(min(x), max(x), len(x))
1234        #print '    y in [%f, %f], len(y) == %d'\
1235        #      %(min(y), max(y), len(y))
1236        #print '    t in [%f, %f], len(t) == %d'\
1237        #      %(min(self.T), max(self.T), len(self.T))
1238        #print '  Quantities:'
1239        #for name in quantity_names:
1240        #    q = quantities[name][:].flat
1241        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1242        #print '  Interpolation points (xi, eta):'\
1243        #      ' number of points == %d ' %interpolation_points.shape[0]
1244        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1245        #                             max(interpolation_points[:,0]))
1246        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1247        #                              max(interpolation_points[:,1]))
1248        #print '  Interpolated quantities (over all timesteps):'
1249        #
1250        #for name in quantity_names:
1251        #    q = precomputed_values[name][:].flat
1252        #    print '    %s at interpolation points in [%f, %f]'\
1253        #          %(name, min(q), max(q))
1254        #print '------------------------------------------------'
1255
1256
1257#-------------------------------------------------------------
1258if __name__ == "__main__":
1259    """
1260    Load in a mesh and data points with attributes.
1261    Fit the attributes to the mesh.
1262    Save a new mesh file.
1263    """
1264    import os, sys
1265    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha]"\
1266            %os.path.basename(sys.argv[0])
1267
1268    if len(sys.argv) < 4:
1269        print usage
1270    else:
1271        mesh_file = sys.argv[1]
1272        point_file = sys.argv[2]
1273        mesh_output_file = sys.argv[3]
1274
1275        expand_search = False
1276        if len(sys.argv) > 4:
1277            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1278                expand_search = True
1279            else:
1280                expand_search = False
1281
1282        verbose = False
1283        if len(sys.argv) > 5:
1284            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1285                verbose = False
1286            else:
1287                verbose = True
1288
1289        if len(sys.argv) > 6:
1290            alpha = sys.argv[6]
1291        else:
1292            alpha = DEFAULT_ALPHA
1293
1294        t0 = time.time()
1295        fit_to_mesh_file(mesh_file,
1296                         point_file,
1297                         mesh_output_file,
1298                         alpha,
1299                         verbose= verbose,
1300                         expand_search = expand_search)
1301
1302        print 'That took %.2f seconds' %(time.time()-t0)
1303
Note: See TracBrowser for help on using the repository browser.