source: inundation-numpy-branch/pyvolution/quad.py @ 2608

Last change on this file since 2608 was 2608, checked in by ole, 18 years ago

Played with custom exceptions for ANUGA

File size: 12.3 KB
Line 
1"""quad.py - quad tree data structure for fast indexing of points in the plane
2
3
4"""
5
6from treenode import TreeNode
7import string, types, sys
8
9#FIXME verts are added one at a time.
10#FIXME add max min x y in general_mesh
11
12class Cell(TreeNode):
13    """class Cell
14
15    One cell in the plane delimited by southern, northern,
16    western, eastern boundaries.
17
18    Public Methods:
19        prune()
20        insert(point)
21        search(x, y)
22        collapse()
23        split()
24        store()
25        retrieve()
26        count()
27    """
28 
29    def __init__(self, southern, northern, western, eastern, 
30                 name = 'cell',
31                 max_points_per_cell = 4):
32 
33        # Initialise base classes
34        TreeNode.__init__(self, string.lower(name))
35       
36        # Initialise cell
37        self.southern = round(southern,5)   
38        self.northern = round(northern,5)
39        self.western = round(western,5)   
40        self.eastern = round(eastern,5)
41
42        # The points in this cell     
43        self.points = []
44       
45        self.max_points_per_cell = max_points_per_cell
46       
47       
48    def __repr__(self):
49        return self.name 
50
51
52    def spawn(self):
53        """Create four child cells unless they already exist
54        """
55
56        if self.children:
57            return
58        else:
59            self.children = []
60
61        # convenience variables
62        cs = self.southern   
63        cn = self.northern
64        cw = self.western   
65        ce = self.eastern
66
67        # create 4 child cells
68        self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,self.name+'_nw'))
69        self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,self.name+'_ne'))
70        self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,self.name+'_se'))
71        self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,self.name+'_sw'))
72       
73 
74    def search(self, x, y):
75    #def search_new(self, x, y):
76        """Find all point indices sharing the same cell as point (x, y)
77        """
78        branch = []
79        points = []
80        if self.children:
81            for child in self:
82                if child.contains(x,y):
83                    brothers = list(self.children)
84                    brothers.remove(child)
85                    branch.append(brothers)
86                    points, branch = child.search_branch(x,y, branch)
87        else:
88            # Leaf node: Get actual waypoints
89            points = self.retrieve()
90
91        self.branch = branch   
92        return points
93
94
95    def search_branch(self, x, y, branch):
96        """Find all point indices sharing the same cell as point (x, y)
97        """
98        points = []
99        if self.children:
100            for child in self:
101                if child.contains(x,y):
102                    brothers = list(self.children)
103                    brothers.remove(child)
104                    branch.append(brothers)
105                    points, branch = child.search_branch(x,y, branch)
106                   
107        else:
108            # Leaf node: Get actual waypoints
109            points = self.retrieve()     
110        return points, branch
111
112
113    def expand_search(self):
114        """Find all point indices 'up' one cell from the last search
115        """
116        points = []
117        if self.branch == []:
118            points = []
119        else:
120            three_cells = self.branch.pop()
121            for cell in three_cells:
122                #print "cell ", cell.show()
123                points += cell.retrieve()
124        return points, self.branch
125
126
127    def contains(*args):   
128        """True only if P's coordinates lie within cell boundaries
129        This methods has two forms:
130       
131        cell.contains(index)
132          #True if cell contains indexed point
133        cell.contains(x, y)
134          #True if cell contains point (x,y)   
135
136        """
137       
138        self = args[0]
139        if len(args) == 2:
140            point_id = int(args[1])
141            x = self.__class__.mesh.coordinates[point_id][0]
142            y = self.__class__.mesh.coordinates[point_id][1]
143
144            #print point_id, x, y
145        elif len(args) == 3:
146            x = float(args[1])
147            y = float(args[2])
148        else:
149            msg = 'Number of arguments to method must be two or three'
150            raise msg                         
151       
152        if y <  self.southern: return False
153        if y >= self.northern: return False
154        if x <  self.western:  return False
155        if x >= self.eastern:  return False
156        return True
157   
158   
159    def insert(self, points, split = False):
160        """insert point(s) in existing tree structure below self
161           and split if requested
162        """
163
164        # Call insert for each element of a list of points
165        if type(points) == types.ListType:
166            for point in points:
167                self.insert(point, split)
168        else:
169            #Only one point given as argument   
170            point = points
171       
172            # Find appropriate cell
173            if self.children is not None:
174                for child in self:
175                    if child.contains(point):
176                        child.insert(point, split)
177                        break
178            else:
179                # self is a leaf cell: insert point into cell
180                if self.contains(point):
181                    self.store(point)
182                else:
183                    x = self.__class__.mesh.coordinates[point][0]
184                    y = self.__class__.mesh.coordinates[point][1]
185                    print "(" + str(x) + "," + str(y) + ")"
186                    raise 'point not in region: %s' %str(point)
187               
188               
189        #Split datastructure if requested       
190        if split is True:
191            self.split()
192               
193
194
195    def store(self,objects):
196       
197        if type(objects) not in [types.ListType,types.TupleType]:
198            self.points.append(objects)
199        else:
200            self.points.extend(objects)
201
202
203    def retrieve(self):
204         objects = []
205         if self.children is None:
206             objects = self.points
207         else: 
208             for child in self:
209                 objects += child.retrieve()
210         return objects 
211
212
213    def count(self, keywords=None):
214        """retrieve number of stored objects beneath this node inclusive
215        """
216       
217        num_waypoint = 0
218        if self.children:
219            for child in self:
220                num_waypoint = num_waypoint + child.count()
221        else:
222            num_waypoint = len(self.points)
223        return num_waypoint
224 
225
226    def clear(self):
227        self.Prune()   # TreeNode method
228
229
230    def clear_leaf_node(self):
231        """Clears storage in leaf node.
232        Called from Treenod.
233        Must exist.     
234        """
235        self.points = []
236       
237       
238    def clear_internal_node(self):
239        """Called from Treenode.   
240        Must exist.
241        """
242        pass
243
244
245
246    def split(self, threshold=None):
247        """
248        Partition cell when number of contained waypoints exceeds
249        threshold.  All waypoints are then moved into correct
250        child cell.
251        """
252        if threshold == None:
253           threshold = self.max_points_per_cell
254           
255        #FIXME, mincellsize removed.  base it on side length, if needed
256       
257        #Protect against silly thresholds such as -1
258        if threshold < 1:
259            return
260       
261        if not self.children:               # Leaf cell
262            if self.count() > threshold :   
263                #Split is needed
264                points = self.retrieve()    # Get points from leaf cell
265                self.clear()                # and remove them from storage
266               
267                self.spawn()                # Spawn child cells and move
268                for p in points:            # points to appropriate child
269                    for child in self:
270                        if child.contains(p):
271                            child.insert(p) 
272                            break
273   
274        if self.children:                   # Parent cell
275            for child in self:              # split (possibly newly created)
276                child.split(threshold)      # child cells recursively
277               
278
279
280    def collapse(self,threshold=None):
281        """
282        collapse child cells into immediate parent if total number of contained waypoints
283        in subtree below is less than or equal to threshold.
284        All waypoints are then moved into parent cell and
285        children are removed. If self is a leaf node initially, do nothing.
286        """
287       
288        if threshold is None:
289            threshold = self.max_points_per_cell       
290
291
292        if self.children:                   # Parent cell   
293            if self.count() <= threshold:   # collapse
294                points = self.retrieve()    # Get all points from child cells
295                self.clear()                # Remove children, self is now a leaf node
296                self.insert(points)         # Insert all points in local storage
297            else:                         
298                for child in self:          # Check if any sub tree can be collapsed
299                    child.collapse(threshold)
300
301
302    def Get_tree(self,depth=0):
303        """Traverse tree below self
304           Print for each node the name and
305           if it is a leaf the number of objects
306        """
307        s = ''
308        if depth == 0:
309            s = '\n'
310           
311        s += "%s%s:" % ('  '*depth, self.name)
312        if self.children:
313            s += '\n'
314            for child in self.children:
315                s += child.Get_tree(depth+1)
316        else:
317            s += '(#wp=%d)\n' %(self.count())
318
319        return s
320
321       
322    def show(self, depth=0):
323        """Traverse tree below self
324           Print for each node the name and
325           if it is a leaf the number of objects
326        """
327        if depth == 0:
328            print 
329        print "%s%s" % ('  '*depth, self.name),
330        if self.children:
331            print
332            for child in self.children:
333                child.show(depth+1)
334        else:
335            print '(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'\
336                  %(self.western, self.eastern,
337                    self.southern, self.northern,
338                    self.count()) 
339
340
341    def show_all(self,depth=0):
342        """Traverse tree below self
343           Print for each node the name and if it is a leaf all its objects
344        """
345        if depth == 0:
346            print 
347        print "%s%s:" % ('  '*depth, self.name),
348        if self.children:
349            print
350            for child in self.children:
351                child.show_all(depth+1)
352        else:
353            print '%s' %self.retrieve()
354
355
356    def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
357        """Traverse tree below self and find minimal cell radius,
358           maximumtree depth and maximum number of waypoints per leaf.
359        """
360
361        if self.children:
362            for child in self.children:
363                min_rad, max_depth, max_points =\
364                         child.Stats(depth+1,min_rad,max_depth,max_points)
365        else:
366            #FIXME remvoe radius stuff
367            #min_rad = sys.maxint
368            #if self.radius < min_rad:   min_rad = self.radius
369            if depth > max_depth: max_depth = depth
370            num_points = self.count()
371            if num_points > max_points: max_points = num_points
372
373        #return min_rad, max_depth, max_points   
374        return max_depth, max_points   
375       
376
377    #Class initialisation method       
378    def initialise(cls, mesh):
379        cls.mesh = mesh
380
381    initialise = classmethod(initialise)
382
383def build_quadtree(mesh, max_points_per_cell = 4):
384    """Build quad tree for mesh.
385
386    All vertices in mesh are stored in quadtree and a reference to the root is returned.
387    """
388
389    from numpy import minimum, maximum
390
391    #Initialise
392    Cell.initialise(mesh)
393
394    #Make root cell
395    #print mesh.coordinates
396   
397    xmin = min(mesh.coordinates[:,0])
398    xmax = max(mesh.coordinates[:,0])
399    ymin = min(mesh.coordinates[:,1])
400    ymax = max(mesh.coordinates[:,1])
401
402   
403    #Ensure boundary points are fully contained in region
404    #It is a property of the cell structure that points on xmax or ymax of any given cell
405    #belong to the neighbouring cell.
406    #Hence, the root cell needs to be expanded slightly
407    ymax += (ymax-ymin)/10
408    xmax += (xmax-xmin)/10
409
410    # To avoid round off error
411    ymin -= (ymax-ymin)/10
412    xmin -= (xmax-xmin)/10   
413
414    #print "xmin", xmin
415    #print "xmax", xmax
416    #print "ymin", ymin
417    #print "ymax", ymax
418   
419    #FIXME: Use mesh.filename if it exists
420    root = Cell(ymin, ymax, xmin, xmax,
421                #name = ....
422                max_points_per_cell = max_points_per_cell)
423
424    #root.show()
425   
426    #Insert indices of all vertices
427    root.insert( range(len(mesh.coordinates)) )
428
429    #Build quad tree and return
430    root.split()
431
432    return root
Note: See TracBrowser for help on using the repository browser.