source: branches/numpy/anuga/utilities/quad.py @ 6441

Last change on this file since 6441 was 6441, checked in by rwilson, 15 years ago

After changes to get_absolute, ensure_numeric, etc.

File size: 14.7 KB
Line 
1"""quad.py - quad tree data structure for fast indexing of points in the plane
2
3
4"""
5
6from treenode import TreeNode
7import string, types, sys
8
9
10#FIXME verts are added one at a time.
11#FIXME add max min x y in general_mesh
12
13class Cell(TreeNode):
14    """class Cell
15
16    One cell in the plane delimited by southern, northern,
17    western, eastern boundaries.
18
19    Public Methods:
20        prune()
21        insert(point)
22        search(x, y)
23        collapse()
24        split()
25        store()
26        retrieve()
27        count()
28    """
29 
30    def __init__(self, southern, northern, western, eastern, mesh,
31                 name = 'cell',
32                 max_points_per_cell = 4):
33 
34        # Initialise base classes
35        TreeNode.__init__(self, string.lower(name))
36       
37        # Initialise cell
38        self.southern = round(southern,5)   
39        self.northern = round(northern,5)
40        self.western = round(western,5)   
41        self.eastern = round(eastern,5)
42        self.mesh = mesh
43
44        # The points in this cell     
45        self.points = []
46       
47        self.max_points_per_cell = max_points_per_cell
48       
49       
50    def __repr__(self):
51        return self.name 
52
53
54    def spawn(self):
55        """Create four child cells unless they already exist
56        """
57
58        if self.children:
59            return
60        else:
61            self.children = []
62
63        # convenience variables
64        cs = self.southern   
65        cn = self.northern
66        cw = self.western   
67        ce = self.eastern   
68        mesh = self.mesh
69
70        # create 4 child cells
71        self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,mesh,self.name+'_nw'))
72        self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,mesh,self.name+'_ne'))
73        self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,mesh,self.name+'_se'))
74        self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,mesh,self.name+'_sw'))
75       
76 
77    def search(self, x, y, get_vertices=False):
78        """Find all point indices sharing the same cell as point (x, y)
79        """
80        branch = []
81        points = []
82        if self.children:
83            for child in self:
84                if child.contains(x,y):
85                    brothers = list(self.children)
86                    brothers.remove(child)
87                    branch.append(brothers)
88                    points, branch = child.search_branch(x,y, branch,
89                                                  get_vertices=get_vertices)
90        else:
91            # Leaf node: Get actual waypoints
92            points = self.retrieve(get_vertices=get_vertices)
93        self.branch = branch   
94        return points
95
96
97    def search_branch(self, x, y, branch, get_vertices=False):
98        """Find all point indices sharing the same cell as point (x, y)
99        """
100        points = []
101        if self.children:
102            for child in self:
103                if child.contains(x,y):
104                    brothers = list(self.children)
105                    brothers.remove(child)
106                    branch.append(brothers)
107                    points, branch = child.search_branch(x,y, branch,
108                                                  get_vertices=get_vertices)
109                   
110        else:
111            # Leaf node: Get actual waypoints
112            points = self.retrieve(get_vertices=get_vertices)     
113        return points, branch
114
115
116    def expand_search(self, get_vertices=False):
117        """Find all point indices 'up' one cell from the last search
118        """
119       
120        points = []
121        if self.branch == []:
122            points = []
123        else:
124            three_cells = self.branch.pop()
125            for cell in three_cells:
126                #print "cell ", cell.show()
127                points += cell.retrieve(get_vertices=get_vertices)
128        return points, self.branch
129
130
131    def contains(*args):   
132        """True only if P's coordinates lie within cell boundaries
133        This methods has two forms:
134       
135        cell.contains(index)
136        #True if cell contains indexed point
137        cell.contains(x, y)
138        #True if cell contains point (x,y)     
139        """
140        self = args[0]
141        if len(args) == 2:
142            point_id = int(args[1])
143            x, y = self.mesh.get_node(point_id, absolute=True)
144        elif len(args) == 3:
145            x = float(args[1])
146            y = float(args[2])
147        else:
148            msg = 'Number of arguments to method must be two or three'
149            raise msg                         
150       
151        if y <  self.southern: return False
152        if y >= self.northern: return False
153        if x <  self.western:  return False
154        if x >= self.eastern:  return False
155        return True
156   
157   
158    def insert(self, points, split = False):
159        """insert point(s) in existing tree structure below self
160           and split if requested
161        """
162
163        # Call insert for each element of a list of points
164        if type(points) == types.ListType:
165            for point in points:
166                self.insert(point, split)
167        else:
168            #Only one point given as argument   
169            point = points
170       
171            # Find appropriate cell
172            if self.children is not None:
173                for child in self:
174                    if child.contains(point):
175                        child.insert(point, split)
176                        break
177            else:
178                # self is a leaf cell: insert point into cell
179                if self.contains(point):
180                    self.store(point)
181                else:
182                    # Have to take into account of georef.
183                    #x = self.mesh.coordinates[point][0]
184                    #y = self.mesh.coordinates[point][1]
185                    node = self.mesh.get_node(point, absolute=True)
186                    msg = ('point not in region: %s\nnode=%s'
187                           % (str(point), str(node)))
188                    raise Exception, msg
189               
190               
191        #Split datastructure if requested       
192        if split is True:
193            self.split()
194               
195
196
197    def store(self,objects):
198       
199        if type(objects) not in [types.ListType,types.TupleType]:
200            self.points.append(objects)
201        else:
202            self.points.extend(objects)
203
204
205    def retrieve_triangles(self):
206        """return a list of lists. For the inner lists,
207        The first element is the triangle index,
208        the second element is a list.for this list
209           the first element is a list of three (x, y) vertices,
210           the following elements are the three triangle normals.
211
212        This info is used in searching for a triangle that a point is in.
213
214        Post condition
215        No more points can be added to the quad tree, since the
216        points data structure is removed.
217        """
218        # FIXME Tidy up the structure that is returned.
219        # if the triangles att has been made
220        # return it.
221        if not hasattr(self,'triangles'):
222            # use a dictionary to remove duplicates
223            triangles = {}
224            verts = self.retrieve_vertices()
225            # print "verts", verts
226            for vert in verts:
227                triangle_list = self.mesh.get_triangles_and_vertices_per_node(vert)
228                for k, _ in triangle_list:
229                    if not triangles.has_key(k):
230                        # print 'k',k
231                        tri = self.mesh.get_vertex_coordinates(k,
232                                                               absolute=True)
233                        n0 = self.mesh.get_normal(k, 0)
234                        n1 = self.mesh.get_normal(k, 1)
235                        n2 = self.mesh.get_normal(k, 2) 
236                        triangles[k]=(tri, (n0, n1, n2))
237            self.triangles = triangles.items()
238            # Delete the old cell data structure to save memory
239            del self.points
240        return self.triangles
241           
242    def retrieve_vertices(self):
243         return self.points
244
245
246    def retrieve(self, get_vertices=True):
247         objects = []
248         if self.children is None:
249             if get_vertices is True:
250                 objects = self.retrieve_vertices()
251             else:
252                 objects =  self.retrieve_triangles()
253         else: 
254             for child in self:
255                 objects += child.retrieve(get_vertices=get_vertices)
256         return objects
257       
258
259    def count(self, keywords=None):
260        """retrieve number of stored objects beneath this node inclusive
261        """
262       
263        num_waypoint = 0
264        if self.children:
265            for child in self:
266                num_waypoint = num_waypoint + child.count()
267        else:
268            num_waypoint = len(self.points)
269        return num_waypoint
270 
271
272    def clear(self):
273        self.Prune()   # TreeNode method
274
275
276    def clear_leaf_node(self):
277        """Clears storage in leaf node.
278        Called from Treenod.
279        Must exist.     
280        """
281        self.points = []
282       
283       
284    def clear_internal_node(self):
285        """Called from Treenode.   
286        Must exist.
287        """
288        pass
289
290
291
292    def split(self, threshold=None):
293        """
294        Partition cell when number of contained waypoints exceeds
295        threshold.  All waypoints are then moved into correct
296        child cell.
297        """
298        if threshold == None:
299           threshold = self.max_points_per_cell
300           
301        #FIXME, mincellsize removed.  base it on side length, if needed
302       
303        #Protect against silly thresholds such as -1
304        if threshold < 1:
305            return
306       
307        if not self.children:               # Leaf cell
308            if self.count() > threshold :   
309                #Split is needed
310                points = self.retrieve()    # Get points from leaf cell
311                self.clear()                # and remove them from storage
312                   
313                self.spawn()                # Spawn child cells and move
314                for p in points:            # points to appropriate child
315                    for child in self:
316                        if child.contains(p):
317                            child.insert(p) 
318                            break
319                       
320        if self.children:                   # Parent cell
321            for child in self:              # split (possibly newly created)
322                child.split(threshold)      # child cells recursively
323               
324
325
326    def collapse(self,threshold=None):
327        """
328        collapse child cells into immediate parent if total number of contained waypoints
329        in subtree below is less than or equal to threshold.
330        All waypoints are then moved into parent cell and
331        children are removed. If self is a leaf node initially, do nothing.
332        """
333       
334        if threshold is None:
335            threshold = self.max_points_per_cell       
336
337
338        if self.children:                   # Parent cell   
339            if self.count() <= threshold:   # collapse
340                points = self.retrieve()    # Get all points from child cells
341                self.clear()                # Remove children, self is now a leaf node
342                self.insert(points)         # Insert all points in local storage
343            else:                         
344                for child in self:          # Check if any sub tree can be collapsed
345                    child.collapse(threshold)
346
347
348    def Get_tree(self,depth=0):
349        """Traverse tree below self
350           Print for each node the name and
351           if it is a leaf the number of objects
352        """
353        s = ''
354        if depth == 0:
355            s = '\n'
356           
357        s += "%s%s:" % ('  '*depth, self.name)
358        if self.children:
359            s += '\n'
360            for child in self.children:
361                s += child.Get_tree(depth+1)
362        else:
363            s += '(#wp=%d)\n' %(self.count())
364
365        return s
366
367       
368    def show(self, depth=0):
369        """Traverse tree below self
370           Print for each node the name and
371           if it is a leaf the number of objects
372        """
373        if depth == 0:
374            print 
375        print "%s%s" % ('  '*depth, self.name),
376        if self.children:
377            print
378            for child in self.children:
379                child.show(depth+1)
380        else:
381            print '(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'\
382                  %(self.western, self.eastern,
383                    self.southern, self.northern,
384                    self.count()) 
385
386
387    def show_all(self,depth=0):
388        """Traverse tree below self
389           Print for each node the name and if it is a leaf all its objects
390        """
391        if depth == 0:
392            print 
393        print "%s%s:" % ('  '*depth, self.name),
394        if self.children:
395            print
396            for child in self.children:
397                child.show_all(depth+1)
398        else:
399            print '%s' %self.retrieve()
400
401
402    def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
403        """Traverse tree below self and find minimal cell radius,
404           maximumtree depth and maximum number of waypoints per leaf.
405        """
406
407        if self.children:
408            for child in self.children:
409                min_rad, max_depth, max_points =\
410                         child.Stats(depth+1,min_rad,max_depth,max_points)
411        else:
412            #FIXME remvoe radius stuff
413            #min_rad = sys.maxint
414            #if self.radius < min_rad:   min_rad = self.radius
415            if depth > max_depth: max_depth = depth
416            num_points = self.count()
417            if num_points > max_points: max_points = num_points
418
419        #return min_rad, max_depth, max_points   
420        return max_depth, max_points   
421       
422
423    #Class initialisation method
424    # this is bad.  It adds a huge memory structure to the class.
425    # When the instance is deleted the mesh hangs round (leaks).
426    #def initialise(cls, mesh):
427    #    cls.mesh = mesh
428
429    #initialise = classmethod(initialise)
430
431def build_quadtree(mesh, max_points_per_cell = 4):
432    """Build quad tree for mesh.
433
434    All vertices in mesh are stored in quadtree and a reference
435    to the root is returned.
436    """
437
438
439    #Make root cell
440    #print mesh.coordinates
441
442    xmin, xmax, ymin, ymax = mesh.get_extent(absolute=True)
443   
444    # Ensure boundary points are fully contained in region
445    # It is a property of the cell structure that
446    # points on xmax or ymax of any given cell
447    # belong to the neighbouring cell.
448    # Hence, the root cell needs to be expanded slightly
449    ymax += (ymax-ymin)/10
450    xmax += (xmax-xmin)/10
451
452    # To avoid round off error
453    ymin -= (ymax-ymin)/10
454    xmin -= (xmax-xmin)/10   
455
456    #print "xmin", xmin
457    #print "xmax", xmax
458    #print "ymin", ymin
459    #print "ymax", ymax
460   
461    #FIXME: Use mesh.filename if it exists
462    # why?
463    root = Cell(ymin, ymax, xmin, xmax,mesh,
464                max_points_per_cell = max_points_per_cell)
465
466    #root.show()
467   
468    #Insert indices of all vertices
469    root.insert( range(mesh.number_of_nodes) )
470
471    #Build quad tree and return
472    root.split()
473
474    return root
Note: See TracBrowser for help on using the repository browser.