source: inundation/fit_interpolate/spike_least_squares.py @ 2879

Last change on this file since 2879 was 2781, checked in by duncan, 19 years ago

investigating ticket #8

File size: 50.1 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20import exceptions
21class ShapeError(exceptions.Exception): pass
22class FittingError(exceptions.Exception): pass
23
24
25#from general_mesh import General_mesh
26from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType, NewAxis
27from pyvolution.mesh import Mesh
28
29from Numeric import zeros, take, compress, array, Float, Int, dot, transpose, concatenate, ArrayType
30from utilities.sparse import Sparse, Sparse_CSR
31from utilities.cg_solve import conjugate_gradient, VectorShapeError
32from utilities.numerical_tools import ensure_numeric, mean, gradient
33
34
35from coordinate_transforms.geo_reference import Geo_reference
36
37import time
38
39
40
41DEFAULT_ALPHA = 0.001
42
43def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
44                     alpha=DEFAULT_ALPHA, verbose= False,
45                     expand_search = False,
46                     data_origin = None,
47                     mesh_origin = None,
48                     precrop = False,
49                     display_errors = True):
50    """
51    Given a mesh file (tsh) and a point attribute file (xya), fit
52    point attributes to the mesh and write a mesh file with the
53    results.
54
55
56    If data_origin is not None it is assumed to be
57    a 3-tuple with geo referenced
58    UTM coordinates (zone, easting, northing)
59
60    NOTE: Throws IOErrors, for a variety of file problems.
61   
62    mesh_origin is the same but refers to the input tsh file.
63    FIXME: When the tsh format contains it own origin, these parameters can go.
64    FIXME: And both origins should be obtained from the specified files.
65    """
66
67    from load_mesh.loadASCII import import_mesh_file, \
68                 import_points_file, export_mesh_file, \
69                 concatinate_attributelist
70
71
72    try:
73        mesh_dict = import_mesh_file(mesh_file)
74    except IOError,e:
75        if display_errors:
76            print "Could not load bad file. ", e
77        raise IOError  #Re-raise exception
78       
79    vertex_coordinates = mesh_dict['vertices']
80    triangles = mesh_dict['triangles']
81    if type(mesh_dict['vertex_attributes']) == ArrayType:
82        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
83    else:
84        old_point_attributes = mesh_dict['vertex_attributes']
85
86    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
87        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
88    else:
89        old_title_list = mesh_dict['vertex_attribute_titles']
90
91    if verbose: print 'tsh file %s loaded' %mesh_file
92
93    # load in the .pts file
94    try:
95        point_dict = import_points_file(point_file, verbose=verbose)
96    except IOError,e:
97        if display_errors:
98            print "Could not load bad file. ", e
99        raise IOError  #Re-raise exception 
100
101    point_coordinates = point_dict['pointlist']
102    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
103
104    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
105        data_origin = point_dict['geo_reference'].get_origin()
106    else:
107        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
108
109    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
110        mesh_origin = mesh_dict['geo_reference'].get_origin()
111    else:
112        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
113
114    if verbose: print "points file loaded"
115    if verbose: print "fitting to mesh"
116    f = fit_to_mesh(vertex_coordinates,
117                    triangles,
118                    point_coordinates,
119                    point_attributes,
120                    alpha = alpha,
121                    verbose = verbose,
122                    expand_search = expand_search,
123                    data_origin = data_origin,
124                    mesh_origin = mesh_origin,
125                    precrop = precrop)
126    if verbose: print "finished fitting to mesh"
127
128    # convert array to list of lists
129    new_point_attributes = f.tolist()
130    #FIXME have this overwrite attributes with the same title - DSG
131    #Put the newer attributes last
132    if old_title_list <> []:
133        old_title_list.extend(title_list)
134        #FIXME can this be done a faster way? - DSG
135        for i in range(len(old_point_attributes)):
136            old_point_attributes[i].extend(new_point_attributes[i])
137        mesh_dict['vertex_attributes'] = old_point_attributes
138        mesh_dict['vertex_attribute_titles'] = old_title_list
139    else:
140        mesh_dict['vertex_attributes'] = new_point_attributes
141        mesh_dict['vertex_attribute_titles'] = title_list
142
143    #FIXME (Ole): Remember to output mesh_origin as well
144    if verbose: print "exporting to file ", mesh_output_file
145
146    try:
147        export_mesh_file(mesh_output_file, mesh_dict)
148    except IOError,e:
149        if display_errors:
150            print "Could not write file. ", e
151        raise IOError
152
153def fit_to_mesh(vertex_coordinates,
154                triangles,
155                point_coordinates,
156                point_attributes,
157                alpha = DEFAULT_ALPHA,
158                verbose = False,
159                acceptable_overshoot = 1.01,
160                expand_search = False,
161                data_origin = None,
162                mesh_origin = None,
163                precrop = False,
164                use_cache = False):
165    """
166    Fit a smooth surface to a triangulation,
167    given data points with attributes.
168
169
170        Inputs:
171
172          vertex_coordinates: List of coordinate pairs [xi, eta] of points
173          constituting mesh (or a an m x 2 Numeric array)
174
175          triangles: List of 3-tuples (or a Numeric array) of
176          integers representing indices of all vertices in the mesh.
177
178          point_coordinates: List of coordinate pairs [x, y] of data points
179          (or an nx2 Numeric array)
180
181          alpha: Smoothing parameter.
182
183          acceptable overshoot: controls the allowed factor by which fitted values
184          may exceed the value of input data. The lower limit is defined
185          as min(z) - acceptable_overshoot*delta z and upper limit
186          as max(z) + acceptable_overshoot*delta z
187         
188
189          point_attributes: Vector or array of data at the point_coordinates.
190
191          data_origin and mesh_origin are 3-tuples consisting of
192          UTM zone, easting and northing. If specified
193          point coordinates and vertex coordinates are assumed to be
194          relative to their respective origins.
195
196    """
197
198    if use_cache is True:
199        from caching.caching import cache
200        interp = cache(_interpolation,
201                       (vertex_coordinates,
202                        triangles,
203                        point_coordinates),
204                       {'alpha': alpha,
205                        'verbose': verbose,
206                        'expand_search': expand_search,
207                        'data_origin': data_origin,
208                        'mesh_origin': mesh_origin,
209                        'precrop': precrop},
210                       verbose = verbose)       
211       
212    else:
213        interp = Interpolation(vertex_coordinates,
214                               triangles,
215                               point_coordinates,
216                               alpha = alpha,
217                               verbose = verbose,
218                               expand_search = expand_search,
219                               data_origin = data_origin,
220                               mesh_origin = mesh_origin,
221                               precrop = precrop)
222
223    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
224
225
226    #Sanity check
227    point_coordinates = ensure_numeric(point_coordinates)
228    vertex_coordinates = ensure_numeric(vertex_coordinates)
229
230    #Data points
231    X = point_coordinates[:,0]
232    Y = point_coordinates[:,1] 
233    Z = ensure_numeric(point_attributes)
234    if len(Z.shape) == 1:
235        Z = Z[:, NewAxis]
236       
237
238    #Data points inside mesh boundary
239    indices = interp.point_indices
240    if indices is not None:   
241        Xc = take(X, indices)
242        Yc = take(Y, indices)   
243        Zc = take(Z, indices)
244    else:
245        Xc = X
246        Yc = Y 
247        Zc = Z       
248   
249    #Vertex coordinates
250    Xi = vertex_coordinates[:,0]
251    Eta = vertex_coordinates[:,1]       
252    Zeta = ensure_numeric(vertex_attributes)
253    if len(Zeta.shape) == 1:
254        Zeta = Zeta[:, NewAxis]   
255
256    for i in range(Zeta.shape[1]): #For each attribute
257        zeta = Zeta[:,i]
258        z = Z[:,i]               
259        zc = Zc[:,i]
260
261        max_zc = max(zc)
262        min_zc = min(zc)
263        delta_zc = max_zc-min_zc
264        upper_limit = max_zc + delta_zc*acceptable_overshoot
265        lower_limit = min_zc - delta_zc*acceptable_overshoot       
266       
267
268        if max(zeta) > upper_limit or min(zeta) < lower_limit:
269            msg = 'Least sqares produced values outside the allowed '
270            msg += 'range [%f, %f].\n' %(lower_limit, upper_limit)
271            msg += 'z in [%f, %f], zeta in [%f, %f].\n' %(min_zc, max_zc,
272                                                          min(zeta), max(zeta))
273            msg += 'If greater range is needed, increase the value of '
274            msg += 'acceptable_fit_overshoot (currently %.2f).\n' %(acceptable_overshoot)
275
276
277            offending_vertices = (zeta > upper_limit or zeta < lower_limit)
278            Xi_c = compress(offending_vertices, Xi)
279            Eta_c = compress(offending_vertices, Eta)
280            offending_coordinates = concatenate((Xi_c[:, NewAxis],
281                                                 Eta_c[:, NewAxis]),
282                                                axis=1)
283
284            msg += 'Offending locations:\n %s' %(offending_coordinates)
285           
286            raise FittingError, msg
287
288
289   
290        if verbose:
291            print '+------------------------------------------------'
292            print 'Least squares statistics'
293            print '+------------------------------------------------'   
294            print 'points: %d points' %(len(z))
295            print '    x in [%f, %f]'%(min(X), max(X))
296            print '    y in [%f, %f]'%(min(Y), max(Y))
297            print '    z in [%f, %f]'%(min(z), max(z))
298            print
299
300            if indices is not None:
301                print 'Cropped points: %d points' %(len(zc))
302                print '    x in [%f, %f]'%(min(Xc), max(Xc))
303                print '    y in [%f, %f]'%(min(Yc), max(Yc))
304                print '    z in [%f, %f]'%(min(zc), max(zc))
305                print
306           
307
308            print 'Mesh: %d vertices' %(len(zeta))
309            print '    xi in [%f, %f]'%(min(Xi), max(Xi))
310            print '    eta in [%f, %f]'%(min(Eta), max(Eta))
311            print '    zeta in [%f, %f]'%(min(zeta), max(zeta))
312            print '+------------------------------------------------'
313
314    return vertex_attributes
315
316
317
318def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
319                    verbose = False, reduction = 1):
320    """Fits attributes from pts file to MxN rectangular mesh
321
322    Read pts file and create rectangular mesh of resolution MxN such that
323    it covers all points specified in pts file.
324
325    FIXME: This may be a temporary function until we decide on
326    netcdf formats etc
327
328    FIXME: Uses elevation hardwired
329    """
330
331    import  mesh_factory
332    from load_mesh.loadASCII import import_points_file
333   
334    if verbose: print 'Read pts'
335    points_dict = import_points_file(pts_name)
336    #points, attributes = util.read_xya(pts_name)
337
338    #Reduce number of points a bit
339    points = points_dict['pointlist'][::reduction]
340    elevation = points_dict['attributelist']['elevation']  #Must be elevation
341    elevation = elevation[::reduction]
342
343    if verbose: print 'Got %d data points' %len(points)
344
345    if verbose: print 'Create mesh'
346    #Find extent
347    max_x = min_x = points[0][0]
348    max_y = min_y = points[0][1]
349    for point in points[1:]:
350        x = point[0]
351        if x > max_x: max_x = x
352        if x < min_x: min_x = x
353        y = point[1]
354        if y > max_y: max_y = y
355        if y < min_y: min_y = y
356
357    #Create appropriate mesh
358    vertex_coordinates, triangles, boundary =\
359         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
360                                (min_x, min_y))
361
362    #Fit attributes to mesh
363    vertex_attributes = fit_to_mesh(vertex_coordinates,
364                        triangles,
365                        points,
366                        elevation, alpha=alpha, verbose=verbose)
367
368
369
370    return vertex_coordinates, triangles, boundary, vertex_attributes
371
372
373def _interpolation(*args, **kwargs):
374    """Private function for use with caching. Reason is that classes
375    may change their byte code between runs which is annoying.
376    """
377   
378    return Interpolation(*args, **kwargs)
379
380
381class Interpolation:
382
383    def __init__(self,
384                 vertex_coordinates,
385                 triangles,
386                 z,
387                 point_coordinates = None,
388                 alpha = None,
389                 verbose = False,
390                 expand_search = True,
391                 interp_only = False,
392                 max_points_per_cell = 30,
393                 mesh_origin = None,
394                 data_origin = None,
395                 precrop = False):
396
397
398        """ Build interpolation matrix mapping from
399        function values at vertices to function values at data points
400
401        Inputs:
402
403          vertex_coordinates: List of coordinate pairs [xi, eta] of
404          points constituting mesh (or a an m x 2 Numeric array)
405          Points may appear multiple times
406          (e.g. if vertices have discontinuities)
407
408          triangles: List of 3-tuples (or a Numeric array) of
409          integers representing indices of all vertices in the mesh.
410
411          point_coordinates: List of coordinate pairs [x, y] of
412          data points (or an nx2 Numeric array)
413          If point_coordinates is absent, only smoothing matrix will
414          be built
415
416          alpha: Smoothing parameter
417
418          data_origin and mesh_origin are 3-tuples consisting of
419          UTM zone, easting and northing. If specified
420          point coordinates and vertex coordinates are assumed to be
421          relative to their respective origins.
422         
423          z: Single 1d vector or array of data at the point_coordinates.
424
425        """
426
427        #Convert input to Numeric arrays
428        triangles = ensure_numeric(triangles, Int)
429        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
430
431        #Build underlying mesh
432        if verbose: print 'Building mesh'
433        #self.mesh = General_mesh(vertex_coordinates, triangles,
434        #FIXME: Trying the normal mesh while testing precrop,
435        #       The functionality of boundary_polygon is needed for that
436
437        #FIXME - geo ref does not have to go into mesh.
438        # Change the point co-ords to conform to the
439        # mesh co-ords early in the code
440        if mesh_origin is None:
441            geo = None
442        else:
443            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
444        self.mesh = Mesh(vertex_coordinates, triangles,
445                         geo_reference = geo)
446
447        if verbose: print 'Checking mesh integrity'
448        self.mesh.check_integrity()
449
450        if verbose: print 'Mesh integrity checked'
451       
452        self.data_origin = data_origin
453
454        self.point_indices = None
455
456        #Smoothing parameter
457        if alpha is None:
458            self.alpha = DEFAULT_ALPHA
459        else:   
460            self.alpha = alpha
461
462
463        if point_coordinates is not None:
464            if verbose: print 'Building interpolation matrix'
465            self.build_interpolation_matrix_A(point_coordinates,
466                                              z,
467                                              verbose = verbose,
468                                              expand_search = expand_search,
469                                              interp_only = interp_only, 
470                                              max_points_per_cell =\
471                                              max_points_per_cell,
472                                              data_origin = data_origin,
473                                              precrop = precrop)
474        #Build coefficient matrices
475        if interp_only == False:
476            self.build_coefficient_matrix_B(point_coordinates,
477                                        verbose = verbose,
478                                        expand_search = expand_search,
479                                        max_points_per_cell =\
480                                        max_points_per_cell,
481                                        data_origin = data_origin,
482                                        precrop = precrop)
483        if verbose: print 'Finished interpolation'
484
485    def set_point_coordinates(self, point_coordinates,
486                              data_origin = None,
487                              verbose = False,
488                              precrop = True):
489        """
490        A public interface to setting the point co-ordinates.
491        """
492        if point_coordinates is not None:
493            if verbose: print 'Building interpolation matrix'
494            self.build_interpolation_matrix_A(point_coordinates,
495                                              z,
496                                              verbose = verbose,
497                                              data_origin = data_origin,
498                                              precrop = precrop)
499        self.build_coefficient_matrix_B(point_coordinates, data_origin)
500
501    def build_coefficient_matrix_B(self, point_coordinates=None,
502                                   verbose = False, expand_search = True,
503                                   max_points_per_cell=30,
504                                   data_origin = None,
505                                   precrop = False):
506        """Build final coefficient matrix"""
507
508
509        if self.alpha <> 0:
510            if verbose: print 'Building smoothing matrix'
511            self.build_smoothing_matrix_D()
512
513        if point_coordinates is not None:
514            if self.alpha <> 0:
515                self.B = self.AtA + self.alpha*self.D
516            else:
517                self.B = self.AtA
518
519            #Convert self.B matrix to CSR format for faster matrix vector
520            self.B = Sparse_CSR(self.B)
521
522    def build_interpolation_matrix_A(self, point_coordinates,
523                                     z,
524                                     verbose = False, expand_search = True,
525                                     max_points_per_cell=30,
526                                     data_origin = None,
527                                     precrop = False,
528                                     interp_only = False):
529        """Build n x m interpolation matrix, where
530        n is the number of data points and
531        m is the number of basis functions phi_k (one per vertex)
532
533        This algorithm uses a quad tree data structure for fast binning of data points
534        origin is a 3-tuple consisting of UTM zone, easting and northing.
535        If specified coordinates are assumed to be relative to this origin.
536
537        This one will override any data_origin that may be specified in
538        interpolation instance
539
540        """
541
542        from pyvolution.quad import build_quadtree
543        from utilities.polygon import inside_polygon
544       
545        #Convert input to Numeric arrays
546        z = ensure_numeric(z, Float)
547           
548        #Convert input to Numeric arrays just in case.
549        point_coordinates = ensure_numeric(point_coordinates, Float)
550
551        #Build n x m interpolation matrix
552        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
553        n = point_coordinates.shape[0]     #Nbr of data points
554        if len(z.shape) > 1:
555            att_num = z.shape[1]
556        else:
557            att_num = 0
558           
559        assert z.shape[0] == point_coordinates.shape[0] 
560        if verbose: print 'Number of attributes: %d' %att_num
561        if verbose: print 'Number of datapoints: %d' %n
562        if verbose: print 'Number of basis functions: %d' %m
563
564        self.A = Sparse(n,m)
565        self.AtA = Sparse(m,m)
566        self.Atz = zeros((m,att_num), Float)
567
568        #Build quad tree of vertices
569        root = build_quadtree(self.mesh,
570                              max_points_per_cell = max_points_per_cell)
571        #root.show()
572        self.expanded_quad_searches = []
573        #Compute matrix elements
574        for i in range(n):
575            #For each data_coordinate point
576
577            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
578            x = point_coordinates[i]
579
580            #Find vertices near x
581            candidate_vertices = root.search(x[0], x[1])
582            is_more_elements = True
583
584            element_found, sigma0, sigma1, sigma2, k = \
585                self.search_triangles_of_vertices(candidate_vertices, x)
586            first_expansion = True
587            while not element_found and is_more_elements and expand_search:
588                #if verbose: print 'Expanding search'
589                if first_expansion == True:
590                    self.expanded_quad_searches.append(1)
591                    first_expansion = False
592                else:
593                    end = len(self.expanded_quad_searches) - 1
594                    assert end >= 0
595                    self.expanded_quad_searches[end] += 1
596                candidate_vertices, branch = root.expand_search()
597                if branch == []:
598                    # Searching all the verts from the root cell that haven't
599                    # been searched.  This is the last try
600                    element_found, sigma0, sigma1, sigma2, k = \
601                      self.search_triangles_of_vertices(candidate_vertices, x)
602                    is_more_elements = False
603                else:
604                    element_found, sigma0, sigma1, sigma2, k = \
605                      self.search_triangles_of_vertices(candidate_vertices, x)
606
607               
608            #Update interpolation matrix A if necessary
609            if element_found is True:
610                #Assign values to matrix A
611
612                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
613                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
614                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
615
616                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
617                js     = [j0,j1,j2]
618
619                for j in js:
620                    self.A[i,j] = sigmas[j]
621                    #self.Atz[n-i-1] +=  sigmas[j]*z[n-j-1]
622                    #print "i",i
623                    #print "j",j
624                    #print "self.Atz",self.Atz
625                    #print "z",z
626                    #a=self.Atz[j]
627                    #b=z[i]
628                    self.Atz[j] +=  sigmas[j]*z[i]
629                    #print "m-i-1",m-i-1
630                    #print "n-j-1",n-j-1
631                    #print "self.A[i,j]",self.A[i,j]
632                    #print "z[n-j-1]",z[n-j-1]
633                    #print "z[i]",z[i]
634                    #print "",
635                    #print "self.Atz",self.Atz                   
636                    for k in js:
637                        if interp_only == False:
638                            self.AtA[j,k] += sigmas[j]*sigmas[k]
639            else:
640                pass
641                #Ok if there is no triangle for datapoint
642                #(as in brute force version)
643                #raise 'Could not find triangle for point', x
644
645
646
647    def search_triangles_of_vertices(self, candidate_vertices, x):
648            #Find triangle containing x:
649            element_found = False
650
651            # This will be returned if element_found = False
652            sigma2 = -10.0
653            sigma0 = -10.0
654            sigma1 = -10.0
655            k = -10.0
656            #print "*$* candidate_vertices", candidate_vertices
657            #For all vertices in same cell as point x
658            for v in candidate_vertices:
659                #FIXME (DSG-DSG): this catches verts with no triangle.
660                #Currently pmesh is producing these.
661                #this should be stopped,
662                if self.mesh.vertexlist[v] is None:
663                    continue
664                #for each triangle id (k) which has v as a vertex
665                for k, _ in self.mesh.vertexlist[v]:
666
667                    #Get the three vertex_points of candidate triangle
668                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
669                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
670                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
671
672                    #print "PDSG - k", k
673                    #print "PDSG - xi0", xi0
674                    #print "PDSG - xi1", xi1
675                    #print "PDSG - xi2", xi2
676                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
677                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
678
679                    #Get the three normals
680                    n0 = self.mesh.get_normal(k, 0)
681                    n1 = self.mesh.get_normal(k, 1)
682                    n2 = self.mesh.get_normal(k, 2)
683
684
685                    #Compute interpolation
686                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
687                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
688                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
689
690                    #print "PDSG - sigma0", sigma0
691                    #print "PDSG - sigma1", sigma1
692                    #print "PDSG - sigma2", sigma2
693
694                    #FIXME: Maybe move out to test or something
695                    epsilon = 1.0e-6
696                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
697
698                    #Check that this triangle contains the data point
699
700                    #Sigmas can get negative within
701                    #machine precision on some machines (e.g nautilus)
702                    #Hence the small eps
703                    eps = 1.0e-15
704                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
705                        element_found = True
706                        break
707
708                if element_found is True:
709                    #Don't look for any other triangle
710                    break
711            return element_found, sigma0, sigma1, sigma2, k
712
713
714
715    def build_interpolation_matrix_A_brute(self, point_coordinates):
716        """Build n x m interpolation matrix, where
717        n is the number of data points and
718        m is the number of basis functions phi_k (one per vertex)
719
720        This is the brute force which is too slow for large problems,
721but could be used for testing
722        """
723
724
725        #Convert input to Numeric arrays
726        point_coordinates = ensure_numeric(point_coordinates, Float)
727
728        #Build n x m interpolation matrix
729        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
730        n = point_coordinates.shape[0]     #Nbr of data points
731
732        self.A = Sparse(n,m)
733        self.AtA = Sparse(m,m)
734
735        #Compute matrix elements
736        for i in range(n):
737            #For each data_coordinate point
738
739            x = point_coordinates[i]
740            element_found = False
741            k = 0
742            while not element_found and k < len(self.mesh):
743                #For each triangle (brute force)
744                #FIXME: Real algorithm should only visit relevant triangles
745
746                #Get the three vertex_points
747                xi0 = self.mesh.get_vertex_coordinate(k, 0)
748                xi1 = self.mesh.get_vertex_coordinate(k, 1)
749                xi2 = self.mesh.get_vertex_coordinate(k, 2)
750
751                #Get the three normals
752                n0 = self.mesh.get_normal(k, 0)
753                n1 = self.mesh.get_normal(k, 1)
754                n2 = self.mesh.get_normal(k, 2)
755
756                #Compute interpolation
757                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
758                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
759                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
760
761                #FIXME: Maybe move out to test or something
762                epsilon = 1.0e-6
763                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
764
765                #Check that this triangle contains data point
766                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
767                    element_found = True
768                    #Assign values to matrix A
769
770                    j0 = self.mesh.triangles[k,0] #Global vertex id
771                    #self.A[i, j0] = sigma0
772
773                    j1 = self.mesh.triangles[k,1] #Global vertex id
774                    #self.A[i, j1] = sigma1
775
776                    j2 = self.mesh.triangles[k,2] #Global vertex id
777                    #self.A[i, j2] = sigma2
778
779                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
780                    js     = [j0,j1,j2]
781
782                    for j in js:
783                        self.A[i,j] = sigmas[j]
784                        for k in js:
785                            self.AtA[j,k] += sigmas[j]*sigmas[k]
786                k = k+1
787
788
789
790    def get_A(self):
791        return self.A.todense()
792
793    def get_B(self):
794        return self.B.todense()
795
796    def get_D(self):
797        return self.D.todense()
798
799        #FIXME: Remember to re-introduce the 1/n factor in the
800        #interpolation term
801
802    def build_smoothing_matrix_D(self):
803        """Build m x m smoothing matrix, where
804        m is the number of basis functions phi_k (one per vertex)
805
806        The smoothing matrix is defined as
807
808        D = D1 + D2
809
810        where
811
812        [D1]_{k,l} = \int_\Omega
813           \frac{\partial \phi_k}{\partial x}
814           \frac{\partial \phi_l}{\partial x}\,
815           dx dy
816
817        [D2]_{k,l} = \int_\Omega
818           \frac{\partial \phi_k}{\partial y}
819           \frac{\partial \phi_l}{\partial y}\,
820           dx dy
821
822
823        The derivatives \frac{\partial \phi_k}{\partial x},
824        \frac{\partial \phi_k}{\partial x} for a particular triangle
825        are obtained by computing the gradient a_k, b_k for basis function k
826        """
827
828        #FIXME: algorithm might be optimised by computing local 9x9
829        #"element stiffness matrices:
830
831        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
832
833        self.D = Sparse(m,m)
834
835        #For each triangle compute contributions to D = D1+D2
836        for i in range(len(self.mesh)):
837
838            #Get area
839            area = self.mesh.areas[i]
840
841            #Get global vertex indices
842            v0 = self.mesh.triangles[i,0]
843            v1 = self.mesh.triangles[i,1]
844            v2 = self.mesh.triangles[i,2]
845
846            #Get the three vertex_points
847            xi0 = self.mesh.get_vertex_coordinate(i, 0)
848            xi1 = self.mesh.get_vertex_coordinate(i, 1)
849            xi2 = self.mesh.get_vertex_coordinate(i, 2)
850
851            #Compute gradients for each vertex
852            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
853                              1, 0, 0)
854
855            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
856                              0, 1, 0)
857
858            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
859                              0, 0, 1)
860
861            #Compute diagonal contributions
862            self.D[v0,v0] += (a0*a0 + b0*b0)*area
863            self.D[v1,v1] += (a1*a1 + b1*b1)*area
864            self.D[v2,v2] += (a2*a2 + b2*b2)*area
865
866            #Compute contributions for basis functions sharing edges
867            e01 = (a0*a1 + b0*b1)*area
868            self.D[v0,v1] += e01
869            self.D[v1,v0] += e01
870
871            e12 = (a1*a2 + b1*b2)*area
872            self.D[v1,v2] += e12
873            self.D[v2,v1] += e12
874
875            e20 = (a2*a0 + b2*b0)*area
876            self.D[v2,v0] += e20
877            self.D[v0,v2] += e20
878
879
880    def fit(self, z):
881        """Fit a smooth surface to given 1d array of data points z.
882
883        The smooth surface is computed at each vertex in the underlying
884        mesh using the formula given in the module doc string.
885
886        Pre Condition:
887          self.A, self.AtA and self.B have been initialised
888
889        Inputs:
890          z: Single 1d vector or array of data at the point_coordinates.
891        """
892
893        #Convert input to Numeric arrays
894        z = ensure_numeric(z, Float)
895
896        if len(z.shape) > 1 :
897            raise VectorShapeError, 'Can only deal with 1d data vector'
898
899        if self.point_indices is not None:
900            #Remove values for any points that were outside mesh
901            z = take(z, self.point_indices)
902
903        #Compute right hand side based on data
904        #FIXME (DSG-DsG): could Sparse_CSR be used here?  Use this format
905        # after a matrix is built, before calcs.
906        Atz = self.Atz
907        #print "self.get_A()", self.get_A()
908        #print "self.Atz",self.Atz
909        #print 'z', z
910        Atz = self.A.trans_mult(z)
911        #print "self.A.trans_mult(z)",self.A.trans_mult(z)
912
913
914        #Check sanity
915        n, m = self.A.shape
916        if n<m and self.alpha == 0.0:
917            msg = 'ERROR (least_squares): Too few data points\n'
918            msg += 'There are only %d data points and alpha == 0. ' %n
919            msg += 'Need at least %d\n' %m
920            msg += 'Alternatively, set smoothing parameter alpha to a small '
921            msg += 'positive value,\ne.g. 1.0e-3.'
922            raise Exception(msg)
923
924
925
926        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
927        #FIXME: Should we store the result here for later use? (ON)
928
929
930    def fit_points(self, z, verbose=False):
931        """Like fit, but more robust when each point has two or more attributes
932        FIXME (Ole): The name fit_points doesn't carry any meaning
933        for me. How about something like fit_multiple or fit_columns?
934        """
935
936        try:
937            if verbose: print 'Solving penalised least_squares problem'
938            return self.fit(z)
939        except VectorShapeError, e:
940            # broadcasting is not supported.
941
942            #Convert input to Numeric arrays
943            z = ensure_numeric(z, Float)
944
945            #Build n x m interpolation matrix
946            m = self.mesh.coordinates.shape[0] #Number of vertices
947            n = z.shape[1]                     #Number of data points
948
949            f = zeros((m,n), Float) #Resulting columns
950
951            for i in range(z.shape[1]):
952                f[:,i] = self.fit(z[:,i])
953
954            return f
955
956
957    def interpolate(self, f):
958        """Evaluate smooth surface f at data points implied in self.A.
959
960        The mesh values representing a smooth surface are
961        assumed to be specified in f. This argument could,
962        for example have been obtained from the method self.fit()
963
964        Pre Condition:
965          self.A has been initialised
966
967        Inputs:
968          f: Vector or array of data at the mesh vertices.
969          If f is an array, interpolation will be done for each column as
970          per underlying matrix-matrix multiplication
971
972        Output:
973          Interpolated values at data points implied in self.A
974
975        """
976
977        return self.A * f
978
979    def cull_outsiders(self, f):
980        pass
981
982
983
984
985class Interpolation_function:
986    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
987    which is interpolated from time series defined at vertices of
988    triangular mesh (such as those stored in sww files)
989
990    Let m be the number of vertices, n the number of triangles
991    and p the number of timesteps.
992
993    Mandatory input
994        time:               px1 array of monotonously increasing times (Float)
995        quantities:         Dictionary of arrays or 1 array (Float)
996                            The arrays must either have dimensions pxm or mx1.
997                            The resulting function will be time dependent in
998                            the former case while it will be constant with
999                            respect to time in the latter case.
1000       
1001    Optional input:
1002        quantity_names:     List of keys into the quantities dictionary
1003        vertex_coordinates: mx2 array of coordinates (Float)
1004        triangles:          nx3 array of indices into vertex_coordinates (Int)
1005        interpolation_points: Nx2 array of coordinates to be interpolated to
1006        verbose:            Level of reporting
1007   
1008   
1009    The quantities returned by the callable object are specified by
1010    the list quantities which must contain the names of the
1011    quantities to be returned and also reflect the order, e.g. for
1012    the shallow water wave equation, on would have
1013    quantities = ['stage', 'xmomentum', 'ymomentum']
1014
1015    The parameter interpolation_points decides at which points interpolated
1016    quantities are to be computed whenever object is called.
1017    If None, return average value
1018    """
1019
1020   
1021   
1022    def __init__(self,
1023                 time,
1024                 quantities,
1025                 quantity_names = None, 
1026                 vertex_coordinates = None,
1027                 triangles = None,
1028                 interpolation_points = None,
1029                 verbose = False):
1030        """Initialise object and build spatial interpolation if required
1031        """
1032
1033        from Numeric import array, zeros, Float, alltrue, concatenate,\
1034             reshape, ArrayType
1035
1036
1037        from config import time_format
1038        import types
1039
1040
1041
1042        #Check temporal info
1043        time = ensure_numeric(time)       
1044        msg = 'Time must be a monotonuosly '
1045        msg += 'increasing sequence %s' %time
1046        assert alltrue(time[1:] - time[:-1] >= 0 ), msg
1047
1048
1049        #Check if quantities is a single array only
1050        if type(quantities) != types.DictType:
1051            quantities = ensure_numeric(quantities)
1052            quantity_names = ['Attribute']
1053
1054            #Make it a dictionary
1055            quantities = {quantity_names[0]: quantities}
1056
1057
1058        #Use keys if no names are specified
1059        if quantity_names is None:
1060            quantity_names = quantities.keys()
1061
1062
1063        #Check spatial info
1064        if vertex_coordinates is None:
1065            self.spatial = False
1066        else:   
1067            vertex_coordinates = ensure_numeric(vertex_coordinates)
1068
1069            assert triangles is not None, 'Triangles array must be specified'
1070            triangles = ensure_numeric(triangles)
1071            self.spatial = True           
1072           
1073
1074 
1075        #Save for use with statistics
1076        self.quantity_names = quantity_names       
1077        self.quantities = quantities       
1078        self.vertex_coordinates = vertex_coordinates
1079        self.interpolation_points = interpolation_points
1080        self.T = time[:]  # Time assumed to be relative to starttime
1081        self.index = 0    # Initial time index
1082        self.precomputed_values = {}
1083           
1084
1085
1086        #Precomputed spatial interpolation if requested
1087        if interpolation_points is not None:
1088            if self.spatial is False:
1089                raise 'Triangles and vertex_coordinates must be specified'
1090           
1091            try:
1092                self.interpolation_points = ensure_numeric(interpolation_points)
1093            except:
1094                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1095                      'or a list of points\n'
1096                msg += 'I got: %s.' %(str(self.interpolation_points)[:60] +\
1097                                      '...')
1098                raise msg
1099
1100
1101            m = len(self.interpolation_points)
1102            p = len(self.T)
1103           
1104            for name in quantity_names:
1105                self.precomputed_values[name] = zeros((p, m), Float)
1106
1107            #Build interpolator
1108            interpol = Interpolation(vertex_coordinates,
1109                                     triangles,
1110                                     point_coordinates = \
1111                                     self.interpolation_points,
1112                                     alpha = 0,
1113                                     precrop = False, 
1114                                     verbose = verbose)
1115
1116            if verbose: print 'Interpolate'
1117            for i, t in enumerate(self.T):
1118                #Interpolate quantities at this timestep
1119                if verbose and i%((p+10)/10)==0:
1120                    print ' time step %d of %d' %(i, p)
1121                   
1122                for name in quantity_names:
1123                    if len(quantities[name].shape) == 2:
1124                        result = interpol.interpolate(quantities[name][i,:])
1125                    else:
1126                       #Assume no time dependency
1127                       result = interpol.interpolate(quantities[name][:])
1128                       
1129                    self.precomputed_values[name][i, :] = result
1130                   
1131                       
1132
1133            #Report
1134            if verbose:
1135                print self.statistics()
1136                #self.print_statistics()
1137           
1138        else:
1139            #Store quantitites as is
1140            for name in quantity_names:
1141                self.precomputed_values[name] = quantities[name]
1142
1143
1144        #else:
1145        #    #Return an average, making this a time series
1146        #    for name in quantity_names:
1147        #        self.values[name] = zeros(len(self.T), Float)
1148        #
1149        #    if verbose: print 'Compute mean values'
1150        #    for i, t in enumerate(self.T):
1151        #        if verbose: print ' time step %d of %d' %(i, len(self.T))
1152        #        for name in quantity_names:
1153        #           self.values[name][i] = mean(quantities[name][i,:])
1154
1155
1156
1157
1158    def __repr__(self):
1159        #return 'Interpolation function (spatio-temporal)'
1160        return self.statistics()
1161   
1162
1163    def __call__(self, t, point_id = None, x = None, y = None):
1164        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1165
1166        Inputs:
1167          t: time - Model time. Must lie within existing timesteps
1168          point_id: index of one of the preprocessed points.
1169          x, y:     Overrides location, point_id ignored
1170         
1171          If spatial info is present and all of x,y,point_id
1172          are None an exception is raised
1173                   
1174          If no spatial info is present, point_id and x,y arguments are ignored
1175          making f a function of time only.
1176
1177         
1178          FIXME: point_id could also be a slice
1179          FIXME: What if x and y are vectors?
1180          FIXME: What about f(x,y) without t?
1181        """
1182
1183        from math import pi, cos, sin, sqrt
1184        from Numeric import zeros, Float
1185        from utilities.numerical_tools import mean       
1186
1187        if self.spatial is True:
1188            if point_id is None:
1189                if x is None or y is None:
1190                    msg = 'Either point_id or x and y must be specified'
1191                    raise Exception(msg)
1192            else:
1193                if self.interpolation_points is None:
1194                    msg = 'Interpolation_function must be instantiated ' +\
1195                          'with a list of interpolation points before parameter ' +\
1196                          'point_id can be used'
1197                    raise Exception(msg)
1198
1199
1200        msg = 'Time interval [%s:%s]' %(self.T[0], self.T[1])
1201        msg += ' does not match model time: %s\n' %t
1202        if t < self.T[0]: raise Exception(msg)
1203        if t > self.T[-1]: raise Exception(msg)
1204
1205        oldindex = self.index #Time index
1206
1207        #Find current time slot
1208        while t > self.T[self.index]: self.index += 1
1209        while t < self.T[self.index]: self.index -= 1
1210
1211        if t == self.T[self.index]:
1212            #Protect against case where t == T[-1] (last time)
1213            # - also works in general when t == T[i]
1214            ratio = 0
1215        else:
1216            #t is now between index and index+1
1217            ratio = (t - self.T[self.index])/\
1218                    (self.T[self.index+1] - self.T[self.index])
1219
1220        #Compute interpolated values
1221        q = zeros(len(self.quantity_names), Float)
1222
1223        for i, name in enumerate(self.quantity_names):
1224            Q = self.precomputed_values[name]
1225
1226            if self.spatial is False:
1227                #If there is no spatial info               
1228                assert len(Q.shape) == 1
1229
1230                Q0 = Q[self.index]
1231                if ratio > 0: Q1 = Q[self.index+1]
1232
1233            else:
1234                if x is not None and y is not None:
1235                    #Interpolate to x, y
1236                   
1237                    raise 'x,y interpolation not yet implemented'
1238                else:
1239                    #Use precomputed point
1240                    Q0 = Q[self.index, point_id]
1241                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1242
1243            #Linear temporal interpolation   
1244            if ratio > 0:
1245                q[i] = Q0 + ratio*(Q1 - Q0)
1246            else:
1247                q[i] = Q0
1248
1249
1250        #Return vector of interpolated values
1251        #if len(q) == 1:
1252        #    return q[0]
1253        #else:
1254        #    return q
1255
1256
1257        #Return vector of interpolated values
1258        #FIXME:
1259        if self.spatial is True:
1260            return q
1261        else:
1262            #Replicate q according to x and y
1263            #This is e.g used for Wind_stress
1264            if x is None or y is None: 
1265                return q
1266            else:
1267                try:
1268                    N = len(x)
1269                except:
1270                    return q
1271                else:
1272                    from Numeric import ones, Float
1273                    #x is a vector - Create one constant column for each value
1274                    N = len(x)
1275                    assert len(y) == N, 'x and y must have same length'
1276                    res = []
1277                    for col in q:
1278                        res.append(col*ones(N, Float))
1279                       
1280                return res
1281
1282
1283    def statistics(self):
1284        """Output statistics about interpolation_function
1285        """
1286       
1287        vertex_coordinates = self.vertex_coordinates
1288        interpolation_points = self.interpolation_points               
1289        quantity_names = self.quantity_names
1290        quantities = self.quantities
1291        precomputed_values = self.precomputed_values                 
1292               
1293        x = vertex_coordinates[:,0]
1294        y = vertex_coordinates[:,1]               
1295
1296        str =  '------------------------------------------------\n'
1297        str += 'Interpolation_function (spatio-temporal) statistics:\n'
1298        str += '  Extent:\n'
1299        str += '    x in [%f, %f], len(x) == %d\n'\
1300               %(min(x), max(x), len(x))
1301        str += '    y in [%f, %f], len(y) == %d\n'\
1302               %(min(y), max(y), len(y))
1303        str += '    t in [%f, %f], len(t) == %d\n'\
1304               %(min(self.T), max(self.T), len(self.T))
1305        str += '  Quantities:\n'
1306        for name in quantity_names:
1307            q = quantities[name][:].flat
1308            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1309
1310        if interpolation_points is not None:   
1311            str += '  Interpolation points (xi, eta):'\
1312                   ' number of points == %d\n' %interpolation_points.shape[0]
1313            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1314                                            max(interpolation_points[:,0]))
1315            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1316                                             max(interpolation_points[:,1]))
1317            str += '  Interpolated quantities (over all timesteps):\n'
1318       
1319            for name in quantity_names:
1320                q = precomputed_values[name][:].flat
1321                str += '    %s at interpolation points in [%f, %f]\n'\
1322                       %(name, min(q), max(q))
1323        str += '------------------------------------------------\n'
1324
1325        return str
1326
1327        #FIXME: Delete
1328        #print '------------------------------------------------'
1329        #print 'Interpolation_function statistics:'
1330        #print '  Extent:'
1331        #print '    x in [%f, %f], len(x) == %d'\
1332        #      %(min(x), max(x), len(x))
1333        #print '    y in [%f, %f], len(y) == %d'\
1334        #      %(min(y), max(y), len(y))
1335        #print '    t in [%f, %f], len(t) == %d'\
1336        #      %(min(self.T), max(self.T), len(self.T))
1337        #print '  Quantities:'
1338        #for name in quantity_names:
1339        #    q = quantities[name][:].flat
1340        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1341        #print '  Interpolation points (xi, eta):'\
1342        #      ' number of points == %d ' %interpolation_points.shape[0]
1343        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1344        #                             max(interpolation_points[:,0]))
1345        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1346        #                              max(interpolation_points[:,1]))
1347        #print '  Interpolated quantities (over all timesteps):'
1348        #
1349        #for name in quantity_names:
1350        #    q = precomputed_values[name][:].flat
1351        #    print '    %s at interpolation points in [%f, %f]'\
1352        #          %(name, min(q), max(q))
1353        #print '------------------------------------------------'
1354
1355
1356#-------------------------------------------------------------
1357if __name__ == "__main__":
1358    """
1359    Load in a mesh and data points with attributes.
1360    Fit the attributes to the mesh.
1361    Save a new mesh file.
1362    """
1363    import os, sys
1364    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1365            %os.path.basename(sys.argv[0])
1366
1367    if len(sys.argv) < 4:
1368        print usage
1369    else:
1370        mesh_file = sys.argv[1]
1371        point_file = sys.argv[2]
1372        mesh_output_file = sys.argv[3]
1373
1374        expand_search = False
1375        if len(sys.argv) > 4:
1376            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1377                expand_search = True
1378            else:
1379                expand_search = False
1380
1381        verbose = False
1382        if len(sys.argv) > 5:
1383            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1384                verbose = False
1385            else:
1386                verbose = True
1387
1388        if len(sys.argv) > 6:
1389            alpha = sys.argv[6]
1390        else:
1391            alpha = DEFAULT_ALPHA
1392
1393        # This is used more for testing
1394        if len(sys.argv) > 7:
1395            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1396                display_errors = False
1397            else:
1398                display_errors = True
1399           
1400        t0 = time.time()
1401        try:
1402            fit_to_mesh_file(mesh_file,
1403                         point_file,
1404                         mesh_output_file,
1405                         alpha,
1406                         verbose= verbose,
1407                         expand_search = expand_search,
1408                         display_errors = display_errors)
1409        except IOError,e:
1410            import sys; sys.exit(1)
1411
1412        print 'That took %.2f seconds' %(time.time()-t0)
1413
Note: See TracBrowser for help on using the repository browser.