source: anuga_core/source/anuga/utilities/quad.py @ 4666

Last change on this file since 4666 was 4666, checked in by duncan, 17 years ago

unverified solution to ticket#176

File size: 14.4 KB
Line 
1"""quad.py - quad tree data structure for fast indexing of points in the plane
2
3
4"""
5
6from treenode import TreeNode
7import string, types, sys
8
9#FIXME verts are added one at a time.
10#FIXME add max min x y in general_mesh
11
12class Cell(TreeNode):
13    """class Cell
14
15    One cell in the plane delimited by southern, northern,
16    western, eastern boundaries.
17
18    Public Methods:
19        prune()
20        insert(point)
21        search(x, y)
22        collapse()
23        split()
24        store()
25        retrieve()
26        count()
27    """
28 
29    def __init__(self, southern, northern, western, eastern, mesh,
30                 name = 'cell',
31                 max_points_per_cell = 4):
32 
33        # Initialise base classes
34        TreeNode.__init__(self, string.lower(name))
35       
36        # Initialise cell
37        self.southern = round(southern,5)   
38        self.northern = round(northern,5)
39        self.western = round(western,5)   
40        self.eastern = round(eastern,5)
41        self.mesh = mesh
42
43        # The points in this cell     
44        self.points = []
45       
46        self.max_points_per_cell = max_points_per_cell
47       
48       
49    def __repr__(self):
50        return self.name 
51
52
53    def spawn(self):
54        """Create four child cells unless they already exist
55        """
56
57        if self.children:
58            return
59        else:
60            self.children = []
61
62        # convenience variables
63        cs = self.southern   
64        cn = self.northern
65        cw = self.western   
66        ce = self.eastern   
67        mesh = self.mesh
68
69        # create 4 child cells
70        self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,mesh,self.name+'_nw'))
71        self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,mesh,self.name+'_ne'))
72        self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,mesh,self.name+'_se'))
73        self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,mesh,self.name+'_sw'))
74       
75 
76    def search(self, x, y, get_vertices=False):
77        """Find all point indices sharing the same cell as point (x, y)
78        """
79        branch = []
80        points = []
81        if self.children:
82            for child in self:
83                if child.contains(x,y):
84                    brothers = list(self.children)
85                    brothers.remove(child)
86                    branch.append(brothers)
87                    points, branch = child.search_branch(x,y, branch,
88                                                  get_vertices=get_vertices)
89        else:
90            # Leaf node: Get actual waypoints
91            points = self.retrieve(get_vertices=get_vertices)
92        self.branch = branch   
93        return points
94
95
96    def search_branch(self, x, y, branch, get_vertices=False):
97        """Find all point indices sharing the same cell as point (x, y)
98        """
99        points = []
100        if self.children:
101            for child in self:
102                if child.contains(x,y):
103                    brothers = list(self.children)
104                    brothers.remove(child)
105                    branch.append(brothers)
106                    points, branch = child.search_branch(x,y, branch,
107                                                  get_vertices=get_vertices)
108                   
109        else:
110            # Leaf node: Get actual waypoints
111            points = self.retrieve(get_vertices=get_vertices)     
112        return points, branch
113
114
115    def expand_search(self, get_vertices=False):
116        """Find all point indices 'up' one cell from the last search
117        """
118       
119        points = []
120        if self.branch == []:
121            points = []
122        else:
123            three_cells = self.branch.pop()
124            for cell in three_cells:
125                #print "cell ", cell.show()
126                points += cell.retrieve(get_vertices=get_vertices)
127        return points, self.branch
128
129
130    def contains(*args):   
131        """True only if P's coordinates lie within cell boundaries
132        This methods has two forms:
133       
134        cell.contains(index)
135          #True if cell contains indexed point
136        cell.contains(x, y)
137          #True if cell contains point (x,y)   
138
139        """
140       
141        self = args[0]
142        if len(args) == 2:
143            point_id = int(args[1])
144            x, y = self.mesh.get_nodes()[point_id]
145
146            #print point_id, x, y
147        elif len(args) == 3:
148            x = float(args[1])
149            y = float(args[2])
150        else:
151            msg = 'Number of arguments to method must be two or three'
152            raise msg                         
153       
154        if y <  self.southern: return False
155        if y >= self.northern: return False
156        if x <  self.western:  return False
157        if x >= self.eastern:  return False
158        return True
159   
160   
161    def insert(self, points, split = False):
162        """insert point(s) in existing tree structure below self
163           and split if requested
164        """
165
166        # Call insert for each element of a list of points
167        if type(points) == types.ListType:
168            for point in points:
169                self.insert(point, split)
170        else:
171            #Only one point given as argument   
172            point = points
173       
174            # Find appropriate cell
175            if self.children is not None:
176                for child in self:
177                    if child.contains(point):
178                        child.insert(point, split)
179                        break
180            else:
181                # self is a leaf cell: insert point into cell
182                if self.contains(point):
183                    self.store(point)
184                else:
185                    x = self.mesh.coordinates[point][0]
186                    y = self.mesh.coordinates[point][1]
187                    print "(" + str(x) + "," + str(y) + ")"
188                    raise 'point not in region: %s' %str(point)
189               
190               
191        #Split datastructure if requested       
192        if split is True:
193            self.split()
194               
195
196
197    def store(self,objects):
198       
199        if type(objects) not in [types.ListType,types.TupleType]:
200            self.points.append(objects)
201        else:
202            self.points.extend(objects)
203
204
205    def retrieve_triangles(self):
206        """return a list of lists. For the inner lists,
207        The first element is the triangle index,
208        the second element is a list.for this list
209           the first element is a list of three (x, y) vertices,
210           the following elements are the three triangle normals.
211
212        This info is used in searching for a triangle that a point is in.
213
214        Post condition
215        No more points can be added to the quad tree, since the
216        points data structure is removed.
217        """
218        # FIXME Tidy up the structure that is returned.
219        # if the triangles att has been made
220        # return it.
221        if not hasattr(self,'triangles'):
222            # use a dictionary to remove duplicates
223            triangles = {}
224            verts = self.retrieve_vertices()
225            # print "verts", verts
226            for vert in verts:
227                triangle_list = self.mesh.get_triangles_and_vertices_per_node(vert)
228                for k, _ in triangle_list:
229                    if not triangles.has_key(k):
230                        # print 'k',k
231                        tri = self.mesh.get_vertex_coordinates(k)
232                        n0 = self.mesh.get_normal(k, 0)
233                        n1 = self.mesh.get_normal(k, 1)
234                        n2 = self.mesh.get_normal(k, 2) 
235                        triangles[k]=(tri, (n0, n1, n2))
236            self.triangles = triangles.items()
237            # Delete the old cell data structure to save memory
238            del self.points
239        return self.triangles
240           
241    def retrieve_vertices(self):
242         return self.points
243
244
245    def retrieve(self, get_vertices=True):
246         objects = []
247         if self.children is None:
248             if get_vertices is True:
249                 objects = self.retrieve_vertices()
250             else:
251                 objects =  self.retrieve_triangles()
252         else: 
253             for child in self:
254                 objects += child.retrieve(get_vertices=get_vertices)
255         return objects
256       
257
258    def count(self, keywords=None):
259        """retrieve number of stored objects beneath this node inclusive
260        """
261       
262        num_waypoint = 0
263        if self.children:
264            for child in self:
265                num_waypoint = num_waypoint + child.count()
266        else:
267            num_waypoint = len(self.points)
268        return num_waypoint
269 
270
271    def clear(self):
272        self.Prune()   # TreeNode method
273
274
275    def clear_leaf_node(self):
276        """Clears storage in leaf node.
277        Called from Treenod.
278        Must exist.     
279        """
280        self.points = []
281       
282       
283    def clear_internal_node(self):
284        """Called from Treenode.   
285        Must exist.
286        """
287        pass
288
289
290
291    def split(self, threshold=None):
292        """
293        Partition cell when number of contained waypoints exceeds
294        threshold.  All waypoints are then moved into correct
295        child cell.
296        """
297        if threshold == None:
298           threshold = self.max_points_per_cell
299           
300        #FIXME, mincellsize removed.  base it on side length, if needed
301       
302        #Protect against silly thresholds such as -1
303        if threshold < 1:
304            return
305       
306        if not self.children:               # Leaf cell
307            if self.count() > threshold :   
308                #Split is needed
309                points = self.retrieve()    # Get points from leaf cell
310                self.clear()                # and remove them from storage
311               
312                self.spawn()                # Spawn child cells and move
313                for p in points:            # points to appropriate child
314                    for child in self:
315                        if child.contains(p):
316                            child.insert(p) 
317                            break
318   
319        if self.children:                   # Parent cell
320            for child in self:              # split (possibly newly created)
321                child.split(threshold)      # child cells recursively
322               
323
324
325    def collapse(self,threshold=None):
326        """
327        collapse child cells into immediate parent if total number of contained waypoints
328        in subtree below is less than or equal to threshold.
329        All waypoints are then moved into parent cell and
330        children are removed. If self is a leaf node initially, do nothing.
331        """
332       
333        if threshold is None:
334            threshold = self.max_points_per_cell       
335
336
337        if self.children:                   # Parent cell   
338            if self.count() <= threshold:   # collapse
339                points = self.retrieve()    # Get all points from child cells
340                self.clear()                # Remove children, self is now a leaf node
341                self.insert(points)         # Insert all points in local storage
342            else:                         
343                for child in self:          # Check if any sub tree can be collapsed
344                    child.collapse(threshold)
345
346
347    def Get_tree(self,depth=0):
348        """Traverse tree below self
349           Print for each node the name and
350           if it is a leaf the number of objects
351        """
352        s = ''
353        if depth == 0:
354            s = '\n'
355           
356        s += "%s%s:" % ('  '*depth, self.name)
357        if self.children:
358            s += '\n'
359            for child in self.children:
360                s += child.Get_tree(depth+1)
361        else:
362            s += '(#wp=%d)\n' %(self.count())
363
364        return s
365
366       
367    def show(self, depth=0):
368        """Traverse tree below self
369           Print for each node the name and
370           if it is a leaf the number of objects
371        """
372        if depth == 0:
373            print 
374        print "%s%s" % ('  '*depth, self.name),
375        if self.children:
376            print
377            for child in self.children:
378                child.show(depth+1)
379        else:
380            print '(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'\
381                  %(self.western, self.eastern,
382                    self.southern, self.northern,
383                    self.count()) 
384
385
386    def show_all(self,depth=0):
387        """Traverse tree below self
388           Print for each node the name and if it is a leaf all its objects
389        """
390        if depth == 0:
391            print 
392        print "%s%s:" % ('  '*depth, self.name),
393        if self.children:
394            print
395            for child in self.children:
396                child.show_all(depth+1)
397        else:
398            print '%s' %self.retrieve()
399
400
401    def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
402        """Traverse tree below self and find minimal cell radius,
403           maximumtree depth and maximum number of waypoints per leaf.
404        """
405
406        if self.children:
407            for child in self.children:
408                min_rad, max_depth, max_points =\
409                         child.Stats(depth+1,min_rad,max_depth,max_points)
410        else:
411            #FIXME remvoe radius stuff
412            #min_rad = sys.maxint
413            #if self.radius < min_rad:   min_rad = self.radius
414            if depth > max_depth: max_depth = depth
415            num_points = self.count()
416            if num_points > max_points: max_points = num_points
417
418        #return min_rad, max_depth, max_points   
419        return max_depth, max_points   
420       
421
422    #Class initialisation method
423    # this is bad.  It adds a huge memory structure to the class.
424    # When the instance is deleted the mesh hangs round (leaks).
425    #def initialise(cls, mesh):
426    #    cls.mesh = mesh
427
428    #initialise = classmethod(initialise)
429
430def build_quadtree(mesh, max_points_per_cell = 4):
431    """Build quad tree for mesh.
432
433    All vertices in mesh are stored in quadtree and a reference
434    to the root is returned.
435    """
436
437    from Numeric import minimum, maximum
438
439
440    #Make root cell
441    #print mesh.coordinates
442
443    nodes = mesh.get_nodes()
444    xmin = min(nodes[:,0])
445    xmax = max(nodes[:,0])
446    ymin = min(nodes[:,1])
447    ymax = max(nodes[:,1])
448
449   
450    # Ensure boundary points are fully contained in region
451    # It is a property of the cell structure that
452    # points on xmax or ymax of any given cell
453    # belong to the neighbouring cell.
454    # Hence, the root cell needs to be expanded slightly
455    ymax += (ymax-ymin)/10
456    xmax += (xmax-xmin)/10
457
458    # To avoid round off error
459    ymin -= (ymax-ymin)/10
460    xmin -= (xmax-xmin)/10   
461
462    #print "xmin", xmin
463    #print "xmax", xmax
464    #print "ymin", ymin
465    #print "ymax", ymax
466   
467    #FIXME: Use mesh.filename if it exists
468    root = Cell(ymin, ymax, xmin, xmax,mesh,
469                #name = ....
470                max_points_per_cell = max_points_per_cell)
471
472    #root.show()
473   
474    #Insert indices of all vertices
475    root.insert( range(mesh.number_of_nodes) )
476
477    #Build quad tree and return
478    root.split()
479
480    return root
Note: See TracBrowser for help on using the repository browser.