source: branches/source_numpy_conversion/anuga/utilities/quad.py @ 6768

Last change on this file since 6768 was 5902, checked in by rwilson, 16 years ago

NumPy? conversion.

File size: 15.0 KB
Line 
1
2"""quad.py - quad tree data structure for fast indexing of points in the plane
3
4
5"""
6
7from treenode import TreeNode
8import string, types, sys
9
10#FIXME verts are added one at a time.
11#FIXME add max min x y in general_mesh
12
13class Cell(TreeNode):
14    """class Cell
15
16    One cell in the plane delimited by southern, northern,
17    western, eastern boundaries.
18
19    Public Methods:
20        prune()
21        insert(point)
22        search(x, y)
23        collapse()
24        split()
25        store()
26        retrieve()
27        count()
28    """
29 
30    def __init__(self, southern, northern, western, eastern, mesh,
31                 name = 'cell',
32                 max_points_per_cell = 4):
33 
34        # Initialise base classes
35        TreeNode.__init__(self, string.lower(name))
36       
37        # Initialise cell
38        self.southern = round(southern,5)   
39        self.northern = round(northern,5)
40        self.western = round(western,5)   
41        self.eastern = round(eastern,5)
42        self.mesh = mesh
43
44        # The points in this cell     
45        self.points = []
46       
47        self.max_points_per_cell = max_points_per_cell
48       
49       
50    def __repr__(self):
51        return self.name 
52
53
54    def spawn(self):
55        """Create four child cells unless they already exist
56        """
57
58        if self.children:
59            return
60        else:
61            self.children = []
62
63        # convenience variables
64        cs = self.southern   
65        cn = self.northern
66        cw = self.western   
67        ce = self.eastern   
68        mesh = self.mesh
69
70        # create 4 child cells
71        self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,mesh,self.name+'_nw'))
72        self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,mesh,self.name+'_ne'))
73        self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,mesh,self.name+'_se'))
74        self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,mesh,self.name+'_sw'))
75       
76 
77    def search(self, x, y, get_vertices=False):
78        """Find all point indices sharing the same cell as point (x, y)
79        """
80        branch = []
81        points = []
82        if self.children:
83            for child in self:
84                if child.contains(x,y):
85                    brothers = list(self.children)
86                    brothers.remove(child)
87                    branch.append(brothers)
88                    points, branch = child.search_branch(x,y, branch,
89                                                  get_vertices=get_vertices)
90        else:
91            # Leaf node: Get actual waypoints
92            points = self.retrieve(get_vertices=get_vertices)
93        self.branch = branch   
94        return points
95
96
97    def search_branch(self, x, y, branch, get_vertices=False):
98        """Find all point indices sharing the same cell as point (x, y)
99        """
100        points = []
101        if self.children:
102            for child in self:
103                if child.contains(x,y):
104                    brothers = list(self.children)
105                    brothers.remove(child)
106                    branch.append(brothers)
107                    points, branch = child.search_branch(x,y, branch,
108                                                  get_vertices=get_vertices)
109                   
110        else:
111            # Leaf node: Get actual waypoints
112            points = self.retrieve(get_vertices=get_vertices)     
113        return points, branch
114
115
116    def expand_search(self, get_vertices=False):
117        """Find all point indices 'up' one cell from the last search
118        """
119       
120        points = []
121        if self.branch == []:
122            points = []
123        else:
124            three_cells = self.branch.pop()
125            for cell in three_cells:
126                #print "cell ", cell.show()
127                points += cell.retrieve(get_vertices=get_vertices)
128        return points, self.branch
129
130
131    def contains(*args):   
132        """True only if P's coordinates lie within cell boundaries
133        This methods has two forms:
134       
135        cell.contains(index)
136        #True if cell contains indexed point
137        cell.contains(x, y)
138        #True if cell contains point (x,y)     
139        """
140        self = args[0]
141        if len(args) == 2:
142            point_id = int(args[1])
143            x, y = self.mesh.get_node(point_id, absolute=True)
144        elif len(args) == 3:
145            x = float(args[1])
146            y = float(args[2])
147        else:
148            msg = 'Number of arguments to method must be two or three'
149            raise msg                         
150       
151        if y <  self.southern: return False
152        if y >= self.northern: return False
153        if x <  self.western:  return False
154        if x >= self.eastern:  return False
155        return True
156   
157   
158    def insert(self, points, split = False):
159        """insert point(s) in existing tree structure below self
160           and split if requested
161        """
162
163        # Call insert for each element of a list of points
164        if type(points) == types.ListType:
165            for point in points:
166                self.insert(point, split)
167        else:
168            #Only one point given as argument   
169            point = points
170       
171            # Find appropriate cell
172            if self.children is not None:
173                for child in self:
174                    if child.contains(point):
175                        child.insert(point, split)
176                        break
177            else:
178                # self is a leaf cell: insert point into cell
179                if self.contains(point):
180                    self.store(point)
181                else:
182                    # Have to take into account of georef.
183                    #x = self.mesh.coordinates[point][0]
184                    #y = self.mesh.coordinates[point][1]
185                    node = self.mesh.get_node(point, absolute=True)
186                    print "node", node
187                    print "(" + str(node[0]) + "," + str(node[1]) + ")"
188                    raise 'point not in region: %s' %str(point)
189               
190               
191        #Split datastructure if requested       
192        if split is True:
193            self.split()
194               
195
196
197    def store(self,objects):
198       
199        if type(objects) not in [types.ListType,types.TupleType]:
200            self.points.append(objects)
201        else:
202            self.points.extend(objects)
203
204
205    def retrieve_triangles(self):
206        """return a list of lists. For the inner lists,
207        The first element is the triangle index,
208        the second element is a list.for this list
209           the first element is a list of three (x, y) vertices,
210           the following elements are the three triangle normals.
211
212        This info is used in searching for a triangle that a point is in.
213
214        Post condition
215        No more points can be added to the quad tree, since the
216        points data structure is removed.
217        """
218        # FIXME Tidy up the structure that is returned.
219        # if the triangles att has been made
220        # return it.
221        if not hasattr(self,'triangles'):
222            # use a dictionary to remove duplicates
223            triangles = {}
224            verts = self.retrieve_vertices()
225            for vert in verts:
226                triangle_list = self.mesh.get_triangles_and_vertices_per_node(vert)
227                for k, _ in triangle_list:
228                    if not triangles.has_key(k):
229                        # print 'k',k
230                        tri = self.mesh.get_vertex_coordinates(k, absolute=True)
231                        n0 = self.mesh.get_normal(k, 0)
232                        n1 = self.mesh.get_normal(k, 1)
233                        n2 = self.mesh.get_normal(k, 2) 
234                        triangles[k]=(tri, (n0, n1, n2))
235            self.triangles = triangles.items()
236            # Delete the old cell data structure to save memory
237            del self.points
238        return self.triangles
239           
240    def retrieve_vertices(self):
241         return self.points
242
243
244    def retrieve(self, get_vertices=True):
245         objects = []
246         if self.children is None:
247             if get_vertices is True:
248                 objects = self.retrieve_vertices()
249             else:
250                 objects =  self.retrieve_triangles()
251         else: 
252             for child in self:
253                 objects += child.retrieve(get_vertices=get_vertices)
254         return objects
255       
256
257    def count(self, keywords=None):
258        """retrieve number of stored objects beneath this node inclusive
259        """
260       
261        num_waypoint = 0
262        if self.children:
263            for child in self:
264                num_waypoint = num_waypoint + child.count()
265        else:
266            num_waypoint = len(self.points)
267        return num_waypoint
268 
269
270    def clear(self):
271        self.Prune()   # TreeNode method
272
273
274    def clear_leaf_node(self):
275        """Clears storage in leaf node.
276        Called from Treenod.
277        Must exist.     
278        """
279        self.points = []
280       
281       
282    def clear_internal_node(self):
283        """Called from Treenode.   
284        Must exist.
285        """
286        pass
287
288
289
290    def split(self, threshold=None):
291        """
292        Partition cell when number of contained waypoints exceeds
293        threshold.  All waypoints are then moved into correct
294        child cell.
295        """
296        if threshold == None:
297           threshold = self.max_points_per_cell
298           
299        #FIXME, mincellsize removed.  base it on side length, if needed
300       
301        #Protect against silly thresholds such as -1
302        if threshold < 1:
303            return
304       
305        if not self.children:               # Leaf cell
306            if self.count() > threshold :   
307                #Split is needed
308                points = self.retrieve()    # Get points from leaf cell
309                self.clear()                # and remove them from storage
310                   
311                self.spawn()                # Spawn child cells and move
312                for p in points:            # points to appropriate child
313                    for child in self:
314                        if child.contains(p):
315                            child.insert(p) 
316                            break
317                       
318        if self.children:                   # Parent cell
319            for child in self:              # split (possibly newly created)
320                child.split(threshold)      # child cells recursively
321               
322
323
324    def collapse(self,threshold=None):
325        """
326        collapse child cells into immediate parent if total number of contained waypoints
327        in subtree below is less than or equal to threshold.
328        All waypoints are then moved into parent cell and
329        children are removed. If self is a leaf node initially, do nothing.
330        """
331       
332        if threshold is None:
333            threshold = self.max_points_per_cell       
334
335
336        if self.children:                   # Parent cell   
337            if self.count() <= threshold:   # collapse
338                points = self.retrieve()    # Get all points from child cells
339                self.clear()                # Remove children, self is now a leaf node
340                self.insert(points)         # Insert all points in local storage
341            else:                         
342                for child in self:          # Check if any sub tree can be collapsed
343                    child.collapse(threshold)
344
345
346    def Get_tree(self,depth=0):
347        """Traverse tree below self
348           Print for each node the name and
349           if it is a leaf the number of objects
350        """
351        s = ''
352        if depth == 0:
353            s = '\n'
354           
355        s += "%s%s:" % ('  '*depth, self.name)
356        if self.children:
357            s += '\n'
358            for child in self.children:
359                s += child.Get_tree(depth+1)
360        else:
361            s += '(#wp=%d)\n' %(self.count())
362
363        return s
364
365       
366    def show(self, depth=0):
367        """Traverse tree below self
368           Print for each node the name and
369           if it is a leaf the number of objects
370        """
371        if depth == 0:
372            print 
373        print "%s%s" % ('  '*depth, self.name),
374        if self.children:
375            print
376            for child in self.children:
377                child.show(depth+1)
378        else:
379            print '(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'\
380                  %(self.western, self.eastern,
381                    self.southern, self.northern,
382                    self.count()) 
383
384
385    def show_all(self,depth=0):
386        """Traverse tree below self
387           Print for each node the name and if it is a leaf all its objects
388        """
389        if depth == 0:
390            print 
391        print "%s%s:" % ('  '*depth, self.name),
392        if self.children:
393            print
394            for child in self.children:
395                child.show_all(depth+1)
396        else:
397            print '%s' %self.retrieve()
398
399
400    def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
401        """Traverse tree below self and find minimal cell radius,
402           maximumtree depth and maximum number of waypoints per leaf.
403        """
404
405        if self.children:
406            for child in self.children:
407                min_rad, max_depth, max_points =\
408                         child.Stats(depth+1,min_rad,max_depth,max_points)
409        else:
410            #FIXME remvoe radius stuff
411            #min_rad = sys.maxint
412            #if self.radius < min_rad:   min_rad = self.radius
413            if depth > max_depth: max_depth = depth
414            num_points = self.count()
415            if num_points > max_points: max_points = num_points
416
417        #return min_rad, max_depth, max_points   
418        return max_depth, max_points   
419       
420
421    #Class initialisation method
422    # this is bad.  It adds a huge memory structure to the class.
423    # When the instance is deleted the mesh hangs round (leaks).
424    #def initialise(cls, mesh):
425    #    cls.mesh = mesh
426
427    #initialise = classmethod(initialise)
428
429def build_quadtree(mesh, max_points_per_cell = 4):
430    """Build quad tree for mesh.
431
432    All vertices in mesh are stored in quadtree and a reference
433    to the root is returned.
434    """
435
436    #Make root cell
437    #print mesh.coordinates
438
439    xmin, xmax, ymin, ymax = mesh.get_extent(absolute=True)
440   
441    # Ensure boundary points are fully contained in region
442    # It is a property of the cell structure that
443    # points on xmax or ymax of any given cell
444    # belong to the neighbouring cell.
445    # Hence, the root cell needs to be expanded slightly
446    ymax += (ymax-ymin)/10
447    xmax += (xmax-xmin)/10
448
449    # To avoid round off error
450    ymin -= (ymax-ymin)/10
451    xmin -= (xmax-xmin)/10   
452
453    #print "xmin", xmin
454    #print "xmax", xmax
455    #print "ymin", ymin
456    #print "ymax", ymax
457   
458    #FIXME: Use mesh.filename if it exists
459    # why?
460    root = Cell(ymin, ymax, xmin, xmax,mesh,
461                max_points_per_cell = max_points_per_cell)
462
463    #root.show()
464   
465    #Insert indices of all vertices
466    root.insert( range(mesh.number_of_nodes) )
467
468    #Build quad tree and return
469    root.split()
470
471    return root
Note: See TracBrowser for help on using the repository browser.