source: anuga_core/source/anuga/utilities/quad.py @ 3945

Last change on this file since 3945 was 3945, checked in by ole, 17 years ago

One large step towards major cleanup. This has mainly to do with
the way vertex coordinates are handled internally.

File size: 12.3 KB
Line 
1"""quad.py - quad tree data structure for fast indexing of points in the plane
2
3
4"""
5
6from treenode import TreeNode
7import string, types, sys
8
9#FIXME verts are added one at a time.
10#FIXME add max min x y in general_mesh
11
12class Cell(TreeNode):
13    """class Cell
14
15    One cell in the plane delimited by southern, northern,
16    western, eastern boundaries.
17
18    Public Methods:
19        prune()
20        insert(point)
21        search(x, y)
22        collapse()
23        split()
24        store()
25        retrieve()
26        count()
27    """
28 
29    def __init__(self, southern, northern, western, eastern, 
30                 name = 'cell',
31                 max_points_per_cell = 4):
32 
33        # Initialise base classes
34        TreeNode.__init__(self, string.lower(name))
35       
36        # Initialise cell
37        self.southern = round(southern,5)   
38        self.northern = round(northern,5)
39        self.western = round(western,5)   
40        self.eastern = round(eastern,5)
41
42        # The points in this cell     
43        self.points = []
44       
45        self.max_points_per_cell = max_points_per_cell
46       
47       
48    def __repr__(self):
49        return self.name 
50
51
52    def spawn(self):
53        """Create four child cells unless they already exist
54        """
55
56        if self.children:
57            return
58        else:
59            self.children = []
60
61        # convenience variables
62        cs = self.southern   
63        cn = self.northern
64        cw = self.western   
65        ce = self.eastern
66
67        # create 4 child cells
68        self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,self.name+'_nw'))
69        self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,self.name+'_ne'))
70        self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,self.name+'_se'))
71        self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,self.name+'_sw'))
72       
73 
74    def search(self, x, y):
75    #def search_new(self, x, y):
76        """Find all point indices sharing the same cell as point (x, y)
77        """
78        branch = []
79        points = []
80        if self.children:
81            for child in self:
82                if child.contains(x,y):
83                    brothers = list(self.children)
84                    brothers.remove(child)
85                    branch.append(brothers)
86                    points, branch = child.search_branch(x,y, branch)
87        else:
88            # Leaf node: Get actual waypoints
89            points = self.retrieve()
90
91        self.branch = branch   
92        return points
93
94
95    def search_branch(self, x, y, branch):
96        """Find all point indices sharing the same cell as point (x, y)
97        """
98        points = []
99        if self.children:
100            for child in self:
101                if child.contains(x,y):
102                    brothers = list(self.children)
103                    brothers.remove(child)
104                    branch.append(brothers)
105                    points, branch = child.search_branch(x,y, branch)
106                   
107        else:
108            # Leaf node: Get actual waypoints
109            points = self.retrieve()     
110        return points, branch
111
112
113    def expand_search(self):
114        """Find all point indices 'up' one cell from the last search
115        """
116        points = []
117        if self.branch == []:
118            points = []
119        else:
120            three_cells = self.branch.pop()
121            for cell in three_cells:
122                #print "cell ", cell.show()
123                points += cell.retrieve()
124        return points, self.branch
125
126
127    def contains(*args):   
128        """True only if P's coordinates lie within cell boundaries
129        This methods has two forms:
130       
131        cell.contains(index)
132          #True if cell contains indexed point
133        cell.contains(x, y)
134          #True if cell contains point (x,y)   
135
136        """
137       
138        self = args[0]
139        if len(args) == 2:
140            point_id = int(args[1])
141            x, y = self.__class__.mesh.get_nodes()[point_id]
142
143            #print point_id, x, y
144        elif len(args) == 3:
145            x = float(args[1])
146            y = float(args[2])
147        else:
148            msg = 'Number of arguments to method must be two or three'
149            raise msg                         
150       
151        if y <  self.southern: return False
152        if y >= self.northern: return False
153        if x <  self.western:  return False
154        if x >= self.eastern:  return False
155        return True
156   
157   
158    def insert(self, points, split = False):
159        """insert point(s) in existing tree structure below self
160           and split if requested
161        """
162
163        # Call insert for each element of a list of points
164        if type(points) == types.ListType:
165            for point in points:
166                self.insert(point, split)
167        else:
168            #Only one point given as argument   
169            point = points
170       
171            # Find appropriate cell
172            if self.children is not None:
173                for child in self:
174                    if child.contains(point):
175                        child.insert(point, split)
176                        break
177            else:
178                # self is a leaf cell: insert point into cell
179                if self.contains(point):
180                    self.store(point)
181                else:
182                    x = self.__class__.mesh.coordinates[point][0]
183                    y = self.__class__.mesh.coordinates[point][1]
184                    print "(" + str(x) + "," + str(y) + ")"
185                    raise 'point not in region: %s' %str(point)
186               
187               
188        #Split datastructure if requested       
189        if split is True:
190            self.split()
191               
192
193
194    def store(self,objects):
195       
196        if type(objects) not in [types.ListType,types.TupleType]:
197            self.points.append(objects)
198        else:
199            self.points.extend(objects)
200
201
202    def retrieve(self):
203         objects = []
204         if self.children is None:
205             objects = self.points
206         else: 
207             for child in self:
208                 objects += child.retrieve()
209         return objects 
210
211
212    def count(self, keywords=None):
213        """retrieve number of stored objects beneath this node inclusive
214        """
215       
216        num_waypoint = 0
217        if self.children:
218            for child in self:
219                num_waypoint = num_waypoint + child.count()
220        else:
221            num_waypoint = len(self.points)
222        return num_waypoint
223 
224
225    def clear(self):
226        self.Prune()   # TreeNode method
227
228
229    def clear_leaf_node(self):
230        """Clears storage in leaf node.
231        Called from Treenod.
232        Must exist.     
233        """
234        self.points = []
235       
236       
237    def clear_internal_node(self):
238        """Called from Treenode.   
239        Must exist.
240        """
241        pass
242
243
244
245    def split(self, threshold=None):
246        """
247        Partition cell when number of contained waypoints exceeds
248        threshold.  All waypoints are then moved into correct
249        child cell.
250        """
251        if threshold == None:
252           threshold = self.max_points_per_cell
253           
254        #FIXME, mincellsize removed.  base it on side length, if needed
255       
256        #Protect against silly thresholds such as -1
257        if threshold < 1:
258            return
259       
260        if not self.children:               # Leaf cell
261            if self.count() > threshold :   
262                #Split is needed
263                points = self.retrieve()    # Get points from leaf cell
264                self.clear()                # and remove them from storage
265               
266                self.spawn()                # Spawn child cells and move
267                for p in points:            # points to appropriate child
268                    for child in self:
269                        if child.contains(p):
270                            child.insert(p) 
271                            break
272   
273        if self.children:                   # Parent cell
274            for child in self:              # split (possibly newly created)
275                child.split(threshold)      # child cells recursively
276               
277
278
279    def collapse(self,threshold=None):
280        """
281        collapse child cells into immediate parent if total number of contained waypoints
282        in subtree below is less than or equal to threshold.
283        All waypoints are then moved into parent cell and
284        children are removed. If self is a leaf node initially, do nothing.
285        """
286       
287        if threshold is None:
288            threshold = self.max_points_per_cell       
289
290
291        if self.children:                   # Parent cell   
292            if self.count() <= threshold:   # collapse
293                points = self.retrieve()    # Get all points from child cells
294                self.clear()                # Remove children, self is now a leaf node
295                self.insert(points)         # Insert all points in local storage
296            else:                         
297                for child in self:          # Check if any sub tree can be collapsed
298                    child.collapse(threshold)
299
300
301    def Get_tree(self,depth=0):
302        """Traverse tree below self
303           Print for each node the name and
304           if it is a leaf the number of objects
305        """
306        s = ''
307        if depth == 0:
308            s = '\n'
309           
310        s += "%s%s:" % ('  '*depth, self.name)
311        if self.children:
312            s += '\n'
313            for child in self.children:
314                s += child.Get_tree(depth+1)
315        else:
316            s += '(#wp=%d)\n' %(self.count())
317
318        return s
319
320       
321    def show(self, depth=0):
322        """Traverse tree below self
323           Print for each node the name and
324           if it is a leaf the number of objects
325        """
326        if depth == 0:
327            print 
328        print "%s%s" % ('  '*depth, self.name),
329        if self.children:
330            print
331            for child in self.children:
332                child.show(depth+1)
333        else:
334            print '(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'\
335                  %(self.western, self.eastern,
336                    self.southern, self.northern,
337                    self.count()) 
338
339
340    def show_all(self,depth=0):
341        """Traverse tree below self
342           Print for each node the name and if it is a leaf all its objects
343        """
344        if depth == 0:
345            print 
346        print "%s%s:" % ('  '*depth, self.name),
347        if self.children:
348            print
349            for child in self.children:
350                child.show_all(depth+1)
351        else:
352            print '%s' %self.retrieve()
353
354
355    def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
356        """Traverse tree below self and find minimal cell radius,
357           maximumtree depth and maximum number of waypoints per leaf.
358        """
359
360        if self.children:
361            for child in self.children:
362                min_rad, max_depth, max_points =\
363                         child.Stats(depth+1,min_rad,max_depth,max_points)
364        else:
365            #FIXME remvoe radius stuff
366            #min_rad = sys.maxint
367            #if self.radius < min_rad:   min_rad = self.radius
368            if depth > max_depth: max_depth = depth
369            num_points = self.count()
370            if num_points > max_points: max_points = num_points
371
372        #return min_rad, max_depth, max_points   
373        return max_depth, max_points   
374       
375
376    #Class initialisation method       
377    def initialise(cls, mesh):
378        cls.mesh = mesh
379
380    initialise = classmethod(initialise)
381
382def build_quadtree(mesh, max_points_per_cell = 4):
383    """Build quad tree for mesh.
384
385    All vertices in mesh are stored in quadtree and a reference to the root is returned.
386    """
387
388    from Numeric import minimum, maximum
389
390    #Initialise
391    Cell.initialise(mesh)
392
393    #Make root cell
394    #print mesh.coordinates
395
396    nodes = mesh.get_nodes()
397    xmin = min(nodes[:,0])
398    xmax = max(nodes[:,0])
399    ymin = min(nodes[:,1])
400    ymax = max(nodes[:,1])
401
402   
403    #Ensure boundary points are fully contained in region
404    #It is a property of the cell structure that points on xmax or ymax of any given cell
405    #belong to the neighbouring cell.
406    #Hence, the root cell needs to be expanded slightly
407    ymax += (ymax-ymin)/10
408    xmax += (xmax-xmin)/10
409
410    # To avoid round off error
411    ymin -= (ymax-ymin)/10
412    xmin -= (xmax-xmin)/10   
413
414    #print "xmin", xmin
415    #print "xmax", xmax
416    #print "ymin", ymin
417    #print "ymax", ymax
418   
419    #FIXME: Use mesh.filename if it exists
420    root = Cell(ymin, ymax, xmin, xmax,
421                #name = ....
422                max_points_per_cell = max_points_per_cell)
423
424    #root.show()
425   
426    #Insert indices of all vertices
427    root.insert( range(mesh.number_of_nodes) )
428
429    #Build quad tree and return
430    root.split()
431
432    return root
Note: See TracBrowser for help on using the repository browser.