source: inundation/ga/storm_surge/pyvolution/quad.py @ 527

Last change on this file since 527 was 484, checked in by ole, 20 years ago

First stab at using quad trees in least_squares.
Unit tests pass and least squares produce results
much faster now.

File size: 10.8 KB
Line 
1"""quad.py - quad tree data structure for fast indexing of points in the plane
2
3
4"""
5
6from treenode import TreeNode
7import string, types, sys
8
9#FIXME verts are added one at a time.
10#FIXME add max min x y in general_mesh
11
12class Cell(TreeNode):
13    """class Cell
14
15    One cell in the plane delimited by southern, northern,
16    western, eastern boundaries.
17
18    Public Methods:
19        prune()
20        insert(point)
21        search(x, y)
22        collapse()
23        split()
24        store()
25        retrieve()
26        count()
27    """
28 
29    def __init__(self, southern, northern, western, eastern, 
30                 name = 'cell',
31                 max_points_per_cell = 4):
32 
33        # Initialise base classes
34        TreeNode.__init__(self, string.lower(name))
35       
36        # Initialise cell
37        self.southern = round(southern,5)   
38        self.northern = round(northern,5)
39        self.western = round(western,5)   
40        self.eastern = round(eastern,5)
41
42        # The points in this cell     
43        self.points = []
44       
45        self.max_points_per_cell = max_points_per_cell
46       
47       
48    def __repr__(self):
49        return self.name 
50
51
52    def spawn(self):
53        """Create four child cells unless they already exist
54        """
55
56        if self.children:
57            return
58        else:
59            self.children = []
60
61        # convenience variables
62        cs = self.southern   
63        cn = self.northern
64        cw = self.western   
65        ce = self.eastern
66
67        # create 4 child cells
68        self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,self.name+'_nw'))
69        self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,self.name+'_ne'))
70        self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,self.name+'_se'))
71        self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,self.name+'_sw'))
72     
73 
74
75    def search(self, x, y):
76        """Find all point indices sharing the same cell as point (x, y)
77        """
78     
79        points = []
80        if self.children:
81            for child in self:
82                if child.contains(x,y):
83                     points += child.search(x,y)
84        else:
85            # Leaf node: Get actual waypoints
86            points = self.retrieve()
87           
88        return points
89
90
91
92    def contains(*args):   
93        """True only if P's coordinates lie within cell boundaries
94        This methods has two forms:
95       
96        cell.contains(index)
97          #True if cell contains indexed point
98        cell.contains(x, y)
99          #True if cell contains point (x,y)   
100
101        """
102       
103        self = args[0]
104        if len(args) == 2:
105            point_id = int(args[1])
106            x = self.__class__.mesh.coordinates[point_id][0]
107            y = self.__class__.mesh.coordinates[point_id][1]
108
109            #print point_id, x, y
110        elif len(args) == 3:
111            x = float(args[1])
112            y = float(args[2])
113        else:
114            msg = 'Number of arguments to method must be two or three'
115            raise msg                         
116       
117        if y <  self.southern: return False
118        if y >= self.northern: return False
119        if x <  self.western:  return False
120        if x >= self.eastern:  return False
121        return True
122   
123   
124    def insert(self, points, split = False):
125        """insert point(s) in existing tree structure below self
126           and split if requested
127        """
128
129        # Call insert for each element of a list of points
130        if type(points) == types.ListType:
131            for point in points:
132                self.insert(point, split)
133        else:
134            #Only one point given as argument   
135            point = points
136       
137            # Find appropriate cell
138            if self.children is not None:
139                for child in self:
140                    if child.contains(point):
141                        child.insert(point, split)
142                        break
143            else:
144                # self is a leaf cell: insert point into cell
145                if self.contains(point):
146                    self.store(point)
147                else:
148                    raise 'point not in region: %s' %str(point)
149               
150               
151        #Split datastructure if requested       
152        if split is True:
153            self.split()
154               
155
156
157    def store(self,objects):
158       
159        if type(objects) not in [types.ListType,types.TupleType]:
160            self.points.append(objects)
161        else:
162            self.points.extend(objects)
163
164
165    def retrieve(self):
166         objects = []
167         if self.children is None:
168             objects = self.points
169         else: 
170             for child in self:
171                 objects += child.retrieve()
172         return objects 
173
174
175    def count(self, keywords=None):
176        """retrieve number of stored objects beneath this node inclusive
177        """
178       
179        num_waypoint = 0
180        if self.children:
181            for child in self:
182                num_waypoint = num_waypoint + child.count()
183        else:
184            num_waypoint = len(self.points)
185        return num_waypoint
186 
187
188    def clear(self):
189        self.Prune()   # TreeNode method
190
191
192    def clear_leaf_node(self):
193        """Clears storage in leaf node.
194        Called from Treenod.
195        Must exist.     
196        """
197        self.points = []
198       
199       
200    def clear_internal_node(self):
201        """Called from Treenode.   
202        Must exist.
203        """
204        pass
205
206
207
208    def split(self, threshold=None):
209        """
210        Partition cell when number of contained waypoints exceeds
211        threshold.  All waypoints are then moved into correct
212        child cell.
213        """
214        if threshold == None:
215           threshold = self.max_points_per_cell
216           
217        #FIXME, mincellsize removed.  base it on side length, if needed
218       
219        #Protect against silly thresholds such as -1
220        if threshold < 1:
221            return
222       
223        if not self.children:               # Leaf cell
224            if self.count() > threshold :   
225                #Split is needed
226                points = self.retrieve()    # Get points from leaf cell
227                self.clear()                # and remove them from storage
228               
229                self.spawn()                # Spawn child cells and move
230                for p in points:            # points to appropriate child
231                    for child in self:
232                        if child.contains(p):
233                            child.insert(p) 
234                            break
235   
236        if self.children:                   # Parent cell
237            for child in self:              # split (possibly newly created)
238                child.split(threshold)      # child cells recursively
239               
240
241
242    def collapse(self,threshold=None):
243        """
244        collapse child cells into immediate parent if total number of contained waypoints
245        in subtree below is less than or equal to threshold.
246        All waypoints are then moved into parent cell and
247        children are removed. If self is a leaf node initially, do nothing.
248        """
249       
250        if threshold is None:
251            threshold = self.max_points_per_cell       
252
253
254        if self.children:                   # Parent cell   
255            if self.count() <= threshold:   # collapse
256                points = self.retrieve()    # Get all points from child cells
257                self.clear()                # Remove children, self is now a leaf node
258                self.insert(points)         # Insert all points in local storage
259            else:                         
260                for child in self:          # Check if any sub tree can be collapsed
261                    child.collapse(threshold)
262
263
264    def Get_tree(self,depth=0):
265        """Traverse tree below self
266           Print for each node the name and
267           if it is a leaf the number of objects
268        """
269        s = ''
270        if depth == 0:
271            s = '\n'
272           
273        s += "%s%s:" % ('  '*depth, self.name)
274        if self.children:
275            s += '\n'
276            for child in self.children:
277                s += child.Get_tree(depth+1)
278        else:
279            s += '(#wp=%d)\n' %(self.count())
280
281        return s
282
283       
284    def show(self, depth=0):
285        """Traverse tree below self
286           Print for each node the name and
287           if it is a leaf the number of objects
288        """
289        if depth == 0:
290            print 
291        print "%s%s" % ('  '*depth, self.name),
292        if self.children:
293            print
294            for child in self.children:
295                child.show(depth+1)
296        else:
297            print '(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'\
298                  %(self.western, self.eastern,
299                    self.southern, self.northern,
300                    self.count()) 
301
302
303    def show_all(self,depth=0):
304        """Traverse tree below self
305           Print for each node the name and if it is a leaf all its objects
306        """
307        if depth == 0:
308            print 
309        print "%s%s:" % ('  '*depth, self.name),
310        if self.children:
311            print
312            for child in self.children:
313                child.show_all(depth+1)
314        else:
315            print '%s' %self.retrieve()
316
317
318    def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
319        """Traverse tree below self and find minimal cell radius,
320           maximumtree depth and maximum number of waypoints per leaf.
321        """
322
323        if self.children:
324            for child in self.children:
325                min_rad, max_depth, max_points =\
326                         child.Stats(depth+1,min_rad,max_depth,max_points)
327        else:
328            #FIXME remvoe radius stuff
329            #min_rad = sys.maxint
330            #if self.radius < min_rad:   min_rad = self.radius
331            if depth > max_depth: max_depth = depth
332            num_points = self.count()
333            if num_points > max_points: max_points = num_points
334
335        #return min_rad, max_depth, max_points   
336        return max_depth, max_points   
337       
338
339    #Class initialisation method       
340    def initialise(cls, mesh):
341        cls.mesh = mesh
342
343    initialise = classmethod(initialise)
344
345def build_quadtree(mesh, max_points_per_cell = 4):
346    """Build quad tree for mesh.
347
348    All vertices in mesh are stored in quadtree and a reference to the root is returned.
349    """
350
351    from Numeric import minimum, maximum
352
353    #Initialise
354    Cell.initialise(mesh)
355
356    #Make root cell
357    #print mesh.coordinates
358   
359    xmin = min(mesh.coordinates[:,0])
360    xmax = max(mesh.coordinates[:,0])
361    ymin = min(mesh.coordinates[:,1])
362    ymax = max(mesh.coordinates[:,1])
363
364    #Ensure boundary points are fully contained in region
365    #It is a property of the cell structure that points on xmax or ymax of any given cell
366    #belong to the neighbouring cell.
367    #Hence, the root cell needs to be expanded slightly
368    ymax += (ymax-ymin)/10
369    xmax += (xmax-xmin)/10     
370
371
372    #FIXME: Use mesh.filename if it exists
373    root = Cell(ymin, ymax, xmin, xmax,
374                #name = ....
375                max_points_per_cell = max_points_per_cell)
376
377    #root.show()
378   
379    #Insert indices of all vertices
380    root.insert( range(len(mesh.coordinates)) )
381
382    #Build quad tree and return
383    root.split()
384
385    return root
Note: See TracBrowser for help on using the repository browser.