source: anuga_core/source/anuga/utilities/quad.py @ 7703

Last change on this file since 7703 was 7703, checked in by hudson, 14 years ago

Refactored quad.py to remove unused methods and duplicated code.

File size: 12.9 KB
Line 
1"""quad.py - quad tree data structure for fast indexing of points in the plane
2
3
4"""
5
6from treenode import TreeNode
7import string, types, sys
8import anuga.utilities.log as log
9
10
11#FIXME verts are added one at a time.
12#FIXME add max min x y in general_mesh
13
14class Cell(TreeNode):
15    """class Cell
16
17    One cell in the plane delimited by southern, northern,
18    western, eastern boundaries.
19
20    Public Methods:
21        insert(point)
22        search(x, y)
23        split()
24        store()
25        retrieve()
26        count()
27    """
28 
29    def __init__(self, southern, northern, western, eastern, mesh,
30                 name = 'cell',
31                 max_points_per_cell = 4):
32 
33        # Initialise base classes
34        TreeNode.__init__(self, string.lower(name))
35       
36        # Initialise cell
37        self.southern = round(southern,5)   
38        self.northern = round(northern,5)
39        self.western = round(western,5)   
40        self.eastern = round(eastern,5)
41        self.mesh = mesh
42
43        # The points in this cell     
44        self.points = []
45       
46        self.max_points_per_cell = max_points_per_cell
47       
48       
49    def __repr__(self):
50        return self.name 
51
52
53    def spawn(self):
54        """Create four child cells unless they already exist
55        """
56
57        if self.children:
58            return
59        else:
60            self.children = []
61
62        # convenience variables
63        cs = self.southern   
64        cn = self.northern
65        cw = self.western   
66        ce = self.eastern   
67        mesh = self.mesh
68
69        # create 4 child cells
70        self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,mesh,self.name+'_nw'))
71        self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,mesh,self.name+'_ne'))
72        self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,mesh,self.name+'_se'))
73        self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,mesh,self.name+'_sw'))
74       
75 
76    def search(self, x, y, get_vertices=False):
77        """Find all point indices sharing the same cell as point (x, y)
78        """
79        branch = []
80        points, branch = self.search_branch(x, y, branch, get_vertices=get_vertices)
81        self.branch = branch 
82        return points
83
84
85    def search_branch(self, x, y, branch, get_vertices=False):
86        """Find all point indices sharing the same cell as point (x, y)
87        """
88        points = []
89        if self.children:
90            for child in self:
91                if (child.western <= x < child.eastern) and (child.southern <= y < child.northern):
92                    brothers = list(self.children)
93                    brothers.remove(child)
94                    branch.append(brothers)
95                    points, branch = child.search_branch(x,y, branch,
96                                                  get_vertices=get_vertices)
97                   
98        else:
99            # Leaf node: Get actual waypoints
100            points = self.retrieve(get_vertices=get_vertices)     
101        return points, branch
102
103
104    def expand_search(self, get_vertices=False):
105        """Find all point indices 'up' one cell from the last search
106        """
107       
108        points = []
109        if self.branch == []:
110            points = []
111        else:
112            three_cells = self.branch.pop()
113            for cell in three_cells:
114                points += cell.retrieve(get_vertices=get_vertices)
115        return points, self.branch
116
117
118    def contains(self, point_id):   
119        """True only if P's coordinates lie within cell boundaries
120        This methods has two forms:
121       
122        cell.contains(index)
123        #True if cell contains indexed point
124        """
125        x, y = self.mesh.get_node(point_id, absolute=True)     
126       
127        return (self.western <= x < self.eastern) and (self.southern <= y < self.northern)
128   
129   
130    def insert(self, points, split = False):
131        """insert point(s) in existing tree structure below self
132           and split if requested
133        """
134
135        # Call insert for each element of a list of points
136        if type(points) == types.ListType:
137            for point in points:
138                self.insert(point, split)
139        else:
140            #Only one point given as argument   
141            point = points
142       
143            # Find appropriate cell
144            if self.children is not None:
145                for child in self:
146                    if child.contains(point):
147                        child.insert(point, split)
148                        break
149            else:
150                # self is a leaf cell: insert point into cell
151                if self.contains(point):
152                    self.store(point)
153                else:
154                    # Have to take into account of georef.
155                    #x = self.mesh.coordinates[point][0]
156                    #y = self.mesh.coordinates[point][1]
157                    node = self.mesh.get_node(point, absolute=True)
158                    msg = ('point not in region: %s\nnode=%s'
159                           % (str(point), str(node)))
160                    raise Exception, msg
161               
162               
163        #Split datastructure if requested       
164        if split is True:
165            self.split()
166               
167
168
169    def store(self,objects):
170       
171        if type(objects) not in [types.ListType,types.TupleType]:
172            self.points.append(objects)
173        else:
174            self.points.extend(objects)
175
176
177    def retrieve_triangles(self):
178        """return a list of lists. For the inner lists,
179        The first element is the triangle index,
180        the second element is a list.for this list
181           the first element is a list of three (x, y) vertices,
182           the following elements are the three triangle normals.
183
184        This info is used in searching for a triangle that a point is in.
185
186        Post condition
187        No more points can be added to the quad tree, since the
188        points data structure is removed.
189        """
190        # FIXME Tidy up the structure that is returned.
191        # if the triangles att has been made
192        # return it.
193        if not hasattr(self,'triangles'):
194            # use a dictionary to remove duplicates
195            triangles = {}
196            verts = self.retrieve_vertices()
197            for vert in verts:
198                triangle_list = self.mesh.get_triangles_and_vertices_per_node(vert)
199                for k, _ in triangle_list:
200                    if not triangles.has_key(k):
201                        tri = self.mesh.get_vertex_coordinates(k,
202                                                               absolute=True)
203                        n0 = self.mesh.get_normal(k, 0)
204                        n1 = self.mesh.get_normal(k, 1)
205                        n2 = self.mesh.get_normal(k, 2) 
206                        triangles[k]=(tri, (n0, n1, n2))
207            self.triangles = triangles.items()
208            # Delete the old cell data structure to save memory
209            del self.points
210        return self.triangles
211           
212    def retrieve_vertices(self):
213         return self.points
214
215
216    def retrieve(self, get_vertices=True):
217         objects = []
218         if self.children is None:
219             if get_vertices is True:
220                 objects = self.retrieve_vertices()
221             else:
222                 objects =  self.retrieve_triangles()
223         else: 
224             for child in self:
225                 objects += child.retrieve(get_vertices=get_vertices)
226         return objects
227       
228
229    def count(self, keywords=None):
230        """retrieve number of stored objects beneath this node inclusive
231        """
232       
233        num_waypoint = 0
234        if self.children:
235            for child in self:
236                num_waypoint = num_waypoint + child.count()
237        else:
238            num_waypoint = len(self.points)
239        return num_waypoint
240 
241
242    def clear(self):
243        self.Prune()   # TreeNode method
244
245
246    def clear_leaf_node(self):
247        """Clears storage in leaf node.
248        Called from Treenod.
249        Must exist.     
250        """
251        self.points = []
252       
253       
254    def clear_internal_node(self):
255        """Called from Treenode.   
256        Must exist.
257        """
258        pass
259
260
261
262    def split(self, threshold=None):
263        """
264        Partition cell when number of contained waypoints exceeds
265        threshold.  All waypoints are then moved into correct
266        child cell.
267        """
268        if threshold == None:
269           threshold = self.max_points_per_cell
270           
271        #FIXME, mincellsize removed.  base it on side length, if needed
272       
273        #Protect against silly thresholds such as -1
274        if threshold < 1:
275            return
276       
277        if not self.children:               # Leaf cell
278            if self.count() > threshold :   
279                #Split is needed
280                points = self.retrieve()    # Get points from leaf cell
281                self.clear()                # and remove them from storage
282                   
283                self.spawn()                # Spawn child cells and move
284                for p in points:            # points to appropriate child
285                    for child in self:
286                        if child.contains(p):
287                            child.insert(p) 
288                            break
289                       
290        if self.children:                   # Parent cell
291            for child in self:              # split (possibly newly created)
292                child.split(threshold)      # child cells recursively
293             
294
295
296    def Get_tree(self,depth=0):
297        """Traverse tree below self
298           Print for each node the name and
299           if it is a leaf the number of objects
300        """
301        s = ''
302        if depth == 0:
303            s = '\n'
304           
305        s += "%s%s:" % ('  '*depth, self.name)
306        if self.children:
307            s += '\n'
308            for child in self.children:
309                s += child.Get_tree(depth+1)
310        else:
311            s += '(#wp=%d)\n' %(self.count())
312
313        return s
314
315       
316    def show(self, depth=0):
317        """Traverse tree below self
318           Print for each node the name and
319           if it is a leaf the number of objects
320        """
321        if depth == 0:
322            log.critical() 
323        log.critical("%s%s" % ('  '*depth, self.name))
324        if self.children:
325            log.critical()
326            for child in self.children:
327                child.show(depth+1)
328        else:
329            log.critical('(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'
330                         % (self.western, self.eastern, self.southern,
331                            self.northern, self.count()))
332
333
334    def show_all(self,depth=0):
335        """Traverse tree below self
336           Print for each node the name and if it is a leaf all its objects
337        """
338        if depth == 0:
339            log.critical() 
340        log.critical("%s%s:" % ('  '*depth, self.name))
341        if self.children:
342            print
343            for child in self.children:
344                child.show_all(depth+1)
345        else:
346            log.critical('%s' % self.retrieve())
347
348
349    def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
350        """Traverse tree below self and find minimal cell radius,
351           maximumtree depth and maximum number of waypoints per leaf.
352        """
353
354        if self.children:
355            for child in self.children:
356                min_rad, max_depth, max_points =\
357                         child.Stats(depth+1,min_rad,max_depth,max_points)
358        else:
359            #FIXME remvoe radius stuff
360            #min_rad = sys.maxint
361            #if self.radius < min_rad:   min_rad = self.radius
362            if depth > max_depth: max_depth = depth
363            num_points = self.count()
364            if num_points > max_points: max_points = num_points
365
366        #return min_rad, max_depth, max_points   
367        return max_depth, max_points   
368       
369
370    #Class initialisation method
371    # this is bad.  It adds a huge memory structure to the class.
372    # When the instance is deleted the mesh hangs round (leaks).
373    #def initialise(cls, mesh):
374    #    cls.mesh = mesh
375
376    #initialise = classmethod(initialise)
377
378def build_quadtree(mesh, max_points_per_cell = 4):
379    """Build quad tree for mesh.
380
381    All vertices in mesh are stored in quadtree and a reference
382    to the root is returned.
383    """
384
385
386    #Make root cell
387    #print mesh.coordinates
388
389    xmin, xmax, ymin, ymax = mesh.get_extent(absolute=True)
390   
391    # Ensure boundary points are fully contained in region
392    # It is a property of the cell structure that
393    # points on xmax or ymax of any given cell
394    # belong to the neighbouring cell.
395    # Hence, the root cell needs to be expanded slightly
396    ymax += (ymax-ymin)/10
397    xmax += (xmax-xmin)/10
398
399    # To avoid round off error
400    ymin -= (ymax-ymin)/10
401    xmin -= (xmax-xmin)/10   
402
403    #print "xmin", xmin
404    #print "xmax", xmax
405    #print "ymin", ymin
406    #print "ymax", ymax
407   
408    #FIXME: Use mesh.filename if it exists
409    # why?
410    root = Cell(ymin, ymax, xmin, xmax,mesh,
411                max_points_per_cell = max_points_per_cell)
412
413    #root.show()
414   
415    #Insert indices of all vertices
416    root.insert( range(mesh.number_of_nodes) )
417
418    #Build quad tree and return
419    root.split()
420
421    return root
Note: See TracBrowser for help on using the repository browser.