source: inundation/ga/storm_surge/pyvolution/quad.py @ 643

Last change on this file since 643 was 611, checked in by duncan, 20 years ago

Optimised the least_squares algorithm for building A matrix

File size: 11.9 KB
Line 
1"""quad.py - quad tree data structure for fast indexing of points in the plane
2
3
4"""
5
6from treenode import TreeNode
7import string, types, sys
8
9#FIXME verts are added one at a time.
10#FIXME add max min x y in general_mesh
11
12class Cell(TreeNode):
13    """class Cell
14
15    One cell in the plane delimited by southern, northern,
16    western, eastern boundaries.
17
18    Public Methods:
19        prune()
20        insert(point)
21        search(x, y)
22        collapse()
23        split()
24        store()
25        retrieve()
26        count()
27    """
28 
29    def __init__(self, southern, northern, western, eastern, 
30                 name = 'cell',
31                 max_points_per_cell = 4):
32 
33        # Initialise base classes
34        TreeNode.__init__(self, string.lower(name))
35       
36        # Initialise cell
37        self.southern = round(southern,5)   
38        self.northern = round(northern,5)
39        self.western = round(western,5)   
40        self.eastern = round(eastern,5)
41
42        # The points in this cell     
43        self.points = []
44       
45        self.max_points_per_cell = max_points_per_cell
46       
47       
48    def __repr__(self):
49        return self.name 
50
51
52    def spawn(self):
53        """Create four child cells unless they already exist
54        """
55
56        if self.children:
57            return
58        else:
59            self.children = []
60
61        # convenience variables
62        cs = self.southern   
63        cn = self.northern
64        cw = self.western   
65        ce = self.eastern
66
67        # create 4 child cells
68        self.AddChild(Cell((cn+cs)/2,cn,cw,(cw+ce)/2,self.name+'_nw'))
69        self.AddChild(Cell((cn+cs)/2,cn,(cw+ce)/2,ce,self.name+'_ne'))
70        self.AddChild(Cell(cs,(cn+cs)/2,(cw+ce)/2,ce,self.name+'_se'))
71        self.AddChild(Cell(cs,(cn+cs)/2,cw,(cw+ce)/2,self.name+'_sw'))
72       
73 
74    def search(self, x, y):
75    #def search_new(self, x, y):
76        """Find all point indices sharing the same cell as point (x, y)
77        """
78        branch = []
79        points = []
80        if self.children:
81            for child in self:
82                if child.contains(x,y):
83                    brothers = list(self.children)
84                    brothers.remove(child)
85                    branch.append(brothers)
86                    points, branch = child.search_branch(x,y, branch)
87        else:
88            # Leaf node: Get actual waypoints
89            points = self.retrieve()
90
91        self.branch = branch   
92        return points
93
94
95    def search_branch(self, x, y, branch):
96        """Find all point indices sharing the same cell as point (x, y)
97        """
98        points = []
99        if self.children:
100            for child in self:
101                if child.contains(x,y):
102                    brothers = list(self.children)
103                    brothers.remove(child)
104                    branch.append(brothers)
105                    points, branch = child.search_branch(x,y, branch)
106                   
107        else:
108            # Leaf node: Get actual waypoints
109            points = self.retrieve()     
110        return points, branch
111
112
113    def expand_search(self):
114        """Find all point indices 'up' one cell from the last search
115        """
116        points = []
117        if self.branch == []:
118            points = []
119        else:
120            three_cells = self.branch.pop()
121            for cell in three_cells:
122                #print "cell ", cell.show()
123                points += cell.retrieve()
124        return points
125
126
127    def contains(*args):   
128        """True only if P's coordinates lie within cell boundaries
129        This methods has two forms:
130       
131        cell.contains(index)
132          #True if cell contains indexed point
133        cell.contains(x, y)
134          #True if cell contains point (x,y)   
135
136        """
137       
138        self = args[0]
139        if len(args) == 2:
140            point_id = int(args[1])
141            x = self.__class__.mesh.coordinates[point_id][0]
142            y = self.__class__.mesh.coordinates[point_id][1]
143
144            #print point_id, x, y
145        elif len(args) == 3:
146            x = float(args[1])
147            y = float(args[2])
148        else:
149            msg = 'Number of arguments to method must be two or three'
150            raise msg                         
151       
152        if y <  self.southern: return False
153        if y >= self.northern: return False
154        if x <  self.western:  return False
155        if x >= self.eastern:  return False
156        return True
157   
158   
159    def insert(self, points, split = False):
160        """insert point(s) in existing tree structure below self
161           and split if requested
162        """
163
164        # Call insert for each element of a list of points
165        if type(points) == types.ListType:
166            for point in points:
167                self.insert(point, split)
168        else:
169            #Only one point given as argument   
170            point = points
171       
172            # Find appropriate cell
173            if self.children is not None:
174                for child in self:
175                    if child.contains(point):
176                        child.insert(point, split)
177                        break
178            else:
179                # self is a leaf cell: insert point into cell
180                if self.contains(point):
181                    self.store(point)
182                else:
183                    raise 'point not in region: %s' %str(point)
184               
185               
186        #Split datastructure if requested       
187        if split is True:
188            self.split()
189               
190
191
192    def store(self,objects):
193       
194        if type(objects) not in [types.ListType,types.TupleType]:
195            self.points.append(objects)
196        else:
197            self.points.extend(objects)
198
199
200    def retrieve(self):
201         objects = []
202         if self.children is None:
203             objects = self.points
204         else: 
205             for child in self:
206                 objects += child.retrieve()
207         return objects 
208
209
210    def count(self, keywords=None):
211        """retrieve number of stored objects beneath this node inclusive
212        """
213       
214        num_waypoint = 0
215        if self.children:
216            for child in self:
217                num_waypoint = num_waypoint + child.count()
218        else:
219            num_waypoint = len(self.points)
220        return num_waypoint
221 
222
223    def clear(self):
224        self.Prune()   # TreeNode method
225
226
227    def clear_leaf_node(self):
228        """Clears storage in leaf node.
229        Called from Treenod.
230        Must exist.     
231        """
232        self.points = []
233       
234       
235    def clear_internal_node(self):
236        """Called from Treenode.   
237        Must exist.
238        """
239        pass
240
241
242
243    def split(self, threshold=None):
244        """
245        Partition cell when number of contained waypoints exceeds
246        threshold.  All waypoints are then moved into correct
247        child cell.
248        """
249        if threshold == None:
250           threshold = self.max_points_per_cell
251           
252        #FIXME, mincellsize removed.  base it on side length, if needed
253       
254        #Protect against silly thresholds such as -1
255        if threshold < 1:
256            return
257       
258        if not self.children:               # Leaf cell
259            if self.count() > threshold :   
260                #Split is needed
261                points = self.retrieve()    # Get points from leaf cell
262                self.clear()                # and remove them from storage
263               
264                self.spawn()                # Spawn child cells and move
265                for p in points:            # points to appropriate child
266                    for child in self:
267                        if child.contains(p):
268                            child.insert(p) 
269                            break
270   
271        if self.children:                   # Parent cell
272            for child in self:              # split (possibly newly created)
273                child.split(threshold)      # child cells recursively
274               
275
276
277    def collapse(self,threshold=None):
278        """
279        collapse child cells into immediate parent if total number of contained waypoints
280        in subtree below is less than or equal to threshold.
281        All waypoints are then moved into parent cell and
282        children are removed. If self is a leaf node initially, do nothing.
283        """
284       
285        if threshold is None:
286            threshold = self.max_points_per_cell       
287
288
289        if self.children:                   # Parent cell   
290            if self.count() <= threshold:   # collapse
291                points = self.retrieve()    # Get all points from child cells
292                self.clear()                # Remove children, self is now a leaf node
293                self.insert(points)         # Insert all points in local storage
294            else:                         
295                for child in self:          # Check if any sub tree can be collapsed
296                    child.collapse(threshold)
297
298
299    def Get_tree(self,depth=0):
300        """Traverse tree below self
301           Print for each node the name and
302           if it is a leaf the number of objects
303        """
304        s = ''
305        if depth == 0:
306            s = '\n'
307           
308        s += "%s%s:" % ('  '*depth, self.name)
309        if self.children:
310            s += '\n'
311            for child in self.children:
312                s += child.Get_tree(depth+1)
313        else:
314            s += '(#wp=%d)\n' %(self.count())
315
316        return s
317
318       
319    def show(self, depth=0):
320        """Traverse tree below self
321           Print for each node the name and
322           if it is a leaf the number of objects
323        """
324        if depth == 0:
325            print 
326        print "%s%s" % ('  '*depth, self.name),
327        if self.children:
328            print
329            for child in self.children:
330                child.show(depth+1)
331        else:
332            print '(xmin=%.2f, xmax=%.2f, ymin=%.2f, ymax=%.2f): [%d]'\
333                  %(self.western, self.eastern,
334                    self.southern, self.northern,
335                    self.count()) 
336
337
338    def show_all(self,depth=0):
339        """Traverse tree below self
340           Print for each node the name and if it is a leaf all its objects
341        """
342        if depth == 0:
343            print 
344        print "%s%s:" % ('  '*depth, self.name),
345        if self.children:
346            print
347            for child in self.children:
348                child.show_all(depth+1)
349        else:
350            print '%s' %self.retrieve()
351
352
353    def stats(self,depth=0,min_rad=sys.maxint,max_depth=0,max_points=0):
354        """Traverse tree below self and find minimal cell radius,
355           maximumtree depth and maximum number of waypoints per leaf.
356        """
357
358        if self.children:
359            for child in self.children:
360                min_rad, max_depth, max_points =\
361                         child.Stats(depth+1,min_rad,max_depth,max_points)
362        else:
363            #FIXME remvoe radius stuff
364            #min_rad = sys.maxint
365            #if self.radius < min_rad:   min_rad = self.radius
366            if depth > max_depth: max_depth = depth
367            num_points = self.count()
368            if num_points > max_points: max_points = num_points
369
370        #return min_rad, max_depth, max_points   
371        return max_depth, max_points   
372       
373
374    #Class initialisation method       
375    def initialise(cls, mesh):
376        cls.mesh = mesh
377
378    initialise = classmethod(initialise)
379
380def build_quadtree(mesh, max_points_per_cell = 4):
381    """Build quad tree for mesh.
382
383    All vertices in mesh are stored in quadtree and a reference to the root is returned.
384    """
385
386    from Numeric import minimum, maximum
387
388    #Initialise
389    Cell.initialise(mesh)
390
391    #Make root cell
392    #print mesh.coordinates
393   
394    xmin = min(mesh.coordinates[:,0])
395    xmax = max(mesh.coordinates[:,0])
396    ymin = min(mesh.coordinates[:,1])
397    ymax = max(mesh.coordinates[:,1])
398
399    #Ensure boundary points are fully contained in region
400    #It is a property of the cell structure that points on xmax or ymax of any given cell
401    #belong to the neighbouring cell.
402    #Hence, the root cell needs to be expanded slightly
403    ymax += (ymax-ymin)/10
404    xmax += (xmax-xmin)/10   
405
406
407    #FIXME: Use mesh.filename if it exists
408    root = Cell(ymin, ymax, xmin, xmax,
409                #name = ....
410                max_points_per_cell = max_points_per_cell)
411
412    #root.show()
413   
414    #Insert indices of all vertices
415    root.insert( range(len(mesh.coordinates)) )
416
417    #Build quad tree and return
418    root.split()
419
420    return root
Note: See TracBrowser for help on using the repository browser.