source: anuga_core/source/anuga/utilities/quad.py @ 5571

Last change on this file since 5571 was 4932, checked in by steve, 16 years ago
File size: 14.7 KB
Line 
1"""quad.py - quad tree data structure for fast indexing of points in the plane
2
3
4"""
5
6from treenode import TreeNode
7import string, types, sys
8
9#FIXME verts are added one at a time.
10#FIXME add max min x y in general_mesh
11
12class Cell(TreeNode):
13    """class Cell
14
15    One cell in the plane delimited by southern, northern,
16    western, eastern boundaries.
17
18    Public Methods:
19        prune()
20        insert(point)
21        search(x, y)
22        collapse()
23        split()
24        store()
25        retrieve()
26        count()
27    """
28 
29    def __init__(self, southern, northern, western, eastern, mesh,
30                 name = 'cell',
31                 max_points_per_cell = 4):
32 
33        # Initialise base classes
34        TreeNode.__init__(self, string.lower(name))
35       
36        # Initialise cell
37        self.southern = round(southern,5)   
38        self.northern = round(northern,5)
39        self.western = round(western,5)   
40        self.eastern = round(eastern,5)
41        self.mesh = mesh
42
43        # The points in this cell     
44        self.points = []
45       
46        self.max_points_per_cell = max_points_per_cell
47       
48       
49    def __repr__(self):
50        return self.name 
51
52
53    def spawn(self):
54        """Create four child cells unless they already exist
55        """
56
57        if self.children:
58            return
59        else:
60            self.children = []
61
62        # convenience variables
63        cs = self.southern   
64        cn = self.northern
65        cw = self.western   
66        ce = self.eastern   
67        mesh = self.mesh
68
69        # create 4 child cells
70        self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,mesh,self.name+'_nw'))
71        self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,mesh,self.name+'_ne'))
72        self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,mesh,self.name+'_se'))
73        self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,mesh,self.name+'_sw'))
74       
75 
76    def search(self, x, y, get_vertices=False):
77        """Find all point indices sharing the same cell as point (x, y)
78        """
79        branch = []
80        points = []
81        if self.children:
82            for child in self:
83                if child.contains(x,y):
84                    brothers = list(self.children)
85                    brothers.remove(child)
86                    branch.append(brothers)
87                    points, branch = child.search_branch(x,y, branch,
88                                                  get_vertices=get_vertices)
89        else:
90            # Leaf node: Get actual waypoints
91            points = self.retrieve(get_vertices=get_vertices)
92        self.branch = branch   
93        return points
94
95
96    def search_branch(self, x, y, branch, get_vertices=False):
97        """Find all point indices sharing the same cell as point (x, y)
98        """
99        points = []
100        if self.children:
101            for child in self:
102                if child.contains(x,y):
103                    brothers = list(self.children)
104                    brothers.remove(child)
105                    branch.append(brothers)
106                    points, branch = child.search_branch(x,y, branch,
107                                                  get_vertices=get_vertices)
108                   
109        else:
110            # Leaf node: Get actual waypoints
111            points = self.retrieve(get_vertices=get_vertices)     
112        return points, branch
113
114
115    def expand_search(self, get_vertices=False):
116        """Find all point indices 'up' one cell from the last search
117        """
118       
119        points = []
120        if self.branch == []:
121            points = []
122        else:
123            three_cells = self.branch.pop()
124            for cell in three_cells:
125                #print "cell ", cell.show()
126                points += cell.retrieve(get_vertices=get_vertices)
127        return points, self.branch
128
129
130    def contains(*args):   
131        """True only if P's coordinates lie within cell boundaries
132        This methods has two forms:
133       
134        cell.contains(index)
135        #True if cell contains indexed point
136        cell.contains(x, y)
137        #True if cell contains point (x,y)     
138        """
139        self = args[0]
140        if len(args) == 2:
141            point_id = int(args[1])
142            x, y = self.mesh.get_node(point_id, absolute=True)
143        elif len(args) == 3:
144            x = float(args[1])
145            y = float(args[2])
146        else:
147            msg = 'Number of arguments to method must be two or three'
148            raise msg                         
149       
150        if y <  self.southern: return False
151        if y >= self.northern: return False
152        if x <  self.western:  return False
153        if x >= self.eastern:  return False
154        return True
155   
156   
157    def insert(self, points, split = False):
158        """insert point(s) in existing tree structure below self
159           and split if requested
160        """
161
162        # Call insert for each element of a list of points
163        if type(points) == types.ListType:
164            for point in points:
165                self.insert(point, split)
166        else:
167            #Only one point given as argument   
168            point = points
169       
170            # Find appropriate cell
171            if self.children is not None:
172                for child in self:
173                    if child.contains(point):
174                        child.insert(point, split)
175                        break
176            else:
177                # self is a leaf cell: insert point into cell
178                if self.contains(point):
179                    self.store(point)
180                else:
181                    # Have to take into account of georef.
182                    #x = self.mesh.coordinates[point][0]
183                    #y = self.mesh.coordinates[point][1]
184                    node = self.mesh.get_node(point, absolute=True)
185                    print "node", node
186                    print "(" + str(node[0]) + "," + str(node[1]) + ")"
187                    raise 'point not in region: %s' %str(point)
188               
189               
190        #Split datastructure if requested       
191        if split is True:
192            self.split()
193               
194
195
196    def store(self,objects):
197       
198        if type(objects) not in [types.ListType,types.TupleType]:
199            self.points.append(objects)
200        else:
201            self.points.extend(objects)
202
203
204    def retrieve_triangles(self):
205        """return a list of lists. For the inner lists,
206        The first element is the triangle index,
207        the second element is a list.for this list
208           the first element is a list of three (x, y) vertices,
209           the following elements are the three triangle normals.
210
211        This info is used in searching for a triangle that a point is in.
212
213        Post condition
214        No more points can be added to the quad tree, since the
215        points data structure is removed.
216        """
217        # FIXME Tidy up the structure that is returned.
218        # if the triangles att has been made
219        # return it.
220        if not hasattr(self,'triangles'):
221            # use a dictionary to remove duplicates
222            triangles = {}
223            verts = self.retrieve_vertices()
224            # print "verts", verts
225            for vert in verts:
226                triangle_list = self.mesh.get_triangles_and_vertices_per_node(vert)
227                for k, _ in triangle_list:
228                    if not triangles.has_key(k):
229                        # print 'k',k
230                        tri = self.mesh.get_vertex_coordinates(k,
231                                                               absolute=True)
232                        n0 = self.mesh.get_normal(k, 0)
233                        n1 = self.mesh.get_normal(k, 1)
234                        n2 = self.mesh.get_normal(k, 2) 
235                        triangles[k]=(tri, (n0, n1, n2))
236            self.triangles = triangles.items()
237            # Delete the old cell data structure to save memory
238            del self.points
239        return self.triangles
240           
241    def retrieve_vertices(self):
242         return self.points
243
244
245    def retrieve(self, get_vertices=True):
246         objects = []
247         if self.children is None:
248             if get_vertices is True:
249                 objects = self.retrieve_vertices()
250             else:
251                 objects =  self.retrieve_triangles()
252         else: 
253             for child in self:
254                 objects += child.retrieve(get_vertices=get_vertices)
255         return objects
256       
257
258    def count(self, keywords=None):
259        """retrieve number of stored objects beneath this node inclusive
260        """
261       
262        num_waypoint = 0
263        if self.children:
264            for child in self:
265                num_waypoint = num_waypoint + child.count()
266        else:
267            num_waypoint = len(self.points)
268        return num_waypoint
269 
270
271    def clear(self):
272        self.Prune()   # TreeNode method
273
274
275    def clear_leaf_node(self):
276        """Clears storage in leaf node.
277        Called from Treenod.
278        Must exist.     
279        """
280        self.points = []
281       
282       
283    def clear_internal_node(self):
284        """Called from Treenode.   
285        Must exist.
286        """
287        pass
288
289
290
291    def split(self, threshold=None):
292        """
293        Partition cell when number of contained waypoints exceeds
294        threshold.  All waypoints are then moved into correct
295        child cell.
296        """
297        if threshold == None:
298           threshold = self.max_points_per_cell
299           
300        #FIXME, mincellsize removed.  base it on side length, if needed
301       
302        #Protect against silly thresholds such as -1
303        if threshold < 1:
304            return
305       
306        if not self.children:               # Leaf cell
307            if self.count() > threshold :   
308                #Split is needed
309                points = self.retrieve()    # Get points from leaf cell
310                self.clear()                # and remove them from storage
311                   
312                self.spawn()                # Spawn child cells and move
313                for p in points:            # points to appropriate child
314                    for child in self:
315                        if child.contains(p):
316                            child.insert(p) 
317                            break
318                       
319        if self.children:                   # Parent cell
320            for child in self:              # split (possibly newly created)
321                child.split(threshold)      # child cells recursively
322               
323
324
325    def collapse(self,threshold=None):
326        """
327        collapse child cells into immediate parent if total number of contained waypoints
328        in subtree below is less than or equal to threshold.
329        All waypoints are then moved into parent cell and
330        children are removed. If self is a leaf node initially, do nothing.
331        """
332       
333        if threshold is None:
334            threshold = self.max_points_per_cell       
335
336
337        if self.children:                   # Parent cell   
338            if self.count() <= threshold:   # collapse
339                points = self.retrieve()    # Get all points from child cells
340                self.clear()                # Remove children, self is now a leaf node
341                self.insert(points)         # Insert all points in local storage
342            else:                         
343                for child in self:          # Check if any sub tree can be collapsed
344                    child.collapse(threshold)
345
346
347    def Get_tree(self,depth=0):
348        """Traverse tree below self
349           Print for each node the name and
350           if it is a leaf the number of objects
351        """
352        s = ''
353        if depth == 0:
354            s = '\n'
355           
356        s += "%s%s:" % ('  '*depth, self.name)
357        if self.children:
358            s += '\n'
359            for child in self.children:
360                s += child.Get_tree(depth+1)
361        else:
362            s += '(#wp=%d)\n' %(self.count())
363
364        return s
365
366       
367    def show(self, depth=0):
368        """Traverse tree below self
369           Print for each node the name and
370           if it is a leaf the number of objects
371        """
372        if depth == 0:
373            print 
374        print "%s%s" % ('  '*depth, self.name),
375        if self.children:
376            print
377            for child in self.children:
378                child.show(depth+1)
379        else:
380            print '(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'\
381                  %(self.western, self.eastern,
382                    self.southern, self.northern,
383                    self.count()) 
384
385
386    def show_all(self,depth=0):
387        """Traverse tree below self
388           Print for each node the name and if it is a leaf all its objects
389        """
390        if depth == 0:
391            print 
392        print "%s%s:" % ('  '*depth, self.name),
393        if self.children:
394            print
395            for child in self.children:
396                child.show_all(depth+1)
397        else:
398            print '%s' %self.retrieve()
399
400
401    def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
402        """Traverse tree below self and find minimal cell radius,
403           maximumtree depth and maximum number of waypoints per leaf.
404        """
405
406        if self.children:
407            for child in self.children:
408                min_rad, max_depth, max_points =\
409                         child.Stats(depth+1,min_rad,max_depth,max_points)
410        else:
411            #FIXME remvoe radius stuff
412            #min_rad = sys.maxint
413            #if self.radius < min_rad:   min_rad = self.radius
414            if depth > max_depth: max_depth = depth
415            num_points = self.count()
416            if num_points > max_points: max_points = num_points
417
418        #return min_rad, max_depth, max_points   
419        return max_depth, max_points   
420       
421
422    #Class initialisation method
423    # this is bad.  It adds a huge memory structure to the class.
424    # When the instance is deleted the mesh hangs round (leaks).
425    #def initialise(cls, mesh):
426    #    cls.mesh = mesh
427
428    #initialise = classmethod(initialise)
429
430def build_quadtree(mesh, max_points_per_cell = 4):
431    """Build quad tree for mesh.
432
433    All vertices in mesh are stored in quadtree and a reference
434    to the root is returned.
435    """
436
437    from Numeric import minimum, maximum
438
439
440    #Make root cell
441    #print mesh.coordinates
442
443    xmin, xmax, ymin, ymax = mesh.get_extent(absolute=True)
444   
445    # Ensure boundary points are fully contained in region
446    # It is a property of the cell structure that
447    # points on xmax or ymax of any given cell
448    # belong to the neighbouring cell.
449    # Hence, the root cell needs to be expanded slightly
450    ymax += (ymax-ymin)/10
451    xmax += (xmax-xmin)/10
452
453    # To avoid round off error
454    ymin -= (ymax-ymin)/10
455    xmin -= (xmax-xmin)/10   
456
457    #print "xmin", xmin
458    #print "xmax", xmax
459    #print "ymin", ymin
460    #print "ymax", ymax
461   
462    #FIXME: Use mesh.filename if it exists
463    # why?
464    root = Cell(ymin, ymax, xmin, xmax,mesh,
465                max_points_per_cell = max_points_per_cell)
466
467    #root.show()
468   
469    #Insert indices of all vertices
470    root.insert( range(mesh.number_of_nodes) )
471
472    #Build quad tree and return
473    root.split()
474
475    return root
Note: See TracBrowser for help on using the repository browser.