source: anuga_core/source/anuga/utilities/quad.py @ 4779

Last change on this file since 4779 was 4779, checked in by duncan, 17 years ago

reduce memory use in quantity.set_value. fit_to_mesh can now use an existing mesh instance, which it does in quantity.set_value.

File size: 14.8 KB
Line 
1"""quad.py - quad tree data structure for fast indexing of points in the plane
2
3
4"""
5
6from treenode import TreeNode
7import string, types, sys
8
9#FIXME verts are added one at a time.
10#FIXME add max min x y in general_mesh
11
12class Cell(TreeNode):
13    """class Cell
14
15    One cell in the plane delimited by southern, northern,
16    western, eastern boundaries.
17
18    Public Methods:
19        prune()
20        insert(point)
21        search(x, y)
22        collapse()
23        split()
24        store()
25        retrieve()
26        count()
27    """
28 
29    def __init__(self, southern, northern, western, eastern, mesh,
30                 name = 'cell',
31                 max_points_per_cell = 4):
32 
33        # Initialise base classes
34        TreeNode.__init__(self, string.lower(name))
35       
36        # Initialise cell
37        self.southern = round(southern,5)   
38        self.northern = round(northern,5)
39        self.western = round(western,5)   
40        self.eastern = round(eastern,5)
41        self.mesh = mesh
42
43        # The points in this cell     
44        self.points = []
45       
46        self.max_points_per_cell = max_points_per_cell
47       
48       
49    def __repr__(self):
50        return self.name 
51
52
53    def spawn(self):
54        """Create four child cells unless they already exist
55        """
56
57        if self.children:
58            return
59        else:
60            self.children = []
61
62        # convenience variables
63        cs = self.southern   
64        cn = self.northern
65        cw = self.western   
66        ce = self.eastern   
67        mesh = self.mesh
68
69        # create 4 child cells
70        self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,mesh,self.name+'_nw'))
71        self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,mesh,self.name+'_ne'))
72        self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,mesh,self.name+'_se'))
73        self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,mesh,self.name+'_sw'))
74       
75 
76    def search(self, x, y, get_vertices=False):
77        """Find all point indices sharing the same cell as point (x, y)
78        """
79        branch = []
80        points = []
81        if self.children:
82            for child in self:
83                if child.contains(x,y):
84                    brothers = list(self.children)
85                    brothers.remove(child)
86                    branch.append(brothers)
87                    points, branch = child.search_branch(x,y, branch,
88                                                  get_vertices=get_vertices)
89        else:
90            # Leaf node: Get actual waypoints
91            points = self.retrieve(get_vertices=get_vertices)
92        self.branch = branch   
93        return points
94
95
96    def search_branch(self, x, y, branch, get_vertices=False):
97        """Find all point indices sharing the same cell as point (x, y)
98        """
99        points = []
100        if self.children:
101            for child in self:
102                if child.contains(x,y):
103                    brothers = list(self.children)
104                    brothers.remove(child)
105                    branch.append(brothers)
106                    points, branch = child.search_branch(x,y, branch,
107                                                  get_vertices=get_vertices)
108                   
109        else:
110            # Leaf node: Get actual waypoints
111            points = self.retrieve(get_vertices=get_vertices)     
112        return points, branch
113
114
115    def expand_search(self, get_vertices=False):
116        """Find all point indices 'up' one cell from the last search
117        """
118       
119        points = []
120        if self.branch == []:
121            points = []
122        else:
123            three_cells = self.branch.pop()
124            for cell in three_cells:
125                #print "cell ", cell.show()
126                points += cell.retrieve(get_vertices=get_vertices)
127        return points, self.branch
128
129
130    def contains(*args):   
131        """True only if P's coordinates lie within cell boundaries
132        This methods has two forms:
133       
134        cell.contains(index)
135          #True if cell contains indexed point
136        cell.contains(x, y)
137          #True if cell contains point (x,y)   
138
139        """
140       
141        self = args[0]
142        if len(args) == 2:
143            point_id = int(args[1])
144            x, y = self.mesh.get_node(point_id, absolute=True)
145
146            #print point_id, x, y
147        elif len(args) == 3:
148            x = float(args[1])
149            y = float(args[2])
150        else:
151            msg = 'Number of arguments to method must be two or three'
152            raise msg                         
153       
154        if y <  self.southern: return False
155        if y >= self.northern: return False
156        if x <  self.western:  return False
157        if x >= self.eastern:  return False
158        return True
159   
160   
161    def insert(self, points, split = False):
162        """insert point(s) in existing tree structure below self
163           and split if requested
164        """
165
166        # Call insert for each element of a list of points
167        if type(points) == types.ListType:
168            for point in points:
169                self.insert(point, split)
170        else:
171            #Only one point given as argument   
172            point = points
173       
174            # Find appropriate cell
175            if self.children is not None:
176                for child in self:
177                    if child.contains(point):
178                        child.insert(point, split)
179                        break
180            else:
181                # self is a leaf cell: insert point into cell
182                if self.contains(point):
183                    self.store(point)
184                else:
185                    # Have to take into account of georef.
186                    #x = self.mesh.coordinates[point][0]
187                    #y = self.mesh.coordinates[point][1]
188                    node = self.mesh.get_node(point, absolute=True)
189                    print "node", node
190                    print "(" + str(node[0]) + "," + str(node[1]) + ")"
191                    raise 'point not in region: %s' %str(point)
192               
193               
194        #Split datastructure if requested       
195        if split is True:
196            self.split()
197               
198
199
200    def store(self,objects):
201       
202        if type(objects) not in [types.ListType,types.TupleType]:
203            self.points.append(objects)
204        else:
205            self.points.extend(objects)
206
207
208    def retrieve_triangles(self):
209        """return a list of lists. For the inner lists,
210        The first element is the triangle index,
211        the second element is a list.for this list
212           the first element is a list of three (x, y) vertices,
213           the following elements are the three triangle normals.
214
215        This info is used in searching for a triangle that a point is in.
216
217        Post condition
218        No more points can be added to the quad tree, since the
219        points data structure is removed.
220        """
221        # FIXME Tidy up the structure that is returned.
222        # if the triangles att has been made
223        # return it.
224        if not hasattr(self,'triangles'):
225            # use a dictionary to remove duplicates
226            triangles = {}
227            verts = self.retrieve_vertices()
228            # print "verts", verts
229            for vert in verts:
230                triangle_list = self.mesh.get_triangles_and_vertices_per_node(vert)
231                for k, _ in triangle_list:
232                    if not triangles.has_key(k):
233                        # print 'k',k
234                        tri = self.mesh.get_vertex_coordinates(k,
235                                                               absolute=True)
236                        n0 = self.mesh.get_normal(k, 0)
237                        n1 = self.mesh.get_normal(k, 1)
238                        n2 = self.mesh.get_normal(k, 2) 
239                        triangles[k]=(tri, (n0, n1, n2))
240            self.triangles = triangles.items()
241            # Delete the old cell data structure to save memory
242            del self.points
243        return self.triangles
244           
245    def retrieve_vertices(self):
246         return self.points
247
248
249    def retrieve(self, get_vertices=True):
250         objects = []
251         if self.children is None:
252             if get_vertices is True:
253                 objects = self.retrieve_vertices()
254             else:
255                 objects =  self.retrieve_triangles()
256         else: 
257             for child in self:
258                 objects += child.retrieve(get_vertices=get_vertices)
259         return objects
260       
261
262    def count(self, keywords=None):
263        """retrieve number of stored objects beneath this node inclusive
264        """
265       
266        num_waypoint = 0
267        if self.children:
268            for child in self:
269                num_waypoint = num_waypoint + child.count()
270        else:
271            num_waypoint = len(self.points)
272        return num_waypoint
273 
274
275    def clear(self):
276        self.Prune()   # TreeNode method
277
278
279    def clear_leaf_node(self):
280        """Clears storage in leaf node.
281        Called from Treenod.
282        Must exist.     
283        """
284        self.points = []
285       
286       
287    def clear_internal_node(self):
288        """Called from Treenode.   
289        Must exist.
290        """
291        pass
292
293
294
295    def split(self, threshold=None):
296        """
297        Partition cell when number of contained waypoints exceeds
298        threshold.  All waypoints are then moved into correct
299        child cell.
300        """
301        if threshold == None:
302           threshold = self.max_points_per_cell
303           
304        #FIXME, mincellsize removed.  base it on side length, if needed
305       
306        #Protect against silly thresholds such as -1
307        if threshold < 1:
308            return
309       
310        if not self.children:               # Leaf cell
311            if self.count() > threshold :   
312                #Split is needed
313                points = self.retrieve()    # Get points from leaf cell
314                self.clear()                # and remove them from storage
315                   
316                self.spawn()                # Spawn child cells and move
317                for p in points:            # points to appropriate child
318                    for child in self:
319                        if child.contains(p):
320                            child.insert(p) 
321                            break
322                       
323        if self.children:                   # Parent cell
324            for child in self:              # split (possibly newly created)
325                child.split(threshold)      # child cells recursively
326               
327
328
329    def collapse(self,threshold=None):
330        """
331        collapse child cells into immediate parent if total number of contained waypoints
332        in subtree below is less than or equal to threshold.
333        All waypoints are then moved into parent cell and
334        children are removed. If self is a leaf node initially, do nothing.
335        """
336       
337        if threshold is None:
338            threshold = self.max_points_per_cell       
339
340
341        if self.children:                   # Parent cell   
342            if self.count() <= threshold:   # collapse
343                points = self.retrieve()    # Get all points from child cells
344                self.clear()                # Remove children, self is now a leaf node
345                self.insert(points)         # Insert all points in local storage
346            else:                         
347                for child in self:          # Check if any sub tree can be collapsed
348                    child.collapse(threshold)
349
350
351    def Get_tree(self,depth=0):
352        """Traverse tree below self
353           Print for each node the name and
354           if it is a leaf the number of objects
355        """
356        s = ''
357        if depth == 0:
358            s = '\n'
359           
360        s += "%s%s:" % ('  '*depth, self.name)
361        if self.children:
362            s += '\n'
363            for child in self.children:
364                s += child.Get_tree(depth+1)
365        else:
366            s += '(#wp=%d)\n' %(self.count())
367
368        return s
369
370       
371    def show(self, depth=0):
372        """Traverse tree below self
373           Print for each node the name and
374           if it is a leaf the number of objects
375        """
376        if depth == 0:
377            print 
378        print "%s%s" % ('  '*depth, self.name),
379        if self.children:
380            print
381            for child in self.children:
382                child.show(depth+1)
383        else:
384            print '(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'\
385                  %(self.western, self.eastern,
386                    self.southern, self.northern,
387                    self.count()) 
388
389
390    def show_all(self,depth=0):
391        """Traverse tree below self
392           Print for each node the name and if it is a leaf all its objects
393        """
394        if depth == 0:
395            print 
396        print "%s%s:" % ('  '*depth, self.name),
397        if self.children:
398            print
399            for child in self.children:
400                child.show_all(depth+1)
401        else:
402            print '%s' %self.retrieve()
403
404
405    def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
406        """Traverse tree below self and find minimal cell radius,
407           maximumtree depth and maximum number of waypoints per leaf.
408        """
409
410        if self.children:
411            for child in self.children:
412                min_rad, max_depth, max_points =\
413                         child.Stats(depth+1,min_rad,max_depth,max_points)
414        else:
415            #FIXME remvoe radius stuff
416            #min_rad = sys.maxint
417            #if self.radius < min_rad:   min_rad = self.radius
418            if depth > max_depth: max_depth = depth
419            num_points = self.count()
420            if num_points > max_points: max_points = num_points
421
422        #return min_rad, max_depth, max_points   
423        return max_depth, max_points   
424       
425
426    #Class initialisation method
427    # this is bad.  It adds a huge memory structure to the class.
428    # When the instance is deleted the mesh hangs round (leaks).
429    #def initialise(cls, mesh):
430    #    cls.mesh = mesh
431
432    #initialise = classmethod(initialise)
433
434def build_quadtree(mesh, max_points_per_cell = 4):
435    """Build quad tree for mesh.
436
437    All vertices in mesh are stored in quadtree and a reference
438    to the root is returned.
439    """
440
441    from Numeric import minimum, maximum
442
443
444    #Make root cell
445    #print mesh.coordinates
446
447   
448    nodes = mesh.get_nodes(absolute=True)
449    xmin = min(nodes[:,0])
450    xmax = max(nodes[:,0])
451    ymin = min(nodes[:,1])
452    ymax = max(nodes[:,1])
453
454    # Don't know why this didn't work
455    #xmin, xmax, ymin, ymax = mesh.get_extent(absolute=True)
456   
457    # Ensure boundary points are fully contained in region
458    # It is a property of the cell structure that
459    # points on xmax or ymax of any given cell
460    # belong to the neighbouring cell.
461    # Hence, the root cell needs to be expanded slightly
462    ymax += (ymax-ymin)/10
463    xmax += (xmax-xmin)/10
464
465    # To avoid round off error
466    ymin -= (ymax-ymin)/10
467    xmin -= (xmax-xmin)/10   
468
469    #print "xmin", xmin
470    #print "xmax", xmax
471    #print "ymin", ymin
472    #print "ymax", ymax
473   
474    #FIXME: Use mesh.filename if it exists
475    root = Cell(ymin, ymax, xmin, xmax,mesh,
476                #name = ....
477                max_points_per_cell = max_points_per_cell)
478
479    #root.show()
480   
481    #Insert indices of all vertices
482    root.insert( range(mesh.number_of_nodes) )
483
484    #Build quad tree and return
485    root.split()
486
487    return root
Note: See TracBrowser for help on using the repository browser.