source: anuga_core/source/anuga/utilities/quad.py @ 7317

Last change on this file since 7317 was 7317, checked in by rwilson, 15 years ago

Replaced 'print' statements with log.critical() calls.

File size: 14.7 KB
Line 
1"""quad.py - quad tree data structure for fast indexing of points in the plane
2
3
4"""
5
6from treenode import TreeNode
7import string, types, sys
8import anuga.utilities.log as log
9
10
11#FIXME verts are added one at a time.
12#FIXME add max min x y in general_mesh
13
14class Cell(TreeNode):
15    """class Cell
16
17    One cell in the plane delimited by southern, northern,
18    western, eastern boundaries.
19
20    Public Methods:
21        prune()
22        insert(point)
23        search(x, y)
24        collapse()
25        split()
26        store()
27        retrieve()
28        count()
29    """
30 
31    def __init__(self, southern, northern, western, eastern, mesh,
32                 name = 'cell',
33                 max_points_per_cell = 4):
34 
35        # Initialise base classes
36        TreeNode.__init__(self, string.lower(name))
37       
38        # Initialise cell
39        self.southern = round(southern,5)   
40        self.northern = round(northern,5)
41        self.western = round(western,5)   
42        self.eastern = round(eastern,5)
43        self.mesh = mesh
44
45        # The points in this cell     
46        self.points = []
47       
48        self.max_points_per_cell = max_points_per_cell
49       
50       
51    def __repr__(self):
52        return self.name 
53
54
55    def spawn(self):
56        """Create four child cells unless they already exist
57        """
58
59        if self.children:
60            return
61        else:
62            self.children = []
63
64        # convenience variables
65        cs = self.southern   
66        cn = self.northern
67        cw = self.western   
68        ce = self.eastern   
69        mesh = self.mesh
70
71        # create 4 child cells
72        self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,mesh,self.name+'_nw'))
73        self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,mesh,self.name+'_ne'))
74        self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,mesh,self.name+'_se'))
75        self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,mesh,self.name+'_sw'))
76       
77 
78    def search(self, x, y, get_vertices=False):
79        """Find all point indices sharing the same cell as point (x, y)
80        """
81        branch = []
82        points = []
83        if self.children:
84            for child in self:
85                if child.contains(x,y):
86                    brothers = list(self.children)
87                    brothers.remove(child)
88                    branch.append(brothers)
89                    points, branch = child.search_branch(x,y, branch,
90                                                  get_vertices=get_vertices)
91        else:
92            # Leaf node: Get actual waypoints
93            points = self.retrieve(get_vertices=get_vertices)
94        self.branch = branch   
95        return points
96
97
98    def search_branch(self, x, y, branch, get_vertices=False):
99        """Find all point indices sharing the same cell as point (x, y)
100        """
101        points = []
102        if self.children:
103            for child in self:
104                if child.contains(x,y):
105                    brothers = list(self.children)
106                    brothers.remove(child)
107                    branch.append(brothers)
108                    points, branch = child.search_branch(x,y, branch,
109                                                  get_vertices=get_vertices)
110                   
111        else:
112            # Leaf node: Get actual waypoints
113            points = self.retrieve(get_vertices=get_vertices)     
114        return points, branch
115
116
117    def expand_search(self, get_vertices=False):
118        """Find all point indices 'up' one cell from the last search
119        """
120       
121        points = []
122        if self.branch == []:
123            points = []
124        else:
125            three_cells = self.branch.pop()
126            for cell in three_cells:
127                points += cell.retrieve(get_vertices=get_vertices)
128        return points, self.branch
129
130
131    def contains(*args):   
132        """True only if P's coordinates lie within cell boundaries
133        This methods has two forms:
134       
135        cell.contains(index)
136        #True if cell contains indexed point
137        cell.contains(x, y)
138        #True if cell contains point (x,y)     
139        """
140        self = args[0]
141        if len(args) == 2:
142            point_id = int(args[1])
143            x, y = self.mesh.get_node(point_id, absolute=True)
144        elif len(args) == 3:
145            x = float(args[1])
146            y = float(args[2])
147        else:
148            msg = 'Number of arguments to method must be two or three'
149            raise msg                         
150       
151        if y <  self.southern: return False
152        if y >= self.northern: return False
153        if x <  self.western:  return False
154        if x >= self.eastern:  return False
155        return True
156   
157   
158    def insert(self, points, split = False):
159        """insert point(s) in existing tree structure below self
160           and split if requested
161        """
162
163        # Call insert for each element of a list of points
164        if type(points) == types.ListType:
165            for point in points:
166                self.insert(point, split)
167        else:
168            #Only one point given as argument   
169            point = points
170       
171            # Find appropriate cell
172            if self.children is not None:
173                for child in self:
174                    if child.contains(point):
175                        child.insert(point, split)
176                        break
177            else:
178                # self is a leaf cell: insert point into cell
179                if self.contains(point):
180                    self.store(point)
181                else:
182                    # Have to take into account of georef.
183                    #x = self.mesh.coordinates[point][0]
184                    #y = self.mesh.coordinates[point][1]
185                    node = self.mesh.get_node(point, absolute=True)
186                    msg = ('point not in region: %s\nnode=%s'
187                           % (str(point), str(node)))
188                    raise Exception, msg
189               
190               
191        #Split datastructure if requested       
192        if split is True:
193            self.split()
194               
195
196
197    def store(self,objects):
198       
199        if type(objects) not in [types.ListType,types.TupleType]:
200            self.points.append(objects)
201        else:
202            self.points.extend(objects)
203
204
205    def retrieve_triangles(self):
206        """return a list of lists. For the inner lists,
207        The first element is the triangle index,
208        the second element is a list.for this list
209           the first element is a list of three (x, y) vertices,
210           the following elements are the three triangle normals.
211
212        This info is used in searching for a triangle that a point is in.
213
214        Post condition
215        No more points can be added to the quad tree, since the
216        points data structure is removed.
217        """
218        # FIXME Tidy up the structure that is returned.
219        # if the triangles att has been made
220        # return it.
221        if not hasattr(self,'triangles'):
222            # use a dictionary to remove duplicates
223            triangles = {}
224            verts = self.retrieve_vertices()
225            for vert in verts:
226                triangle_list = self.mesh.get_triangles_and_vertices_per_node(vert)
227                for k, _ in triangle_list:
228                    if not triangles.has_key(k):
229                        tri = self.mesh.get_vertex_coordinates(k,
230                                                               absolute=True)
231                        n0 = self.mesh.get_normal(k, 0)
232                        n1 = self.mesh.get_normal(k, 1)
233                        n2 = self.mesh.get_normal(k, 2) 
234                        triangles[k]=(tri, (n0, n1, n2))
235            self.triangles = triangles.items()
236            # Delete the old cell data structure to save memory
237            del self.points
238        return self.triangles
239           
240    def retrieve_vertices(self):
241         return self.points
242
243
244    def retrieve(self, get_vertices=True):
245         objects = []
246         if self.children is None:
247             if get_vertices is True:
248                 objects = self.retrieve_vertices()
249             else:
250                 objects =  self.retrieve_triangles()
251         else: 
252             for child in self:
253                 objects += child.retrieve(get_vertices=get_vertices)
254         return objects
255       
256
257    def count(self, keywords=None):
258        """retrieve number of stored objects beneath this node inclusive
259        """
260       
261        num_waypoint = 0
262        if self.children:
263            for child in self:
264                num_waypoint = num_waypoint + child.count()
265        else:
266            num_waypoint = len(self.points)
267        return num_waypoint
268 
269
270    def clear(self):
271        self.Prune()   # TreeNode method
272
273
274    def clear_leaf_node(self):
275        """Clears storage in leaf node.
276        Called from Treenod.
277        Must exist.     
278        """
279        self.points = []
280       
281       
282    def clear_internal_node(self):
283        """Called from Treenode.   
284        Must exist.
285        """
286        pass
287
288
289
290    def split(self, threshold=None):
291        """
292        Partition cell when number of contained waypoints exceeds
293        threshold.  All waypoints are then moved into correct
294        child cell.
295        """
296        if threshold == None:
297           threshold = self.max_points_per_cell
298           
299        #FIXME, mincellsize removed.  base it on side length, if needed
300       
301        #Protect against silly thresholds such as -1
302        if threshold < 1:
303            return
304       
305        if not self.children:               # Leaf cell
306            if self.count() > threshold :   
307                #Split is needed
308                points = self.retrieve()    # Get points from leaf cell
309                self.clear()                # and remove them from storage
310                   
311                self.spawn()                # Spawn child cells and move
312                for p in points:            # points to appropriate child
313                    for child in self:
314                        if child.contains(p):
315                            child.insert(p) 
316                            break
317                       
318        if self.children:                   # Parent cell
319            for child in self:              # split (possibly newly created)
320                child.split(threshold)      # child cells recursively
321               
322
323
324    def collapse(self,threshold=None):
325        """
326        collapse child cells into immediate parent if total number of contained waypoints
327        in subtree below is less than or equal to threshold.
328        All waypoints are then moved into parent cell and
329        children are removed. If self is a leaf node initially, do nothing.
330        """
331       
332        if threshold is None:
333            threshold = self.max_points_per_cell       
334
335
336        if self.children:                   # Parent cell   
337            if self.count() <= threshold:   # collapse
338                points = self.retrieve()    # Get all points from child cells
339                self.clear()                # Remove children, self is now a leaf node
340                self.insert(points)         # Insert all points in local storage
341            else:                         
342                for child in self:          # Check if any sub tree can be collapsed
343                    child.collapse(threshold)
344
345
346    def Get_tree(self,depth=0):
347        """Traverse tree below self
348           Print for each node the name and
349           if it is a leaf the number of objects
350        """
351        s = ''
352        if depth == 0:
353            s = '\n'
354           
355        s += "%s%s:" % ('  '*depth, self.name)
356        if self.children:
357            s += '\n'
358            for child in self.children:
359                s += child.Get_tree(depth+1)
360        else:
361            s += '(#wp=%d)\n' %(self.count())
362
363        return s
364
365       
366    def show(self, depth=0):
367        """Traverse tree below self
368           Print for each node the name and
369           if it is a leaf the number of objects
370        """
371        if depth == 0:
372            log.critical() 
373        log.critical("%s%s" % ('  '*depth, self.name))
374        if self.children:
375            log.critical()
376            for child in self.children:
377                child.show(depth+1)
378        else:
379            log.critical('(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'
380                         % (self.western, self.eastern, self.southern,
381                            self.northern, self.count()))
382
383
384    def show_all(self,depth=0):
385        """Traverse tree below self
386           Print for each node the name and if it is a leaf all its objects
387        """
388        if depth == 0:
389            log.critical() 
390        log.critical("%s%s:" % ('  '*depth, self.name))
391        if self.children:
392            print
393            for child in self.children:
394                child.show_all(depth+1)
395        else:
396            log.critical('%s' % self.retrieve())
397
398
399    def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
400        """Traverse tree below self and find minimal cell radius,
401           maximumtree depth and maximum number of waypoints per leaf.
402        """
403
404        if self.children:
405            for child in self.children:
406                min_rad, max_depth, max_points =\
407                         child.Stats(depth+1,min_rad,max_depth,max_points)
408        else:
409            #FIXME remvoe radius stuff
410            #min_rad = sys.maxint
411            #if self.radius < min_rad:   min_rad = self.radius
412            if depth > max_depth: max_depth = depth
413            num_points = self.count()
414            if num_points > max_points: max_points = num_points
415
416        #return min_rad, max_depth, max_points   
417        return max_depth, max_points   
418       
419
420    #Class initialisation method
421    # this is bad.  It adds a huge memory structure to the class.
422    # When the instance is deleted the mesh hangs round (leaks).
423    #def initialise(cls, mesh):
424    #    cls.mesh = mesh
425
426    #initialise = classmethod(initialise)
427
428def build_quadtree(mesh, max_points_per_cell = 4):
429    """Build quad tree for mesh.
430
431    All vertices in mesh are stored in quadtree and a reference
432    to the root is returned.
433    """
434
435
436    #Make root cell
437    #print mesh.coordinates
438
439    xmin, xmax, ymin, ymax = mesh.get_extent(absolute=True)
440   
441    # Ensure boundary points are fully contained in region
442    # It is a property of the cell structure that
443    # points on xmax or ymax of any given cell
444    # belong to the neighbouring cell.
445    # Hence, the root cell needs to be expanded slightly
446    ymax += (ymax-ymin)/10
447    xmax += (xmax-xmin)/10
448
449    # To avoid round off error
450    ymin -= (ymax-ymin)/10
451    xmin -= (xmax-xmin)/10   
452
453    #print "xmin", xmin
454    #print "xmax", xmax
455    #print "ymin", ymin
456    #print "ymax", ymax
457   
458    #FIXME: Use mesh.filename if it exists
459    # why?
460    root = Cell(ymin, ymax, xmin, xmax,mesh,
461                max_points_per_cell = max_points_per_cell)
462
463    #root.show()
464   
465    #Insert indices of all vertices
466    root.insert( range(mesh.number_of_nodes) )
467
468    #Build quad tree and return
469    root.split()
470
471    return root
Note: See TracBrowser for help on using the repository browser.