Ignore:
Timestamp:
May 5, 2010, 4:06:02 PM (14 years ago)
Author:
hudson
Message:

New quadtree implementation - unoptimised and no tree balancing. A couple of failing unit tests to fix.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • anuga_core/source/anuga/utilities/quad.py

    r7703 r7707  
    1 """quad.py - quad tree data structure for fast indexing of points in the plane
     1"""quad.py - quad tree data structure for fast indexing of regions in the plane.
     2
     3This is a generic structure that can be used to store any geometry in a quadtree.
    24
    35
     
    810import anuga.utilities.log as log
    911
    10 
    11 #FIXME verts are added one at a time.
    12 #FIXME add max min x y in general_mesh
    13 
     12# Allow children to be slightly bigger than their parents to prevent straddling of a boundary
     13SPLIT_BORDER_RATIO    = 0.55
     14
     15class AABB:
     16    """Axially-aligned bounding box class.
     17    """
     18   
     19    def __init__(self, xmin, xmax, ymin, ymax):
     20        self.xmin = round(xmin,5)   
     21        self.xmax = round(xmax,5)
     22        self.ymin = round(ymin,5)   
     23        self.ymax = round(ymax,5)
     24
     25    def __repr__(self):
     26        return '(xmin:%f, xmax:%f, ymin:%f, ymax:%f)' \
     27               % (round(self.xmin,1), round(self.xmax,1), round(self.ymin,1), round(self.ymax, 1))
     28       
     29    def size(self):
     30        """return size as (w,h)"""
     31        return self.xmax - self.xmin, self.ymax - self.ymin
     32       
     33    def split(self, border=SPLIT_BORDER_RATIO):
     34        """Split along shorter axis.
     35           return 2 subdivided AABBs.
     36        """
     37       
     38        width, height = self.size()
     39        assert width >= 0 and height >= 0
     40       
     41        if (width > height):
     42            # split vertically
     43            return AABB(self.xmin, self.xmin+width*border, self.ymin, self.ymax), \
     44                   AABB(self.xmax-width*border, self.xmax, self.ymin, self.ymax)
     45        else:
     46            # split horizontally       
     47            return AABB(self.xmin, self.xmax, self.ymin, self.ymin+height*border), \
     48                   AABB(self.xmin, self.xmax, self.ymax-height*border, self.ymax)   
     49   
     50    def is_trivial_in(self, test):
     51        if (test.xmin < self.xmin) or (test.xmax > self.xmax):
     52            return False       
     53        if (test.ymin < self.ymin) or (test.ymax > self.ymax):
     54            return False       
     55        return True
     56 
     57    def contains(self, x, y):
     58        return (self.xmin <= x <= self.xmax) and (self.ymin <= y <= self.ymax)
     59           
    1460class Cell(TreeNode):
    1561    """class Cell
     
    2773    """
    2874 
    29     def __init__(self, southern, northern, western, eastern, mesh,
    30                  name = 'cell',
    31                  max_points_per_cell = 4):
     75    def __init__(self, extents,
     76         name = 'cell'):
    3277 
    3378        # Initialise base classes
    3479        TreeNode.__init__(self, string.lower(name))
    35        
    36         # Initialise cell
    37         self.southern = round(southern,5)   
    38         self.northern = round(northern,5)
    39         self.western = round(western,5)   
    40         self.eastern = round(eastern,5)
    41         self.mesh = mesh
    42 
     80   
     81        self.extents = extents
     82       
    4383        # The points in this cell     
    44         self.points = []
    45        
    46         self.max_points_per_cell = max_points_per_cell
    47        
    48        
     84        self.leaves = []
     85        self.children = None
     86       
     87   
    4988    def __repr__(self):
    50         return self.name 
    51 
    52 
    53     def spawn(self):
    54         """Create four child cells unless they already exist
    55         """
    56 
    57         if self.children:
    58             return
    59         else:
    60             self.children = []
    61 
    62         # convenience variables
    63         cs = self.southern   
    64         cn = self.northern
    65         cw = self.western   
    66         ce = self.eastern   
    67         mesh = self.mesh
    68 
    69         # create 4 child cells
    70         self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,mesh,self.name+'_nw'))
    71         self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,mesh,self.name+'_ne'))
    72         self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,mesh,self.name+'_se'))
    73         self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,mesh,self.name+'_sw'))
    74        
    75  
    76     def search(self, x, y, get_vertices=False):
    77         """Find all point indices sharing the same cell as point (x, y)
    78         """
    79         branch = []
    80         points, branch = self.search_branch(x, y, branch, get_vertices=get_vertices)
    81         self.branch = branch 
    82         return points
    83 
    84 
    85     def search_branch(self, x, y, branch, get_vertices=False):
    86         """Find all point indices sharing the same cell as point (x, y)
    87         """
    88         points = []
    89         if self.children:
    90             for child in self:
    91                 if (child.western <= x < child.eastern) and (child.southern <= y < child.northern):
    92                     brothers = list(self.children)
    93                     brothers.remove(child)
    94                     branch.append(brothers)
    95                     points, branch = child.search_branch(x,y, branch,
    96                                                   get_vertices=get_vertices)
    97                    
    98         else:
    99             # Leaf node: Get actual waypoints
    100             points = self.retrieve(get_vertices=get_vertices)     
    101         return points, branch
    102 
    103 
    104     def expand_search(self, get_vertices=False):
    105         """Find all point indices 'up' one cell from the last search
    106         """
    107        
    108         points = []
    109         if self.branch == []:
    110             points = []
    111         else:
    112             three_cells = self.branch.pop()
    113             for cell in three_cells:
    114                 points += cell.retrieve(get_vertices=get_vertices)
    115         return points, self.branch
    116 
    117 
    118     def contains(self, point_id):   
    119         """True only if P's coordinates lie within cell boundaries
    120         This methods has two forms:
    121        
    122         cell.contains(index)
    123         #True if cell contains indexed point
    124         """
    125         x, y = self.mesh.get_node(point_id, absolute=True)     
    126        
    127         return (self.western <= x < self.eastern) and (self.southern <= y < self.northern)
    128    
    129    
    130     def insert(self, points, split = False):
    131         """insert point(s) in existing tree structure below self
    132            and split if requested
    133         """
    134 
    135         # Call insert for each element of a list of points
    136         if type(points) == types.ListType:
    137             for point in points:
    138                 self.insert(point, split)
    139         else:
    140             #Only one point given as argument   
    141             point = points
    142        
    143             # Find appropriate cell
    144             if self.children is not None:
    145                 for child in self:
    146                     if child.contains(point):
    147                         child.insert(point, split)
    148                         break
    149             else:
    150                 # self is a leaf cell: insert point into cell
    151                 if self.contains(point):
    152                     self.store(point)
    153                 else:
    154                     # Have to take into account of georef.
    155                     #x = self.mesh.coordinates[point][0]
    156                     #y = self.mesh.coordinates[point][1]
    157                     node = self.mesh.get_node(point, absolute=True)
    158                     msg = ('point not in region: %s\nnode=%s'
    159                            % (str(point), str(node)))
    160                     raise Exception, msg
    161                
    162                
    163         #Split datastructure if requested       
    164         if split is True:
    165             self.split()
    166                
    167 
    168 
    169     def store(self,objects):
    170        
    171         if type(objects) not in [types.ListType,types.TupleType]:
    172             self.points.append(objects)
    173         else:
    174             self.points.extend(objects)
    175 
    176 
    177     def retrieve_triangles(self):
    178         """return a list of lists. For the inner lists,
    179         The first element is the triangle index,
    180         the second element is a list.for this list
    181            the first element is a list of three (x, y) vertices,
    182            the following elements are the three triangle normals.
    183 
    184         This info is used in searching for a triangle that a point is in.
    185 
    186         Post condition
    187         No more points can be added to the quad tree, since the
    188         points data structure is removed.
    189         """
    190         # FIXME Tidy up the structure that is returned.
    191         # if the triangles att has been made
    192         # return it.
    193         if not hasattr(self,'triangles'):
    194             # use a dictionary to remove duplicates
    195             triangles = {}
    196             verts = self.retrieve_vertices()
    197             for vert in verts:
    198                 triangle_list = self.mesh.get_triangles_and_vertices_per_node(vert)
    199                 for k, _ in triangle_list:
    200                     if not triangles.has_key(k):
    201                         tri = self.mesh.get_vertex_coordinates(k,
    202                                                                absolute=True)
    203                         n0 = self.mesh.get_normal(k, 0)
    204                         n1 = self.mesh.get_normal(k, 1)
    205                         n2 = self.mesh.get_normal(k, 2)
    206                         triangles[k]=(tri, (n0, n1, n2))
    207             self.triangles = triangles.items()
    208             # Delete the old cell data structure to save memory
    209             del self.points
    210         return self.triangles
    211            
    212     def retrieve_vertices(self):
    213          return self.points
    214 
    215 
    216     def retrieve(self, get_vertices=True):
    217          objects = []
    218          if self.children is None:
    219              if get_vertices is True:
    220                  objects = self.retrieve_vertices()
    221              else:
    222                  objects =  self.retrieve_triangles()
    223          else: 
    224              for child in self:
    225                  objects += child.retrieve(get_vertices=get_vertices)
    226          return objects
    227        
    228 
    229     def count(self, keywords=None):
    230         """retrieve number of stored objects beneath this node inclusive
    231         """
    232        
    233         num_waypoint = 0
    234         if self.children:
    235             for child in self:
    236                 num_waypoint = num_waypoint + child.count()
    237         else:
    238             num_waypoint = len(self.points)
    239         return num_waypoint
    240  
     89        str = '%s: leaves: %d' \
     90               % (self.name , len(self.leaves))
     91        if self.children:
     92            str += ', children: %d' % (len(self.children))
     93        return str
     94
     95   
    24196
    24297    def clear(self):
     
    246101    def clear_leaf_node(self):
    247102        """Clears storage in leaf node.
    248         Called from Treenod.
    249         Must exist.     
    250         """
    251         self.points = []
    252        
    253        
     103    Called from Treenode.
     104    Must exist.   
     105    """
     106        self.leaves = []
     107   
     108   
    254109    def clear_internal_node(self):
    255110        """Called from Treenode.   
    256         Must exist.
    257         """
    258         pass
    259 
    260 
    261 
    262     def split(self, threshold=None):
    263         """
    264         Partition cell when number of contained waypoints exceeds
    265         threshold.  All waypoints are then moved into correct
    266         child cell.
    267         """
    268         if threshold == None:
    269            threshold = self.max_points_per_cell
    270            
    271         #FIXME, mincellsize removed.  base it on side length, if needed
    272        
    273         #Protect against silly thresholds such as -1
    274         if threshold < 1:
    275             return
    276        
    277         if not self.children:               # Leaf cell
    278             if self.count() > threshold :   
    279                 #Split is needed
    280                 points = self.retrieve()    # Get points from leaf cell
    281                 self.clear()                # and remove them from storage
    282                    
    283                 self.spawn()                # Spawn child cells and move
    284                 for p in points:            # points to appropriate child
    285                     for child in self:
    286                         if child.contains(p):
    287                             child.insert(p)
    288                             break
    289                        
    290         if self.children:                   # Parent cell
    291             for child in self:              # split (possibly newly created)
    292                 child.split(threshold)      # child cells recursively
    293              
    294 
    295 
    296     def Get_tree(self,depth=0):
    297         """Traverse tree below self
    298            Print for each node the name and
    299            if it is a leaf the number of objects
    300         """
    301         s = ''
    302         if depth == 0:
    303             s = '\n'
     111    Must exist.
     112    """
     113        self.leaves = []
     114
     115
     116    def insert(self, new_leaf):
     117        # process list items sequentially
     118        if type(new_leaf)==type(list()):
     119            ret_val = []
     120            for leaf in new_leaf:
     121                self._insert(leaf)
     122        else:
     123            self._insert(new_leaf)
     124
     125
     126    def _insert(self, new_leaf):   
     127        new_region, data = new_leaf
     128       
     129        # recurse down to any children until we get an intersection
     130        if self.children:
     131            for child in self.children:
     132                if child.extents.is_trivial_in(new_region):
     133                    child._insert(new_leaf)
     134                    return
     135        else:           
     136            # try splitting this cell and see if we get a trivial in
     137            subregion1, subregion2 = self.extents.split()
     138            if subregion1.is_trivial_in(new_region):
     139                self.children = [Cell(subregion1), Cell(subregion2)]   
     140                self.children[0]._insert(new_leaf)
     141                return
     142            elif subregion2.is_trivial_in(new_region):
     143                self.children = [Cell(subregion1), Cell(subregion2)]   
     144                self.children[1]._insert(new_leaf)
     145                return               
     146   
     147        # recursion ended without finding a fit, so attach it as a leaf
     148        self.leaves.append(new_leaf)
     149       
     150     
     151    def retrieve(self):
     152        """Get all leaves from this tree. """
     153       
     154        leaves_found = list(self.leaves)
     155       
     156        if not self.children:
     157            return leaves_found
     158
     159        for child in self.children:
     160            leaves_found.extend(child.retrieve())
    304161           
    305         s += "%s%s:" % ('  '*depth, self.name)
    306         if self.children:
    307             s += '\n'
    308             for child in self.children:
    309                 s += child.Get_tree(depth+1)
    310         else:
    311             s += '(#wp=%d)\n' %(self.count())
    312 
    313         return s
    314 
    315        
     162        return leaves_found
     163
     164    def count(self):
     165        """Count all leaves from this tree. """
     166       
     167        leaves_found = len(self.leaves)
     168       
     169        if not self.children:
     170            return leaves_found
     171
     172        for child in self.children:
     173            leaves_found += child.count()
     174           
     175        return leaves_found       
     176
    316177    def show(self, depth=0):
    317178        """Traverse tree below self
    318            Print for each node the name and
    319            if it is a leaf the number of objects
    320179        """
    321180        if depth == 0:
    322181            log.critical()
    323         log.critical("%s%s" % ('  '*depth, self.name))
     182        print '%s%s' % ('  '*depth, self.name), self.extents,' [', self.leaves, ']'
    324183        if self.children:
    325184            log.critical()
    326185            for child in self.children:
    327186                child.show(depth+1)
    328         else:
    329             log.critical('(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'
    330                          % (self.western, self.eastern, self.southern,
    331                             self.northern, self.count()))
    332 
    333 
    334     def show_all(self,depth=0):
    335         """Traverse tree below self
    336            Print for each node the name and if it is a leaf all its objects
    337         """
    338         if depth == 0:
    339             log.critical()
    340         log.critical("%s%s:" % ('  '*depth, self.name))
    341         if self.children:
    342             print
    343             for child in self.children:
    344                 child.show_all(depth+1)
    345         else:
    346             log.critical('%s' % self.retrieve())
    347 
    348 
    349     def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
    350         """Traverse tree below self and find minimal cell radius,
    351            maximumtree depth and maximum number of waypoints per leaf.
    352         """
    353 
    354         if self.children:
    355             for child in self.children:
    356                 min_rad, max_depth, max_points =\
    357                          child.Stats(depth+1,min_rad,max_depth,max_points)
    358         else:
    359             #FIXME remvoe radius stuff
    360             #min_rad = sys.maxint
    361             #if self.radius < min_rad:   min_rad = self.radius
    362             if depth > max_depth: max_depth = depth
    363             num_points = self.count()
    364             if num_points > max_points: max_points = num_points
    365 
    366         #return min_rad, max_depth, max_points   
    367         return max_depth, max_points   
    368        
    369 
    370     #Class initialisation method
    371     # this is bad.  It adds a huge memory structure to the class.
    372     # When the instance is deleted the mesh hangs round (leaks).
    373     #def initialise(cls, mesh):
    374     #    cls.mesh = mesh
    375 
    376     #initialise = classmethod(initialise)
    377 
     187 
     188
     189    def search(self, x, y, get_vertices = False):
     190        """return a list of possible intersections with geometry"""
     191       
     192        intersecting_regions = []
     193       
     194        # test all leaves to see if they intersect the point
     195        for leaf in self.leaves:
     196            aabb, data = leaf
     197            if aabb.contains(x, y):
     198                if get_vertices:
     199                    intersecting_regions.append(leaf)
     200                else:
     201                    intersecting_regions.append(data)
     202       
     203        # recurse down into nodes that the point passes through
     204        if self.children:
     205            for child in self.children:   
     206                if child.extents.contains(x, y):
     207                    intersecting_regions.extend(child.search(x, y, get_vertices))
     208             
     209        return intersecting_regions
     210       
     211#from anuga.pmesh.mesh import Mesh
     212   
    378213def build_quadtree(mesh, max_points_per_cell = 4):
    379214    """Build quad tree for mesh.
     
    406241    #print "ymax", ymax
    407242   
    408     #FIXME: Use mesh.filename if it exists
    409     # why?
    410     root = Cell(ymin, ymax, xmin, xmax,mesh,
    411                 max_points_per_cell = max_points_per_cell)
    412 
    413     #root.show()
    414    
    415     #Insert indices of all vertices
    416     root.insert( range(mesh.number_of_nodes) )
    417 
    418     #Build quad tree and return
    419     root.split()
     243    root = Cell(AABB(xmin, xmax, ymin, ymax))
     244   
     245    N = len(mesh)
     246
     247    # Get x,y coordinates for all vertices for all triangles
     248    V = mesh.get_vertex_coordinates(absolute=True)
     249       
     250    # Check each triangle
     251    for i in range(N):
     252        x0, y0 = V[3*i, :]
     253        x1, y1 = V[3*i+1, :]
     254        x2, y2 = V[3*i+2, :]
     255
     256        # insert a tuple with an AABB, and the triangle index as data
     257        root._insert((AABB(min([x0, x1, x2]), max([x0, x1, x2]), \
     258                         min([y0, y1, y2]), max([y0, y1, y2])), \
     259                         i))
    420260
    421261    return root
Note: See TracChangeset for help on using the changeset viewer.