source: inundation/fit_interpolate/spike_least_squares.py @ 3330

Last change on this file since 3330 was 2897, checked in by duncan, 18 years ago

close to finishing fit, which replaces least squares

File size: 50.3 KB
Line 
1"""Least squares smooting and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20import exceptions
21class ShapeError(exceptions.Exception): pass
22class FittingError(exceptions.Exception): pass
23
24
25#from general_mesh import General_mesh
26from Numeric import zeros, array, Float, Int, dot, transpose, concatenate, ArrayType, NewAxis
27from pyvolution.mesh import Mesh
28
29from Numeric import zeros, take, compress, array, Float, Int, dot, transpose, concatenate, ArrayType
30from utilities.sparse import Sparse, Sparse_CSR
31from utilities.cg_solve import conjugate_gradient, VectorShapeError
32from utilities.numerical_tools import ensure_numeric, mean, gradient
33
34
35from coordinate_transforms.geo_reference import Geo_reference
36
37import time
38
39
40
41DEFAULT_ALPHA = 0.001
42
43def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
44                     alpha=DEFAULT_ALPHA, verbose= False,
45                     expand_search = False,
46                     data_origin = None,
47                     mesh_origin = None,
48                     precrop = False,
49                     display_errors = True):
50    """
51    Given a mesh file (tsh) and a point attribute file (xya), fit
52    point attributes to the mesh and write a mesh file with the
53    results.
54
55
56    If data_origin is not None it is assumed to be
57    a 3-tuple with geo referenced
58    UTM coordinates (zone, easting, northing)
59
60    NOTE: Throws IOErrors, for a variety of file problems.
61   
62    mesh_origin is the same but refers to the input tsh file.
63    FIXME: When the tsh format contains it own origin, these parameters can go.
64    FIXME: And both origins should be obtained from the specified files.
65    """
66
67    from load_mesh.loadASCII import import_mesh_file, \
68                 import_points_file, export_mesh_file, \
69                 concatinate_attributelist
70
71
72    try:
73        mesh_dict = import_mesh_file(mesh_file)
74    except IOError,e:
75        if display_errors:
76            print "Could not load bad file. ", e
77        raise IOError  #Re-raise exception
78       
79    vertex_coordinates = mesh_dict['vertices']
80    triangles = mesh_dict['triangles']
81    if type(mesh_dict['vertex_attributes']) == ArrayType:
82        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
83    else:
84        old_point_attributes = mesh_dict['vertex_attributes']
85
86    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
87        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
88    else:
89        old_title_list = mesh_dict['vertex_attribute_titles']
90
91    if verbose: print 'tsh file %s loaded' %mesh_file
92
93    # load in the .pts file
94    try:
95        point_dict = import_points_file(point_file, verbose=verbose)
96    except IOError,e:
97        if display_errors:
98            print "Could not load bad file. ", e
99        raise IOError  #Re-raise exception 
100
101    point_coordinates = point_dict['pointlist']
102    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
103
104    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
105        data_origin = point_dict['geo_reference'].get_origin()
106    else:
107        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
108
109    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
110        mesh_origin = mesh_dict['geo_reference'].get_origin()
111    else:
112        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
113
114    if verbose: print "points file loaded"
115    if verbose: print "fitting to mesh"
116    f = fit_to_mesh(vertex_coordinates,
117                    triangles,
118                    point_coordinates,
119                    point_attributes,
120                    alpha = alpha,
121                    verbose = verbose,
122                    expand_search = expand_search,
123                    data_origin = data_origin,
124                    mesh_origin = mesh_origin,
125                    precrop = precrop)
126    if verbose: print "finished fitting to mesh"
127
128    # convert array to list of lists
129    new_point_attributes = f.tolist()
130    #FIXME have this overwrite attributes with the same title - DSG
131    #Put the newer attributes last
132    if old_title_list <> []:
133        old_title_list.extend(title_list)
134        #FIXME can this be done a faster way? - DSG
135        for i in range(len(old_point_attributes)):
136            old_point_attributes[i].extend(new_point_attributes[i])
137        mesh_dict['vertex_attributes'] = old_point_attributes
138        mesh_dict['vertex_attribute_titles'] = old_title_list
139    else:
140        mesh_dict['vertex_attributes'] = new_point_attributes
141        mesh_dict['vertex_attribute_titles'] = title_list
142
143    #FIXME (Ole): Remember to output mesh_origin as well
144    if verbose: print "exporting to file ", mesh_output_file
145
146    try:
147        export_mesh_file(mesh_output_file, mesh_dict)
148    except IOError,e:
149        if display_errors:
150            print "Could not write file. ", e
151        raise IOError
152
153def fit_to_mesh(vertex_coordinates,
154                triangles,
155                point_coordinates,
156                point_attributes,
157                alpha = DEFAULT_ALPHA,
158                verbose = False,
159                acceptable_overshoot = 1.01,
160                expand_search = False,
161                data_origin = None,
162                mesh_origin = None,
163                precrop = False,
164                use_cache = False):
165    """
166    Fit a smooth surface to a triangulation,
167    given data points with attributes.
168
169
170        Inputs:
171
172          vertex_coordinates: List of coordinate pairs [xi, eta] of points
173          constituting mesh (or a an m x 2 Numeric array)
174
175          triangles: List of 3-tuples (or a Numeric array) of
176          integers representing indices of all vertices in the mesh.
177
178          point_coordinates: List of coordinate pairs [x, y] of data points
179          (or an nx2 Numeric array)
180
181          alpha: Smoothing parameter.
182
183          acceptable overshoot: controls the allowed factor by which fitted values
184          may exceed the value of input data. The lower limit is defined
185          as min(z) - acceptable_overshoot*delta z and upper limit
186          as max(z) + acceptable_overshoot*delta z
187         
188
189          point_attributes: Vector or array of data at the point_coordinates.
190
191          data_origin and mesh_origin are 3-tuples consisting of
192          UTM zone, easting and northing. If specified
193          point coordinates and vertex coordinates are assumed to be
194          relative to their respective origins.
195
196    """
197
198    if use_cache is True:
199        from caching.caching import cache
200        interp = cache(_interpolation,
201                       (vertex_coordinates,
202                        triangles,
203                        point_coordinates),
204                       {'alpha': alpha,
205                        'verbose': verbose,
206                        'expand_search': expand_search,
207                        'data_origin': data_origin,
208                        'mesh_origin': mesh_origin,
209                        'precrop': precrop},
210                       verbose = verbose)       
211       
212    else:
213        interp = Interpolation(vertex_coordinates,
214                               triangles,
215                               point_coordinates,
216                               alpha = alpha,
217                               verbose = verbose,
218                               expand_search = expand_search,
219                               data_origin = data_origin,
220                               mesh_origin = mesh_origin,
221                               precrop = precrop)
222
223    vertex_attributes = interp.fit_points(point_attributes, verbose = verbose)
224
225
226    #Sanity check
227    point_coordinates = ensure_numeric(point_coordinates)
228    vertex_coordinates = ensure_numeric(vertex_coordinates)
229
230    #Data points
231    X = point_coordinates[:,0]
232    Y = point_coordinates[:,1] 
233    Z = ensure_numeric(point_attributes)
234    if len(Z.shape) == 1:
235        Z = Z[:, NewAxis]
236       
237
238    #Data points inside mesh boundary
239    indices = interp.point_indices
240    if indices is not None:   
241        Xc = take(X, indices)
242        Yc = take(Y, indices)   
243        Zc = take(Z, indices)
244    else:
245        Xc = X
246        Yc = Y 
247        Zc = Z       
248   
249    #Vertex coordinates
250    Xi = vertex_coordinates[:,0]
251    Eta = vertex_coordinates[:,1]       
252    Zeta = ensure_numeric(vertex_attributes)
253    if len(Zeta.shape) == 1:
254        Zeta = Zeta[:, NewAxis]   
255
256    for i in range(Zeta.shape[1]): #For each attribute
257        zeta = Zeta[:,i]
258        z = Z[:,i]               
259        zc = Zc[:,i]
260
261        max_zc = max(zc)
262        min_zc = min(zc)
263        delta_zc = max_zc-min_zc
264        upper_limit = max_zc + delta_zc*acceptable_overshoot
265        lower_limit = min_zc - delta_zc*acceptable_overshoot       
266       
267
268        if max(zeta) > upper_limit or min(zeta) < lower_limit:
269            msg = 'Least sqares produced values outside the allowed '
270            msg += 'range [%f, %f].\n' %(lower_limit, upper_limit)
271            msg += 'z in [%f, %f], zeta in [%f, %f].\n' %(min_zc, max_zc,
272                                                          min(zeta), max(zeta))
273            msg += 'If greater range is needed, increase the value of '
274            msg += 'acceptable_fit_overshoot (currently %.2f).\n' %(acceptable_overshoot)
275
276
277            offending_vertices = (zeta > upper_limit or zeta < lower_limit)
278            Xi_c = compress(offending_vertices, Xi)
279            Eta_c = compress(offending_vertices, Eta)
280            offending_coordinates = concatenate((Xi_c[:, NewAxis],
281                                                 Eta_c[:, NewAxis]),
282                                                axis=1)
283
284            msg += 'Offending locations:\n %s' %(offending_coordinates)
285           
286            raise FittingError, msg
287
288
289   
290        if verbose:
291            print '+------------------------------------------------'
292            print 'Least squares statistics'
293            print '+------------------------------------------------'   
294            print 'points: %d points' %(len(z))
295            print '    x in [%f, %f]'%(min(X), max(X))
296            print '    y in [%f, %f]'%(min(Y), max(Y))
297            print '    z in [%f, %f]'%(min(z), max(z))
298            print
299
300            if indices is not None:
301                print 'Cropped points: %d points' %(len(zc))
302                print '    x in [%f, %f]'%(min(Xc), max(Xc))
303                print '    y in [%f, %f]'%(min(Yc), max(Yc))
304                print '    z in [%f, %f]'%(min(zc), max(zc))
305                print
306           
307
308            print 'Mesh: %d vertices' %(len(zeta))
309            print '    xi in [%f, %f]'%(min(Xi), max(Xi))
310            print '    eta in [%f, %f]'%(min(Eta), max(Eta))
311            print '    zeta in [%f, %f]'%(min(zeta), max(zeta))
312            print '+------------------------------------------------'
313
314    return vertex_attributes
315
316
317
318def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
319                    verbose = False, reduction = 1):
320    """Fits attributes from pts file to MxN rectangular mesh
321
322    Read pts file and create rectangular mesh of resolution MxN such that
323    it covers all points specified in pts file.
324
325    FIXME: This may be a temporary function until we decide on
326    netcdf formats etc
327
328    FIXME: Uses elevation hardwired
329    """
330
331    import  mesh_factory
332    from load_mesh.loadASCII import import_points_file
333   
334    if verbose: print 'Read pts'
335    points_dict = import_points_file(pts_name)
336    #points, attributes = util.read_xya(pts_name)
337
338    #Reduce number of points a bit
339    points = points_dict['pointlist'][::reduction]
340    elevation = points_dict['attributelist']['elevation']  #Must be elevation
341    elevation = elevation[::reduction]
342
343    if verbose: print 'Got %d data points' %len(points)
344
345    if verbose: print 'Create mesh'
346    #Find extent
347    max_x = min_x = points[0][0]
348    max_y = min_y = points[0][1]
349    for point in points[1:]:
350        x = point[0]
351        if x > max_x: max_x = x
352        if x < min_x: min_x = x
353        y = point[1]
354        if y > max_y: max_y = y
355        if y < min_y: min_y = y
356
357    #Create appropriate mesh
358    vertex_coordinates, triangles, boundary =\
359         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
360                                (min_x, min_y))
361
362    #Fit attributes to mesh
363    vertex_attributes = fit_to_mesh(vertex_coordinates,
364                        triangles,
365                        points,
366                        elevation, alpha=alpha, verbose=verbose)
367
368
369
370    return vertex_coordinates, triangles, boundary, vertex_attributes
371
372
373def _interpolation(*args, **kwargs):
374    """Private function for use with caching. Reason is that classes
375    may change their byte code between runs which is annoying.
376    """
377   
378    return Interpolation(*args, **kwargs)
379
380
381class Interpolation:
382
383    def __init__(self,
384                 vertex_coordinates,
385                 triangles,
386                 z,
387                 point_coordinates = None,
388                 alpha = None,
389                 verbose = False,
390                 expand_search = True,
391                 interp_only = False,
392                 max_points_per_cell = 30,
393                 mesh_origin = None,
394                 data_origin = None,
395                 precrop = False):
396
397
398        """ Build interpolation matrix mapping from
399        function values at vertices to function values at data points
400
401        Inputs:
402
403          vertex_coordinates: List of coordinate pairs [xi, eta] of
404          points constituting mesh (or a an m x 2 Numeric array)
405          Points may appear multiple times
406          (e.g. if vertices have discontinuities)
407
408          triangles: List of 3-tuples (or a Numeric array) of
409          integers representing indices of all vertices in the mesh.
410
411          point_coordinates: List of coordinate pairs [x, y] of
412          data points (or an nx2 Numeric array)
413          If point_coordinates is absent, only smoothing matrix will
414          be built
415
416          alpha: Smoothing parameter
417
418          data_origin and mesh_origin are 3-tuples consisting of
419          UTM zone, easting and northing. If specified
420          point coordinates and vertex coordinates are assumed to be
421          relative to their respective origins.
422         
423          z: Single 1d vector or array of data at the point_coordinates.
424
425        """
426
427        #Convert input to Numeric arrays
428        triangles = ensure_numeric(triangles, Int)
429        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
430
431        #Build underlying mesh
432        if verbose: print 'Building mesh'
433        #self.mesh = General_mesh(vertex_coordinates, triangles,
434        #FIXME: Trying the normal mesh while testing precrop,
435        #       The functionality of boundary_polygon is needed for that
436
437        #FIXME - geo ref does not have to go into mesh.
438        # Change the point co-ords to conform to the
439        # mesh co-ords early in the code
440        if mesh_origin is None:
441            geo = None
442        else:
443            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
444        self.mesh = Mesh(vertex_coordinates, triangles,
445                         geo_reference = geo)
446
447        if verbose: print 'Checking mesh integrity'
448        self.mesh.check_integrity()
449
450        if verbose: print 'Mesh integrity checked'
451       
452        self.data_origin = data_origin
453
454        self.point_indices = None
455
456        #Smoothing parameter
457        if alpha is None:
458            self.alpha = DEFAULT_ALPHA
459        else:   
460            self.alpha = alpha
461
462
463        if point_coordinates is not None:
464            if verbose: print 'Building interpolation matrix'
465            self.build_interpolation_matrix_A(point_coordinates,
466                                              z,
467                                              verbose = verbose,
468                                              expand_search = expand_search,
469                                              interp_only = interp_only, 
470                                              max_points_per_cell =\
471                                              max_points_per_cell,
472                                              data_origin = data_origin,
473                                              precrop = precrop)
474        #Build coefficient matrices
475        if interp_only == False:
476            self.build_coefficient_matrix_B(point_coordinates,
477                                        verbose = verbose,
478                                        expand_search = expand_search,
479                                        max_points_per_cell =\
480                                        max_points_per_cell,
481                                        data_origin = data_origin,
482                                        precrop = precrop)
483        if verbose: print 'Finished interpolation'
484
485    def set_point_coordinates(self, point_coordinates,
486                              data_origin = None,
487                              verbose = False,
488                              precrop = True):
489        """
490        A public interface to setting the point co-ordinates.
491        """
492        if point_coordinates is not None:
493            if verbose: print 'Building interpolation matrix'
494            self.build_interpolation_matrix_A(point_coordinates,
495                                              z,
496                                              verbose = verbose,
497                                              data_origin = data_origin,
498                                              precrop = precrop)
499        self.build_coefficient_matrix_B(point_coordinates, data_origin)
500
501    def build_coefficient_matrix_B(self, point_coordinates=None,
502                                   verbose = False, expand_search = True,
503                                   max_points_per_cell=30,
504                                   data_origin = None,
505                                   precrop = False):
506        """Build final coefficient matrix"""
507
508
509        if self.alpha <> 0:
510            if verbose: print 'Building smoothing matrix'
511            self.build_smoothing_matrix_D()
512
513        if point_coordinates is not None:
514            if self.alpha <> 0:
515                self.B = self.AtA + self.alpha*self.D
516            else:
517                self.B = self.AtA
518
519            #Convert self.B matrix to CSR format for faster matrix vector
520            self.B = Sparse_CSR(self.B)
521
522    def build_interpolation_matrix_A(self, point_coordinates,
523                                     z,
524                                     verbose = False, expand_search = True,
525                                     max_points_per_cell=30,
526                                     data_origin = None,
527                                     precrop = False,
528                                     interp_only = False):
529        """Build n x m interpolation matrix, where
530        n is the number of data points and
531        m is the number of basis functions phi_k (one per vertex)
532
533        This algorithm uses a quad tree data structure for fast binning of data points
534        origin is a 3-tuple consisting of UTM zone, easting and northing.
535        If specified coordinates are assumed to be relative to this origin.
536
537        This one will override any data_origin that may be specified in
538        interpolation instance
539
540        """
541
542        from pyvolution.quad import build_quadtree
543        from utilities.polygon import inside_polygon
544       
545        #Convert input to Numeric arrays
546        z = ensure_numeric(z, Float)
547           
548        #Convert input to Numeric arrays just in case.
549        point_coordinates = ensure_numeric(point_coordinates, Float)
550
551        #Build n x m interpolation matrix
552        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
553        n = point_coordinates.shape[0]     #Nbr of data points
554        if len(z.shape) > 1:
555            att_num = z.shape[1]
556        else:
557            att_num = 0
558           
559        assert z.shape[0] == point_coordinates.shape[0] 
560        if verbose: print 'Number of attributes: %d' %att_num
561        if verbose: print 'Number of datapoints: %d' %n
562        if verbose: print 'Number of basis functions: %d' %m
563
564        self.A = Sparse(n,m)
565        self.AtA = Sparse(m,m)
566        self.Atz = zeros((m,att_num), Float)
567        print "self.Atz",self.Atz
568        self.Atz[0] = 100
569        print "self.Atz",self.Atz
570
571        #Build quad tree of vertices
572        root = build_quadtree(self.mesh,
573                              max_points_per_cell = max_points_per_cell)
574        #root.show()
575        self.expanded_quad_searches = []
576        #Compute matrix elements
577        for i in range(n):
578            #For each data_coordinate point
579
580            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
581            x = point_coordinates[i]
582
583            #Find vertices near x
584            candidate_vertices = root.search(x[0], x[1])
585            is_more_elements = True
586
587            element_found, sigma0, sigma1, sigma2, k = \
588                self.search_triangles_of_vertices(candidate_vertices, x)
589            first_expansion = True
590            while not element_found and is_more_elements and expand_search:
591                #if verbose: print 'Expanding search'
592                if first_expansion == True:
593                    self.expanded_quad_searches.append(1)
594                    first_expansion = False
595                else:
596                    end = len(self.expanded_quad_searches) - 1
597                    assert end >= 0
598                    self.expanded_quad_searches[end] += 1
599                candidate_vertices, branch = root.expand_search()
600                if branch == []:
601                    # Searching all the verts from the root cell that haven't
602                    # been searched.  This is the last try
603                    element_found, sigma0, sigma1, sigma2, k = \
604                      self.search_triangles_of_vertices(candidate_vertices, x)
605                    is_more_elements = False
606                else:
607                    element_found, sigma0, sigma1, sigma2, k = \
608                      self.search_triangles_of_vertices(candidate_vertices, x)
609
610               
611            #Update interpolation matrix A if necessary
612            if element_found is True:
613                #Assign values to matrix A
614
615                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
616                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
617                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
618
619                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
620                js     = [j0,j1,j2]
621
622                for j in js:
623                    self.A[i,j] = sigmas[j]
624                    #self.Atz[n-i-1] +=  sigmas[j]*z[n-j-1]
625                    #print "i",i
626                    #print "j",j
627                    #print "self.Atz",self.Atz
628                    #print "z",z
629                    #a=self.Atz[j]
630                    #b=z[i]
631                    self.Atz[j] +=  sigmas[j]*z[i]
632                    #print "m-i-1",m-i-1
633                    #print "n-j-1",n-j-1
634                    #print "self.A[i,j]",self.A[i,j]
635                    #print "z[n-j-1]",z[n-j-1]
636                    #print "z[i]",z[i]
637                    #print "",
638                    #print "self.Atz",self.Atz                   
639                    for k in js:
640                        if interp_only == False:
641                            self.AtA[j,k] += sigmas[j]*sigmas[k]
642            else:
643                pass
644                #Ok if there is no triangle for datapoint
645                #(as in brute force version)
646                #raise 'Could not find triangle for point', x
647
648
649
650    def search_triangles_of_vertices(self, candidate_vertices, x):
651            #Find triangle containing x:
652            element_found = False
653
654            # This will be returned if element_found = False
655            sigma2 = -10.0
656            sigma0 = -10.0
657            sigma1 = -10.0
658            k = -10.0
659            #print "*$* candidate_vertices", candidate_vertices
660            #For all vertices in same cell as point x
661            for v in candidate_vertices:
662                #FIXME (DSG-DSG): this catches verts with no triangle.
663                #Currently pmesh is producing these.
664                #this should be stopped,
665                if self.mesh.vertexlist[v] is None:
666                    continue
667                #for each triangle id (k) which has v as a vertex
668                for k, _ in self.mesh.vertexlist[v]:
669
670                    #Get the three vertex_points of candidate triangle
671                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
672                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
673                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
674
675                    #print "PDSG - k", k
676                    #print "PDSG - xi0", xi0
677                    #print "PDSG - xi1", xi1
678                    #print "PDSG - xi2", xi2
679                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
680                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
681
682                    #Get the three normals
683                    n0 = self.mesh.get_normal(k, 0)
684                    n1 = self.mesh.get_normal(k, 1)
685                    n2 = self.mesh.get_normal(k, 2)
686
687
688                    #Compute interpolation
689                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
690                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
691                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
692
693                    #print "PDSG - sigma0", sigma0
694                    #print "PDSG - sigma1", sigma1
695                    #print "PDSG - sigma2", sigma2
696
697                    #FIXME: Maybe move out to test or something
698                    epsilon = 1.0e-6
699                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
700
701                    #Check that this triangle contains the data point
702
703                    #Sigmas can get negative within
704                    #machine precision on some machines (e.g nautilus)
705                    #Hence the small eps
706                    eps = 1.0e-15
707                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
708                        element_found = True
709                        break
710
711                if element_found is True:
712                    #Don't look for any other triangle
713                    break
714            return element_found, sigma0, sigma1, sigma2, k
715
716
717
718    def build_interpolation_matrix_A_brute(self, point_coordinates):
719        """Build n x m interpolation matrix, where
720        n is the number of data points and
721        m is the number of basis functions phi_k (one per vertex)
722
723        This is the brute force which is too slow for large problems,
724but could be used for testing
725        """
726
727
728        #Convert input to Numeric arrays
729        point_coordinates = ensure_numeric(point_coordinates, Float)
730
731        #Build n x m interpolation matrix
732        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
733        n = point_coordinates.shape[0]     #Nbr of data points
734
735        self.A = Sparse(n,m)
736        self.AtA = Sparse(m,m)
737
738        #Compute matrix elements
739        for i in range(n):
740            #For each data_coordinate point
741
742            x = point_coordinates[i]
743            element_found = False
744            k = 0
745            while not element_found and k < len(self.mesh):
746                #For each triangle (brute force)
747                #FIXME: Real algorithm should only visit relevant triangles
748
749                #Get the three vertex_points
750                xi0 = self.mesh.get_vertex_coordinate(k, 0)
751                xi1 = self.mesh.get_vertex_coordinate(k, 1)
752                xi2 = self.mesh.get_vertex_coordinate(k, 2)
753
754                #Get the three normals
755                n0 = self.mesh.get_normal(k, 0)
756                n1 = self.mesh.get_normal(k, 1)
757                n2 = self.mesh.get_normal(k, 2)
758
759                #Compute interpolation
760                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
761                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
762                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
763
764                #FIXME: Maybe move out to test or something
765                epsilon = 1.0e-6
766                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
767
768                #Check that this triangle contains data point
769                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
770                    element_found = True
771                    #Assign values to matrix A
772
773                    j0 = self.mesh.triangles[k,0] #Global vertex id
774                    #self.A[i, j0] = sigma0
775
776                    j1 = self.mesh.triangles[k,1] #Global vertex id
777                    #self.A[i, j1] = sigma1
778
779                    j2 = self.mesh.triangles[k,2] #Global vertex id
780                    #self.A[i, j2] = sigma2
781
782                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
783                    js     = [j0,j1,j2]
784
785                    for j in js:
786                        self.A[i,j] = sigmas[j]
787                        for k in js:
788                            self.AtA[j,k] += sigmas[j]*sigmas[k]
789                k = k+1
790
791
792
793    def get_A(self):
794        return self.A.todense()
795
796    def get_B(self):
797        return self.B.todense()
798
799    def get_D(self):
800        return self.D.todense()
801
802        #FIXME: Remember to re-introduce the 1/n factor in the
803        #interpolation term
804
805    def build_smoothing_matrix_D(self):
806        """Build m x m smoothing matrix, where
807        m is the number of basis functions phi_k (one per vertex)
808
809        The smoothing matrix is defined as
810
811        D = D1 + D2
812
813        where
814
815        [D1]_{k,l} = \int_\Omega
816           \frac{\partial \phi_k}{\partial x}
817           \frac{\partial \phi_l}{\partial x}\,
818           dx dy
819
820        [D2]_{k,l} = \int_\Omega
821           \frac{\partial \phi_k}{\partial y}
822           \frac{\partial \phi_l}{\partial y}\,
823           dx dy
824
825
826        The derivatives \frac{\partial \phi_k}{\partial x},
827        \frac{\partial \phi_k}{\partial x} for a particular triangle
828        are obtained by computing the gradient a_k, b_k for basis function k
829        """
830
831        #FIXME: algorithm might be optimised by computing local 9x9
832        #"element stiffness matrices:
833
834        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
835
836        self.D = Sparse(m,m)
837
838        #For each triangle compute contributions to D = D1+D2
839        for i in range(len(self.mesh)):
840
841            #Get area
842            area = self.mesh.areas[i]
843
844            #Get global vertex indices
845            v0 = self.mesh.triangles[i,0]
846            v1 = self.mesh.triangles[i,1]
847            v2 = self.mesh.triangles[i,2]
848
849            #Get the three vertex_points
850            xi0 = self.mesh.get_vertex_coordinate(i, 0)
851            xi1 = self.mesh.get_vertex_coordinate(i, 1)
852            xi2 = self.mesh.get_vertex_coordinate(i, 2)
853
854            #Compute gradients for each vertex
855            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
856                              1, 0, 0)
857
858            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
859                              0, 1, 0)
860
861            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
862                              0, 0, 1)
863
864            #Compute diagonal contributions
865            self.D[v0,v0] += (a0*a0 + b0*b0)*area
866            self.D[v1,v1] += (a1*a1 + b1*b1)*area
867            self.D[v2,v2] += (a2*a2 + b2*b2)*area
868
869            #Compute contributions for basis functions sharing edges
870            e01 = (a0*a1 + b0*b1)*area
871            self.D[v0,v1] += e01
872            self.D[v1,v0] += e01
873
874            e12 = (a1*a2 + b1*b2)*area
875            self.D[v1,v2] += e12
876            self.D[v2,v1] += e12
877
878            e20 = (a2*a0 + b2*b0)*area
879            self.D[v2,v0] += e20
880            self.D[v0,v2] += e20
881
882
883    def fit(self, z):
884        """Fit a smooth surface to given 1d array of data points z.
885
886        The smooth surface is computed at each vertex in the underlying
887        mesh using the formula given in the module doc string.
888
889        Pre Condition:
890          self.A, self.AtA and self.B have been initialised
891
892        Inputs:
893          z: Single 1d vector or array of data at the point_coordinates.
894        """
895
896        #Convert input to Numeric arrays
897        z = ensure_numeric(z, Float)
898
899        if len(z.shape) > 1 :
900            raise VectorShapeError, 'Can only deal with 1d data vector'
901
902        if self.point_indices is not None:
903            #Remove values for any points that were outside mesh
904            z = take(z, self.point_indices)
905
906        #Compute right hand side based on data
907        #FIXME (DSG-DsG): could Sparse_CSR be used here?  Use this format
908        # after a matrix is built, before calcs.
909        Atz = self.Atz
910        #print "self.get_A()", self.get_A()
911        #print "self.Atz",self.Atz
912        #print 'z', z
913        Atz = self.A.trans_mult(z)
914        #print "self.A.trans_mult(z)",self.A.trans_mult(z)
915
916
917        #Check sanity
918        n, m = self.A.shape
919        if n<m and self.alpha == 0.0:
920            msg = 'ERROR (least_squares): Too few data points\n'
921            msg += 'There are only %d data points and alpha == 0. ' %n
922            msg += 'Need at least %d\n' %m
923            msg += 'Alternatively, set smoothing parameter alpha to a small '
924            msg += 'positive value,\ne.g. 1.0e-3.'
925            raise Exception(msg)
926
927
928
929        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
930        #FIXME: Should we store the result here for later use? (ON)
931
932
933    def fit_points(self, z, verbose=False):
934        """Like fit, but more robust when each point has two or more attributes
935        FIXME (Ole): The name fit_points doesn't carry any meaning
936        for me. How about something like fit_multiple or fit_columns?
937        """
938
939        try:
940            if verbose: print 'Solving penalised least_squares problem'
941            return self.fit(z)
942        except VectorShapeError, e:
943            # broadcasting is not supported.
944
945            #Convert input to Numeric arrays
946            z = ensure_numeric(z, Float)
947
948            #Build n x m interpolation matrix
949            m = self.mesh.coordinates.shape[0] #Number of vertices
950            n = z.shape[1]                     #Number of data points
951
952            f = zeros((m,n), Float) #Resulting columns
953
954            for i in range(z.shape[1]):
955                f[:,i] = self.fit(z[:,i])
956
957            return f
958
959
960    def interpolate(self, f):
961        """Evaluate smooth surface f at data points implied in self.A.
962
963        The mesh values representing a smooth surface are
964        assumed to be specified in f. This argument could,
965        for example have been obtained from the method self.fit()
966
967        Pre Condition:
968          self.A has been initialised
969
970        Inputs:
971          f: Vector or array of data at the mesh vertices.
972          If f is an array, interpolation will be done for each column as
973          per underlying matrix-matrix multiplication
974
975        Output:
976          Interpolated values at data points implied in self.A
977
978        """
979
980        return self.A * f
981
982    def cull_outsiders(self, f):
983        pass
984
985
986
987
988class Interpolation_function:
989    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
990    which is interpolated from time series defined at vertices of
991    triangular mesh (such as those stored in sww files)
992
993    Let m be the number of vertices, n the number of triangles
994    and p the number of timesteps.
995
996    Mandatory input
997        time:               px1 array of monotonously increasing times (Float)
998        quantities:         Dictionary of arrays or 1 array (Float)
999                            The arrays must either have dimensions pxm or mx1.
1000                            The resulting function will be time dependent in
1001                            the former case while it will be constant with
1002                            respect to time in the latter case.
1003       
1004    Optional input:
1005        quantity_names:     List of keys into the quantities dictionary
1006        vertex_coordinates: mx2 array of coordinates (Float)
1007        triangles:          nx3 array of indices into vertex_coordinates (Int)
1008        interpolation_points: Nx2 array of coordinates to be interpolated to
1009        verbose:            Level of reporting
1010   
1011   
1012    The quantities returned by the callable object are specified by
1013    the list quantities which must contain the names of the
1014    quantities to be returned and also reflect the order, e.g. for
1015    the shallow water wave equation, on would have
1016    quantities = ['stage', 'xmomentum', 'ymomentum']
1017
1018    The parameter interpolation_points decides at which points interpolated
1019    quantities are to be computed whenever object is called.
1020    If None, return average value
1021    """
1022
1023   
1024   
1025    def __init__(self,
1026                 time,
1027                 quantities,
1028                 quantity_names = None, 
1029                 vertex_coordinates = None,
1030                 triangles = None,
1031                 interpolation_points = None,
1032                 verbose = False):
1033        """Initialise object and build spatial interpolation if required
1034        """
1035
1036        from Numeric import array, zeros, Float, alltrue, concatenate,\
1037             reshape, ArrayType
1038
1039
1040        from config import time_format
1041        import types
1042
1043
1044
1045        #Check temporal info
1046        time = ensure_numeric(time)       
1047        msg = 'Time must be a monotonuosly '
1048        msg += 'increasing sequence %s' %time
1049        assert alltrue(time[1:] - time[:-1] >= 0 ), msg
1050
1051
1052        #Check if quantities is a single array only
1053        if type(quantities) != types.DictType:
1054            quantities = ensure_numeric(quantities)
1055            quantity_names = ['Attribute']
1056
1057            #Make it a dictionary
1058            quantities = {quantity_names[0]: quantities}
1059
1060
1061        #Use keys if no names are specified
1062        if quantity_names is None:
1063            quantity_names = quantities.keys()
1064
1065
1066        #Check spatial info
1067        if vertex_coordinates is None:
1068            self.spatial = False
1069        else:   
1070            vertex_coordinates = ensure_numeric(vertex_coordinates)
1071
1072            assert triangles is not None, 'Triangles array must be specified'
1073            triangles = ensure_numeric(triangles)
1074            self.spatial = True           
1075           
1076
1077 
1078        #Save for use with statistics
1079        self.quantity_names = quantity_names       
1080        self.quantities = quantities       
1081        self.vertex_coordinates = vertex_coordinates
1082        self.interpolation_points = interpolation_points
1083        self.time = time[:]  # Time assumed to be relative to starttime
1084        self.index = 0    # Initial time index
1085        self.precomputed_values = {}
1086           
1087
1088
1089        #Precomputed spatial interpolation if requested
1090        if interpolation_points is not None:
1091            if self.spatial is False:
1092                raise 'Triangles and vertex_coordinates must be specified'
1093           
1094            try:
1095                self.interpolation_points = ensure_numeric(interpolation_points)
1096            except:
1097                msg = 'Interpolation points must be an N x 2 Numeric array '+\
1098                      'or a list of points\n'
1099                msg += 'I got: %s.' %(str(self.interpolation_points)[:60] +\
1100                                      '...')
1101                raise msg
1102
1103
1104            m = len(self.interpolation_points)
1105            p = len(self.time)
1106           
1107            for name in quantity_names:
1108                self.precomputed_values[name] = zeros((p, m), Float)
1109
1110            #Build interpolator
1111            interpol = Interpolation(vertex_coordinates,
1112                                     triangles,
1113                                     point_coordinates = \
1114                                     self.interpolation_points,
1115                                     alpha = 0,
1116                                     precrop = False, 
1117                                     verbose = verbose)
1118
1119            if verbose: print 'Interpolate'
1120            for i, t in enumerate(self.time):
1121                #Interpolate quantities at this timestep
1122                if verbose and i%((p+10)/10)==0:
1123                    print ' time step %d of %d' %(i, p)
1124                   
1125                for name in quantity_names:
1126                    if len(quantities[name].shape) == 2:
1127                        result = interpol.interpolate(quantities[name][i,:])
1128                    else:
1129                       #Assume no time dependency
1130                       result = interpol.interpolate(quantities[name][:])
1131                       
1132                    self.precomputed_values[name][i, :] = result
1133                   
1134                       
1135
1136            #Report
1137            if verbose:
1138                print self.statistics()
1139                #self.print_statistics()
1140           
1141        else:
1142            #Store quantitites as is
1143            for name in quantity_names:
1144                self.precomputed_values[name] = quantities[name]
1145
1146
1147        #else:
1148        #    #Return an average, making this a time series
1149        #    for name in quantity_names:
1150        #        self.values[name] = zeros(len(self.time), Float)
1151        #
1152        #    if verbose: print 'Compute mean values'
1153        #    for i, t in enumerate(self.time):
1154        #        if verbose: print ' time step %d of %d' %(i, len(self.time))
1155        #        for name in quantity_names:
1156        #           self.values[name][i] = mean(quantities[name][i,:])
1157
1158
1159
1160
1161    def __repr__(self):
1162        #return 'Interpolation function (spatio-temporal)'
1163        return self.statistics()
1164   
1165
1166    def __call__(self, t, point_id = None, x = None, y = None):
1167        """Evaluate f(t), f(t, point_id) or f(t, x, y)
1168
1169        Inputs:
1170          t: time - Model time. Must lie within existing timesteps
1171          point_id: index of one of the preprocessed points.
1172          x, y:     Overrides location, point_id ignored
1173         
1174          If spatial info is present and all of x,y,point_id
1175          are None an exception is raised
1176                   
1177          If no spatial info is present, point_id and x,y arguments are ignored
1178          making f a function of time only.
1179
1180         
1181          FIXME: point_id could also be a slice
1182          FIXME: What if x and y are vectors?
1183          FIXME: What about f(x,y) without t?
1184        """
1185
1186        from math import pi, cos, sin, sqrt
1187        from Numeric import zeros, Float
1188        from utilities.numerical_tools import mean       
1189
1190        if self.spatial is True:
1191            if point_id is None:
1192                if x is None or y is None:
1193                    msg = 'Either point_id or x and y must be specified'
1194                    raise Exception(msg)
1195            else:
1196                if self.interpolation_points is None:
1197                    msg = 'Interpolation_function must be instantiated ' +\
1198                          'with a list of interpolation points before parameter ' +\
1199                          'point_id can be used'
1200                    raise Exception(msg)
1201
1202
1203        msg = 'Time interval [%s:%s]' %(self.time[0], self.time[-1])
1204        msg += ' does not match model time: %s\n' %t
1205        if t < self.time[0]: raise Exception(msg)
1206        if t > self.time[-1]: raise Exception(msg)
1207
1208        oldindex = self.index #Time index
1209
1210        #Find current time slot
1211        while t > self.time[self.index]: self.index += 1
1212        while t < self.time[self.index]: self.index -= 1
1213
1214        if t == self.time[self.index]:
1215            #Protect against case where t == T[-1] (last time)
1216            # - also works in general when t == T[i]
1217            ratio = 0
1218        else:
1219            #t is now between index and index+1
1220            ratio = (t - self.time[self.index])/\
1221                    (self.time[self.index+1] - self.time[self.index])
1222
1223        #Compute interpolated values
1224        q = zeros(len(self.quantity_names), Float)
1225
1226        for i, name in enumerate(self.quantity_names):
1227            Q = self.precomputed_values[name]
1228
1229            if self.spatial is False:
1230                #If there is no spatial info               
1231                assert len(Q.shape) == 1
1232
1233                Q0 = Q[self.index]
1234                if ratio > 0: Q1 = Q[self.index+1]
1235
1236            else:
1237                if x is not None and y is not None:
1238                    #Interpolate to x, y
1239                   
1240                    raise 'x,y interpolation not yet implemented'
1241                else:
1242                    #Use precomputed point
1243                    Q0 = Q[self.index, point_id]
1244                    if ratio > 0: Q1 = Q[self.index+1, point_id]
1245
1246            #Linear temporal interpolation   
1247            if ratio > 0:
1248                q[i] = Q0 + ratio*(Q1 - Q0)
1249            else:
1250                q[i] = Q0
1251
1252
1253        #Return vector of interpolated values
1254        #if len(q) == 1:
1255        #    return q[0]
1256        #else:
1257        #    return q
1258
1259
1260        #Return vector of interpolated values
1261        #FIXME:
1262        if self.spatial is True:
1263            return q
1264        else:
1265            #Replicate q according to x and y
1266            #This is e.g used for Wind_stress
1267            if x is None or y is None: 
1268                return q
1269            else:
1270                try:
1271                    N = len(x)
1272                except:
1273                    return q
1274                else:
1275                    from Numeric import ones, Float
1276                    #x is a vector - Create one constant column for each value
1277                    N = len(x)
1278                    assert len(y) == N, 'x and y must have same length'
1279                    res = []
1280                    for col in q:
1281                        res.append(col*ones(N, Float))
1282                       
1283                return res
1284
1285
1286    def statistics(self):
1287        """Output statistics about interpolation_function
1288        """
1289       
1290        vertex_coordinates = self.vertex_coordinates
1291        interpolation_points = self.interpolation_points               
1292        quantity_names = self.quantity_names
1293        quantities = self.quantities
1294        precomputed_values = self.precomputed_values                 
1295               
1296        x = vertex_coordinates[:,0]
1297        y = vertex_coordinates[:,1]               
1298
1299        str =  '------------------------------------------------\n'
1300        str += 'Interpolation_function (spatio-temporal) statistics:\n'
1301        str += '  Extent:\n'
1302        str += '    x in [%f, %f], len(x) == %d\n'\
1303               %(min(x), max(x), len(x))
1304        str += '    y in [%f, %f], len(y) == %d\n'\
1305               %(min(y), max(y), len(y))
1306        str += '    t in [%f, %f], len(t) == %d\n'\
1307               %(min(self.time), max(self.time), len(self.time))
1308        str += '  Quantities:\n'
1309        for name in quantity_names:
1310            q = quantities[name][:].flat
1311            str += '    %s in [%f, %f]\n' %(name, min(q), max(q))
1312
1313        if interpolation_points is not None:   
1314            str += '  Interpolation points (xi, eta):'\
1315                   ' number of points == %d\n' %interpolation_points.shape[0]
1316            str += '    xi in [%f, %f]\n' %(min(interpolation_points[:,0]),
1317                                            max(interpolation_points[:,0]))
1318            str += '    eta in [%f, %f]\n' %(min(interpolation_points[:,1]),
1319                                             max(interpolation_points[:,1]))
1320            str += '  Interpolated quantities (over all timesteps):\n'
1321       
1322            for name in quantity_names:
1323                q = precomputed_values[name][:].flat
1324                str += '    %s at interpolation points in [%f, %f]\n'\
1325                       %(name, min(q), max(q))
1326        str += '------------------------------------------------\n'
1327
1328        return str
1329
1330        #FIXME: Delete
1331        #print '------------------------------------------------'
1332        #print 'Interpolation_function statistics:'
1333        #print '  Extent:'
1334        #print '    x in [%f, %f], len(x) == %d'\
1335        #      %(min(x), max(x), len(x))
1336        #print '    y in [%f, %f], len(y) == %d'\
1337        #      %(min(y), max(y), len(y))
1338        #print '    t in [%f, %f], len(t) == %d'\
1339        #      %(min(self.time), max(self.time), len(self.time))
1340        #print '  Quantities:'
1341        #for name in quantity_names:
1342        #    q = quantities[name][:].flat
1343        #    print '    %s in [%f, %f]' %(name, min(q), max(q))
1344        #print '  Interpolation points (xi, eta):'\
1345        #      ' number of points == %d ' %interpolation_points.shape[0]
1346        #print '    xi in [%f, %f]' %(min(interpolation_points[:,0]),
1347        #                             max(interpolation_points[:,0]))
1348        #print '    eta in [%f, %f]' %(min(interpolation_points[:,1]),
1349        #                              max(interpolation_points[:,1]))
1350        #print '  Interpolated quantities (over all timesteps):'
1351        #
1352        #for name in quantity_names:
1353        #    q = precomputed_values[name][:].flat
1354        #    print '    %s at interpolation points in [%f, %f]'\
1355        #          %(name, min(q), max(q))
1356        #print '------------------------------------------------'
1357
1358
1359#-------------------------------------------------------------
1360if __name__ == "__main__":
1361    """
1362    Load in a mesh and data points with attributes.
1363    Fit the attributes to the mesh.
1364    Save a new mesh file.
1365    """
1366    import os, sys
1367    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][vervose|non_verbose] [alpha] [display_errors|no_display_errors]"\
1368            %os.path.basename(sys.argv[0])
1369
1370    if len(sys.argv) < 4:
1371        print usage
1372    else:
1373        mesh_file = sys.argv[1]
1374        point_file = sys.argv[2]
1375        mesh_output_file = sys.argv[3]
1376
1377        expand_search = False
1378        if len(sys.argv) > 4:
1379            if sys.argv[4][0] == "e" or sys.argv[4][0] == "E":
1380                expand_search = True
1381            else:
1382                expand_search = False
1383
1384        verbose = False
1385        if len(sys.argv) > 5:
1386            if sys.argv[5][0] == "n" or sys.argv[5][0] == "N":
1387                verbose = False
1388            else:
1389                verbose = True
1390
1391        if len(sys.argv) > 6:
1392            alpha = sys.argv[6]
1393        else:
1394            alpha = DEFAULT_ALPHA
1395
1396        # This is used more for testing
1397        if len(sys.argv) > 7:
1398            if sys.argv[7][0] == "n" or sys.argv[5][0] == "N":
1399                display_errors = False
1400            else:
1401                display_errors = True
1402           
1403        t0 = time.time()
1404        try:
1405            fit_to_mesh_file(mesh_file,
1406                         point_file,
1407                         mesh_output_file,
1408                         alpha,
1409                         verbose= verbose,
1410                         expand_search = expand_search,
1411                         display_errors = display_errors)
1412        except IOError,e:
1413            import sys; sys.exit(1)
1414
1415        print 'That took %.2f seconds' %(time.time()-t0)
1416
Note: See TracBrowser for help on using the repository browser.