Changeset 7707


Ignore:
Timestamp:
May 5, 2010, 4:06:02 PM (15 years ago)
Author:
James Hudson
Message:

New quadtree implementation - unoptimised and no tree balancing. A couple of failing unit tests to fix.

Location:
anuga_core/source/anuga
Files:
6 edited

Legend:

Unmodified
Added
Removed
  • anuga_core/source/anuga/fit_interpolate/fit.py

    r7317 r7707  
    281281           
    282282            element_found, sigma0, sigma1, sigma2, k = \
    283                            search_tree_of_vertices(self.root, x)
     283                           search_tree_of_vertices(self.root, self.mesh, x)
    284284           
    285285            if element_found is True:
  • anuga_core/source/anuga/fit_interpolate/interpolate.py

    r7685 r7707  
    149149            I = cache(wrap_Interpolate, (args, kwargs), {}, verbose=verbose)
    150150    else:
    151         I = apply(Interpolate, args, kwargs)
    152 
     151        I = apply(Interpolate, args, kwargs)           
     152               
    153153    # Call interpolate method with interpolation points
    154154    result = I.interpolate_block(vertex_values, interpolation_points,
     
    334334
    335335        See interpolate for doc info.
    336         """
    337 
     336        """     
     337               
    338338        # FIXME (Ole): I reckon we should change the interface so that
    339339        # the user can specify the interpolation matrix instead of the
     
    489489
    490490        centroids = []
    491        
    492491        inside_poly_indices = []
    493492       
     
    503502            x = point_coordinates[i]
    504503            element_found, sigma0, sigma1, sigma2, k = \
    505                            search_tree_of_vertices(self.root, x)
     504                           search_tree_of_vertices(self.root, self.mesh, x)
    506505                       
    507506        # Update interpolation matrix A if necessary
  • anuga_core/source/anuga/fit_interpolate/search_functions.py

    r7701 r7707  
    3030
    3131# FIXME(Ole): Could we come up with a less confusing structure?
     32# FIXME(James): remove this global var
    3233LAST_TRIANGLE = [[-10,
    3334                   (num.array([[max_float, max_float],
     
    3839                     num.array([-1.1,-1.1])))]]
    3940
    40 def search_tree_of_vertices(root, x):
     41last_triangle = LAST_TRIANGLE                                   
     42                                         
     43def search_tree_of_vertices(root, mesh, x):
    4144    """
    4245    Find the triangle (element) that the point x is in.
     
    4447    Inputs:
    4548        root: A quad tree of the vertices
     49        mesh: The mesh which the quad tree indexes into
    4650        x:    The point being placed
    4751   
     
    6771   
    6872    # Get triangles in the cell that the point is in.
    69     # Triangle is a list, first element triangle_id,
    70     # second element the triangle
    71     triangles = root.search(x[0], x[1])
     73    tri_indices = root.search(x[0], x[1])
     74    triangles = _trilist_from_indices(mesh, tri_indices)
     75   
    7276    element_found, sigma0, sigma1, sigma2, k = \
    7377                   _search_triangles_of_vertices(triangles, x)
     
    7579    is_more_elements = True
    7680   
    77     while not element_found and is_more_elements:
    78         triangles, branch = root.expand_search()
    79         if branch == []:
    80             # Searching all the verts from the root cell that haven't
    81             # been searched.  This is the last try
    82             element_found, sigma0, sigma1, sigma2, k = \
    83                            _search_triangles_of_vertices(triangles, x)
    84             is_more_elements = False
    85         else:
    86             element_found, sigma0, sigma1, sigma2, k = \
    87                        _search_triangles_of_vertices(triangles, x)
     81    # while not element_found and is_more_elements:
     82        # triangles, branch = root.expand_search()
     83        # if branch == []:
     84            # # Searching all the verts from the root cell that haven't
     85            # # been searched.  This is the last try
     86            # element_found, sigma0, sigma1, sigma2, k = \
     87                           # _search_triangles_of_vertices(triangles, x)
     88            # is_more_elements = False
     89        # else:
     90            # element_found, sigma0, sigma1, sigma2, k = \
     91                       # _search_triangles_of_vertices(triangles, x)
    8892                       
    8993       
     
    103107    global last_triangle
    104108
    105     x = ensure_numeric(x, num.float)    
     109    x = ensure_numeric(x, num.float)    
    106110   
    107111    # These statments are needed if triangles is empty
     
    114118    for k, tri_verts_norms in triangles:
    115119        tri = tri_verts_norms[0]
    116         tri = ensure_numeric(tri)               
     120        tri = ensure_numeric(tri)       
    117121        # k is the triangle index
    118         # tri is a list of verts (x, y), representing a tringle
     122        # tri is a list of verts (x, y), representing a triangle
    119123        # Find triangle that contains x (if any) and interpolate
    120124       
    121         # Input check disabled to speed things up.     
     125        # Input check disabled to speed things up.   
    122126        if bool(_is_inside_triangle(x, tri, int(True), 1.0e-12, 1.0e-12)):
    123127           
     
    135139
    136140
     141def _trilist_from_indices(mesh, indices):
     142    """return a list of lists. For the inner lists,
     143    The first element is the triangle index,
     144    the second element is a list.for this list
     145       the first element is a list of three (x, y) vertices,
     146       the following elements are the three triangle normals.
     147
     148    """
     149
     150    ret_list = []
     151    for i in indices:
     152        vertices = mesh.get_vertex_coordinates(triangle_id=i, absolute=True)
     153        n0 = mesh.get_normal(i, 0)
     154        n1 = mesh.get_normal(i, 1)
     155        n2 = mesh.get_normal(i, 2)
     156        ret_list.append([i, [vertices, (n0, n1, n2)]])
     157    return ret_list
     158               
    137159           
    138160def compute_interpolation_values(triangle, n0, n1, n2, x):
  • anuga_core/source/anuga/fit_interpolate/test_search_functions.py

    r7276 r7707  
    55from search_functions import search_tree_of_vertices, set_last_triangle
    66from search_functions import _search_triangles_of_vertices
     7from search_functions import _trilist_from_indices
    78from search_functions import compute_interpolation_values
    89
     
    3940
    4041
     42    def test_off_and_boundary(self):
     43        """test_off: Test a point off the mesh
     44        """
     45
     46        points, vertices, boundary = rectangular(1, 1, 1, 1)
     47        mesh = Mesh(points, vertices, boundary)
     48
     49        #Test that points are arranged in a counter clock wise order
     50        mesh.check_integrity()
     51
     52        root = build_quadtree(mesh, max_points_per_cell = 1)
     53        set_last_triangle()
     54
     55        found, s0, s1, s2, k = search_tree_of_vertices(root, mesh, [-0.2, 10.7])
     56        assert found is False
     57
     58        found, s0, s1, s2, k = search_tree_of_vertices(root, mesh, [0, 0])
     59        assert found is True
     60               
     61               
    4162    def test_small(self):
    4263        """test_small: Two triangles
     
    5374
    5475        x = [0.2, 0.7]
    55         found, s0, s1, s2, k = search_tree_of_vertices(root, x)
     76        found, s0, s1, s2, k = search_tree_of_vertices(root, mesh, x)
    5677        assert k == 1 # Triangle one
    57         assert found is True
    58 
     78        assert found is True           
     79               
    5980    def test_bigger(self):
    6081        """test_bigger
     
    7697                  [10, 3]]:
    7798           
    78             found, s0, s1, s2, k = search_tree_of_vertices(root,
    79                                                            ensure_numeric(x))
    80 
     99            found, s0, s1, s2, k = search_tree_of_vertices(root, mesh,
     100                                                           ensure_numeric(x))                                                             
     101                                                                                                                   
    81102            if k >= 0:
    82103                V = mesh.get_vertex_coordinates(k) # nodes for triangle k
     
    109130                      [10, 3]]:
    110131               
    111                 found, s0, s1, s2, k = search_tree_of_vertices(root, x)
     132                found, s0, s1, s2, k = search_tree_of_vertices(root, mesh, x)
    112133
    113134                if k >= 0:
     
    134155        # One point
    135156        x = ensure_numeric([0.5, 0.5])
    136         candidate_vertices = root.search(x[0], x[1])
    137 
    138         #print x, candidate_vertices
     157
     158        triangles = _trilist_from_indices(mesh, root.search(x[0], x[1]))
     159       
    139160        found, sigma0, sigma1, sigma2, k = \
    140                _search_triangles_of_vertices(candidate_vertices,
    141                                              x)
     161               _search_triangles_of_vertices(triangles, x)
    142162
    143163        if k >= 0:
     
    155175                  [10, 3]]:
    156176               
    157             candidate_vertices = root.search(x[0], x[1])
     177            triangles = _trilist_from_indices(mesh, root.search(x[0], x[1]))
    158178
    159179            #print x, candidate_vertices
    160180            found, sigma0, sigma1, sigma2, k = \
    161                    _search_triangles_of_vertices(candidate_vertices,
     181                   _search_triangles_of_vertices(triangles,
    162182                                                 ensure_numeric(x))
    163183            if k >= 0:
     
    194214   
    195215
    196         root = Cell(-3, 9, -3, 9, mesh,
     216        root = Cell(-3, 9, -3, 9,
    197217                    max_points_per_cell = 4)
    198218        #Insert indices of all vertices
     
    206226        x = [2.5, 1.5]
    207227        element_found, sigma0, sigma1, sigma2, k = \
    208                        search_tree_of_vertices(root, x)
     228                       search_tree_of_vertices(root, mesh, x)
    209229        # One point
    210230        x = [3.00005, 2.999994]
    211231        element_found, sigma0, sigma1, sigma2, k = \
    212                        search_tree_of_vertices(root, x)
     232                       search_tree_of_vertices(root, mesh, x)
    213233        assert element_found is True
    214234        assert k == 1
  • anuga_core/source/anuga/utilities/quad.py

    r7703 r7707  
    1 """quad.py - quad tree data structure for fast indexing of points in the plane
     1"""quad.py - quad tree data structure for fast indexing of regions in the plane.
     2
     3This is a generic structure that can be used to store any geometry in a quadtree.
    24
    35
     
    810import anuga.utilities.log as log
    911
    10 
    11 #FIXME verts are added one at a time.
    12 #FIXME add max min x y in general_mesh
    13 
     12# Allow children to be slightly bigger than their parents to prevent straddling of a boundary
     13SPLIT_BORDER_RATIO    = 0.55
     14
     15class AABB:
     16    """Axially-aligned bounding box class.
     17    """
     18   
     19    def __init__(self, xmin, xmax, ymin, ymax):
     20        self.xmin = round(xmin,5)   
     21        self.xmax = round(xmax,5)
     22        self.ymin = round(ymin,5)   
     23        self.ymax = round(ymax,5)
     24
     25    def __repr__(self):
     26        return '(xmin:%f, xmax:%f, ymin:%f, ymax:%f)' \
     27               % (round(self.xmin,1), round(self.xmax,1), round(self.ymin,1), round(self.ymax, 1))
     28       
     29    def size(self):
     30        """return size as (w,h)"""
     31        return self.xmax - self.xmin, self.ymax - self.ymin
     32       
     33    def split(self, border=SPLIT_BORDER_RATIO):
     34        """Split along shorter axis.
     35           return 2 subdivided AABBs.
     36        """
     37       
     38        width, height = self.size()
     39        assert width >= 0 and height >= 0
     40       
     41        if (width > height):
     42            # split vertically
     43            return AABB(self.xmin, self.xmin+width*border, self.ymin, self.ymax), \
     44                   AABB(self.xmax-width*border, self.xmax, self.ymin, self.ymax)
     45        else:
     46            # split horizontally       
     47            return AABB(self.xmin, self.xmax, self.ymin, self.ymin+height*border), \
     48                   AABB(self.xmin, self.xmax, self.ymax-height*border, self.ymax)   
     49   
     50    def is_trivial_in(self, test):
     51        if (test.xmin < self.xmin) or (test.xmax > self.xmax):
     52            return False       
     53        if (test.ymin < self.ymin) or (test.ymax > self.ymax):
     54            return False       
     55        return True
     56 
     57    def contains(self, x, y):
     58        return (self.xmin <= x <= self.xmax) and (self.ymin <= y <= self.ymax)
     59           
    1460class Cell(TreeNode):
    1561    """class Cell
     
    2773    """
    2874 
    29     def __init__(self, southern, northern, western, eastern, mesh,
    30                  name = 'cell',
    31                  max_points_per_cell = 4):
     75    def __init__(self, extents,
     76         name = 'cell'):
    3277 
    3378        # Initialise base classes
    3479        TreeNode.__init__(self, string.lower(name))
    35        
    36         # Initialise cell
    37         self.southern = round(southern,5)   
    38         self.northern = round(northern,5)
    39         self.western = round(western,5)   
    40         self.eastern = round(eastern,5)
    41         self.mesh = mesh
    42 
     80   
     81        self.extents = extents
     82       
    4383        # The points in this cell     
    44         self.points = []
    45        
    46         self.max_points_per_cell = max_points_per_cell
    47        
    48        
     84        self.leaves = []
     85        self.children = None
     86       
     87   
    4988    def __repr__(self):
    50         return self.name 
    51 
    52 
    53     def spawn(self):
    54         """Create four child cells unless they already exist
    55         """
    56 
    57         if self.children:
    58             return
    59         else:
    60             self.children = []
    61 
    62         # convenience variables
    63         cs = self.southern   
    64         cn = self.northern
    65         cw = self.western   
    66         ce = self.eastern   
    67         mesh = self.mesh
    68 
    69         # create 4 child cells
    70         self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,mesh,self.name+'_nw'))
    71         self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,mesh,self.name+'_ne'))
    72         self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,mesh,self.name+'_se'))
    73         self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,mesh,self.name+'_sw'))
    74        
    75  
    76     def search(self, x, y, get_vertices=False):
    77         """Find all point indices sharing the same cell as point (x, y)
    78         """
    79         branch = []
    80         points, branch = self.search_branch(x, y, branch, get_vertices=get_vertices)
    81         self.branch = branch 
    82         return points
    83 
    84 
    85     def search_branch(self, x, y, branch, get_vertices=False):
    86         """Find all point indices sharing the same cell as point (x, y)
    87         """
    88         points = []
    89         if self.children:
    90             for child in self:
    91                 if (child.western <= x < child.eastern) and (child.southern <= y < child.northern):
    92                     brothers = list(self.children)
    93                     brothers.remove(child)
    94                     branch.append(brothers)
    95                     points, branch = child.search_branch(x,y, branch,
    96                                                   get_vertices=get_vertices)
    97                    
    98         else:
    99             # Leaf node: Get actual waypoints
    100             points = self.retrieve(get_vertices=get_vertices)     
    101         return points, branch
    102 
    103 
    104     def expand_search(self, get_vertices=False):
    105         """Find all point indices 'up' one cell from the last search
    106         """
    107        
    108         points = []
    109         if self.branch == []:
    110             points = []
    111         else:
    112             three_cells = self.branch.pop()
    113             for cell in three_cells:
    114                 points += cell.retrieve(get_vertices=get_vertices)
    115         return points, self.branch
    116 
    117 
    118     def contains(self, point_id):   
    119         """True only if P's coordinates lie within cell boundaries
    120         This methods has two forms:
    121        
    122         cell.contains(index)
    123         #True if cell contains indexed point
    124         """
    125         x, y = self.mesh.get_node(point_id, absolute=True)     
    126        
    127         return (self.western <= x < self.eastern) and (self.southern <= y < self.northern)
    128    
    129    
    130     def insert(self, points, split = False):
    131         """insert point(s) in existing tree structure below self
    132            and split if requested
    133         """
    134 
    135         # Call insert for each element of a list of points
    136         if type(points) == types.ListType:
    137             for point in points:
    138                 self.insert(point, split)
    139         else:
    140             #Only one point given as argument   
    141             point = points
    142        
    143             # Find appropriate cell
    144             if self.children is not None:
    145                 for child in self:
    146                     if child.contains(point):
    147                         child.insert(point, split)
    148                         break
    149             else:
    150                 # self is a leaf cell: insert point into cell
    151                 if self.contains(point):
    152                     self.store(point)
    153                 else:
    154                     # Have to take into account of georef.
    155                     #x = self.mesh.coordinates[point][0]
    156                     #y = self.mesh.coordinates[point][1]
    157                     node = self.mesh.get_node(point, absolute=True)
    158                     msg = ('point not in region: %s\nnode=%s'
    159                            % (str(point), str(node)))
    160                     raise Exception, msg
    161                
    162                
    163         #Split datastructure if requested       
    164         if split is True:
    165             self.split()
    166                
    167 
    168 
    169     def store(self,objects):
    170        
    171         if type(objects) not in [types.ListType,types.TupleType]:
    172             self.points.append(objects)
    173         else:
    174             self.points.extend(objects)
    175 
    176 
    177     def retrieve_triangles(self):
    178         """return a list of lists. For the inner lists,
    179         The first element is the triangle index,
    180         the second element is a list.for this list
    181            the first element is a list of three (x, y) vertices,
    182            the following elements are the three triangle normals.
    183 
    184         This info is used in searching for a triangle that a point is in.
    185 
    186         Post condition
    187         No more points can be added to the quad tree, since the
    188         points data structure is removed.
    189         """
    190         # FIXME Tidy up the structure that is returned.
    191         # if the triangles att has been made
    192         # return it.
    193         if not hasattr(self,'triangles'):
    194             # use a dictionary to remove duplicates
    195             triangles = {}
    196             verts = self.retrieve_vertices()
    197             for vert in verts:
    198                 triangle_list = self.mesh.get_triangles_and_vertices_per_node(vert)
    199                 for k, _ in triangle_list:
    200                     if not triangles.has_key(k):
    201                         tri = self.mesh.get_vertex_coordinates(k,
    202                                                                absolute=True)
    203                         n0 = self.mesh.get_normal(k, 0)
    204                         n1 = self.mesh.get_normal(k, 1)
    205                         n2 = self.mesh.get_normal(k, 2)
    206                         triangles[k]=(tri, (n0, n1, n2))
    207             self.triangles = triangles.items()
    208             # Delete the old cell data structure to save memory
    209             del self.points
    210         return self.triangles
    211            
    212     def retrieve_vertices(self):
    213          return self.points
    214 
    215 
    216     def retrieve(self, get_vertices=True):
    217          objects = []
    218          if self.children is None:
    219              if get_vertices is True:
    220                  objects = self.retrieve_vertices()
    221              else:
    222                  objects =  self.retrieve_triangles()
    223          else: 
    224              for child in self:
    225                  objects += child.retrieve(get_vertices=get_vertices)
    226          return objects
    227        
    228 
    229     def count(self, keywords=None):
    230         """retrieve number of stored objects beneath this node inclusive
    231         """
    232        
    233         num_waypoint = 0
    234         if self.children:
    235             for child in self:
    236                 num_waypoint = num_waypoint + child.count()
    237         else:
    238             num_waypoint = len(self.points)
    239         return num_waypoint
    240  
     89        str = '%s: leaves: %d' \
     90               % (self.name , len(self.leaves))
     91        if self.children:
     92            str += ', children: %d' % (len(self.children))
     93        return str
     94
     95   
    24196
    24297    def clear(self):
     
    246101    def clear_leaf_node(self):
    247102        """Clears storage in leaf node.
    248         Called from Treenod.
    249         Must exist.     
    250         """
    251         self.points = []
    252        
    253        
     103    Called from Treenode.
     104    Must exist.   
     105    """
     106        self.leaves = []
     107   
     108   
    254109    def clear_internal_node(self):
    255110        """Called from Treenode.   
    256         Must exist.
    257         """
    258         pass
    259 
    260 
    261 
    262     def split(self, threshold=None):
    263         """
    264         Partition cell when number of contained waypoints exceeds
    265         threshold.  All waypoints are then moved into correct
    266         child cell.
    267         """
    268         if threshold == None:
    269            threshold = self.max_points_per_cell
    270            
    271         #FIXME, mincellsize removed.  base it on side length, if needed
    272        
    273         #Protect against silly thresholds such as -1
    274         if threshold < 1:
    275             return
    276        
    277         if not self.children:               # Leaf cell
    278             if self.count() > threshold :   
    279                 #Split is needed
    280                 points = self.retrieve()    # Get points from leaf cell
    281                 self.clear()                # and remove them from storage
    282                    
    283                 self.spawn()                # Spawn child cells and move
    284                 for p in points:            # points to appropriate child
    285                     for child in self:
    286                         if child.contains(p):
    287                             child.insert(p)
    288                             break
    289                        
    290         if self.children:                   # Parent cell
    291             for child in self:              # split (possibly newly created)
    292                 child.split(threshold)      # child cells recursively
    293              
    294 
    295 
    296     def Get_tree(self,depth=0):
    297         """Traverse tree below self
    298            Print for each node the name and
    299            if it is a leaf the number of objects
    300         """
    301         s = ''
    302         if depth == 0:
    303             s = '\n'
     111    Must exist.
     112    """
     113        self.leaves = []
     114
     115
     116    def insert(self, new_leaf):
     117        # process list items sequentially
     118        if type(new_leaf)==type(list()):
     119            ret_val = []
     120            for leaf in new_leaf:
     121                self._insert(leaf)
     122        else:
     123            self._insert(new_leaf)
     124
     125
     126    def _insert(self, new_leaf):   
     127        new_region, data = new_leaf
     128       
     129        # recurse down to any children until we get an intersection
     130        if self.children:
     131            for child in self.children:
     132                if child.extents.is_trivial_in(new_region):
     133                    child._insert(new_leaf)
     134                    return
     135        else:           
     136            # try splitting this cell and see if we get a trivial in
     137            subregion1, subregion2 = self.extents.split()
     138            if subregion1.is_trivial_in(new_region):
     139                self.children = [Cell(subregion1), Cell(subregion2)]   
     140                self.children[0]._insert(new_leaf)
     141                return
     142            elif subregion2.is_trivial_in(new_region):
     143                self.children = [Cell(subregion1), Cell(subregion2)]   
     144                self.children[1]._insert(new_leaf)
     145                return               
     146   
     147        # recursion ended without finding a fit, so attach it as a leaf
     148        self.leaves.append(new_leaf)
     149       
     150     
     151    def retrieve(self):
     152        """Get all leaves from this tree. """
     153       
     154        leaves_found = list(self.leaves)
     155       
     156        if not self.children:
     157            return leaves_found
     158
     159        for child in self.children:
     160            leaves_found.extend(child.retrieve())
    304161           
    305         s += "%s%s:" % ('  '*depth, self.name)
    306         if self.children:
    307             s += '\n'
    308             for child in self.children:
    309                 s += child.Get_tree(depth+1)
    310         else:
    311             s += '(#wp=%d)\n' %(self.count())
    312 
    313         return s
    314 
    315        
     162        return leaves_found
     163
     164    def count(self):
     165        """Count all leaves from this tree. """
     166       
     167        leaves_found = len(self.leaves)
     168       
     169        if not self.children:
     170            return leaves_found
     171
     172        for child in self.children:
     173            leaves_found += child.count()
     174           
     175        return leaves_found       
     176
    316177    def show(self, depth=0):
    317178        """Traverse tree below self
    318            Print for each node the name and
    319            if it is a leaf the number of objects
    320179        """
    321180        if depth == 0:
    322181            log.critical()
    323         log.critical("%s%s" % ('  '*depth, self.name))
     182        print '%s%s' % ('  '*depth, self.name), self.extents,' [', self.leaves, ']'
    324183        if self.children:
    325184            log.critical()
    326185            for child in self.children:
    327186                child.show(depth+1)
    328         else:
    329             log.critical('(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'
    330                          % (self.western, self.eastern, self.southern,
    331                             self.northern, self.count()))
    332 
    333 
    334     def show_all(self,depth=0):
    335         """Traverse tree below self
    336            Print for each node the name and if it is a leaf all its objects
    337         """
    338         if depth == 0:
    339             log.critical()
    340         log.critical("%s%s:" % ('  '*depth, self.name))
    341         if self.children:
    342             print
    343             for child in self.children:
    344                 child.show_all(depth+1)
    345         else:
    346             log.critical('%s' % self.retrieve())
    347 
    348 
    349     def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
    350         """Traverse tree below self and find minimal cell radius,
    351            maximumtree depth and maximum number of waypoints per leaf.
    352         """
    353 
    354         if self.children:
    355             for child in self.children:
    356                 min_rad, max_depth, max_points =\
    357                          child.Stats(depth+1,min_rad,max_depth,max_points)
    358         else:
    359             #FIXME remvoe radius stuff
    360             #min_rad = sys.maxint
    361             #if self.radius < min_rad:   min_rad = self.radius
    362             if depth > max_depth: max_depth = depth
    363             num_points = self.count()
    364             if num_points > max_points: max_points = num_points
    365 
    366         #return min_rad, max_depth, max_points   
    367         return max_depth, max_points   
    368        
    369 
    370     #Class initialisation method
    371     # this is bad.  It adds a huge memory structure to the class.
    372     # When the instance is deleted the mesh hangs round (leaks).
    373     #def initialise(cls, mesh):
    374     #    cls.mesh = mesh
    375 
    376     #initialise = classmethod(initialise)
    377 
     187 
     188
     189    def search(self, x, y, get_vertices = False):
     190        """return a list of possible intersections with geometry"""
     191       
     192        intersecting_regions = []
     193       
     194        # test all leaves to see if they intersect the point
     195        for leaf in self.leaves:
     196            aabb, data = leaf
     197            if aabb.contains(x, y):
     198                if get_vertices:
     199                    intersecting_regions.append(leaf)
     200                else:
     201                    intersecting_regions.append(data)
     202       
     203        # recurse down into nodes that the point passes through
     204        if self.children:
     205            for child in self.children:   
     206                if child.extents.contains(x, y):
     207                    intersecting_regions.extend(child.search(x, y, get_vertices))
     208             
     209        return intersecting_regions
     210       
     211#from anuga.pmesh.mesh import Mesh
     212   
    378213def build_quadtree(mesh, max_points_per_cell = 4):
    379214    """Build quad tree for mesh.
     
    406241    #print "ymax", ymax
    407242   
    408     #FIXME: Use mesh.filename if it exists
    409     # why?
    410     root = Cell(ymin, ymax, xmin, xmax,mesh,
    411                 max_points_per_cell = max_points_per_cell)
    412 
    413     #root.show()
    414    
    415     #Insert indices of all vertices
    416     root.insert( range(mesh.number_of_nodes) )
    417 
    418     #Build quad tree and return
    419     root.split()
     243    root = Cell(AABB(xmin, xmax, ymin, ymax))
     244   
     245    N = len(mesh)
     246
     247    # Get x,y coordinates for all vertices for all triangles
     248    V = mesh.get_vertex_coordinates(absolute=True)
     249       
     250    # Check each triangle
     251    for i in range(N):
     252        x0, y0 = V[3*i, :]
     253        x1, y1 = V[3*i+1, :]
     254        x2, y2 = V[3*i+2, :]
     255
     256        # insert a tuple with an AABB, and the triangle index as data
     257        root._insert((AABB(min([x0, x1, x2]), max([x0, x1, x2]), \
     258                         min([y0, y1, y2]), max([y0, y1, y2])), \
     259                         i))
    420260
    421261    return root
  • anuga_core/source/anuga/utilities/test_quad.py

    r7703 r7707  
    22import numpy as num
    33
    4 from quad import Cell, build_quadtree
     4from quad import AABB, Cell, build_quadtree
    55from anuga.abstract_2d_finite_volumes.general_mesh import General_mesh as Mesh
    66
     
    1212
    1313    def setUp(self):
    14 
    15         a = [3, 107]
    16         b = [5, 107]
    17         c = [5, 105]
    18         d = [7, 107]
    19         e = [15, 115]
    20         f = [15, 130]
    21         g = [30, 110]
    22         h = [30, 130]
     14        pass
     15
     16    def tearDown(self):
     17        pass
     18
     19    def test_AABB_contains(self):
     20        box = AABB(1, 21, 1, 11)
     21        assert box.contains(10, 5)
     22        assert box.contains(1, 1)
     23        assert box.contains(20, 6)
     24        assert not box.contains(-1, -1)
     25        assert not box.contains(5, 70)
     26        assert not box.contains(6, -70)
     27        assert not box.contains(-1, 6)
     28        assert not box.contains(50, 6)       
     29       
     30    def test_AABB_split_vert(self):
     31        parent = AABB(1, 21, 1, 11)
     32       
     33        child1, child2 = parent.split(0.6)
     34
     35        self.assertEqual(child1.xmin, 1)
     36        self.assertEqual(child1.xmax, 13)
     37        self.assertEqual(child1.ymin, 1)
     38        self.assertEqual(child1.ymax, 11)
     39       
     40        self.assertEqual(child2.xmin, 9)
     41        self.assertEqual(child2.xmax, 21)
     42        self.assertEqual(child2.ymin, 1)
     43        self.assertEqual(child2.ymax, 11)   
     44
     45    def test_AABB_split_horiz(self):
     46        parent = AABB(1, 11, 1, 41)
     47       
     48        child1, child2 = parent.split(0.6)
     49
     50        self.assertEqual(child1.xmin, 1)
     51        self.assertEqual(child1.xmax, 11)
     52        self.assertEqual(child1.ymin, 1)
     53        self.assertEqual(child1.ymax, 25)
     54       
     55        self.assertEqual(child2.xmin, 1)
     56        self.assertEqual(child2.xmax, 11)
     57        self.assertEqual(child2.ymin, 17)
     58        self.assertEqual(child2.ymax, 41)         
     59       
     60    def test_add_data(self):
     61        cell = Cell(AABB(0,10, 0,5))
     62        cell.insert([(AABB(1,3, 1, 3), 111), (AABB(8,9, 1, 2), 222),  \
     63                     (AABB(7, 8, 3, 4), 333), (AABB(1, 10, 0, 1), 444)])
     64
     65        result = cell.retrieve()
     66        assert type(result) in [types.ListType,types.TupleType],\
     67                            'should be a list'
     68
     69        self.assertEqual(len(result),4)
     70       
     71    def test_search(self):
     72        test_region = (AABB(8,9, 1, 2), 222)
     73        cell = Cell(AABB(0,10, 0,5))
     74        cell.insert([(AABB(1,3, 1, 3), 111), test_region,  \
     75                     (AABB(7, 8, 3, 4), 333), (AABB(1, 10, 0, 1), 444)])
     76
     77        result =  cell.search(x = 8.5, y = 1.5, get_vertices=True)
     78        assert type(result) in [types.ListType,types.TupleType],\
     79                            'should be a list'
     80        self.assertEqual(result, [test_region], 'only 1 point should intersect')
     81
     82
     83    def test_clear_1(self):
     84        cell = Cell(AABB(0,10, 0,5))   
     85        cell.insert([(AABB(1,3, 1, 3), 111), (AABB(8,9, 1, 2), 222),  \
     86                     (AABB(7, 8, 3, 4), 333), (AABB(1, 10, 0, 1), 444)])
     87                     
     88        assert len(cell.retrieve()) == 4
     89        cell.clear()
     90
     91        assert len(cell.retrieve()) == 0
     92
     93    def test_build_quadtree(self):
     94
     95        a = [3, 7]
     96        b = [5, 7]
     97        c = [5, 5]
     98        d = [7, 7]
     99        e = [15, 15]
     100        f = [15, 30]
     101        g = [30, 10]
     102        h = [30, 30]
    23103
    24104        points = [a, b, c, d, e, f, g, h]
     
    27107        vertices = [[1,0,2], [1,3,4], [1,2,3], [5,4,7], [4,6,7]]
    28108
    29         mesh = Mesh(points, vertices)
    30         self.mesh = mesh
    31         self.cell = Cell(100, 140, 0, 40, mesh, 'cell')
    32 
    33     def tearDown(self):
    34         pass
    35 
    36     def test_add_points_2_cell(self):
    37         self.cell.insert(0)
    38         self.cell.insert(1)
    39 
    40         result = self.cell.retrieve()
    41         assert type(result) in [types.ListType,types.TupleType],\
    42                                 'should be a list'
    43         self.assertEqual(len(result),2)
    44 
    45     def test_add_points_2_cellII(self):
    46         self.cell.insert([0,1,2,3,4,5,6,7])
    47 
    48         result = self.cell.retrieve()
    49         assert type(result) in [types.ListType,types.TupleType],\
    50                                 'should be a list'
    51         self.assertEqual(len(result),8)
    52 
    53 
    54     def test_search(self):
    55         self.cell.insert([0,1,2,3,4,5,6,7])
    56         self.cell.split(4)
    57 
    58         result =  self.cell.search(x = 1, y = 101, get_vertices=True)
    59         assert type(result) in [types.ListType,types.TupleType],\
    60                                 'should be a list'
    61         self.assertEqual(result, [0,1,2,3])
    62 
    63 
    64     def test_clear_1(self):
    65         self.cell.insert([0,1,2,3,4,5,6,7])
    66         assert self.cell.count() == 8
    67         self.cell.clear()
    68 
    69         #This one actually revealed a bug :-)
    70         assert self.cell.count() == 0
    71 
    72     def test_clear_2(self):
    73         self.cell.insert([0,1,2,3,4,5,6,7])
    74         assert self.cell.count() == 8
    75         self.cell.split(2)
    76         assert self.cell.count() == 8
    77 
    78         self.cell.clear()
    79         assert self.cell.count() == 0
    80 
    81 
    82 
    83     def test_split(self):
    84         self.cell.insert([0,1,2,3,4,5,6,7], split = False)
    85 
    86         #No children yet
    87         assert self.cell.children is None
    88         assert self.cell.count() == 8
    89 
    90         #Split
    91         self.cell.split(4)
    92         #self.cell.show()
    93         #self.cell.show_all()
    94 
    95 
    96         #Now there are children
    97         assert self.cell.children is not None
    98         assert self.cell.count() == 8
    99 
    100 
    101     def test_build_quadtree(self):
    102 
    103         Q = build_quadtree(self.mesh)
     109        mesh = Mesh(points, vertices)
     110   
     111        Q = build_quadtree(mesh)
    104112        #Q.show()
    105113        #print Q.count()
    106         assert Q.count() == 8
    107 
    108 
    109 
    110         result = Q.search(3, 105, get_vertices=True)
     114        self.assertEqual(Q.count(), len(vertices))
     115
     116        # test a point that falls within a triangle
     117        result = Q.search(10, 10, get_vertices=True)
    111118        assert type(result) in [types.ListType,types.TupleType],\
    112                                 'should be a list'
    113         #print "result",result
    114         self.assertEqual(result, [0,1,2,3])
     119                            'should be a list'
     120        pos, index = result[0]
     121        self.assertEqual(index, 1)
    115122
    116123
    117124    def test_build_quadtreeII(self):
    118125
    119         self.cell = Cell(100, 140, 0, 40, 'cell')
     126        self.cell = Cell(AABB(100, 140, 0, 40), 'cell')
    120127
    121128        p0 = [34.6292076111,-7999.92529297]
     
    128135        vertices = [[0,1,2],[0,2,3]]
    129136
    130         mesh = Mesh(points, vertices)
     137        mesh = Mesh(points, vertices)
    131138
    132139        #This was causing round off error
    133140        Q = build_quadtree(mesh)
    134141       
    135     def test_interpolate_one_point_many_triangles(self):
     142    def NOtest_interpolate_one_point_many_triangles(self):
    136143        # this test has 10 triangles that share the same vert.
    137144        # If the number of points per cell in  a quad tree is less
     
    175182                      ]
    176183       
    177         mesh = Mesh(vertices, triangles)
     184        mesh = Mesh(vertices, triangles)
    178185        try:
    179186            Q = build_quadtree(mesh, max_points_per_cell = 9)
     
    186193    def test_retrieve_triangles(self):
    187194
    188         cell = Cell(0, 6, 0, 6, 'cell', max_points_per_cell=4)
     195        cell = Cell(AABB(0, 6, 0, 6), 'cell')
    189196
    190197        p0 = [2,1]
     
    198205        vertices = [[0,1,2],[0,2,3],[1,4,2]]
    199206
    200         mesh = Mesh(points, vertices)
     207        mesh = Mesh(points, vertices)
    201208
    202209        Q = build_quadtree(mesh)
    203         results = Q.search(5,1)
    204         assert len(results),2
    205         #print "results", results
    206         #print "results[0][0]", results[0][0]
    207         assert results[0],0
    208         assert results[1],2
    209         assert results[0][1],[[ 2.,  1.],
    210                      [ 4.,  1.],
    211                      [ 4.,  4.]]
    212         assert results[1][1],[[ 4.,  1.],
    213                      [ 5.,  4.],
    214                      [ 4.,  4.]]
    215         # this is the normals
    216         assert results[0][1][1],[[1.,  0.],
    217                      [-0.83205029,  0.5547002],
    218                      [ 0.,  -1.]]
    219                      
    220         # assert num.allclose(num.array(results),[[[ 2.,  1.],
    221         #[ 4.,  1.], [ 4.,  4.]], [[ 4.,  1.],[ 5.,  4.],[ 4.,  4.]]] )
     210        results = Q.search(4.5, 3)
     211        assert len(results) == 1
     212        self.assertEqual(results[0], 2)
    222213        results = Q.search(5,4.)
    223         ### print "results",results
    224         # results_dic={}
    225         # results_dic.update(results)
    226         assert len(results),3
    227         #print "results_dic[0]", results_dic[0]
    228         assert results[0][1],[[ 2.,  1.],
    229                      [ 4.,  1.],
    230                      [ 4.,  4.]]
    231         assert results[1][1],[[ 2.,  1.],
    232                      [ 4.,  4.],
    233                      [ 2.,  4.]]
    234         assert results[2][1],[[ 4.,  1.],
    235                      [ 5.,  4.],
    236                      [ 4.,  4.]]
    237         #assert allclose(array(results),[[[ 2.,  1.],[ 4.,  1.], [ 4.,  4.]]
    238          #                               ,[[ 2.,  1.],[ 4.,  4.], [ 2.,  4.]],
    239         #[[ 4.,  1.],  [ 5.,  4.], [ 4.,  4.]],
    240          #                               [[ 4.,  1.], [ 5.,  4.], [ 4.,  4.]]])
    241        
     214        self.assertEqual(len(results),1)
     215        self.assertEqual(results[0], 2)
    242216################################################################################
    243217
Note: See TracChangeset for help on using the changeset viewer.