source: branches/inundation-numpy-branch/pyvolution/least_squares.py @ 7077

Last change on this file since 7077 was 3514, checked in by duncan, 19 years ago

Hi all,
I'm doing a change in the anuga structure, moving the code to

\anuga_core\source\anuga

After you have done an svn update, the PYTHONPATH has to be changed to;
PYTHONPATH = anuga_core/source/

This is part of the changes required to make installing anuga quicker and to reduce the size of our sandpits.

If any imports are broken, try fixing them, for example by prefixing them with "anuga.". If this seems to have really broken things, email or phone me.

Cheers
Duncan

File size: 48.1 KB
Line 
"""Least squares smoothing and interpolation.
2
3   Implements a penalised least-squares fit and associated interpolations.
4
5   The penalty term (or smoothing term) is controlled by the smoothing
6   parameter alpha.
7   With a value of alpha=0, the fit function will attempt
8   to interpolate as closely as possible in the least-squares sense.
9   With values alpha > 0, a certain amount of smoothing will be applied.
10   A positive alpha is essential in cases where there are too few
11   data points.
12   A negative alpha is not allowed.
13   A typical value of alpha is 1.0e-6
14
15
16   Ole Nielsen, Stephen Roberts, Duncan Gray, Christopher Zoppou
17   Geoscience Australia, 2004.
18"""
19
20#import exceptions
21#class ShapeError(exceptions.Exception): pass
22
23#from general_mesh import General_mesh
24from numpy import zeros, take, array, Float, Int, dot, transpose, concatenate, ArrayType
25
26#FIXME (Ole): Meshes to move somewhere else
27from anuga.pyvolution.mesh import Mesh
28from anuga.utilities.sparse import Sparse, Sparse_CSR
29from anuga.utilities.cg_solve import conjugate_gradient, VectorShapeError
30from anuga.utilities.numerical_tools import ensure_numeric, mean, gradient
31
32
33from coordinate_transforms.geo_reference import Geo_reference
34
35import time
36
37
38
39
# Smoothing parameter used by the fitting functions in this module when the
# caller does not give one (Interpolation.__init__ also maps alpha=None to it).
# NOTE(review): the module docstring quotes 1.0e-6 as a "typical" alpha; this
# default is considerably larger.
DEFAULT_ALPHA = 1.0e-3
41
42def fit_to_mesh_file(mesh_file, point_file, mesh_output_file,
43                     alpha=DEFAULT_ALPHA, verbose= False,
44                     expand_search = False,
45                     data_origin = None,
46                     mesh_origin = None,
47                     precrop = False,
48                     display_errors = True):
49    """
50    Given a mesh file (tsh) and a point attribute file (xya), fit
51    point attributes to the mesh and write a mesh file with the
52    results.
53
54
55    If data_origin is not None it is assumed to be
56    a 3-tuple with geo referenced
57    UTM coordinates (zone, easting, northing)
58
59    NOTE: Throws IOErrors, for a variety of file problems.
60   
61    mesh_origin is the same but refers to the input tsh file.
62    FIXME: When the tsh format contains it own origin, these parameters can go.
63    FIXME: And both origins should be obtained from the specified files.
64    """
65
66    from load_mesh.loadASCII import import_mesh_file, \
67                 import_points_file, export_mesh_file, \
68                 concatinate_attributelist
69
70
71    try:
72        mesh_dict = import_mesh_file(mesh_file)
73    except IOError,e:
74        if display_errors:
75            print "Could not load bad file. ", e
76        raise IOError  #Re-raise exception
77       
78    vertex_coordinates = mesh_dict['vertices']
79    triangles = mesh_dict['triangles']
80    if type(mesh_dict['vertex_attributes']) == ArrayType:
81        old_point_attributes = mesh_dict['vertex_attributes'].tolist()
82    else:
83        old_point_attributes = mesh_dict['vertex_attributes']
84
85    if type(mesh_dict['vertex_attribute_titles']) == ArrayType:
86        old_title_list = mesh_dict['vertex_attribute_titles'].tolist()
87    else:
88        old_title_list = mesh_dict['vertex_attribute_titles']
89
90    if verbose: print 'tsh file %s loaded' %mesh_file
91
92    # load in the .pts file
93    try:
94        point_dict = import_points_file(point_file, verbose=verbose)
95    except IOError,e:
96        if display_errors:
97            print "Could not load bad file. ", e
98        raise IOError  #Re-raise exception 
99
100    point_coordinates = point_dict['pointlist']
101    title_list,point_attributes = concatinate_attributelist(point_dict['attributelist'])
102
103    if point_dict.has_key('geo_reference') and not point_dict['geo_reference'] is None:
104        data_origin = point_dict['geo_reference'].get_origin()
105    else:
106        data_origin = (56, 0, 0) #FIXME(DSG-DSG)
107
108    if mesh_dict.has_key('geo_reference') and not mesh_dict['geo_reference'] is None:
109        mesh_origin = mesh_dict['geo_reference'].get_origin()
110    else:
111        mesh_origin = (56, 0, 0) #FIXME(DSG-DSG)
112
113    if verbose: print "points file loaded"
114    if verbose: print "fitting to mesh"
115    f = fit_to_mesh(vertex_coordinates,
116                    triangles,
117                    point_coordinates,
118                    point_attributes,
119                    alpha = alpha,
120                    verbose = verbose,
121                    expand_search = expand_search,
122                    data_origin = data_origin,
123                    mesh_origin = mesh_origin,
124                    precrop = precrop)
125    if verbose: print "finished fitting to mesh"
126
127    # convert array to list of lists
128    new_point_attributes = f.tolist()
129    #FIXME have this overwrite attributes with the same title - DSG
130    #Put the newer attributes last
131    if old_title_list <> []:
132        old_title_list.extend(title_list)
133        #FIXME can this be done a faster way? - DSG
134        for i in range(len(old_point_attributes)):
135            old_point_attributes[i].extend(new_point_attributes[i])
136        mesh_dict['vertex_attributes'] = old_point_attributes
137        mesh_dict['vertex_attribute_titles'] = old_title_list
138    else:
139        mesh_dict['vertex_attributes'] = new_point_attributes
140        mesh_dict['vertex_attribute_titles'] = title_list
141
142    #FIXME (Ole): Remember to output mesh_origin as well
143    if verbose: print "exporting to file ", mesh_output_file
144
145    try:
146        export_mesh_file(mesh_output_file, mesh_dict)
147    except IOError,e:
148        if display_errors:
149            print "Could not write file. ", e
150        raise IOError
151
def fit_to_mesh(vertex_coordinates,
                triangles,
                point_coordinates,
                point_attributes,
                alpha = DEFAULT_ALPHA,
                verbose = False,
                expand_search = False,
                data_origin = None,
                mesh_origin = None,
                precrop = False,
                use_cache = False):
    """Fit a smooth surface to a triangulation, given data points with attributes.

    Inputs:

      vertex_coordinates: List of coordinate pairs [xi, eta] of points
      constituting the mesh (or an m x 2 Numeric array)

      triangles: List of 3-tuples (or a Numeric array) of
      integers representing indices of all vertices in the mesh.

      point_coordinates: List of coordinate pairs [x, y] of data points
      (or an n x 2 Numeric array)

      point_attributes: Vector or array of data at the point_coordinates.

      alpha: Smoothing parameter.

      verbose, expand_search, precrop: Passed on to Interpolation.

      data_origin and mesh_origin are 3-tuples consisting of
      UTM zone, easting and northing. If specified
      point coordinates and vertex coordinates are assumed to be
      relative to their respective origins.

      use_cache: If True, the Interpolation object is obtained through
      caching.caching.cache (via _interpolation) instead of being
      constructed directly.

    Returns the result of Interpolation.fit_points(point_attributes),
    i.e. the fitted attribute values at the mesh vertices.
    """

    # Arguments shared by the cached and the direct construction paths
    positional = (vertex_coordinates, triangles, point_coordinates)
    options = {'alpha': alpha,
               'verbose': verbose,
               'expand_search': expand_search,
               'data_origin': data_origin,
               'mesh_origin': mesh_origin,
               'precrop': precrop}

    if use_cache is not True:
        interp = Interpolation(*positional, **options)
    else:
        from caching.caching import cache
        interp = cache(_interpolation, positional, options,
                       verbose = verbose)

    return interp.fit_points(point_attributes, verbose = verbose)
217
218
219
220def pts2rectangular(pts_name, M, N, alpha = DEFAULT_ALPHA,
221                    verbose = False, reduction = 1):
222    """Fits attributes from pts file to MxN rectangular mesh
223
224    Read pts file and create rectangular mesh of resolution MxN such that
225    it covers all points specified in pts file.
226
227    FIXME: This may be a temporary function until we decide on
228    netcdf formats etc
229
230    FIXME: Uses elevation hardwired
231    """
232
233    import  mesh_factory
234    from load_mesh.loadASCII import import_points_file
235   
236    if verbose: print 'Read pts'
237    points_dict = import_points_file(pts_name)
238    #points, attributes = util.read_xya(pts_name)
239
240    #Reduce number of points a bit
241    points = points_dict['pointlist'][::reduction]
242    elevation = points_dict['attributelist']['elevation']  #Must be elevation
243    elevation = elevation[::reduction]
244
245    if verbose: print 'Got %d data points' %len(points)
246
247    if verbose: print 'Create mesh'
248    #Find extent
249    max_x = min_x = points[0][0]
250    max_y = min_y = points[0][1]
251    for point in points[1:]:
252        x = point[0]
253        if x > max_x: max_x = x
254        if x < min_x: min_x = x
255        y = point[1]
256        if y > max_y: max_y = y
257        if y < min_y: min_y = y
258
259    #Create appropriate mesh
260    vertex_coordinates, triangles, boundary =\
261         mesh_factory.rectangular(M, N, max_x-min_x, max_y-min_y,
262                                (min_x, min_y))
263
264    #Fit attributes to mesh
265    vertex_attributes = fit_to_mesh(vertex_coordinates,
266                        triangles,
267                        points,
268                        elevation, alpha=alpha, verbose=verbose)
269
270
271
272    return vertex_coordinates, triangles, boundary, vertex_attributes
273
274
def _interpolation(*positional, **keywords):
    """Construct an Interpolation from the given arguments.

    This module-level factory is handed to the caching machinery instead of
    the Interpolation class itself, because a class may change its byte
    code between runs, which is annoying for caching.
    """
    factory = Interpolation
    return factory(*positional, **keywords)
281
282
283class Interpolation:
284
285    def __init__(self,
286                 vertex_coordinates,
287                 triangles,
288                 point_coordinates = None,
289                 alpha = None,
290                 verbose = False,
291                 expand_search = True,
292                 interp_only = False,
293                 max_points_per_cell = 30,
294                 mesh_origin = None,
295                 data_origin = None,
296                 precrop = False):
297
298
299        """ Build interpolation matrix mapping from
300        function values at vertices to function values at data points
301
302        Inputs:
303
304          vertex_coordinates: List of coordinate pairs [xi, eta] of
305          points constituting mesh (or a an m x 2 Numeric array)
306          Points may appear multiple times
307          (e.g. if vertices have discontinuities)
308
309          triangles: List of 3-tuples (or a Numeric array) of
310          integers representing indices of all vertices in the mesh.
311
312          point_coordinates: List of coordinate pairs [x, y] of
313          data points (or an nx2 Numeric array)
314          If point_coordinates is absent, only smoothing matrix will
315          be built
316
317          alpha: Smoothing parameter
318
319          data_origin and mesh_origin are 3-tuples consisting of
320          UTM zone, easting and northing. If specified
321          point coordinates and vertex coordinates are assumed to be
322          relative to their respective origins.
323
324        """
325
326        #Convert input to Numeric arrays
327        triangles = ensure_numeric(triangles, Int)
328        vertex_coordinates = ensure_numeric(vertex_coordinates, Float)
329
330        #Build underlying mesh
331        if verbose: print 'Building mesh'
332        #self.mesh = General_mesh(vertex_coordinates, triangles,
333        #FIXME: Trying the normal mesh while testing precrop,
334        #       The functionality of boundary_polygon is needed for that
335
336        #FIXME - geo ref does not have to go into mesh.
337        # Change the point co-ords to conform to the
338        # mesh co-ords early in the code
339        if mesh_origin is None:
340            geo = None
341        else:
342            geo = Geo_reference(mesh_origin[0],mesh_origin[1],mesh_origin[2])
343        self.mesh = Mesh(vertex_coordinates, triangles,
344                         geo_reference = geo)
345       
346        self.mesh.check_integrity()
347
348        self.data_origin = data_origin
349
350        self.point_indices = None
351
352        #Smoothing parameter
353        if alpha is None:
354            self.alpha = DEFAULT_ALPHA
355        else:   
356            self.alpha = alpha
357
358
359        if point_coordinates is not None:
360            if verbose: print 'Building interpolation matrix'
361            self.build_interpolation_matrix_A(point_coordinates,
362                                              verbose = verbose,
363                                              expand_search = expand_search,
364                                              interp_only = interp_only, 
365                                              max_points_per_cell =\
366                                              max_points_per_cell,
367                                              data_origin = data_origin,
368                                              precrop = precrop)
369        #Build coefficient matrices
370        if interp_only == False:
371            self.build_coefficient_matrix_B(point_coordinates,
372                                        verbose = verbose,
373                                        expand_search = expand_search,
374                                        max_points_per_cell =\
375                                        max_points_per_cell,
376                                        data_origin = data_origin,
377                                        precrop = precrop)
378
379    def set_point_coordinates(self, point_coordinates,
380                              data_origin = None,
381                              verbose = False,
382                              precrop = True):
383        """
384        A public interface to setting the point co-ordinates.
385        """
386        if point_coordinates is not None:
387            if verbose: print 'Building interpolation matrix'
388            self.build_interpolation_matrix_A(point_coordinates,
389                                              verbose = verbose,
390                                              data_origin = data_origin,
391                                              precrop = precrop)
392        self.build_coefficient_matrix_B(point_coordinates, data_origin)
393
394    def build_coefficient_matrix_B(self, point_coordinates=None,
395                                   verbose = False, expand_search = True,
396                                   max_points_per_cell=30,
397                                   data_origin = None,
398                                   precrop = False):
399        """Build final coefficient matrix"""
400
401
402        if self.alpha <> 0:
403            if verbose: print 'Building smoothing matrix'
404            self.build_smoothing_matrix_D()
405
406        if point_coordinates is not None:
407            if self.alpha <> 0:
408                self.B = self.AtA + self.alpha*self.D
409            else:
410                self.B = self.AtA
411
412            #Convert self.B matrix to CSR format for faster matrix vector
413            self.B = Sparse_CSR(self.B)
414
415    def build_interpolation_matrix_A(self, point_coordinates,
416                                     verbose = False, expand_search = True,
417                                     max_points_per_cell=30,
418                                     data_origin = None,
419                                     precrop = False,
420                                     interp_only = False):
421        """Build n x m interpolation matrix, where
422        n is the number of data points and
423        m is the number of basis functions phi_k (one per vertex)
424
425        This algorithm uses a quad tree data structure for fast binning of data points
426        origin is a 3-tuple consisting of UTM zone, easting and northing.
427        If specified coordinates are assumed to be relative to this origin.
428
429        This one will override any data_origin that may be specified in
430        interpolation instance
431
432        """
433
434
435
436        #FIXME (Ole): Check that this function is memeory efficient.
437        #6 million datapoints and 300000 basis functions
438        #causes out-of-memory situation
439        #First thing to check is whether there is room for self.A and self.AtA
440        #
441        #Maybe we need some sort of blocking
442
443        from anuga.pyvolution.quad import build_quadtree
444        from anuga.utilities.polygon import inside_polygon
445       
446
447        if data_origin is None:
448            data_origin = self.data_origin #Use the one from
449                                           #interpolation instance
450
451        #Convert input to Numeric arrays just in case.
452        point_coordinates = ensure_numeric(point_coordinates, Float)
453
454        #Keep track of discarded points (if any).
455        #This is only registered if precrop is True
456        self.cropped_points = False
457
458        #Shift data points to same origin as mesh (if specified)
459
460        #FIXME this will shift if there was no geo_ref.
461        #But all this should be removed anyhow.
462        #change coords before this point
463        mesh_origin = self.mesh.geo_reference.get_origin()
464        if point_coordinates is not None:
465            if data_origin is not None:
466                if mesh_origin is not None:
467
468                    #Transformation:
469                    #
470                    #Let x_0 be the reference point of the point coordinates
471                    #and xi_0 the reference point of the mesh.
472                    #
473                    #A point coordinate (x + x_0) is then made relative
474                    #to xi_0 by
475                    #
476                    # x_new = x + x_0 - xi_0
477                    #
478                    #and similarly for eta
479
480                    x_offset = data_origin[1] - mesh_origin[1]
481                    y_offset = data_origin[2] - mesh_origin[2]
482                else: #Shift back to a zero origin
483                    x_offset = data_origin[1]
484                    y_offset = data_origin[2]
485
486                point_coordinates[:,0] += x_offset
487                point_coordinates[:,1] += y_offset
488            else:
489                if mesh_origin is not None:
490                    #Use mesh origin for data points
491                    point_coordinates[:,0] -= mesh_origin[1]
492                    point_coordinates[:,1] -= mesh_origin[2]
493
494
495
496        #Remove points falling outside mesh boundary
497        #This reduced one example from 1356 seconds to 825 seconds
498
499       
500        if precrop is True:
501            #from Numeric import take
502
503            if verbose: print 'Getting boundary polygon'
504            P = self.mesh.get_boundary_polygon()
505
506            if verbose: print 'Getting indices inside mesh boundary'
507            indices = inside_polygon(point_coordinates, P, verbose = verbose)
508
509
510            if len(indices) != point_coordinates.shape[0]:
511                self.cropped_points = True
512                if verbose:
513                    print 'Done - %d points outside mesh have been cropped.'\
514                          %(point_coordinates.shape[0] - len(indices))
515
516            point_coordinates = take(point_coordinates, indices)
517            self.point_indices = indices
518
519
520
521
522        #Build n x m interpolation matrix
523        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
524        n = point_coordinates.shape[0]     #Nbr of data points
525
526        if verbose: print 'Number of datapoints: %d' %n
527        if verbose: print 'Number of basis functions: %d' %m
528
529        #FIXME (Ole): We should use CSR here since mat-mat mult is now OK.
530        #However, Sparse_CSR does not have the same methods as Sparse yet
531        #The tests will reveal what needs to be done
532
533        #
534        #self.A = Sparse_CSR(Sparse(n,m))
535        #self.AtA = Sparse_CSR(Sparse(m,m))
536        self.A = Sparse(n,m)
537        self.AtA = Sparse(m,m)
538
539        #Build quad tree of vertices (FIXME: Is this the right spot for that?)
540        root = build_quadtree(self.mesh,
541                              max_points_per_cell = max_points_per_cell)
542        #root.show()
543        self.expanded_quad_searches = []
544        #Compute matrix elements
545        for i in range(n):
546            #For each data_coordinate point
547
548            if verbose and i%((n+10)/10)==0: print 'Doing %d of %d' %(i, n)
549            x = point_coordinates[i]
550
551            #Find vertices near x
552            candidate_vertices = root.search(x[0], x[1])
553            is_more_elements = True
554
555            element_found, sigma0, sigma1, sigma2, k = \
556                self.search_triangles_of_vertices(candidate_vertices, x)
557            first_expansion = True
558            while not element_found and is_more_elements and expand_search:
559                #if verbose: print 'Expanding search'
560                if first_expansion == True:
561                    self.expanded_quad_searches.append(1)
562                    first_expansion = False
563                else:
564                    end = len(self.expanded_quad_searches) - 1
565                    assert end >= 0
566                    self.expanded_quad_searches[end] += 1
567                candidate_vertices, branch = root.expand_search()
568                if branch == []:
569                    # Searching all the verts from the root cell that haven't
570                    # been searched.  This is the last try
571                    element_found, sigma0, sigma1, sigma2, k = \
572                      self.search_triangles_of_vertices(candidate_vertices, x)
573                    is_more_elements = False
574                else:
575                    element_found, sigma0, sigma1, sigma2, k = \
576                      self.search_triangles_of_vertices(candidate_vertices, x)
577
578               
579            #Update interpolation matrix A if necessary
580            if element_found is True:
581                #Assign values to matrix A
582
583                j0 = self.mesh.triangles[k,0] #Global vertex id for sigma0
584                j1 = self.mesh.triangles[k,1] #Global vertex id for sigma1
585                j2 = self.mesh.triangles[k,2] #Global vertex id for sigma2
586
587                sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
588                js     = [j0,j1,j2]
589
590                for j in js:
591                    self.A[i,j] = sigmas[j]
592                    for k in js:
593                        if interp_only == False:
594                            self.AtA[j,k] += sigmas[j]*sigmas[k]
595            else:
596                pass
597                #Ok if there is no triangle for datapoint
598                #(as in brute force version)
599                #raise 'Could not find triangle for point', x
600
601
602
603    def search_triangles_of_vertices(self, candidate_vertices, x):
604            #Find triangle containing x:
605            element_found = False
606
607            # This will be returned if element_found = False
608            sigma2 = -10.0
609            sigma0 = -10.0
610            sigma1 = -10.0
611            k = -10.0
612            #print "*$* candidate_vertices", candidate_vertices
613            #For all vertices in same cell as point x
614            for v in candidate_vertices:
615                #FIXME (DSG-DSG): this catches verts with no triangle.
616                #Currently pmesh is producing these.
617                #this should be stopped,
618                if self.mesh.vertexlist[v] is None:
619                    continue
620                #for each triangle id (k) which has v as a vertex
621                for k, _ in self.mesh.vertexlist[v]:
622
623                    #Get the three vertex_points of candidate triangle
624                    xi0 = self.mesh.get_vertex_coordinate(k, 0)
625                    xi1 = self.mesh.get_vertex_coordinate(k, 1)
626                    xi2 = self.mesh.get_vertex_coordinate(k, 2)
627
628                    #print "PDSG - k", k
629                    #print "PDSG - xi0", xi0
630                    #print "PDSG - xi1", xi1
631                    #print "PDSG - xi2", xi2
632                    #print "PDSG element %i verts((%f, %f),(%f, %f),(%f, %f))"\
633                    #   % (k, xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1])
634
635                    #Get the three normals
636                    n0 = self.mesh.get_normal(k, 0)
637                    n1 = self.mesh.get_normal(k, 1)
638                    n2 = self.mesh.get_normal(k, 2)
639
640
641                    #Compute interpolation
642                    sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
643                    sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
644                    sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
645
646                    #print "PDSG - sigma0", sigma0
647                    #print "PDSG - sigma1", sigma1
648                    #print "PDSG - sigma2", sigma2
649
650                    #FIXME: Maybe move out to test or something
651                    epsilon = 1.0e-6
652                    assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
653
654                    #Check that this triangle contains the data point
655
656                    #Sigmas can get negative within
657                    #machine precision on some machines (e.g nautilus)
658                    #Hence the small eps
659                    eps = 1.0e-15
660                    if sigma0 >= -eps and sigma1 >= -eps and sigma2 >= -eps:
661                        element_found = True
662                        break
663
664                if element_found is True:
665                    #Don't look for any other triangle
666                    break
667            return element_found, sigma0, sigma1, sigma2, k
668
669
670
671    def build_interpolation_matrix_A_brute(self, point_coordinates):
672        """Build n x m interpolation matrix, where
673        n is the number of data points and
674        m is the number of basis functions phi_k (one per vertex)
675
676        This is the brute force which is too slow for large problems,
677        but could be used for testing
678        """
679
680
681        #Convert input to Numeric arrays
682        point_coordinates = ensure_numeric(point_coordinates, Float)
683
684        #Build n x m interpolation matrix
685        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
686        n = point_coordinates.shape[0]     #Nbr of data points
687
688        self.A = Sparse(n,m)
689        self.AtA = Sparse(m,m)
690
691        #Compute matrix elements
692        for i in range(n):
693            #For each data_coordinate point
694
695            x = point_coordinates[i]
696            element_found = False
697            k = 0
698            while not element_found and k < len(self.mesh):
699                #For each triangle (brute force)
700                #FIXME: Real algorithm should only visit relevant triangles
701
702                #Get the three vertex_points
703                xi0 = self.mesh.get_vertex_coordinate(k, 0)
704                xi1 = self.mesh.get_vertex_coordinate(k, 1)
705                xi2 = self.mesh.get_vertex_coordinate(k, 2)
706
707                #Get the three normals
708                n0 = self.mesh.get_normal(k, 0)
709                n1 = self.mesh.get_normal(k, 1)
710                n2 = self.mesh.get_normal(k, 2)
711
712                #Compute interpolation
713                sigma2 = dot((x-xi0), n2)/dot((xi2-xi0), n2)
714                sigma0 = dot((x-xi1), n0)/dot((xi0-xi1), n0)
715                sigma1 = dot((x-xi2), n1)/dot((xi1-xi2), n1)
716
717                #FIXME: Maybe move out to test or something
718                epsilon = 1.0e-6
719                assert abs(sigma0 + sigma1 + sigma2 - 1.0) < epsilon
720
721                #Check that this triangle contains data point
722                if sigma0 >= 0 and sigma1 >= 0 and sigma2 >= 0:
723                    element_found = True
724                    #Assign values to matrix A
725
726                    j0 = self.mesh.triangles[k,0] #Global vertex id
727                    #self.A[i, j0] = sigma0
728
729                    j1 = self.mesh.triangles[k,1] #Global vertex id
730                    #self.A[i, j1] = sigma1
731
732                    j2 = self.mesh.triangles[k,2] #Global vertex id
733                    #self.A[i, j2] = sigma2
734
735                    sigmas = {j0:sigma0, j1:sigma1, j2:sigma2}
736                    js     = [j0,j1,j2]
737
738                    for j in js:
739                        self.A[i,j] = sigmas[j]
740                        for k in js:
741                            self.AtA[j,k] += sigmas[j]*sigmas[k]
742                k = k+1
743
744
745
746    def get_A(self):
747        return self.A.todense()
748
749    def get_B(self):
750        return self.B.todense()
751
752    def get_D(self):
753        return self.D.todense()
754
755        #FIXME: Remember to re-introduce the 1/n factor in the
756        #interpolation term
757
758    def build_smoothing_matrix_D(self):
759        """Build m x m smoothing matrix, where
760        m is the number of basis functions phi_k (one per vertex)
761
762        The smoothing matrix is defined as
763
764        D = D1 + D2
765
766        where
767
768        [D1]_{k,l} = \int_\Omega
769           \frac{\partial \phi_k}{\partial x}
770           \frac{\partial \phi_l}{\partial x}\,
771           dx dy
772
773        [D2]_{k,l} = \int_\Omega
774           \frac{\partial \phi_k}{\partial y}
775           \frac{\partial \phi_l}{\partial y}\,
776           dx dy
777
778
779        The derivatives \frac{\partial \phi_k}{\partial x},
780        \frac{\partial \phi_k}{\partial x} for a particular triangle
781        are obtained by computing the gradient a_k, b_k for basis function k
782        """
783
784        #FIXME: algorithm might be optimised by computing local 9x9
785        #"element stiffness matrices:
786
787        m = self.mesh.coordinates.shape[0] #Nbr of basis functions (1/vertex)
788
789        self.D = Sparse(m,m)
790
791        #For each triangle compute contributions to D = D1+D2
792        for i in range(len(self.mesh)):
793
794            #Get area
795            area = self.mesh.areas[i]
796
797            #Get global vertex indices
798            v0 = self.mesh.triangles[i,0]
799            v1 = self.mesh.triangles[i,1]
800            v2 = self.mesh.triangles[i,2]
801
802            #Get the three vertex_points
803            xi0 = self.mesh.get_vertex_coordinate(i, 0)
804            xi1 = self.mesh.get_vertex_coordinate(i, 1)
805            xi2 = self.mesh.get_vertex_coordinate(i, 2)
806
807            #Compute gradients for each vertex
808            a0, b0 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
809                              1, 0, 0)
810
811            a1, b1 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
812                              0, 1, 0)
813
814            a2, b2 = gradient(xi0[0], xi0[1], xi1[0], xi1[1], xi2[0], xi2[1],
815                              0, 0, 1)
816
817            #Compute diagonal contributions
818            self.D[v0,v0] += (a0*a0 + b0*b0)*area
819            self.D[v1,v1] += (a1*a1 + b1*b1)*area
820            self.D[v2,v2] += (a2*a2 + b2*b2)*area
821
822            #Compute contributions for basis functions sharing edges
823            e01 = (a0*a1 + b0*b1)*area
824            self.D[v0,v1] += e01
825            self.D[v1,v0] += e01
826
827            e12 = (a1*a2 + b1*b2)*area
828            self.D[v1,v2] += e12
829            self.D[v2,v1] += e12
830
831            e20 = (a2*a0 + b2*b0)*area
832            self.D[v2,v0] += e20
833            self.D[v0,v2] += e20
834
835
836    def fit(self, z):
837        """Fit a smooth surface to given 1d array of data points z.
838
839        The smooth surface is computed at each vertex in the underlying
840        mesh using the formula given in the module doc string.
841
842        Pre Condition:
843          self.A, self.AtA and self.B have been initialised
844
845        Inputs:
846          z: Single 1d vector or array of data at the point_coordinates.
847        """
848
849        #Convert input to Numeric arrays
850        z = ensure_numeric(z, Float)
851
852        if len(z.shape) > 1 :
853            raise VectorShapeError, 'Can only deal with 1d data vector'
854
855        if self.point_indices is not None:
856            #Remove values for any points that were outside mesh
857            z = take(z, self.point_indices)
858
859        #Compute right hand side based on data
860        #FIXME (DSG-DsG): could Sparse_CSR be used here?  Use this format
861        # after a matrix is built, before calcs.
862        Atz = self.A.trans_mult(z)
863
864
865        #Check sanity
866        n, m = self.A.shape
867        if n<m and self.alpha == 0.0:
868            msg = 'ERROR (least_squares): Too few data points\n'
869            msg += 'There are only %d data points and alpha == 0. ' %n
870            msg += 'Need at least %d\n' %m
871            msg += 'Alternatively, set smoothing parameter alpha to a small '
872            msg += 'positive value,\ne.g. 1.0e-3.'
873            raise msg
874
875
876
877        return conjugate_gradient(self.B, Atz, Atz, imax=2*len(Atz) )
878        #FIXME: Should we store the result here for later use? (ON)
879
880
881    def fit_points(self, z, verbose=False):
882        """Like fit, but more robust when each point has two or more attributes
883        FIXME (Ole): The name fit_points doesn't carry any meaning
884        for me. How about something like fit_multiple or fit_columns?
885        """
886
887        try:
888            if verbose: print 'Solving penalised least_squares problem'
889            return self.fit(z)
890        except VectorShapeError, e:
891            # broadcasting is not supported.
892
893            #Convert input to Numeric arrays
894            z = ensure_numeric(z, Float)
895
896            #Build n x m interpolation matrix
897            m = self.mesh.coordinates.shape[0] #Number of vertices
898            n = z.shape[1]                     #Number of data points
899
900            f = zeros((m,n), Float) #Resulting columns
901
902            for i in range(z.shape[1]):
903                f[:,i] = self.fit(z[:,i])
904
905            return f
906
907
908    def interpolate(self, f):
909        """Evaluate smooth surface f at data points implied in self.A.
910
911        The mesh values representing a smooth surface are
912        assumed to be specified in f. This argument could,
913        for example have been obtained from the method self.fit()
914
915        Pre Condition:
916          self.A has been initialised
917
918        Inputs:
919          f: Vector or array of data at the mesh vertices.
920          If f is an array, interpolation will be done for each column as
921          per underlying matrix-matrix multiplication
922
923        Output:
924          Interpolated values at data points implied in self.A
925
926        """
927
928        return self.A * f
929
930    def cull_outsiders(self, f):
931        pass
932
933
934
935
class Interpolation_function:
    """Interpolation_function - creates callable object f(t, id) or f(t,x,y)
    which is interpolated from time series defined at vertices of
    triangular mesh (such as those stored in sww files)

    Let m be the number of vertices, n the number of triangles
    and p the number of timesteps.

    Mandatory input
        time:               px1 array of non-decreasing times (Float)
        quantities:         Dictionary of arrays or 1 array (Float)
                            The arrays must either have dimensions pxm or mx1.
                            The resulting function will be time dependent in
                            the former case while it will be constant with
                            respect to time in the latter case.

    Optional input:
        quantity_names:     List of keys into the quantities dictionary
                            (ignored, and set to ['Attribute'], when
                            quantities is a single array)
        vertex_coordinates: mx2 array of coordinates (Float)
        triangles:          nx3 array of indices into vertex_coordinates (Int)
        interpolation_points: Nx2 array of coordinates to be interpolated to
        verbose:            Level of reporting


    The quantities returned by the callable object are specified by
    the list quantity_names which must contain the names of the
    quantities to be returned and also reflect the order, e.g. for
    the shallow water wave equation, one would have
    quantity_names = ['stage', 'xmomentum', 'ymomentum']

    The parameter interpolation_points decides at which points interpolated
    quantities are to be computed whenever object is called.
    If None, no spatial interpolation is precomputed and the quantities
    are stored unchanged; the object is then only usable as a function
    of time for data without spatial information.
    """



    def __init__(self,
                 time,
                 quantities,
                 quantity_names = None,
                 vertex_coordinates = None,
                 triangles = None,
                 interpolation_points = None,
                 verbose = False):
        """Initialise object and build spatial interpolation if required

        Raises:
          AssertionError: if time is decreasing anywhere, or if
                          vertex_coordinates is given without triangles.
          ValueError:     if interpolation_points is given without spatial
                          information, or cannot be converted to an array.
        """

        from Numeric import zeros, Float, alltrue

        #Check temporal info (message only formatted if the check fails)
        time = ensure_numeric(time)
        assert alltrue(time[1:] - time[:-1] >= 0 ), \
               'Time must be a monotonuosly increasing sequence %s' %time


        #Check if quantities is a single array only
        if not isinstance(quantities, dict):
            quantity_names = ['Attribute']

            #Make it a dictionary
            quantities = {quantity_names[0]: ensure_numeric(quantities)}

        #Use keys if no names are specified
        if quantity_names is None:
            quantity_names = list(quantities.keys())


        #Check spatial info
        if vertex_coordinates is None:
            self.spatial = False
        else:
            vertex_coordinates = ensure_numeric(vertex_coordinates)

            assert triangles is not None, 'Triangles array must be specified'
            triangles = ensure_numeric(triangles)
            self.spatial = True


        #Save for use with statistics
        self.quantity_names = quantity_names
        self.quantities = quantities
        self.vertex_coordinates = vertex_coordinates
        self.interpolation_points = interpolation_points
        self.T = time[:]  # Time assumed to be relative to starttime
        self.index = 0    # Initial time index
        self.precomputed_values = {}


        if interpolation_points is None:
            #Store quantities as is
            for name in quantity_names:
                self.precomputed_values[name] = quantities[name]
            return


        #Precomputed spatial interpolation was requested
        if self.spatial is False:
            raise ValueError('Triangles and vertex_coordinates must be specified')

        try:
            self.interpolation_points = ensure_numeric(interpolation_points)
        except Exception:
            msg = 'Interpolation points must be an N x 2 Numeric array '+\
                  'or a list of points\n'
            msg += 'I got: %s.' %(str(self.interpolation_points)[:60] +\
                                  '...')
            raise ValueError(msg)


        m = len(self.interpolation_points)
        p = len(self.T)

        for name in quantity_names:
            self.precomputed_values[name] = zeros((p, m), Float)

        #Build interpolator
        interpol = Interpolation(vertex_coordinates,
                                 triangles,
                                 point_coordinates = \
                                 self.interpolation_points,
                                 alpha = 0,
                                 precrop = False,
                                 verbose = verbose)

        if verbose: print('Interpolate')

        #Quantities that are not pxm arrays are assumed to have no time
        #dependency: interpolate them once instead of at every timestep
        constant_results = {}
        for name in quantity_names:
            if len(quantities[name].shape) != 2:
                constant_results[name] = interpol.interpolate(quantities[name][:])

        report_every = (p+10)//10
        for i in range(p):
            #Interpolate quantities at this timestep
            if verbose and i%report_every == 0:
                print(' time step %d of %d' %(i, p))

            for name in quantity_names:
                if name in constant_results:
                    result = constant_results[name]
                else:
                    result = interpol.interpolate(quantities[name][i,:])

                self.precomputed_values[name][i, :] = result

        #Report
        if verbose:
            print(self.statistics())


    def __repr__(self):
        return self.statistics()


    def __call__(self, t, point_id = None, x = None, y = None):
        """Evaluate f(t), f(t, point_id) or f(t, x, y)

        Inputs:
          t: time - Model time. Must lie within existing timesteps
          point_id: index of one of the preprocessed points.
          x, y:     Intended to override location (point_id ignored);
                    for spatial data this is not yet implemented and
                    raises NotImplementedError.

          If spatial info is present and all of x,y,point_id
          are None an exception is raised

          If no spatial info is present, point_id and x,y arguments are ignored
          making f a function of time only. If x (and y) are then vectors of
          length N, a list holding one constant length-N vector per quantity
          is returned instead of the plain result.

        Output:
          Array with one value per quantity, in the order of quantity_names,
          linearly interpolated in time.

        Raises:
          ValueError: if required spatial arguments are missing or t lies
                      outside [T[0], T[-1]].

          FIXME: point_id could also be a slice
          FIXME: What if x and y are vectors?
          FIXME: What about f(x,y) without t?
        """

        from Numeric import zeros, ones, Float

        if self.spatial is True:
            if point_id is None:
                if x is None or y is None:
                    msg = 'Either point_id or x and y must be specified'
                    raise ValueError(msg)
            elif self.interpolation_points is None:
                msg = 'Interpolation_function must be instantiated ' +\
                      'with a list of interpolation points before parameter ' +\
                      'point_id can be used'
                raise ValueError(msg)

        #The valid interval ends at the LAST time T[-1]; T[1] would both
        #misreport it and fail for a single timestep
        if t < self.T[0] or t > self.T[-1]:
            msg = 'Time interval [%s:%s]' %(self.T[0], self.T[-1])
            msg += ' does not match model time: %s\n' %t
            raise ValueError(msg)

        #Find current time slot, searching from the slot found by the
        #previous call (self.index)
        while t > self.T[self.index]: self.index += 1
        while t < self.T[self.index]: self.index -= 1

        if t == self.T[self.index]:
            #Protect against case where t == T[-1] (last time)
            # - also works in general when t == T[i]
            ratio = 0
        else:
            #t is now between index and index+1
            ratio = (t - self.T[self.index])/\
                    (self.T[self.index+1] - self.T[self.index])

        #Compute interpolated values
        q = zeros(len(self.quantity_names), Float)

        for i, name in enumerate(self.quantity_names):
            Q = self.precomputed_values[name]

            if self.spatial is False:
                #If there is no spatial info
                assert len(Q.shape) == 1

                Q0 = Q[self.index]
                if ratio > 0: Q1 = Q[self.index+1]
            elif x is not None and y is not None:
                raise NotImplementedError('x,y interpolation not yet implemented')
            else:
                #Use precomputed point
                Q0 = Q[self.index, point_id]
                if ratio > 0: Q1 = Q[self.index+1, point_id]

            #Linear temporal interpolation
            if ratio > 0:
                q[i] = Q0 + ratio*(Q1 - Q0)
            else:
                q[i] = Q0

        #Return vector of interpolated values
        if self.spatial is True or x is None or y is None:
            return q

        #No spatial info, but x and y given: replicate q according to x and y
        #This is e.g used for Wind_stress
        try:
            N = len(x)
        except TypeError:
            #x is a scalar - nothing to replicate
            return q

        #x is a vector - Create one constant column for each value
        assert len(y) == N, 'x and y must have same length'
        return [col*ones(N, Float) for col in q]


    def statistics(self):
        """Output statistics about interpolation_function

        Returns a multi-line string (ending in a newline) with the extent
        of the vertex coordinates (when spatial info is present) and of
        the times, the range of every quantity and, when interpolation
        points were given, their extent and the range of the interpolated
        values.
        """

        lines = ['------------------------------------------------',
                 'Interpolation_function (spatio-temporal) statistics:',
                 '  Extent:']

        #vertex_coordinates is None for purely temporal data
        if self.vertex_coordinates is not None:
            x = self.vertex_coordinates[:,0]
            y = self.vertex_coordinates[:,1]
            lines.append('    x in [%f, %f], len(x) == %d'
                         %(min(x), max(x), len(x)))
            lines.append('    y in [%f, %f], len(y) == %d'
                         %(min(y), max(y), len(y)))

        lines.append('    t in [%f, %f], len(t) == %d'
                     %(min(self.T), max(self.T), len(self.T)))

        lines.append('  Quantities:')
        for name in self.quantity_names:
            q = self.quantities[name][:].flat
            lines.append('    %s in [%f, %f]' %(name, min(q), max(q)))

        points = self.interpolation_points
        if points is not None:
            lines.append('  Interpolation points (xi, eta):'
                         ' number of points == %d' %points.shape[0])
            lines.append('    xi in [%f, %f]' %(min(points[:,0]),
                                                max(points[:,0])))
            lines.append('    eta in [%f, %f]' %(min(points[:,1]),
                                                 max(points[:,1])))
            lines.append('  Interpolated quantities (over all timesteps):')

            for name in self.quantity_names:
                q = self.precomputed_values[name][:].flat
                lines.append('    %s at interpolation points in [%f, %f]'
                             %(name, min(q), max(q)))

        lines.append('------------------------------------------------')

        return '\n'.join(lines) + '\n'
1305
1306
1307#-------------------------------------------------------------
if __name__ == "__main__":
    # Command line tool:
    # Load in a mesh and data points with attributes.
    # Fit the attributes to the mesh.
    # Save a new mesh file.
    import os, sys
    usage = "usage: %s mesh_input.tsh point.xya mesh_output.tsh [expand|no_expand][verbose|non_verbose] [alpha] [display_errors|no_display_errors]"\
            %os.path.basename(sys.argv[0])

    if len(sys.argv) < 4:
        print(usage)
    else:
        mesh_file = sys.argv[1]
        point_file = sys.argv[2]
        mesh_output_file = sys.argv[3]

        # Optional switches are recognised by their first letter only
        expand_search = False
        if len(sys.argv) > 4:
            expand_search = sys.argv[4][:1] in ('e', 'E')

        verbose = False
        if len(sys.argv) > 5:
            verbose = sys.argv[5][:1] not in ('n', 'N')

        # alpha arrives from the command line as a string
        if len(sys.argv) > 6:
            alpha = float(sys.argv[6])
        else:
            alpha = DEFAULT_ALPHA

        options = {'verbose': verbose,
                   'expand_search': expand_search}

        # This is used more for testing.
        # display_errors was previously undefined (NameError) when not given
        # and was decided from argv[5] instead of argv[7]. It is now only
        # passed when given explicitly.
        # NOTE(review): relies on fit_to_mesh_file having a default for
        # display_errors - confirm.
        if len(sys.argv) > 7:
            options['display_errors'] = sys.argv[7][:1] not in ('n', 'N')

        t0 = time.time()
        try:
            fit_to_mesh_file(mesh_file,
                             point_file,
                             mesh_output_file,
                             alpha,
                             **options)
        except IOError:
            # Report the I/O problem instead of exiting silently
            sys.stderr.write('%s\n' % sys.exc_info()[1])
            sys.exit(1)

        print('That took %.2f seconds' %(time.time()-t0))
Note: See TracBrowser for help on using the repository browser.