source: anuga_core/source/anuga/utilities/quad.py @ 4808

Last change on this file since 4808 was 4808, checked in by duncan, 17 years ago

bug fix

File size: 14.6 KB
Line 
1"""quad.py - quad tree data structure for fast indexing of points in the plane
2
3
4"""
5
6from treenode import TreeNode
7import string, types, sys
8
9#FIXME verts are added one at a time.
10#FIXME add max min x y in general_mesh
11
12class Cell(TreeNode):
13    """class Cell
14
15    One cell in the plane delimited by southern, northern,
16    western, eastern boundaries.
17
18    Public Methods:
19        prune()
20        insert(point)
21        search(x, y)
22        collapse()
23        split()
24        store()
25        retrieve()
26        count()
27    """
28 
29    def __init__(self, southern, northern, western, eastern, mesh,
30                 name = 'cell',
31                 max_points_per_cell = 4):
32 
33        # Initialise base classes
34        TreeNode.__init__(self, string.lower(name))
35       
36        # Initialise cell
37        self.southern = round(southern,5)   
38        self.northern = round(northern,5)
39        self.western = round(western,5)   
40        self.eastern = round(eastern,5)
41        self.mesh = mesh
42
43        # The points in this cell     
44        self.points = []
45       
46        self.max_points_per_cell = max_points_per_cell
47       
48       
49    def __repr__(self):
50        return self.name 
51
52
53    def spawn(self):
54        """Create four child cells unless they already exist
55        """
56
57        if self.children:
58            return
59        else:
60            self.children = []
61
62        # convenience variables
63        cs = self.southern   
64        cn = self.northern
65        cw = self.western   
66        ce = self.eastern   
67        mesh = self.mesh
68
69        # create 4 child cells
70        self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,mesh,self.name+'_nw'))
71        self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,mesh,self.name+'_ne'))
72        self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,mesh,self.name+'_se'))
73        self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,mesh,self.name+'_sw'))
74       
75 
76    def search(self, x, y, get_vertices=False):
77        """Find all point indices sharing the same cell as point (x, y)
78        """
79        branch = []
80        points = []
81        if self.children:
82            for child in self:
83                if child.contains(x,y):
84                    brothers = list(self.children)
85                    brothers.remove(child)
86                    branch.append(brothers)
87                    points, branch = child.search_branch(x,y, branch,
88                                                  get_vertices=get_vertices)
89        else:
90            # Leaf node: Get actual waypoints
91            points = self.retrieve(get_vertices=get_vertices)
92        self.branch = branch   
93        return points
94
95
96    def search_branch(self, x, y, branch, get_vertices=False):
97        """Find all point indices sharing the same cell as point (x, y)
98        """
99        points = []
100        if self.children:
101            for child in self:
102                if child.contains(x,y):
103                    brothers = list(self.children)
104                    brothers.remove(child)
105                    branch.append(brothers)
106                    points, branch = child.search_branch(x,y, branch,
107                                                  get_vertices=get_vertices)
108                   
109        else:
110            # Leaf node: Get actual waypoints
111            points = self.retrieve(get_vertices=get_vertices)     
112        return points, branch
113
114
115    def expand_search(self, get_vertices=False):
116        """Find all point indices 'up' one cell from the last search
117        """
118       
119        points = []
120        if self.branch == []:
121            points = []
122        else:
123            three_cells = self.branch.pop()
124            for cell in three_cells:
125                #print "cell ", cell.show()
126                points += cell.retrieve(get_vertices=get_vertices)
127        return points, self.branch
128
129
130    def contains(*args):   
131        """True only if P's coordinates lie within cell boundaries
132        This methods has two forms:
133       
134        cell.contains(index)
135          #True if cell contains indexed point
136        cell.contains(x, y)
137          #True if cell contains point (x,y)   
138
139        """
140       
141        self = args[0]
142        if len(args) == 2:
143            point_id = int(args[1])
144            x, y = self.mesh.get_node(point_id, absolute=True)
145        elif len(args) == 3:
146            x = float(args[1])
147            y = float(args[2])
148        else:
149            msg = 'Number of arguments to method must be two or three'
150            raise msg                         
151       
152        if y <  self.southern: return False
153        if y >= self.northern: return False
154        if x <  self.western:  return False
155        if x >= self.eastern:  return False
156        return True
157   
158   
159    def insert(self, points, split = False):
160        """insert point(s) in existing tree structure below self
161           and split if requested
162        """
163
164        # Call insert for each element of a list of points
165        if type(points) == types.ListType:
166            for point in points:
167                self.insert(point, split)
168        else:
169            #Only one point given as argument   
170            point = points
171       
172            # Find appropriate cell
173            if self.children is not None:
174                for child in self:
175                    if child.contains(point):
176                        child.insert(point, split)
177                        break
178            else:
179                # self is a leaf cell: insert point into cell
180                if self.contains(point):
181                    self.store(point)
182                else:
183                    # Have to take into account of georef.
184                    #x = self.mesh.coordinates[point][0]
185                    #y = self.mesh.coordinates[point][1]
186                    node = self.mesh.get_node(point, absolute=True)
187                    print "node", node
188                    print "(" + str(node[0]) + "," + str(node[1]) + ")"
189                    raise 'point not in region: %s' %str(point)
190               
191               
192        #Split datastructure if requested       
193        if split is True:
194            self.split()
195               
196
197
198    def store(self,objects):
199       
200        if type(objects) not in [types.ListType,types.TupleType]:
201            self.points.append(objects)
202        else:
203            self.points.extend(objects)
204
205
206    def retrieve_triangles(self):
207        """return a list of lists. For the inner lists,
208        The first element is the triangle index,
209        the second element is a list.for this list
210           the first element is a list of three (x, y) vertices,
211           the following elements are the three triangle normals.
212
213        This info is used in searching for a triangle that a point is in.
214
215        Post condition
216        No more points can be added to the quad tree, since the
217        points data structure is removed.
218        """
219        # FIXME Tidy up the structure that is returned.
220        # if the triangles att has been made
221        # return it.
222        if not hasattr(self,'triangles'):
223            # use a dictionary to remove duplicates
224            triangles = {}
225            verts = self.retrieve_vertices()
226            # print "verts", verts
227            for vert in verts:
228                triangle_list = self.mesh.get_triangles_and_vertices_per_node(vert)
229                for k, _ in triangle_list:
230                    if not triangles.has_key(k):
231                        # print 'k',k
232                        tri = self.mesh.get_vertex_coordinates(k,
233                                                               absolute=True)
234                        n0 = self.mesh.get_normal(k, 0)
235                        n1 = self.mesh.get_normal(k, 1)
236                        n2 = self.mesh.get_normal(k, 2) 
237                        triangles[k]=(tri, (n0, n1, n2))
238            self.triangles = triangles.items()
239            # Delete the old cell data structure to save memory
240            del self.points
241        return self.triangles
242           
243    def retrieve_vertices(self):
244         return self.points
245
246
247    def retrieve(self, get_vertices=True):
248         objects = []
249         if self.children is None:
250             if get_vertices is True:
251                 objects = self.retrieve_vertices()
252             else:
253                 objects =  self.retrieve_triangles()
254         else: 
255             for child in self:
256                 objects += child.retrieve(get_vertices=get_vertices)
257         return objects
258       
259
260    def count(self, keywords=None):
261        """retrieve number of stored objects beneath this node inclusive
262        """
263       
264        num_waypoint = 0
265        if self.children:
266            for child in self:
267                num_waypoint = num_waypoint + child.count()
268        else:
269            num_waypoint = len(self.points)
270        return num_waypoint
271 
272
273    def clear(self):
274        self.Prune()   # TreeNode method
275
276
277    def clear_leaf_node(self):
278        """Clears storage in leaf node.
279        Called from Treenod.
280        Must exist.     
281        """
282        self.points = []
283       
284       
285    def clear_internal_node(self):
286        """Called from Treenode.   
287        Must exist.
288        """
289        pass
290
291
292
293    def split(self, threshold=None):
294        """
295        Partition cell when number of contained waypoints exceeds
296        threshold.  All waypoints are then moved into correct
297        child cell.
298        """
299        if threshold == None:
300           threshold = self.max_points_per_cell
301           
302        #FIXME, mincellsize removed.  base it on side length, if needed
303       
304        #Protect against silly thresholds such as -1
305        if threshold < 1:
306            return
307       
308        if not self.children:               # Leaf cell
309            if self.count() > threshold :   
310                #Split is needed
311                points = self.retrieve()    # Get points from leaf cell
312                self.clear()                # and remove them from storage
313                   
314                self.spawn()                # Spawn child cells and move
315                for p in points:            # points to appropriate child
316                    for child in self:
317                        if child.contains(p):
318                            child.insert(p) 
319                            break
320                       
321        if self.children:                   # Parent cell
322            for child in self:              # split (possibly newly created)
323                child.split(threshold)      # child cells recursively
324               
325
326
327    def collapse(self,threshold=None):
328        """
329        collapse child cells into immediate parent if total number of contained waypoints
330        in subtree below is less than or equal to threshold.
331        All waypoints are then moved into parent cell and
332        children are removed. If self is a leaf node initially, do nothing.
333        """
334       
335        if threshold is None:
336            threshold = self.max_points_per_cell       
337
338
339        if self.children:                   # Parent cell   
340            if self.count() <= threshold:   # collapse
341                points = self.retrieve()    # Get all points from child cells
342                self.clear()                # Remove children, self is now a leaf node
343                self.insert(points)         # Insert all points in local storage
344            else:                         
345                for child in self:          # Check if any sub tree can be collapsed
346                    child.collapse(threshold)
347
348
349    def Get_tree(self,depth=0):
350        """Traverse tree below self
351           Print for each node the name and
352           if it is a leaf the number of objects
353        """
354        s = ''
355        if depth == 0:
356            s = '\n'
357           
358        s += "%s%s:" % ('  '*depth, self.name)
359        if self.children:
360            s += '\n'
361            for child in self.children:
362                s += child.Get_tree(depth+1)
363        else:
364            s += '(#wp=%d)\n' %(self.count())
365
366        return s
367
368       
369    def show(self, depth=0):
370        """Traverse tree below self
371           Print for each node the name and
372           if it is a leaf the number of objects
373        """
374        if depth == 0:
375            print 
376        print "%s%s" % ('  '*depth, self.name),
377        if self.children:
378            print
379            for child in self.children:
380                child.show(depth+1)
381        else:
382            print '(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'\
383                  %(self.western, self.eastern,
384                    self.southern, self.northern,
385                    self.count()) 
386
387
388    def show_all(self,depth=0):
389        """Traverse tree below self
390           Print for each node the name and if it is a leaf all its objects
391        """
392        if depth == 0:
393            print 
394        print "%s%s:" % ('  '*depth, self.name),
395        if self.children:
396            print
397            for child in self.children:
398                child.show_all(depth+1)
399        else:
400            print '%s' %self.retrieve()
401
402
403    def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
404        """Traverse tree below self and find minimal cell radius,
405           maximumtree depth and maximum number of waypoints per leaf.
406        """
407
408        if self.children:
409            for child in self.children:
410                min_rad, max_depth, max_points =\
411                         child.Stats(depth+1,min_rad,max_depth,max_points)
412        else:
413            #FIXME remvoe radius stuff
414            #min_rad = sys.maxint
415            #if self.radius < min_rad:   min_rad = self.radius
416            if depth > max_depth: max_depth = depth
417            num_points = self.count()
418            if num_points > max_points: max_points = num_points
419
420        #return min_rad, max_depth, max_points   
421        return max_depth, max_points   
422       
423
424    #Class initialisation method
425    # this is bad.  It adds a huge memory structure to the class.
426    # When the instance is deleted the mesh hangs round (leaks).
427    #def initialise(cls, mesh):
428    #    cls.mesh = mesh
429
430    #initialise = classmethod(initialise)
431
432def build_quadtree(mesh, max_points_per_cell = 4):
433    """Build quad tree for mesh.
434
435    All vertices in mesh are stored in quadtree and a reference
436    to the root is returned.
437    """
438
439    from Numeric import minimum, maximum
440
441
442    #Make root cell
443    #print mesh.coordinates
444
445    xmin, xmax, ymin, ymax = mesh.get_extent(absolute=True)
446   
447    # Ensure boundary points are fully contained in region
448    # It is a property of the cell structure that
449    # points on xmax or ymax of any given cell
450    # belong to the neighbouring cell.
451    # Hence, the root cell needs to be expanded slightly
452    ymax += (ymax-ymin)/10
453    xmax += (xmax-xmin)/10
454
455    # To avoid round off error
456    ymin -= (ymax-ymin)/10
457    xmin -= (xmax-xmin)/10   
458
459    #print "xmin", xmin
460    #print "xmax", xmax
461    #print "ymin", ymin
462    #print "ymax", ymax
463   
464    #FIXME: Use mesh.filename if it exists
465    # why?
466    root = Cell(ymin, ymax, xmin, xmax,mesh,
467                max_points_per_cell = max_points_per_cell)
468
469    #root.show()
470   
471    #Insert indices of all vertices
472    root.insert( range(mesh.number_of_nodes) )
473
474    #Build quad tree and return
475    root.split()
476
477    return root
Note: See TracBrowser for help on using the repository browser.