source: inundation/pyvolution/least_squares.py @ 1888

Last change on this file since 1888 was 1888, checked in by duncan, 19 years ago

added test for checking exit status when using least_squares.py in the command line.

File size: 46.9 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20import exceptions
21class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    NOTE: Throws IOErrors, for a variety of file problems.
75   
76    mesh_origin is the same but refers to the input tsh file.
77    FIXME: When the tsh format contains it own origin, these parameters can go.
78    FIXME: And both origins should be obtained from the specified files.
79    """
80
81    from load_mesh.loadASCII import import_mesh_file, \
82                 import_points_file, export_mesh_file, \
83                 concatinate_attributelist
84
85
86    try:
87        mesh_dict = import_mesh_file(mesh_file)
88    except IOError,e:
89        if display_errors:
90            print "Could not load bad file. ", e
91        raise IOError  #Re-raise exception
92       
93    vertex_coordinates = mesh_dict['vertices']
94    triangles = mesh_dict['triangles']
95    if type(mesh_dict['vertex_attributes']) == ArrayType:
96        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
97    else:
98        old_point_attributes = mesh_dict['vertex_attributes']
99
100    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
101        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
102    else:
103        old_title_list = mesh_dict['vertex_attribute_titles']
104
105    if verbose: print 'tsh file %s loaded' %mesh_file
106
107    # load in the .pts file
108    try:
109        point_dict = import_points_file(point_file, verbose=verbose)
110    except IOError,e:
111        if display_errors:
112            print "Could not load bad file. ", e
113        raise IOError  #Re-raise exception 
114
115    point_coordinates = point_dict['pointlist']
116    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
117
118    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
119        data_origin = point_dict['geo_reference'].get_origin()
120    else:
121        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
122
123    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
124        mesh_origin = mesh_dict['geo_reference'].get_origin()
125    else:
126        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
127
128    if verbose: print "points file loaded"
129    if verbose:print "fitting to mesh"
130    f = fit_to_mesh(vertex_coordinates,
131                    triangles,
132                    point_coordinates,
133                    point_attributes,
134                    alpha = alpha,
135                    verbose = verbose,
136                    expand_search = expand_search,
137                    data_origin = data_origin,
138                    mesh_origin = mesh_origin,
139                    precrop = precrop)
140    if verbose: print "finished fitting to mesh"
141
142    # convert array to list of lists
143    new_point_attributes = f.tolist()
144    #FIXME have this overwrite attributes with the same title - DSG
145    #Put the newer attributes last
146    if old_title_list <> []:
147        old_title_list.extend(title_list)
148        #FIXME can this be done a faster way? - DSG
149        for i in range(len(old_point_attributes)):
150            old_point_attributes[i].extend(new_point_attributes[i])
151        mesh_dict['vertex_attributes'] = old_point_attributes
152        mesh_dict['vertex_attribute_titles'] = old_title_list
153    else:
154        mesh_dict['vertex_attributes'] = new_point_attributes
155        mesh_dict['vertex_attribute_titles'] = title_list
156
157    #FIXME (Ole): Remember to output mesh_origin as well
158    if verbose: print "exporting to file ",mesh_output_file
159
160    try:
161        export_mesh_file(mesh_output_file, mesh_dict)
162    except IOError,e:
163        if display_errors:
164            print "Could not write file. ", e
165        raise IOError
166
167def fit_to_mesh(vertex_coordinates,
168                triangles,
169                point_coordinates,
170                point_attributes,
171                alpha = DEFAULT_ALPHA,
172                verbose = False,
173                expand_search = False,
174                data_origin = None,
175                mesh_origin = None,
176                precrop = False):
177    """
178    Fit a smooth surface to a triangulation,
179    given data points with attributes.
180
181
182        Inputs:
183
184          vertex_coordinates: List of coordinate pairs [xi, eta] of points
185          constituting mesh (or a an m x 2 Numeric array)
186
187          triangles: List of 3-tuples (or a Numeric array) of
188          integers representing indices of all vertices in the mesh.
189
190          point_coordinates: List of coordinate pairs [x, y] of data points
191          (or an nx2 Numeric array)
192
193          alpha: Smoothing parameter.
194
195          point_attributes: Vector or array of data at the point_coordinates.
196
197          data_origin and mesh_origin are 3-tuples consisting of
198          UTM zone, easting and northing. If specified
199          point coordinates and vertex coordinates are assumed to be
200          relative to their respective origins.
201
202    """
203    interp = Interpolation(vertex_coordinates,
204                           triangles,
205                           point_coordinates,
206                           alpha = alpha,
207                           verbose = verbose,
208                           expand_search = expand_search,
209                           data_origin = data_origin,
210                           mesh_origin = mesh_origin,
211                           precrop = precrop)
212
213    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
214    return vertex_attributes
215
216
217
218def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
219                    verbose = False, reduction = 1, format = 'netcdf'):
220    """Fits attributes from pts file to MxN rectangular mesh
221
222    Read pts file and create rectangular mesh of resolution MxN such that
223    it covers all points specified in pts file.
224
225    FIXME: This may be a temporary function until we decide on
226    netcdf formats etc
227
228    FIXME: Uses elevation hardwired
229    """
230
231    import util, mesh_factory
232
233    if verbose: print 'Read pts'
234    points, attributes = util.read_xya(pts_name, format)
235
236    #Reduce number of points a bit
237    points = points[::reduction]
238    elevation = attributes['elevation']  #Must be elevation
239    elevation = elevation[::reduction]
240
241    if verbose: print 'Got %d data points' %len(points)
242
243    if verbose: print 'Create mesh'
244    #Find extent
245    max_x = min_x = points[0][0]
246    max_y = min_y = points[0][1]
247    for point in points[1:]:
248        x = point[0]
249        if x > max_x: max_x = x
250        if x < min_x: min_x = x
251        y = point[1]
252        if y > max_y: max_y = y
253        if y < min_y: min_y = y
254
255    #Create appropriate mesh
256    vertex_coordinates, triangles, boundary =\
257         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
258                                (min_x, min_y))
259
260    #Fit attributes to mesh
261    vertex_attributes = fit_to_mesh(vertex_coordinates,
262                        triangles,
263                        points,
264                        elevation, alpha=alpha, verbose=verbose)
265
266
267
268    return vertex_coordinates, triangles, boundary, vertex_attributes
269
270
271
272class Interpolation:
273
274    def __init__(self,
275                 vertex_coordinates,
276                 triangles,
277                 point_coordinates = None,
278                 alpha = None,
279                 verbose = False,
280                 expand_search = True,
281                 max_points_per_cell = 30,
282                 mesh_origin = None,
283                 data_origin = None,
284                 precrop = False):
285
286
287        """ Build interpolation matrix mapping from
288        function values at vertices to function values at data points
289
290        Inputs:
291
292          vertex_coordinates: List of coordinate pairs [xi, eta] of
293          points constituting mesh (or a an m x 2 Numeric array)
294          Points may appear multiple times
295          (e.g. if vertices have discontinuities)
296
297          triangles: List of 3-tuples (or a Numeric array) of
298          integers representing indices of all vertices in the mesh.
299
300          point_coordinates: List of coordinate pairs [x, y] of
301          data points (or an nx2 Numeric array)
302          If point_coordinates is absent, only smoothing matrix will
303          be built
304
305          alpha: Smoothing parameter
306
307          data_origin and mesh_origin are 3-tuples consisting of
308          UTM zone, easting and northing. If specified
309          point coordinates and vertex coordinates are assumed to be
310          relative to their respective origins.
311
312        """
313        from util import ensure_numeric
314
315        #Convert input to Numeric arrays
316        triangles = ensure_numeric(triangles, Int)
317        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
318
319        #Build underlying mesh
320        if verbose: print 'Building mesh'
321        #self.mesh = General_mesh(vertex_coordinates, triangles,
322        #FIXME: Trying the normal mesh while testing precrop,
323        #       The functionality of boundary_polygon is needed for that
324
325        #FIXME - geo ref does not have to go into mesh.
326        # Change the point co-ords to conform to the
327        # mesh co-ords early in the code
328        if mesh_origin == None:
329            geo = None
330        else:
331            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
332        self.mesh = Mesh(vertex_coordinates, triangles,
333                         geo_reference = geo)
334       
335        self.mesh.check_integrity()
336
337        self.data_origin = data_origin
338
339        self.point_indices = None
340
341        #Smoothing parameter
342        if alpha is None:
343            self.alpha = DEFAULT_ALPHA
344        else:   
345            self.alpha = alpha
346
347        #Build coefficient matrices
348        self.build_coefficient_matrix_B(point_coordinates,
349                                        verbose = verbose,
350                                        expand_search = expand_search,
351                                        max_points_per_cell =\
352                                        max_points_per_cell,
353                                        data_origin = data_origin,
354                                        precrop = precrop)
355
356
357    def set_point_coordinates(self, point_coordinates,
358                              data_origin = None):
359        """
360        A public interface to setting the point co-ordinates.
361        """
362        self.build_coefficient_matrix_B(point_coordinates, data_origin)
363
364    def build_coefficient_matrix_B(self, point_coordinates=None,
365                                   verbose = False, expand_search = True,
366                                   max_points_per_cell=30,
367                                   data_origin = None,
368                                   precrop = False):
369        """Build final coefficient matrix"""
370
371
372        if self.alpha <> 0:
373            if verbose: print 'Building smoothing matrix'
374            self.build_smoothing_matrix_D()
375
376        if point_coordinates is not None:
377
378            if verbose: print 'Building interpolation matrix'
379            self.build_interpolation_matrix_A(point_coordinates,
380                                              verbose = verbose,
381                                              expand_search = expand_search,
382                                              max_points_per_cell =\
383                                              max_points_per_cell,
384                                              data_origin = data_origin,
385                                              precrop = precrop)
386
387            if self.alpha <> 0:
388                self.B = self.AtA + self.alpha*self.D
389            else:
390                self.B = self.AtA
391
392            #Convert self.B matrix to CSR format for faster matrix vector
393            self.B = Sparse_CSR(self.B)
394
395    def build_interpolation_matrix_A(self, point_coordinates,
396                                     verbose = False, expand_search = True,
397                                     max_points_per_cell=30,
398                                     data_origin = None,
399                                     precrop = False):
400        """Build n x m interpolation matrix, where
401        n is the number of data points and
402        m is the number of basis functions phi_k (one per vertex)
403
404        This algorithm uses a quad tree data structure for fast binning of data points
405        origin is a 3-tuple consisting of UTM zone, easting and northing.
406        If specified coordinates are assumed to be relative to this origin.
407
408        This one will override any data_origin that may be specified in
409        interpolation instance
410
411        """
412
413
414        #FIXME (Ole): Check that this function is memeory efficient.
415        #6 million datapoints and 300000 basis functions
416        #causes out-of-memory situation
417        #First thing to check is whether there is room for self.A and self.AtA
418        #
419        #Maybe we need some sort of blocking
420
421        from quad import build_quadtree
422        from util import ensure_numeric
423
424        if data_origin is None:
425            data_origin = self.data_origin #Use the one from
426                                           #interpolation instance
427
428        #Convert input to Numeric arrays just in case.
429        point_coordinates = ensure_numeric(point_coordinates, Float)
430
431        #Keep track of discarded points (if any).
432        #This is only registered if precrop is True
433        self.cropped_points = False
434
435        #Shift data points to same origin as mesh (if specified)
436
437        #FIXME this will shift if there was no geo_ref.
438        #But all this should be removed anyhow.
439        #change coords before this point
440        mesh_origin = self.mesh.geo_reference.get_origin()
441        if point_coordinates is not None:
442            if data_origin is not None:
443                if mesh_origin is not None:
444
445                    #Transformation:
446                    #
447                    #Let x_0 be the reference point of the point coordinates
448                    #and xi_0 the reference point of the mesh.
449                    #
450                    #A point coordinate (x + x_0) is then made relative
451                    #to xi_0 by
452                    #
453                    # x_new = x + x_0 - xi_0
454                    #
455                    #and similarly for eta
456
457                    x_offset = data_origin[1] - mesh_origin[1]
458                    y_offset = data_origin[2] - mesh_origin[2]
459                else: #Shift back to a zero origin
460                    x_offset = data_origin[1]
461                    y_offset = data_origin[2]
462
463                point_coordinates[:,0] += x_offset
464                point_coordinates[:,1] += y_offset
465            else:
466                if mesh_origin is not None:
467                    #Use mesh origin for data points
468                    point_coordinates[:,0] -= mesh_origin[1]
469                    point_coordinates[:,1] -= mesh_origin[2]
470
471
472
473        #Remove points falling outside mesh boundary
474        #This reduced one example from 1356 seconds to 825 seconds
475        if precrop is True:
476            from Numeric import take
477            from util import inside_polygon
478
479            if verbose: print 'Getting boundary polygon'
480            P = self.mesh.get_boundary_polygon()
481
482            if verbose: print 'Getting indices inside mesh boundary'
483            indices = inside_polygon(point_coordinates, P, verbose = verbose)
484
485
486            if len(indices) != point_coordinates.shape[0]:
487                self.cropped_points = True
488                if verbose:
489                    print 'Done - %d points outside mesh have been cropped.'\
490                          %(point_coordinates.shape[0] - len(indices))
491
492            point_coordinates = take(point_coordinates, indices)
493            self.point_indices = indices
494
495
496
497
498        #Build n x m interpolation matrix
499        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
500        n = point_coordinates.shape[0]     #Nbr of data points
501
502        if verbose: print 'Number of datapoints: %d' %n
503        if verbose: print 'Number of basis functions: %d' %m
504
505        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
506        #However, Sparse_CSR does not have the same methods as Sparse yet
507        #The tests will reveal what needs to be done
508        #self.A = Sparse_CSR(Sparse(n,m))
509        #self.AtA = Sparse_CSR(Sparse(m,m))
510        self.A = Sparse(n,m)
511        self.AtA = Sparse(m,m)
512
513        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
514        root = build_quadtree(self.mesh,
515                              max_points_per_cell = max_points_per_cell)
516
517        #Compute matrix elements
518        for i in range(n):
519            #For each data_coordinate point
520
521            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
522            x = point_coordinates[i]
523
524            #Find vertices near x
525            candidate_vertices = root.search(x[0], x[1])
526            is_more_elements = True
527
528            element_found, sigma0, sigma1, sigma2, k = \
529                self.search_triangles_of_vertices(candidate_vertices, x)
530            while not element_found and is_more_elements and expand_search:
531                #if verbose: print 'Expanding search'
532                candidate_vertices, branch = root.expand_search()
533                if branch == []:
534                    # Searching all the verts from the root cell that haven't
535                    # been searched.  This is the last try
536                    element_found, sigma0, sigma1, sigma2, k = \
537                      self.search_triangles_of_vertices(candidate_vertices, x)
538                    is_more_elements = False
539                else:
540                    element_found, sigma0, sigma1, sigma2, k = \
541                      self.search_triangles_of_vertices(candidate_vertices, x)
542
543
544            #Update interpolation matrix A if necessary
545            if element_found is True:
546                #Assign values to matrix A
547
548                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
549                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
550                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
551
552                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
553                js     = [j0,j1,j2]
554
555                for j in js:
556                    self.A[i,j] = sigmas[j]
557                    for k in js:
558                        self.AtA[j,k] += sigmas[j]*sigmas[k]
559            else:
560                pass
561                #Ok if there is no triangle for datapoint
562                #(as in brute force version)
563                #raise 'Could not find triangle for point', x
564
565
566
567    def search_triangles_of_vertices(self, candidate_vertices, x):
568            #Find triangle containing x:
569            element_found = False
570
571            # This will be returned if element_found = False
572            sigma2 = -10.0
573            sigma0 = -10.0
574            sigma1 = -10.0
575            k = -10.0
576
577            #For all vertices in same cell as point x
578            for v in candidate_vertices:
579
580                #for each triangle id (k) which has v as a vertex
581                for k, _ in self.mesh.vertexlist[v]:
582
583                    #Get the three vertex_points of candidate triangle
584                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
585                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
586                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
587
588                    #print "PDSG - k", k
589                    #print "PDSG - xi0", xi0
590                    #print "PDSG - xi1", xi1
591                    #print "PDSG - xi2", xi2
592                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
593                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
594
595                    #Get the three normals
596                    n0 = self.mesh.get_normal(k, 0)
597                    n1 = self.mesh.get_normal(k, 1)
598                    n2 = self.mesh.get_normal(k, 2)
599
600
601                    #Compute interpolation
602                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
603                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
604                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
605
606                    #print "PDSG - sigma0", sigma0
607                    #print "PDSG - sigma1", sigma1
608                    #print "PDSG - sigma2", sigma2
609
610                    #FIXME: Maybe move out to test or something
611                    epsilon = 1.0e-6
612                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
613
614                    #Check that this triangle contains the data point
615
616                    #Sigmas can get negative within
617                    #machine precision on some machines (e.g nautilus)
618                    #Hence the small eps
619                    eps = 1.0e-15
620                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
621                        element_found = True
622                        break
623
624                if element_found is True:
625                    #Don't look for any other triangle
626                    break
627            return element_found, sigma0, sigma1, sigma2, k
628
629
630
631    def build_interpolation_matrix_A_brute(self, point_coordinates):
632        """Build n x m interpolation matrix, where
633        n is the number of data points and
634        m is the number of basis functions phi_k (one per vertex)
635
636        This is the brute force which is too slow for large problems,
637        but could be used for testing
638        """
639
640        from util import ensure_numeric
641
642        #Convert input to Numeric arrays
643        point_coordinates = ensure_numeric(point_coordinates, Float)
644
645        #Build n x m interpolation matrix
646        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
647        n = point_coordinates.shape[0]     #Nbr of data points
648
649        self.A = Sparse(n,m)
650        self.AtA = Sparse(m,m)
651
652        #Compute matrix elements
653        for i in range(n):
654            #For each data_coordinate point
655
656            x = point_coordinates[i]
657            element_found = False
658            k = 0
659            while not element_found and k < len(self.mesh):
660                #For each triangle (brute force)
661                #FIXME: Real algorithm should only visit relevant triangles
662
663                #Get the three vertex_points
664                xi0 = self.mesh.get_vertex_coordinate(k, 0)
665                xi1 = self.mesh.get_vertex_coordinate(k, 1)
666                xi2 = self.mesh.get_vertex_coordinate(k, 2)
667
668                #Get the three normals
669                n0 = self.mesh.get_normal(k, 0)
670                n1 = self.mesh.get_normal(k, 1)
671                n2 = self.mesh.get_normal(k, 2)
672
673                #Compute interpolation
674                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
675                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
676                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
677
678                #FIXME: Maybe move out to test or something
679                epsilon = 1.0e-6
680                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
681
682                #Check that this triangle contains data point
683                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
684                    element_found = True
685                    #Assign values to matrix A
686
687                    j0 = self.mesh.triangles[k,0] #Global vertex id
688                    #self.A[i, j0] = sigma0
689
690                    j1 = self.mesh.triangles[k,1] #Global vertex id
691                    #self.A[i, j1] = sigma1
692
693                    j2 = self.mesh.triangles[k,2] #Global vertex id
694                    #self.A[i, j2] = sigma2
695
696                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
697                    js     = [j0,j1,j2]
698
699                    for j in js:
700                        self.A[i,j] = sigmas[j]
701                        for k in js:
702                            self.AtA[j,k] += sigmas[j]*sigmas[k]
703                k = k+1
704
705
706
707    def get_A(self):
708        return self.A.todense()
709
710    def get_B(self):
711        return self.B.todense()
712
713    def get_D(self):
714        return self.D.todense()
715
716        #FIXME: Remember to re-introduce the 1/n factor in the
717        #interpolation term
718
719    def build_smoothing_matrix_D(self):
720        """Build m x m smoothing matrix, where
721        m is the number of basis functions phi_k (one per vertex)
722
723        The smoothing matrix is defined as
724
725        D = D1 + D2
726
727        where
728
729        [D1]_{k,l} = \int_\Omega
730           \frac{\partial \phi_k}{\partial x}
731           \frac{\partial \phi_l}{\partial x}\,
732           dx dy
733
734        [D2]_{k,l} = \int_\Omega
735           \frac{\partial \phi_k}{\partial y}
736           \frac{\partial \phi_l}{\partial y}\,
737           dx dy
738
739
740        The derivatives \frac{\partial \phi_k}{\partial x},
741        \frac{\partial \phi_k}{\partial x} for a particular triangle
742        are obtained by computing the gradient a_k, b_k for basis function k
743        """
744
745        #FIXME: algorithm might be optimised by computing local 9x9
746        #"element stiffness matrices:
747
748        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
749
750        self.D = Sparse(m,m)
751
752        #For each triangle compute contributions to D = D1+D2
753        for i in range(len(self.mesh)):
754
755            #Get area
756            area = self.mesh.areas[i]
757
758            #Get global vertex indices
759            v0 = self.mesh.triangles[i,0]
760            v1 = self.mesh.triangles[i,1]
761            v2 = self.mesh.triangles[i,2]
762
763            #Get the three vertex_points
764            xi0 = self.mesh.get_vertex_coordinate(i, 0)
765            xi1 = self.mesh.get_vertex_coordinate(i, 1)
766            xi2 = self.mesh.get_vertex_coordinate(i, 2)
767
768            #Compute gradients for each vertex
769            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
770                              1, 0, 0)
771
772            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
773                              0, 1, 0)
774
775            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
776                              0, 0, 1)
777
778            #Compute diagonal contributions
779            self.D[v0,v0] += (a0*a0 + b0*b0)*area
780            self.D[v1,v1] += (a1*a1 + b1*b1)*area
781            self.D[v2,v2] += (a2*a2 + b2*b2)*area
782
783            #Compute contributions for basis functions sharing edges
784            e01 = (a0*a1 + b0*b1)*area
785            self.D[v0,v1] += e01
786            self.D[v1,v0] += e01
787
788            e12 = (a1*a2 + b1*b2)*area
789            self.D[v1,v2] += e12
790            self.D[v2,v1] += e12
791
792            e20 = (a2*a0 + b2*b0)*area
793            self.D[v2,v0] += e20
794            self.D[v0,v2] += e20
795
796
797    def fit(self, z):
798        """Fit a smooth surface to given 1d array of data points z.
799
800        The smooth surface is computed at each vertex in the underlying
801        mesh using the formula given in the module doc string.
802
803        Pre Condition:
804          self.A, self.AtA and self.B have been initialised
805
806        Inputs:
807          z: Single 1d vector or array of data at the point_coordinates.
808        """
809
810        #Convert input to Numeric arrays
811        from util import ensure_numeric
812        z = ensure_numeric(z, Float)
813
814        if len(z.shape) > 1 :
815            raise VectorShapeError, 'Can only deal with 1d data vector'
816
817        if self.point_indices is not None:
818            #Remove values for any points that were outside mesh
819            z = take(z, self.point_indices)
820
821        #Compute right hand side based on data
822        Atz = self.A.trans_mult(z)
823
824
825        #Check sanity
826        n, m = self.A.shape
827        if n<m and self.alpha == 0.0:
828            msg = 'ERROR (least_squares): Too few data points\n'
829            msg += 'There are only %d data points and alpha == 0. ' %n
830            msg += 'Need at least %d\n' %m
831            msg += 'Alternatively, set smoothing parameter alpha to a small '
832            msg += 'positive value,\ne.g. 1.0e-3.'
833            raise msg
834
835
836
837        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
838        #FIXME: Should we store the result here for later use? (ON)
839
840
841    def fit_points(self, z, verbose=False):
842        """Like fit, but more robust when each point has two or more attributes
843        FIXME (Ole): The name fit_points doesn't carry any meaning
844        for me. How about something like fit_multiple or fit_columns?
845        """
846
847        try:
848            if verbose: print 'Solving penalised least_squares problem'
849            return self.fit(z)
850        except VectorShapeError, e:
851            # broadcasting is not supported.
852
853            #Convert input to Numeric arrays
854            from util import ensure_numeric
855            z = ensure_numeric(z, Float)
856
857            #Build n x m interpolation matrix
858            m = self.mesh.coordinates.shape[0] #Number of vertices
859            n = z.shape[1]                     #Number of data points
860
861            f = zeros((m,n), Float) #Resulting columns
862
863            for i in range(z.shape[1]):
864                f[:,i] = self.fit(z[:,i])
865
866            return f
867
868
869    def interpolate(self, f):
870        """Evaluate smooth surface f at data points implied in self.A.
871
872        The mesh values representing a smooth surface are
873        assumed to be specified in f. This argument could,
874        for example have been obtained from the method self.fit()
875
876        Pre Condition:
877          self.A has been initialised
878
879        Inputs:
880          f: Vector or array of data at the mesh vertices.
881          If f is an array, interpolation will be done for each column as
882          per underlying matrix-matrix multiplication
883
884        Output:
885          Interpolated values at data points implied in self.A
886
887        """
888
889        return self.A * f
890
891    def cull_outsiders(self, f):
892        pass
893
894
895
896
897class Interpolation_function:
898    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
899    which is interpolated from time series defined at vertices of
900    triangular mesh (such as those stored in sww files)
901
902    Let m be the number of vertices, n the number of triangles
903    and p the number of timesteps.
904
905    Mandatory input
906        time:               px1 array of monotonously increasing times (Float)
907        quantities:         Dictionary of pxm arrays or 1 pxm array (Float)
908       
909    Optional input:
910        quantity_names:     List of keys into the quantities dictionary
911        vertex_coordinates: mx2 array of coordinates (Float)
912        triangles:          nx3 array of indices into vertex_coordinates (Int)
913        interpolation_points: array of coordinates to be interpolated to
914        verbose:            Level of reporting
915   
916   
917    The quantities returned by the callable object are specified by
918    the list quantities which must contain the names of the
919    quantities to be returned and also reflect the order, e.g. for
920    the shallow water wave equation, on would have
921    quantities = ['stage', 'xmomentum', 'ymomentum']
922
923    The parameter interpolation_points decides at which points interpolated
924    quantities are to be computed whenever object is called.
925    If None, return average value
926    """
927
928   
929   
930    def __init__(self,
931                 time,
932                 quantities,
933                 quantity_names = None, 
934                 vertex_coordinates = None,
935                 triangles = None,
936                 interpolation_points = None,
937                 verbose = False):
938        """Initialise object and build spatial interpolation if required
939        """
940
941        from Numeric import array, zeros, Float, alltrue, concatenate,\
942             reshape, ArrayType
943
944
945        from util import mean, ensure_numeric
946        from config import time_format
947        import types
948
949
950
951        #Check temporal info
952        time = ensure_numeric(time)       
953        msg = 'Time must be a monotonuosly '
954        msg += 'increasing sequence %s' %time
955        assert alltrue(time[1:] - time[:-1] > 0 ), msg
956
957
958        #Check if quantities is a single array only
959        if type(quantities) != types.DictType:
960            quantities = ensure_numeric(quantities)
961            quantity_names = ['Attribute']
962
963            #Make it a dictionary
964            quantities = {quantity_names[0]: quantities}
965
966
967        #Use keys if no names are specified
968        if quantity_names is None:
969            quantity_names = quantities.keys()
970
971
972        #Check spatial info
973        if vertex_coordinates is None:
974            self.spatial = False
975        else:   
976            vertex_coordinates = ensure_numeric(vertex_coordinates)
977
978            assert triangles is not None, 'Triangles array must be specified'
979            triangles = ensure_numeric(triangles)
980            self.spatial = True           
981           
982
983 
984        #Save for use with statistics
985        self.quantity_names = quantity_names       
986        self.quantities = quantities       
987        self.vertex_coordinates = vertex_coordinates
988        self.interpolation_points = interpolation_points
989        self.T = time[:]  #Time assumed to be relative to starttime
990        self.index = 0    #Initial time index
991        self.precomputed_values = {}
992           
993
994
995        #Precomputed spatial interpolation if requested
996        if interpolation_points is not None:
997            if self.spatial is False:
998                raise 'Triangles and vertex_coordinates must be specified'
999           
1000
1001            try:
1002                self.interpolation_points =\
1003                                          ensure_numeric(self.interpolation_points)
1004            except:
1005                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1006                      'or a list of points\n'
1007                msg += 'I got: %s.' %( str(self.interpolation_points)[:60] + '...')
1008                raise msg
1009
1010
1011            for name in quantity_names:
1012                self.precomputed_values[name] =\
1013                                              zeros((len(self.T),
1014                                                     len(self.interpolation_points)),
1015                                                    Float)
1016
1017            #Build interpolator
1018            interpol = Interpolation(vertex_coordinates,
1019                                     triangles,
1020                                     point_coordinates = self.interpolation_points,
1021                                     alpha = 0,
1022                                     precrop = False, 
1023                                     verbose = verbose)
1024
1025            if verbose: print 'Interpolate'
1026            n = len(self.T)
1027            for i, t in enumerate(self.T):
1028                #Interpolate quantities at this timestep
1029                if verbose and i%((n+10)/10)==0:
1030                    print ' time step %d of %d' %(i, n)
1031                   
1032                for name in quantity_names:
1033                    self.precomputed_values[name][i, :] =\
1034                    interpol.interpolate(quantities[name][i,:])
1035
1036            #Report
1037            if verbose:
1038                print self.statistics()
1039                #self.print_statistics()
1040           
1041        else:
1042            #Store quantitites as is
1043            for name in quantity_names:
1044                self.precomputed_values[name] = quantities[name]
1045
1046
1047        #else:
1048        #    #Return an average, making this a time series
1049        #    for name in quantity_names:
1050        #        self.values[name] = zeros(len(self.T), Float)
1051        #
1052        #    if verbose: print 'Compute mean values'
1053        #    for i, t in enumerate(self.T):
1054        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1055        #        for name in quantity_names:
1056        #           self.values[name][i] = mean(quantities[name][i,:])
1057
1058
1059
1060
1061    def __repr__(self):
1062        #return 'Interpolation function (spation-temporal)'
1063        return self.statistics()
1064   
1065
1066    def __call__(self, t, point_id = None, x = None, y = None):
1067        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1068
1069        Inputs:
1070          t: time - Model time. Must lie within existing timesteps
1071          point_id: index of one of the preprocessed points.
1072          x, y:     Overrides location, point_id ignored
1073         
1074          If spatial info is present and all of x,y,point_id
1075          are None an exception is raised
1076                   
1077          If no spatial info is present, point_id and x,y arguments are ignored
1078          making f a function of time only.
1079
1080         
1081          FIXME: point_id could also be a slice
1082          FIXME: What if x and y are vectors?
1083          FIXME: What about f(x,y) without t?
1084        """
1085
1086        from math import pi, cos, sin, sqrt
1087        from Numeric import zeros, Float
1088        from util import mean       
1089
1090        if self.spatial is True:
1091            if point_id is None:
1092                if x is None or y is None:
1093                    msg = 'Either point_id or x and y must be specified'
1094                    raise msg
1095            else:
1096                if self.interpolation_points is None:
1097                    msg = 'Interpolation_function must be instantiated ' +\
1098                          'with a list of interpolation points before parameter ' +\
1099                          'point_id can be used'
1100                    raise msg
1101
1102
1103        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1104        msg += ' does not match model time: %s\n' %t
1105        if t < self.T[0]: raise msg
1106        if t > self.T[-1]: raise msg
1107
1108        oldindex = self.index #Time index
1109
1110        #Find current time slot
1111        while t > self.T[self.index]: self.index += 1
1112        while t < self.T[self.index]: self.index -= 1
1113
1114        if t == self.T[self.index]:
1115            #Protect against case where t == T[-1] (last time)
1116            # - also works in general when t == T[i]
1117            ratio = 0
1118        else:
1119            #t is now between index and index+1
1120            ratio = (t - self.T[self.index])/\
1121                    (self.T[self.index+1] - self.T[self.index])
1122
1123        #Compute interpolated values
1124        q = zeros(len(self.quantity_names), Float)
1125
1126        for i, name in enumerate(self.quantity_names):
1127            Q = self.precomputed_values[name]
1128
1129            if self.spatial is False:
1130                #If there is no spatial info               
1131                assert len(Q.shape) == 1
1132
1133                Q0 = Q[self.index]
1134                if ratio > 0: Q1 = Q[self.index+1]
1135
1136            else:
1137                if x is not None and y is not None:
1138                    #Interpolate to x, y
1139                   
1140                    raise 'x,y interpolation not yet implemented'
1141                else:
1142                    #Use precomputed point
1143                    Q0 = Q[self.index, point_id]
1144                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1145
1146            #Linear temporal interpolation   
1147            if ratio > 0:
1148                q[i] = Q0 + ratio*(Q1 - Q0)
1149            else:
1150                q[i] = Q0
1151
1152
1153        #Return vector of interpolated values
1154        #if len(q) == 1:
1155        #    return q[0]
1156        #else:
1157        #    return q
1158
1159
1160        #Return vector of interpolated values
1161        #FIXME:
1162        if self.spatial is True:
1163            return q
1164        else:
1165            #Replicate q according to x and y
1166            #This is e.g used for Wind_stress
1167            if x == None or y == None: 
1168                return q
1169            else:
1170                try:
1171                    N = len(x)
1172                except:
1173                    return q
1174                else:
1175                    from Numeric import ones, Float
1176                    #x is a vector - Create one constant column for each value
1177                    N = len(x)
1178                    assert len(y) == N, 'x and y must have same length'
1179                    res = []
1180                    for col in q:
1181                        res.append(col*ones(N, Float))
1182                       
1183                return res
1184
1185
1186    def statistics(self):
1187        """Output statistics about interpolation_function
1188        """
1189       
1190        vertex_coordinates = self.vertex_coordinates
1191        interpolation_points = self.interpolation_points               
1192        quantity_names = self.quantity_names
1193        quantities = self.quantities
1194        precomputed_values = self.precomputed_values                 
1195               
1196        x = vertex_coordinates[:,0]
1197        y = vertex_coordinates[:,1]               
1198
1199        str =  '------------------------------------------------\n'
1200        str += 'Interpolation_function (spation-temporal) statistics:\n'
1201        str += '  Extent:\n'
1202        str += '    x in [%f, %f], len(x) == %d\n'\
1203               %(min(x), max(x), len(x))
1204        str += '    y in [%f, %f], len(y) == %d\n'\
1205               %(min(y), max(y), len(y))
1206        str += '    t in [%f, %f], len(t) == %d\n'\
1207               %(min(self.T), max(self.T), len(self.T))
1208        str += '  Quantities:\n'
1209        for name in quantity_names:
1210            q = quantities[name][:].flat
1211            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1212
1213        if interpolation_points is not None:   
1214            str += '  Interpolation points (xi, eta):'\
1215                   ' number of points == %d\n' %interpolation_points.shape[0]
1216            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1217                                            max(interpolation_points[:,0]))
1218            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1219                                             max(interpolation_points[:,1]))
1220            str += '  Interpolated quantities (over all timesteps):\n'
1221       
1222            for name in quantity_names:
1223                q = precomputed_values[name][:].flat
1224                str += '    %s at interpolation points in [%f, %f]\n'\
1225                       %(name, min(q), max(q))
1226        str += '------------------------------------------------\n'
1227
1228        return str
1229
1230        #FIXME: Delete
1231        #print '------------------------------------------------'
1232        #print 'Interpolation_function statistics:'
1233        #print '  Extent:'
1234        #print '    x in [%f, %f], len(x) == %d'\
1235        #      %(min(x), max(x), len(x))
1236        #print '    y in [%f, %f], len(y) == %d'\
1237        #      %(min(y), max(y), len(y))
1238        #print '    t in [%f, %f], len(t) == %d'\
1239        #      %(min(self.T), max(self.T), len(self.T))
1240        #print '  Quantities:'
1241        #for name in quantity_names:
1242        #    q = quantities[name][:].flat
1243        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1244        #print '  Interpolation points (xi, eta):'\
1245        #      ' number of points == %d ' %interpolation_points.shape[0]
1246        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1247        #                             max(interpolation_points[:,0]))
1248        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1249        #                              max(interpolation_points[:,1]))
1250        #print '  Interpolated quantities (over all timesteps):'
1251        #
1252        #for name in quantity_names:
1253        #    q = precomputed_values[name][:].flat
1254        #    print '    %s at interpolation points in [%f, %f]'\
1255        #          %(name, min(q), max(q))
1256        #print '------------------------------------------------'
1257
1258
1259#-------------------------------------------------------------
1260if __name__ == "__main__":
1261    """
1262    Load in a mesh and data points with attributes.
1263    Fit the attributes to the mesh.
1264    Save a new mesh file.
1265    """
1266    import os, sys
1267    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1268            %os.path.basename(sys.argv[0])
1269
1270    if len(sys.argv) < 4:
1271        print usage
1272    else:
1273        mesh_file = sys.argv[1]
1274        point_file = sys.argv[2]
1275        mesh_output_file = sys.argv[3]
1276
1277        expand_search = False
1278        if len(sys.argv) > 4:
1279            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1280                expand_search = True
1281            else:
1282                expand_search = False
1283
1284        verbose = False
1285        if len(sys.argv) > 5:
1286            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1287                verbose = False
1288            else:
1289                verbose = True
1290
1291        if len(sys.argv) > 6:
1292            alpha = sys.argv[6]
1293        else:
1294            alpha = DEFAULT_ALPHA
1295
1296        # This is used more for testing
1297        if len(sys.argv) > 7:
1298            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1299                display_errors = False
1300            else:
1301                display_errors = True
1302           
1303        t0 = time.time()
1304        try:
1305            fit_to_mesh_file(mesh_file,
1306                         point_file,
1307                         mesh_output_file,
1308                         alpha,
1309                         verbose= verbose,
1310                         expand_search = expand_search,
1311                         display_errors = display_errors)
1312        except IOError,e:
1313            import sys; sys.exit(1)
1314
1315        print 'That took %.2f seconds' %(time.time()-t0)
1316
Note: See TracBrowser for help on using the repository browser.