source: inundation/pyvolution/least_squares.py @ 1891

Last change on this file since 1891 was 1891, checked in by duncan, 18 years ago

checking in least squares changes, ticket 8

File size: 47.6 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20import exceptions
21class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType
25from mesh import Mesh
26
27from Numeric import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
28from sparse import Sparse, Sparse_CSR
29from cg_solve import conjugate_gradient, VectorShapeError
30
31from coordinate_transforms.geo_reference import Geo_reference
32
33import time
34
35
36try:
37    from util import gradient
38except ImportError, e:
39    #FIXME reduce the dependency of modules in pyvolution
40    # Have util in a dir, working like load_mesh, and get rid of this
41    def gradient(x0, y0, x1, y1, x2, y2, q0, q1, q2):
42        """
43        """
44
45        det = (y2-y0)*(x1-x0) - (y1-y0)*(x2-x0)
46        a = (y2-y0)*(q1-q0) - (y1-y0)*(q2-q0)
47        a /= det
48
49        b = (x1-x0)*(q2-q0) - (x2-x0)*(q1-q0)
50        b /= det
51
52        return a, b
53
54
55DEFAULT_ALPHA = 0.001
56
57def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
58                     alpha=DEFAULT_ALPHA, verbose= False,
59                     expand_search = False,
60                     data_origin = None,
61                     mesh_origin = None,
62                     precrop = False,
63                     display_errors = True):
64    """
65    Given a mesh file (tsh) and a point attribute file (xya), fit
66    point attributes to the mesh and write a mesh file with the
67    results.
68
69
70    If data_origin is not None it is assumed to be
71    a 3-tuple with geo referenced
72    UTM coordinates (zone, easting, northing)
73
74    NOTE: Throws IOErrors, for a variety of file problems.
75   
76    mesh_origin is the same but refers to the input tsh file.
77    FIXME: When the tsh format contains it own origin, these parameters can go.
78    FIXME: And both origins should be obtained from the specified files.
79    """
80
81    from load_mesh.loadASCII import import_mesh_file, \
82                 import_points_file, export_mesh_file, \
83                 concatinate_attributelist
84
85
86    try:
87        mesh_dict = import_mesh_file(mesh_file)
88    except IOError,e:
89        if display_errors:
90            print "Could not load bad file. ", e
91        raise IOError  #Re-raise exception
92       
93    vertex_coordinates = mesh_dict['vertices']
94    triangles = mesh_dict['triangles']
95    if type(mesh_dict['vertex_attributes']) == ArrayType:
96        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
97    else:
98        old_point_attributes = mesh_dict['vertex_attributes']
99
100    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
101        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
102    else:
103        old_title_list = mesh_dict['vertex_attribute_titles']
104
105    if verbose: print 'tsh file %s loaded' %mesh_file
106
107    # load in the .pts file
108    try:
109        point_dict = import_points_file(point_file, verbose=verbose)
110    except IOError,e:
111        if display_errors:
112            print "Could not load bad file. ", e
113        raise IOError  #Re-raise exception 
114
115    point_coordinates = point_dict['pointlist']
116    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
117
118    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
119        data_origin = point_dict['geo_reference'].get_origin()
120    else:
121        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
122
123    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
124        mesh_origin = mesh_dict['geo_reference'].get_origin()
125    else:
126        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
127
128    if verbose: print "points file loaded"
129    if verbose:print "fitting to mesh"
130    f = fit_to_mesh(vertex_coordinates,
131                    triangles,
132                    point_coordinates,
133                    point_attributes,
134                    alpha = alpha,
135                    verbose = verbose,
136                    expand_search = expand_search,
137                    data_origin = data_origin,
138                    mesh_origin = mesh_origin,
139                    precrop = precrop)
140    if verbose: print "finished fitting to mesh"
141
142    # convert array to list of lists
143    new_point_attributes = f.tolist()
144    #FIXME have this overwrite attributes with the same title - DSG
145    #Put the newer attributes last
146    if old_title_list <> []:
147        old_title_list.extend(title_list)
148        #FIXME can this be done a faster way? - DSG
149        for i in range(len(old_point_attributes)):
150            old_point_attributes[i].extend(new_point_attributes[i])
151        mesh_dict['vertex_attributes'] = old_point_attributes
152        mesh_dict['vertex_attribute_titles'] = old_title_list
153    else:
154        mesh_dict['vertex_attributes'] = new_point_attributes
155        mesh_dict['vertex_attribute_titles'] = title_list
156
157    #FIXME (Ole): Remember to output mesh_origin as well
158    if verbose: print "exporting to file ",mesh_output_file
159
160    try:
161        export_mesh_file(mesh_output_file, mesh_dict)
162    except IOError,e:
163        if display_errors:
164            print "Could not write file. ", e
165        raise IOError
166
167def fit_to_mesh(vertex_coordinates,
168                triangles,
169                point_coordinates,
170                point_attributes,
171                alpha = DEFAULT_ALPHA,
172                verbose = False,
173                expand_search = False,
174                data_origin = None,
175                mesh_origin = None,
176                precrop = False):
177    """
178    Fit a smooth surface to a triangulation,
179    given data points with attributes.
180
181
182        Inputs:
183
184          vertex_coordinates: List of coordinate pairs [xi, eta] of points
185          constituting mesh (or a an m x 2 Numeric array)
186
187          triangles: List of 3-tuples (or a Numeric array) of
188          integers representing indices of all vertices in the mesh.
189
190          point_coordinates: List of coordinate pairs [x, y] of data points
191          (or an nx2 Numeric array)
192
193          alpha: Smoothing parameter.
194
195          point_attributes: Vector or array of data at the point_coordinates.
196
197          data_origin and mesh_origin are 3-tuples consisting of
198          UTM zone, easting and northing. If specified
199          point coordinates and vertex coordinates are assumed to be
200          relative to their respective origins.
201
202    """
203    interp = Interpolation(vertex_coordinates,
204                           triangles,
205                           point_coordinates,
206                           alpha = alpha,
207                           verbose = verbose,
208                           expand_search = expand_search,
209                           data_origin = data_origin,
210                           mesh_origin = mesh_origin,
211                           precrop = precrop)
212
213    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
214    return vertex_attributes
215
216
217
218def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
219                    verbose = False, reduction = 1, format = 'netcdf'):
220    """Fits attributes from pts file to MxN rectangular mesh
221
222    Read pts file and create rectangular mesh of resolution MxN such that
223    it covers all points specified in pts file.
224
225    FIXME: This may be a temporary function until we decide on
226    netcdf formats etc
227
228    FIXME: Uses elevation hardwired
229    """
230
231    import util, mesh_factory
232
233    if verbose: print 'Read pts'
234    points, attributes = util.read_xya(pts_name, format)
235
236    #Reduce number of points a bit
237    points = points[::reduction]
238    elevation = attributes['elevation']  #Must be elevation
239    elevation = elevation[::reduction]
240
241    if verbose: print 'Got %d data points' %len(points)
242
243    if verbose: print 'Create mesh'
244    #Find extent
245    max_x = min_x = points[0][0]
246    max_y = min_y = points[0][1]
247    for point in points[1:]:
248        x = point[0]
249        if x > max_x: max_x = x
250        if x < min_x: min_x = x
251        y = point[1]
252        if y > max_y: max_y = y
253        if y < min_y: min_y = y
254
255    #Create appropriate mesh
256    vertex_coordinates, triangles, boundary =\
257         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
258                                (min_x, min_y))
259
260    #Fit attributes to mesh
261    vertex_attributes = fit_to_mesh(vertex_coordinates,
262                        triangles,
263                        points,
264                        elevation, alpha=alpha, verbose=verbose)
265
266
267
268    return vertex_coordinates, triangles, boundary, vertex_attributes
269
270
271
272class Interpolation:
273
274    def __init__(self,
275                 vertex_coordinates,
276                 triangles,
277                 point_coordinates = None,
278                 alpha = None,
279                 verbose = False,
280                 expand_search = True,
281                 interp_only = False,
282                 max_points_per_cell = 30,
283                 mesh_origin = None,
284                 data_origin = None,
285                 precrop = False):
286
287
288        """ Build interpolation matrix mapping from
289        function values at vertices to function values at data points
290
291        Inputs:
292
293          vertex_coordinates: List of coordinate pairs [xi, eta] of
294          points constituting mesh (or a an m x 2 Numeric array)
295          Points may appear multiple times
296          (e.g. if vertices have discontinuities)
297
298          triangles: List of 3-tuples (or a Numeric array) of
299          integers representing indices of all vertices in the mesh.
300
301          point_coordinates: List of coordinate pairs [x, y] of
302          data points (or an nx2 Numeric array)
303          If point_coordinates is absent, only smoothing matrix will
304          be built
305
306          alpha: Smoothing parameter
307
308          data_origin and mesh_origin are 3-tuples consisting of
309          UTM zone, easting and northing. If specified
310          point coordinates and vertex coordinates are assumed to be
311          relative to their respective origins.
312
313        """
314        from util import ensure_numeric
315
316        #Convert input to Numeric arrays
317        triangles = ensure_numeric(triangles, Int)
318        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
319
320        #Build underlying mesh
321        if verbose: print 'Building mesh'
322        #self.mesh = General_mesh(vertex_coordinates, triangles,
323        #FIXME: Trying the normal mesh while testing precrop,
324        #       The functionality of boundary_polygon is needed for that
325
326        #FIXME - geo ref does not have to go into mesh.
327        # Change the point co-ords to conform to the
328        # mesh co-ords early in the code
329        if mesh_origin == None:
330            geo = None
331        else:
332            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
333        self.mesh = Mesh(vertex_coordinates, triangles,
334                         geo_reference = geo)
335       
336        self.mesh.check_integrity()
337
338        self.data_origin = data_origin
339
340        self.point_indices = None
341
342        #Smoothing parameter
343        if alpha is None:
344            self.alpha = DEFAULT_ALPHA
345        else:   
346            self.alpha = alpha
347
348
349        if point_coordinates is not None:
350            if verbose: print 'Building interpolation matrix'
351            self.build_interpolation_matrix_A(point_coordinates,
352                                              verbose = verbose,
353                                              expand_search = expand_search,
354                                              interp_only = interp_only, 
355                                              max_points_per_cell =\
356                                              max_points_per_cell,
357                                              data_origin = data_origin,
358                                              precrop = precrop)
359        #Build coefficient matrices
360        if interp_only == False:
361            self.build_coefficient_matrix_B(point_coordinates,
362                                        verbose = verbose,
363                                        expand_search = expand_search,
364                                        max_points_per_cell =\
365                                        max_points_per_cell,
366                                        data_origin = data_origin,
367                                        precrop = precrop)
368
369
370    def set_point_coordinates(self, point_coordinates,
371                              data_origin = None,
372                              verbose = False,
373                              precrop = True):
374        """
375        A public interface to setting the point co-ordinates.
376        """
377        if point_coordinates is not None:
378            if verbose: print 'Building interpolation matrix'
379            self.build_interpolation_matrix_A(point_coordinates,
380                                              verbose = verbose,
381                                              data_origin = data_origin,
382                                              precrop = precrop)
383        self.build_coefficient_matrix_B(point_coordinates, data_origin)
384
385    def build_coefficient_matrix_B(self, point_coordinates=None,
386                                   verbose = False, expand_search = True,
387                                   max_points_per_cell=30,
388                                   data_origin = None,
389                                   precrop = False):
390        """Build final coefficient matrix"""
391
392
393        if self.alpha <> 0:
394            if verbose: print 'Building smoothing matrix'
395            self.build_smoothing_matrix_D()
396
397        if point_coordinates is not None:
398            if self.alpha <> 0:
399                self.B = self.AtA + self.alpha*self.D
400            else:
401                self.B = self.AtA
402
403            #Convert self.B matrix to CSR format for faster matrix vector
404            self.B = Sparse_CSR(self.B)
405
406    def build_interpolation_matrix_A(self, point_coordinates,
407                                     verbose = False, expand_search = True,
408                                     max_points_per_cell=30,
409                                     data_origin = None,
410                                     precrop = False,
411                                     interp_only = False):
412        """Build n x m interpolation matrix, where
413        n is the number of data points and
414        m is the number of basis functions phi_k (one per vertex)
415
416        This algorithm uses a quad tree data structure for fast binning of data points
417        origin is a 3-tuple consisting of UTM zone, easting and northing.
418        If specified coordinates are assumed to be relative to this origin.
419
420        This one will override any data_origin that may be specified in
421        interpolation instance
422
423        """
424
425
426        #FIXME (Ole): Check that this function is memeory efficient.
427        #6 million datapoints and 300000 basis functions
428        #causes out-of-memory situation
429        #First thing to check is whether there is room for self.A and self.AtA
430        #
431        #Maybe we need some sort of blocking
432
433        from quad import build_quadtree
434        from util import ensure_numeric
435
436        if data_origin is None:
437            data_origin = self.data_origin #Use the one from
438                                           #interpolation instance
439
440        #Convert input to Numeric arrays just in case.
441        point_coordinates = ensure_numeric(point_coordinates, Float)
442
443        #Keep track of discarded points (if any).
444        #This is only registered if precrop is True
445        self.cropped_points = False
446
447        #Shift data points to same origin as mesh (if specified)
448
449        #FIXME this will shift if there was no geo_ref.
450        #But all this should be removed anyhow.
451        #change coords before this point
452        mesh_origin = self.mesh.geo_reference.get_origin()
453        if point_coordinates is not None:
454            if data_origin is not None:
455                if mesh_origin is not None:
456
457                    #Transformation:
458                    #
459                    #Let x_0 be the reference point of the point coordinates
460                    #and xi_0 the reference point of the mesh.
461                    #
462                    #A point coordinate (x + x_0) is then made relative
463                    #to xi_0 by
464                    #
465                    # x_new = x + x_0 - xi_0
466                    #
467                    #and similarly for eta
468
469                    x_offset = data_origin[1] - mesh_origin[1]
470                    y_offset = data_origin[2] - mesh_origin[2]
471                else: #Shift back to a zero origin
472                    x_offset = data_origin[1]
473                    y_offset = data_origin[2]
474
475                point_coordinates[:,0] += x_offset
476                point_coordinates[:,1] += y_offset
477            else:
478                if mesh_origin is not None:
479                    #Use mesh origin for data points
480                    point_coordinates[:,0] -= mesh_origin[1]
481                    point_coordinates[:,1] -= mesh_origin[2]
482
483
484
485        #Remove points falling outside mesh boundary
486        #This reduced one example from 1356 seconds to 825 seconds
487        if precrop is True:
488            from Numeric import take
489            from util import inside_polygon
490
491            if verbose: print 'Getting boundary polygon'
492            P = self.mesh.get_boundary_polygon()
493
494            if verbose: print 'Getting indices inside mesh boundary'
495            indices = inside_polygon(point_coordinates, P, verbose = verbose)
496
497
498            if len(indices) != point_coordinates.shape[0]:
499                self.cropped_points = True
500                if verbose:
501                    print 'Done - %d points outside mesh have been cropped.'\
502                          %(point_coordinates.shape[0] - len(indices))
503
504            point_coordinates = take(point_coordinates, indices)
505            self.point_indices = indices
506
507
508
509
510        #Build n x m interpolation matrix
511        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
512        n = point_coordinates.shape[0]     #Nbr of data points
513
514        if verbose: print 'Number of datapoints: %d' %n
515        if verbose: print 'Number of basis functions: %d' %m
516
517        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
518        #However, Sparse_CSR does not have the same methods as Sparse yet
519        #The tests will reveal what needs to be done
520        #self.A = Sparse_CSR(Sparse(n,m))
521        #self.AtA = Sparse_CSR(Sparse(m,m))
522        self.A = Sparse(n,m)
523        self.AtA = Sparse(m,m)
524
525        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
526        root = build_quadtree(self.mesh,
527                              max_points_per_cell = max_points_per_cell)
528
529        #Compute matrix elements
530        for i in range(n):
531            #For each data_coordinate point
532
533            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
534            x = point_coordinates[i]
535
536            #Find vertices near x
537            candidate_vertices = root.search(x[0], x[1])
538            is_more_elements = True
539
540            element_found, sigma0, sigma1, sigma2, k = \
541                self.search_triangles_of_vertices(candidate_vertices, x)
542            while not element_found and is_more_elements and expand_search:
543                #if verbose: print 'Expanding search'
544                candidate_vertices, branch = root.expand_search()
545                if branch == []:
546                    # Searching all the verts from the root cell that haven't
547                    # been searched.  This is the last try
548                    element_found, sigma0, sigma1, sigma2, k = \
549                      self.search_triangles_of_vertices(candidate_vertices, x)
550                    is_more_elements = False
551                else:
552                    element_found, sigma0, sigma1, sigma2, k = \
553                      self.search_triangles_of_vertices(candidate_vertices, x)
554
555
556            #Update interpolation matrix A if necessary
557            if element_found is True:
558                #Assign values to matrix A
559
560                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
561                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
562                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
563
564                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
565                js     = [j0,j1,j2]
566
567                for j in js:
568                    self.A[i,j] = sigmas[j]
569                    for k in js:
570                        if interp_only == False:
571                            self.AtA[j,k] += sigmas[j]*sigmas[k]
572            else:
573                pass
574                #Ok if there is no triangle for datapoint
575                #(as in brute force version)
576                #raise 'Could not find triangle for point', x
577
578
579
580    def search_triangles_of_vertices(self, candidate_vertices, x):
581            #Find triangle containing x:
582            element_found = False
583
584            # This will be returned if element_found = False
585            sigma2 = -10.0
586            sigma0 = -10.0
587            sigma1 = -10.0
588            k = -10.0
589
590            #For all vertices in same cell as point x
591            for v in candidate_vertices:
592
593                #for each triangle id (k) which has v as a vertex
594                for k, _ in self.mesh.vertexlist[v]:
595
596                    #Get the three vertex_points of candidate triangle
597                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
598                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
599                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
600
601                    #print "PDSG - k", k
602                    #print "PDSG - xi0", xi0
603                    #print "PDSG - xi1", xi1
604                    #print "PDSG - xi2", xi2
605                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
606                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
607
608                    #Get the three normals
609                    n0 = self.mesh.get_normal(k, 0)
610                    n1 = self.mesh.get_normal(k, 1)
611                    n2 = self.mesh.get_normal(k, 2)
612
613
614                    #Compute interpolation
615                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
616                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
617                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
618
619                    #print "PDSG - sigma0", sigma0
620                    #print "PDSG - sigma1", sigma1
621                    #print "PDSG - sigma2", sigma2
622
623                    #FIXME: Maybe move out to test or something
624                    epsilon = 1.0e-6
625                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
626
627                    #Check that this triangle contains the data point
628
629                    #Sigmas can get negative within
630                    #machine precision on some machines (e.g nautilus)
631                    #Hence the small eps
632                    eps = 1.0e-15
633                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
634                        element_found = True
635                        break
636
637                if element_found is True:
638                    #Don't look for any other triangle
639                    break
640            return element_found, sigma0, sigma1, sigma2, k
641
642
643
644    def build_interpolation_matrix_A_brute(self, point_coordinates):
645        """Build n x m interpolation matrix, where
646        n is the number of data points and
647        m is the number of basis functions phi_k (one per vertex)
648
649        This is the brute force which is too slow for large problems,
650        but could be used for testing
651        """
652
653        from util import ensure_numeric
654
655        #Convert input to Numeric arrays
656        point_coordinates = ensure_numeric(point_coordinates, Float)
657
658        #Build n x m interpolation matrix
659        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
660        n = point_coordinates.shape[0]     #Nbr of data points
661
662        self.A = Sparse(n,m)
663        self.AtA = Sparse(m,m)
664
665        #Compute matrix elements
666        for i in range(n):
667            #For each data_coordinate point
668
669            x = point_coordinates[i]
670            element_found = False
671            k = 0
672            while not element_found and k < len(self.mesh):
673                #For each triangle (brute force)
674                #FIXME: Real algorithm should only visit relevant triangles
675
676                #Get the three vertex_points
677                xi0 = self.mesh.get_vertex_coordinate(k, 0)
678                xi1 = self.mesh.get_vertex_coordinate(k, 1)
679                xi2 = self.mesh.get_vertex_coordinate(k, 2)
680
681                #Get the three normals
682                n0 = self.mesh.get_normal(k, 0)
683                n1 = self.mesh.get_normal(k, 1)
684                n2 = self.mesh.get_normal(k, 2)
685
686                #Compute interpolation
687                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
688                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
689                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
690
691                #FIXME: Maybe move out to test or something
692                epsilon = 1.0e-6
693                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
694
695                #Check that this triangle contains data point
696                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
697                    element_found = True
698                    #Assign values to matrix A
699
700                    j0 = self.mesh.triangles[k,0] #Global vertex id
701                    #self.A[i, j0] = sigma0
702
703                    j1 = self.mesh.triangles[k,1] #Global vertex id
704                    #self.A[i, j1] = sigma1
705
706                    j2 = self.mesh.triangles[k,2] #Global vertex id
707                    #self.A[i, j2] = sigma2
708
709                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
710                    js     = [j0,j1,j2]
711
712                    for j in js:
713                        self.A[i,j] = sigmas[j]
714                        for k in js:
715                            self.AtA[j,k] += sigmas[j]*sigmas[k]
716                k = k+1
717
718
719
720    def get_A(self):
721        return self.A.todense()
722
723    def get_B(self):
724        return self.B.todense()
725
726    def get_D(self):
727        return self.D.todense()
728
729        #FIXME: Remember to re-introduce the 1/n factor in the
730        #interpolation term
731
732    def build_smoothing_matrix_D(self):
733        """Build m x m smoothing matrix, where
734        m is the number of basis functions phi_k (one per vertex)
735
736        The smoothing matrix is defined as
737
738        D = D1 + D2
739
740        where
741
742        [D1]_{k,l} = \int_\Omega
743           \frac{\partial \phi_k}{\partial x}
744           \frac{\partial \phi_l}{\partial x}\,
745           dx dy
746
747        [D2]_{k,l} = \int_\Omega
748           \frac{\partial \phi_k}{\partial y}
749           \frac{\partial \phi_l}{\partial y}\,
750           dx dy
751
752
753        The derivatives \frac{\partial \phi_k}{\partial x},
754        \frac{\partial \phi_k}{\partial x} for a particular triangle
755        are obtained by computing the gradient a_k, b_k for basis function k
756        """
757
758        #FIXME: algorithm might be optimised by computing local 9x9
759        #"element stiffness matrices:
760
761        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
762
763        self.D = Sparse(m,m)
764
765        #For each triangle compute contributions to D = D1+D2
766        for i in range(len(self.mesh)):
767
768            #Get area
769            area = self.mesh.areas[i]
770
771            #Get global vertex indices
772            v0 = self.mesh.triangles[i,0]
773            v1 = self.mesh.triangles[i,1]
774            v2 = self.mesh.triangles[i,2]
775
776            #Get the three vertex_points
777            xi0 = self.mesh.get_vertex_coordinate(i, 0)
778            xi1 = self.mesh.get_vertex_coordinate(i, 1)
779            xi2 = self.mesh.get_vertex_coordinate(i, 2)
780
781            #Compute gradients for each vertex
782            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
783                              1, 0, 0)
784
785            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
786                              0, 1, 0)
787
788            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
789                              0, 0, 1)
790
791            #Compute diagonal contributions
792            self.D[v0,v0] += (a0*a0 + b0*b0)*area
793            self.D[v1,v1] += (a1*a1 + b1*b1)*area
794            self.D[v2,v2] += (a2*a2 + b2*b2)*area
795
796            #Compute contributions for basis functions sharing edges
797            e01 = (a0*a1 + b0*b1)*area
798            self.D[v0,v1] += e01
799            self.D[v1,v0] += e01
800
801            e12 = (a1*a2 + b1*b2)*area
802            self.D[v1,v2] += e12
803            self.D[v2,v1] += e12
804
805            e20 = (a2*a0 + b2*b0)*area
806            self.D[v2,v0] += e20
807            self.D[v0,v2] += e20
808
809
810    def fit(self, z):
811        """Fit a smooth surface to given 1d array of data points z.
812
813        The smooth surface is computed at each vertex in the underlying
814        mesh using the formula given in the module doc string.
815
816        Pre Condition:
817          self.A, self.AtA and self.B have been initialised
818
819        Inputs:
820          z: Single 1d vector or array of data at the point_coordinates.
821        """
822
823        #Convert input to Numeric arrays
824        from util import ensure_numeric
825        z = ensure_numeric(z, Float)
826
827        if len(z.shape) > 1 :
828            raise VectorShapeError, 'Can only deal with 1d data vector'
829
830        if self.point_indices is not None:
831            #Remove values for any points that were outside mesh
832            z = take(z, self.point_indices)
833
834        #Compute right hand side based on data
835        Atz = self.A.trans_mult(z)
836
837
838        #Check sanity
839        n, m = self.A.shape
840        if n<m and self.alpha == 0.0:
841            msg = 'ERROR (least_squares): Too few data points\n'
842            msg += 'There are only %d data points and alpha == 0. ' %n
843            msg += 'Need at least %d\n' %m
844            msg += 'Alternatively, set smoothing parameter alpha to a small '
845            msg += 'positive value,\ne.g. 1.0e-3.'
846            raise msg
847
848
849
850        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
851        #FIXME: Should we store the result here for later use? (ON)
852
853
854    def fit_points(self, z, verbose=False):
855        """Like fit, but more robust when each point has two or more attributes
856        FIXME (Ole): The name fit_points doesn't carry any meaning
857        for me. How about something like fit_multiple or fit_columns?
858        """
859
860        try:
861            if verbose: print 'Solving penalised least_squares problem'
862            return self.fit(z)
863        except VectorShapeError, e:
864            # broadcasting is not supported.
865
866            #Convert input to Numeric arrays
867            from util import ensure_numeric
868            z = ensure_numeric(z, Float)
869
870            #Build n x m interpolation matrix
871            m = self.mesh.coordinates.shape[0] #Number of vertices
872            n = z.shape[1]                     #Number of data points
873
874            f = zeros((m,n), Float) #Resulting columns
875
876            for i in range(z.shape[1]):
877                f[:,i] = self.fit(z[:,i])
878
879            return f
880
881
882    def interpolate(self, f):
883        """Evaluate smooth surface f at data points implied in self.A.
884
885        The mesh values representing a smooth surface are
886        assumed to be specified in f. This argument could,
887        for example have been obtained from the method self.fit()
888
889        Pre Condition:
890          self.A has been initialised
891
892        Inputs:
893          f: Vector or array of data at the mesh vertices.
894          If f is an array, interpolation will be done for each column as
895          per underlying matrix-matrix multiplication
896
897        Output:
898          Interpolated values at data points implied in self.A
899
900        """
901
902        return self.A * f
903
904    def cull_outsiders(self, f):
905        pass
906
907
908
909
910class Interpolation_function:
911    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
912    which is interpolated from time series defined at vertices of
913    triangular mesh (such as those stored in sww files)
914
915    Let m be the number of vertices, n the number of triangles
916    and p the number of timesteps.
917
918    Mandatory input
919        time:               px1 array of monotonously increasing times (Float)
920        quantities:         Dictionary of pxm arrays or 1 pxm array (Float)
921       
922    Optional input:
923        quantity_names:     List of keys into the quantities dictionary
924        vertex_coordinates: mx2 array of coordinates (Float)
925        triangles:          nx3 array of indices into vertex_coordinates (Int)
926        interpolation_points: array of coordinates to be interpolated to
927        verbose:            Level of reporting
928   
929   
930    The quantities returned by the callable object are specified by
931    the list quantities which must contain the names of the
932    quantities to be returned and also reflect the order, e.g. for
933    the shallow water wave equation, on would have
934    quantities = ['stage', 'xmomentum', 'ymomentum']
935
936    The parameter interpolation_points decides at which points interpolated
937    quantities are to be computed whenever object is called.
938    If None, return average value
939    """
940
941   
942   
943    def __init__(self,
944                 time,
945                 quantities,
946                 quantity_names = None, 
947                 vertex_coordinates = None,
948                 triangles = None,
949                 interpolation_points = None,
950                 verbose = False):
951        """Initialise object and build spatial interpolation if required
952        """
953
954        from Numeric import array, zeros, Float, alltrue, concatenate,\
955             reshape, ArrayType
956
957
958        from util import mean, ensure_numeric
959        from config import time_format
960        import types
961
962
963
964        #Check temporal info
965        time = ensure_numeric(time)       
966        msg = 'Time must be a monotonuosly '
967        msg += 'increasing sequence %s' %time
968        assert alltrue(time[1:] - time[:-1] > 0 ), msg
969
970
971        #Check if quantities is a single array only
972        if type(quantities) != types.DictType:
973            quantities = ensure_numeric(quantities)
974            quantity_names = ['Attribute']
975
976            #Make it a dictionary
977            quantities = {quantity_names[0]: quantities}
978
979
980        #Use keys if no names are specified
981        if quantity_names is None:
982            quantity_names = quantities.keys()
983
984
985        #Check spatial info
986        if vertex_coordinates is None:
987            self.spatial = False
988        else:   
989            vertex_coordinates = ensure_numeric(vertex_coordinates)
990
991            assert triangles is not None, 'Triangles array must be specified'
992            triangles = ensure_numeric(triangles)
993            self.spatial = True           
994           
995
996 
997        #Save for use with statistics
998        self.quantity_names = quantity_names       
999        self.quantities = quantities       
1000        self.vertex_coordinates = vertex_coordinates
1001        self.interpolation_points = interpolation_points
1002        self.T = time[:]  #Time assumed to be relative to starttime
1003        self.index = 0    #Initial time index
1004        self.precomputed_values = {}
1005           
1006
1007
1008        #Precomputed spatial interpolation if requested
1009        if interpolation_points is not None:
1010            if self.spatial is False:
1011                raise 'Triangles and vertex_coordinates must be specified'
1012           
1013
1014            try:
1015                self.interpolation_points =\
1016                                          ensure_numeric(self.interpolation_points)
1017            except:
1018                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1019                      'or a list of points\n'
1020                msg += 'I got: %s.' %( str(self.interpolation_points)[:60] + '...')
1021                raise msg
1022
1023
1024            for name in quantity_names:
1025                self.precomputed_values[name] =\
1026                                              zeros((len(self.T),
1027                                                     len(self.interpolation_points)),
1028                                                    Float)
1029
1030            #Build interpolator
1031            interpol = Interpolation(vertex_coordinates,
1032                                     triangles,
1033                                     point_coordinates = self.interpolation_points,
1034                                     alpha = 0,
1035                                     precrop = False, 
1036                                     verbose = verbose)
1037
1038            if verbose: print 'Interpolate'
1039            n = len(self.T)
1040            for i, t in enumerate(self.T):
1041                #Interpolate quantities at this timestep
1042                if verbose and i%((n+10)/10)==0:
1043                    print ' time step %d of %d' %(i, n)
1044                   
1045                for name in quantity_names:
1046                    self.precomputed_values[name][i, :] =\
1047                    interpol.interpolate(quantities[name][i,:])
1048
1049            #Report
1050            if verbose:
1051                print self.statistics()
1052                #self.print_statistics()
1053           
1054        else:
1055            #Store quantitites as is
1056            for name in quantity_names:
1057                self.precomputed_values[name] = quantities[name]
1058
1059
1060        #else:
1061        #    #Return an average, making this a time series
1062        #    for name in quantity_names:
1063        #        self.values[name] = zeros(len(self.T), Float)
1064        #
1065        #    if verbose: print 'Compute mean values'
1066        #    for i, t in enumerate(self.T):
1067        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1068        #        for name in quantity_names:
1069        #           self.values[name][i] = mean(quantities[name][i,:])
1070
1071
1072
1073
1074    def __repr__(self):
1075        #return 'Interpolation function (spation-temporal)'
1076        return self.statistics()
1077   
1078
1079    def __call__(self, t, point_id = None, x = None, y = None):
1080        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1081
1082        Inputs:
1083          t: time - Model time. Must lie within existing timesteps
1084          point_id: index of one of the preprocessed points.
1085          x, y:     Overrides location, point_id ignored
1086         
1087          If spatial info is present and all of x,y,point_id
1088          are None an exception is raised
1089                   
1090          If no spatial info is present, point_id and x,y arguments are ignored
1091          making f a function of time only.
1092
1093         
1094          FIXME: point_id could also be a slice
1095          FIXME: What if x and y are vectors?
1096          FIXME: What about f(x,y) without t?
1097        """
1098
1099        from math import pi, cos, sin, sqrt
1100        from Numeric import zeros, Float
1101        from util import mean       
1102
1103        if self.spatial is True:
1104            if point_id is None:
1105                if x is None or y is None:
1106                    msg = 'Either point_id or x and y must be specified'
1107                    raise msg
1108            else:
1109                if self.interpolation_points is None:
1110                    msg = 'Interpolation_function must be instantiated ' +\
1111                          'with a list of interpolation points before parameter ' +\
1112                          'point_id can be used'
1113                    raise msg
1114
1115
1116        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1117        msg += ' does not match model time: %s\n' %t
1118        if t < self.T[0]: raise msg
1119        if t > self.T[-1]: raise msg
1120
1121        oldindex = self.index #Time index
1122
1123        #Find current time slot
1124        while t > self.T[self.index]: self.index += 1
1125        while t < self.T[self.index]: self.index -= 1
1126
1127        if t == self.T[self.index]:
1128            #Protect against case where t == T[-1] (last time)
1129            # - also works in general when t == T[i]
1130            ratio = 0
1131        else:
1132            #t is now between index and index+1
1133            ratio = (t - self.T[self.index])/\
1134                    (self.T[self.index+1] - self.T[self.index])
1135
1136        #Compute interpolated values
1137        q = zeros(len(self.quantity_names), Float)
1138
1139        for i, name in enumerate(self.quantity_names):
1140            Q = self.precomputed_values[name]
1141
1142            if self.spatial is False:
1143                #If there is no spatial info               
1144                assert len(Q.shape) == 1
1145
1146                Q0 = Q[self.index]
1147                if ratio > 0: Q1 = Q[self.index+1]
1148
1149            else:
1150                if x is not None and y is not None:
1151                    #Interpolate to x, y
1152                   
1153                    raise 'x,y interpolation not yet implemented'
1154                else:
1155                    #Use precomputed point
1156                    Q0 = Q[self.index, point_id]
1157                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1158
1159            #Linear temporal interpolation   
1160            if ratio > 0:
1161                q[i] = Q0 + ratio*(Q1 - Q0)
1162            else:
1163                q[i] = Q0
1164
1165
1166        #Return vector of interpolated values
1167        #if len(q) == 1:
1168        #    return q[0]
1169        #else:
1170        #    return q
1171
1172
1173        #Return vector of interpolated values
1174        #FIXME:
1175        if self.spatial is True:
1176            return q
1177        else:
1178            #Replicate q according to x and y
1179            #This is e.g used for Wind_stress
1180            if x == None or y == None: 
1181                return q
1182            else:
1183                try:
1184                    N = len(x)
1185                except:
1186                    return q
1187                else:
1188                    from Numeric import ones, Float
1189                    #x is a vector - Create one constant column for each value
1190                    N = len(x)
1191                    assert len(y) == N, 'x and y must have same length'
1192                    res = []
1193                    for col in q:
1194                        res.append(col*ones(N, Float))
1195                       
1196                return res
1197
1198
1199    def statistics(self):
1200        """Output statistics about interpolation_function
1201        """
1202       
1203        vertex_coordinates = self.vertex_coordinates
1204        interpolation_points = self.interpolation_points               
1205        quantity_names = self.quantity_names
1206        quantities = self.quantities
1207        precomputed_values = self.precomputed_values                 
1208               
1209        x = vertex_coordinates[:,0]
1210        y = vertex_coordinates[:,1]               
1211
1212        str =  '------------------------------------------------\n'
1213        str += 'Interpolation_function (spation-temporal) statistics:\n'
1214        str += '  Extent:\n'
1215        str += '    x in [%f, %f], len(x) == %d\n'\
1216               %(min(x), max(x), len(x))
1217        str += '    y in [%f, %f], len(y) == %d\n'\
1218               %(min(y), max(y), len(y))
1219        str += '    t in [%f, %f], len(t) == %d\n'\
1220               %(min(self.T), max(self.T), len(self.T))
1221        str += '  Quantities:\n'
1222        for name in quantity_names:
1223            q = quantities[name][:].flat
1224            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1225
1226        if interpolation_points is not None:   
1227            str += '  Interpolation points (xi, eta):'\
1228                   ' number of points == %d\n' %interpolation_points.shape[0]
1229            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1230                                            max(interpolation_points[:,0]))
1231            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1232                                             max(interpolation_points[:,1]))
1233            str += '  Interpolated quantities (over all timesteps):\n'
1234       
1235            for name in quantity_names:
1236                q = precomputed_values[name][:].flat
1237                str += '    %s at interpolation points in [%f, %f]\n'\
1238                       %(name, min(q), max(q))
1239        str += '------------------------------------------------\n'
1240
1241        return str
1242
1243        #FIXME: Delete
1244        #print '------------------------------------------------'
1245        #print 'Interpolation_function statistics:'
1246        #print '  Extent:'
1247        #print '    x in [%f, %f], len(x) == %d'\
1248        #      %(min(x), max(x), len(x))
1249        #print '    y in [%f, %f], len(y) == %d'\
1250        #      %(min(y), max(y), len(y))
1251        #print '    t in [%f, %f], len(t) == %d'\
1252        #      %(min(self.T), max(self.T), len(self.T))
1253        #print '  Quantities:'
1254        #for name in quantity_names:
1255        #    q = quantities[name][:].flat
1256        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1257        #print '  Interpolation points (xi, eta):'\
1258        #      ' number of points == %d ' %interpolation_points.shape[0]
1259        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1260        #                             max(interpolation_points[:,0]))
1261        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1262        #                              max(interpolation_points[:,1]))
1263        #print '  Interpolated quantities (over all timesteps):'
1264        #
1265        #for name in quantity_names:
1266        #    q = precomputed_values[name][:].flat
1267        #    print '    %s at interpolation points in [%f, %f]'\
1268        #          %(name, min(q), max(q))
1269        #print '------------------------------------------------'
1270
1271
1272#-------------------------------------------------------------
1273if __name__ == "__main__":
1274    """
1275    Load in a mesh and data points with attributes.
1276    Fit the attributes to the mesh.
1277    Save a new mesh file.
1278    """
1279    import os, sys
1280    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1281            %os.path.basename(sys.argv[0])
1282
1283    if len(sys.argv) < 4:
1284        print usage
1285    else:
1286        mesh_file = sys.argv[1]
1287        point_file = sys.argv[2]
1288        mesh_output_file = sys.argv[3]
1289
1290        expand_search = False
1291        if len(sys.argv) > 4:
1292            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1293                expand_search = True
1294            else:
1295                expand_search = False
1296
1297        verbose = False
1298        if len(sys.argv) > 5:
1299            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1300                verbose = False
1301            else:
1302                verbose = True
1303
1304        if len(sys.argv) > 6:
1305            alpha = sys.argv[6]
1306        else:
1307            alpha = DEFAULT_ALPHA
1308
1309        # This is used more for testing
1310        if len(sys.argv) > 7:
1311            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1312                display_errors = False
1313            else:
1314                display_errors = True
1315           
1316        t0 = time.time()
1317        try:
1318            fit_to_mesh_file(mesh_file,
1319                         point_file,
1320                         mesh_output_file,
1321                         alpha,
1322                         verbose= verbose,
1323                         expand_search = expand_search,
1324                         display_errors = display_errors)
1325        except IOError,e:
1326            import sys; sys.exit(1)
1327
1328        print 'That took %.2f seconds' %(time.time()-t0)
1329
Note: See TracBrowser for help on using the repository browser.