source: anuga_core/source/anuga/utilities/quad.py @ 4653

Last change on this file since 4653 was 4653, checked in by duncan, 16 years ago

checking in for benchmarking. When fitting cell data - triangle vertices and norms - are calculated the first time a point is looked for in a cell. This is to speed thing up.

File size: 14.1 KB
Line 
1"""quad.py - quad tree data structure for fast indexing of points in the plane
2
3
4"""
5
6from treenode import TreeNode
7import string, types, sys
8
9#FIXME verts are added one at a time.
10#FIXME add max min x y in general_mesh
11
12class Cell(TreeNode):
13    """class Cell
14
15    One cell in the plane delimited by southern, northern,
16    western, eastern boundaries.
17
18    Public Methods:
19        prune()
20        insert(point)
21        search(x, y)
22        collapse()
23        split()
24        store()
25        retrieve()
26        count()
27    """
28 
29    def __init__(self, southern, northern, western, eastern, 
30                 name = 'cell',
31                 max_points_per_cell = 4):
32 
33        # Initialise base classes
34        TreeNode.__init__(self, string.lower(name))
35       
36        # Initialise cell
37        self.southern = round(southern,5)   
38        self.northern = round(northern,5)
39        self.western = round(western,5)   
40        self.eastern = round(eastern,5)
41
42        # The points in this cell     
43        self.points = []
44       
45        self.max_points_per_cell = max_points_per_cell
46       
47       
48    def __repr__(self):
49        return self.name 
50
51
52    def spawn(self):
53        """Create four child cells unless they already exist
54        """
55
56        if self.children:
57            return
58        else:
59            self.children = []
60
61        # convenience variables
62        cs = self.southern   
63        cn = self.northern
64        cw = self.western   
65        ce = self.eastern
66
67        # create 4 child cells
68        self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,self.name+'_nw'))
69        self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,self.name+'_ne'))
70        self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,self.name+'_se'))
71        self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,self.name+'_sw'))
72       
73 
74    def search(self, x, y, get_vertices=False):
75        """Find all point indices sharing the same cell as point (x, y)
76        """
77        branch = []
78        points = []
79        if self.children:
80            for child in self:
81                if child.contains(x,y):
82                    brothers = list(self.children)
83                    brothers.remove(child)
84                    branch.append(brothers)
85                    points, branch = child.search_branch(x,y, branch,
86                                                  get_vertices=get_vertices)
87        else:
88            # Leaf node: Get actual waypoints
89            points = self.retrieve(get_vertices=get_vertices)
90        self.branch = branch   
91        return points
92
93
94    def search_branch(self, x, y, branch, get_vertices=False):
95        """Find all point indices sharing the same cell as point (x, y)
96        """
97        points = []
98        if self.children:
99            for child in self:
100                if child.contains(x,y):
101                    brothers = list(self.children)
102                    brothers.remove(child)
103                    branch.append(brothers)
104                    points, branch = child.search_branch(x,y, branch,
105                                                  get_vertices=get_vertices)
106                   
107        else:
108            # Leaf node: Get actual waypoints
109            points = self.retrieve(get_vertices=get_vertices)     
110        return points, branch
111
112
113    def expand_search(self, get_vertices=False):
114        """Find all point indices 'up' one cell from the last search
115        """
116        points = []
117        if self.branch == []:
118            points = []
119        else:
120            three_cells = self.branch.pop()
121            for cell in three_cells:
122                #print "cell ", cell.show()
123                points += cell.retrieve(get_vertices=get_vertices)
124        return points, self.branch
125
126
127    def contains(*args):   
128        """True only if P's coordinates lie within cell boundaries
129        This methods has two forms:
130       
131        cell.contains(index)
132          #True if cell contains indexed point
133        cell.contains(x, y)
134          #True if cell contains point (x,y)   
135
136        """
137       
138        self = args[0]
139        if len(args) == 2:
140            point_id = int(args[1])
141            x, y = self.__class__.mesh.get_nodes()[point_id]
142
143            #print point_id, x, y
144        elif len(args) == 3:
145            x = float(args[1])
146            y = float(args[2])
147        else:
148            msg = 'Number of arguments to method must be two or three'
149            raise msg                         
150       
151        if y <  self.southern: return False
152        if y >= self.northern: return False
153        if x <  self.western:  return False
154        if x >= self.eastern:  return False
155        return True
156   
157   
158    def insert(self, points, split = False):
159        """insert point(s) in existing tree structure below self
160           and split if requested
161        """
162
163        # Call insert for each element of a list of points
164        if type(points) == types.ListType:
165            for point in points:
166                self.insert(point, split)
167        else:
168            #Only one point given as argument   
169            point = points
170       
171            # Find appropriate cell
172            if self.children is not None:
173                for child in self:
174                    if child.contains(point):
175                        child.insert(point, split)
176                        break
177            else:
178                # self is a leaf cell: insert point into cell
179                if self.contains(point):
180                    self.store(point)
181                else:
182                    x = self.__class__.mesh.coordinates[point][0]
183                    y = self.__class__.mesh.coordinates[point][1]
184                    print "(" + str(x) + "," + str(y) + ")"
185                    raise 'point not in region: %s' %str(point)
186               
187               
188        #Split datastructure if requested       
189        if split is True:
190            self.split()
191               
192
193
194    def store(self,objects):
195       
196        if type(objects) not in [types.ListType,types.TupleType]:
197            self.points.append(objects)
198        else:
199            self.points.extend(objects)
200
201
202    def retrieve_triangles(self):
203        """return a list of lists. For the inner lists,
204        The first element is the triangle index,
205        the second element is a list.for this list
206           the first element is a list of three (x, y) vertices,
207           the following elements are the three triangle normals.
208
209        This info is used in searching for a triangle that a point is in.
210        """
211        # FIXME Tidy up the structure that is returned.
212        # if the triangles att has been made
213        # return it.
214        if not hasattr(self,'triangles'):
215            # use a dictionary to remove duplicates
216            triangles = {}
217            verts = self.retrieve_vertices()
218            # print "verts", verts
219            for vert in verts:
220                triangle_list = self.__class__.mesh.get_triangles_and_vertices_per_node(vert)
221                for k, _ in triangle_list:
222                    if not triangles.has_key(k):
223                        # print 'k',k
224                        tri = self.__class__.mesh.get_vertex_coordinates(k)
225                        n0 = self.__class__.mesh.get_normal(k, 0)
226                        n1 = self.__class__.mesh.get_normal(k, 1)
227                        n2 = self.__class__.mesh.get_normal(k, 2) 
228                        triangles[k]=(tri, (n0, n1, n2))
229            self.triangles = triangles.items()
230        return self.triangles
231           
232    def retrieve_vertices(self):
233         objects = []
234         if self.children is None:
235             objects = self.points
236         else: 
237             for child in self:
238                 objects += child.retrieve()
239         return objects 
240
241
242    def retrieve(self, get_vertices=True):
243        if get_vertices is True:
244            return self.retrieve_vertices()
245        else:
246            return self.retrieve_triangles()
247       
248
249    def count(self, keywords=None):
250        """retrieve number of stored objects beneath this node inclusive
251        """
252       
253        num_waypoint = 0
254        if self.children:
255            for child in self:
256                num_waypoint = num_waypoint + child.count()
257        else:
258            num_waypoint = len(self.points)
259        return num_waypoint
260 
261
262    def clear(self):
263        self.Prune()   # TreeNode method
264
265
266    def clear_leaf_node(self):
267        """Clears storage in leaf node.
268        Called from Treenod.
269        Must exist.     
270        """
271        self.points = []
272       
273       
274    def clear_internal_node(self):
275        """Called from Treenode.   
276        Must exist.
277        """
278        pass
279
280
281
282    def split(self, threshold=None):
283        """
284        Partition cell when number of contained waypoints exceeds
285        threshold.  All waypoints are then moved into correct
286        child cell.
287        """
288        if threshold == None:
289           threshold = self.max_points_per_cell
290           
291        #FIXME, mincellsize removed.  base it on side length, if needed
292       
293        #Protect against silly thresholds such as -1
294        if threshold < 1:
295            return
296       
297        if not self.children:               # Leaf cell
298            if self.count() > threshold :   
299                #Split is needed
300                points = self.retrieve()    # Get points from leaf cell
301                self.clear()                # and remove them from storage
302               
303                self.spawn()                # Spawn child cells and move
304                for p in points:            # points to appropriate child
305                    for child in self:
306                        if child.contains(p):
307                            child.insert(p) 
308                            break
309   
310        if self.children:                   # Parent cell
311            for child in self:              # split (possibly newly created)
312                child.split(threshold)      # child cells recursively
313               
314
315
316    def collapse(self,threshold=None):
317        """
318        collapse child cells into immediate parent if total number of contained waypoints
319        in subtree below is less than or equal to threshold.
320        All waypoints are then moved into parent cell and
321        children are removed. If self is a leaf node initially, do nothing.
322        """
323       
324        if threshold is None:
325            threshold = self.max_points_per_cell       
326
327
328        if self.children:                   # Parent cell   
329            if self.count() <= threshold:   # collapse
330                points = self.retrieve()    # Get all points from child cells
331                self.clear()                # Remove children, self is now a leaf node
332                self.insert(points)         # Insert all points in local storage
333            else:                         
334                for child in self:          # Check if any sub tree can be collapsed
335                    child.collapse(threshold)
336
337
338    def Get_tree(self,depth=0):
339        """Traverse tree below self
340           Print for each node the name and
341           if it is a leaf the number of objects
342        """
343        s = ''
344        if depth == 0:
345            s = '\n'
346           
347        s += "%s%s:" % ('  '*depth, self.name)
348        if self.children:
349            s += '\n'
350            for child in self.children:
351                s += child.Get_tree(depth+1)
352        else:
353            s += '(#wp=%d)\n' %(self.count())
354
355        return s
356
357       
358    def show(self, depth=0):
359        """Traverse tree below self
360           Print for each node the name and
361           if it is a leaf the number of objects
362        """
363        if depth == 0:
364            print 
365        print "%s%s" % ('  '*depth, self.name),
366        if self.children:
367            print
368            for child in self.children:
369                child.show(depth+1)
370        else:
371            print '(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'\
372                  %(self.western, self.eastern,
373                    self.southern, self.northern,
374                    self.count()) 
375
376
377    def show_all(self,depth=0):
378        """Traverse tree below self
379           Print for each node the name and if it is a leaf all its objects
380        """
381        if depth == 0:
382            print 
383        print "%s%s:" % ('  '*depth, self.name),
384        if self.children:
385            print
386            for child in self.children:
387                child.show_all(depth+1)
388        else:
389            print '%s' %self.retrieve()
390
391
392    def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
393        """Traverse tree below self and find minimal cell radius,
394           maximumtree depth and maximum number of waypoints per leaf.
395        """
396
397        if self.children:
398            for child in self.children:
399                min_rad, max_depth, max_points =\
400                         child.Stats(depth+1,min_rad,max_depth,max_points)
401        else:
402            #FIXME remvoe radius stuff
403            #min_rad = sys.maxint
404            #if self.radius < min_rad:   min_rad = self.radius
405            if depth > max_depth: max_depth = depth
406            num_points = self.count()
407            if num_points > max_points: max_points = num_points
408
409        #return min_rad, max_depth, max_points   
410        return max_depth, max_points   
411       
412
413    #Class initialisation method       
414    def initialise(cls, mesh):
415        cls.mesh = mesh
416
417    initialise = classmethod(initialise)
418
419def build_quadtree(mesh, max_points_per_cell = 4):
420    """Build quad tree for mesh.
421
422    All vertices in mesh are stored in quadtree and a reference to the root is returned.
423    """
424
425    from Numeric import minimum, maximum
426
427    #Initialise
428    Cell.initialise(mesh)
429
430    #Make root cell
431    #print mesh.coordinates
432
433    nodes = mesh.get_nodes()
434    xmin = min(nodes[:,0])
435    xmax = max(nodes[:,0])
436    ymin = min(nodes[:,1])
437    ymax = max(nodes[:,1])
438
439   
440    #Ensure boundary points are fully contained in region
441    #It is a property of the cell structure that points on xmax or ymax of any given cell
442    #belong to the neighbouring cell.
443    #Hence, the root cell needs to be expanded slightly
444    ymax += (ymax-ymin)/10
445    xmax += (xmax-xmin)/10
446
447    # To avoid round off error
448    ymin -= (ymax-ymin)/10
449    xmin -= (xmax-xmin)/10   
450
451    #print "xmin", xmin
452    #print "xmax", xmax
453    #print "ymin", ymin
454    #print "ymax", ymax
455   
456    #FIXME: Use mesh.filename if it exists
457    root = Cell(ymin, ymax, xmin, xmax,
458                #name = ....
459                max_points_per_cell = max_points_per_cell)
460
461    #root.show()
462   
463    #Insert indices of all vertices
464    root.insert( range(mesh.number_of_nodes) )
465
466    #Build quad tree and return
467    root.split()
468
469    return root
Note: See TracBrowser for help on using the repository browser.