forked from hgfalling/pyquest
-
Notifications
You must be signed in to change notification settings - Fork 3
/
tree.py
316 lines (279 loc) · 9.59 KB
/
tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
"""
tree.py: Defines a tree structure for use in the questionnaire.
"""
import copy
class ClusterTreeNode(object):
    """
    One node of a hierarchical clustering tree.

    A node stores a sorted list of unique elements, a reference to its
    parent (None for a root) and a list of child nodes.

    make_index() must be called on the root once construction is finished:
    indexing, iteration, len() on a root, and most of the query methods read
    the nodes_list / idx attributes that make_index() creates.
    """

    def __init__(self, elements, parent=None):
        """
        elements: iterable of hashable, sortable items; duplicates are dropped.
        parent: the parent node, or None for a root.
        """
        self.parent = parent
        self.elements = sorted(set(elements))
        self.children = []

    def __getitem__(self, key):
        """
        Allows lookup of tree nodes by index, like row_tree[17].
        Requires make_index() to have been called on this node.
        """
        return self.nodes_list[key]

    def __iter__(self):
        """
        Allows iteration of the tree without using traverse(), ie:
        for node in tree:
        <do whatever>
        Requires make_index() to have been called on this node.
        """
        for node in self.nodes_list:
            yield node

    def __len__(self):
        """
        Number of nodes in the tree rooted at this node (see tree_size).
        """
        return self.tree_size

    def iterkeys(self):
        """
        Alias for iterating over the indexed nodes.
        """
        return self.__iter__()

    def compare(self, other):
        """
        Compares this tree to a different tree.
        Returns True if for every node x in this tree, the set of elements
        in x is equal to the set of elements in some node in the other tree.
        Reads self.nodes_list, so make_index() must have been called on self.
        """
        # Element lists are hashed as tuples so each lookup is O(1) instead of
        # a scan over every node of the other tree.
        other_element_sets = set(tuple(x.elements) for x in other.traverse())
        for node in self.nodes_list:
            if tuple(node.elements) not in other_element_sets:
                return False
        return True

    def __eq__(self, other):
        """
        Two trees are equal when each one's nodes all appear (by element
        list) in the other. Non-node operands are never equal.
        """
        if not isinstance(other, ClusterTreeNode):
            return NotImplemented
        return self.compare(other) and other.compare(self)

    def __ne__(self, other):
        """
        Negation of __eq__. Defined explicitly because Python 2 does not
        derive != from ==.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def create_subclusters(self, partition):
        """
        Divides a tree node into pieces based on partition.
        partition should be a group indicator for each element of the node.
        So partition should have the same length as self.elements.
        One child is appended per distinct label, in sorted label order.
        Useful for top-down clustering.
        """
        assert len(partition) == len(self.elements), \
            "partition must have one label per element"
        # Group in a single pass instead of rescanning once per label.
        groups = {}
        for element, label in zip(self.elements, partition):
            groups.setdefault(label, []).append(element)
        for label in sorted(groups):
            self.children.append(ClusterTreeNode(groups[label], self))

    def assign_to_parent(self, parent):
        """
        Assigns a subcluster to a parent cluster.
        Will add the elements of self to the elements of parent if they
        are not already there.
        Useful for bottom-up clustering.
        Note: self is not removed from the children of any previous parent.
        """
        self.parent = parent
        parent.children.append(self)
        parent.elements = sorted(set(parent.elements).union(self.elements))

    def traverse(self, floor_level=None):
        """
        Performs a BFS traversal of the tree.
        External use is deprecated now, use the iterator methods instead.
        Left here for compatibility reasons, may become a _ method later.
        Later note: okay to use for non-root nodes of the tree.

        floor_level: if given, nodes below this (absolute, 1-based) level
        are not visited.
        Returns the visited nodes ordered by level, then by smallest element.
        Raises ValueError if a visited node has no elements.
        """
        order = [self]
        position = 0
        # Walk the list with a cursor rather than pop(0), which is O(n).
        while position < len(order):
            node = order[position]
            position += 1
            if floor_level is None or node.level <= floor_level - 1:
                order.extend(node.children)
        # Tuple key: same ordering as level*1e10+min(elements) for ordinary
        # integer elements, without the magnitude/numeric-type restriction.
        order.sort(key=lambda x: (x.level, min(x.elements)))
        return order

    def dfs_leaves(self):
        """
        Depth-first leaves search.
        Returns the leaf nodes under this node (this node itself if it has
        no children), in left-to-right depth-first order.
        """
        found = []
        stack = [self]
        while stack:
            node = stack.pop()
            if not node.children:
                found.append(node)
            else:
                # Reversed so the leftmost child is popped first.
                stack.extend(reversed(node.children))
        return found

    def _resolve_level(self, level):
        """
        Converts a level argument to an absolute 1-based level.
        None means the bottom level; negative values count up from the
        bottom, so -1 is the bottom level itself.
        """
        bottom = self.level + self.tree_depth - 1
        if level is None:
            return bottom
        if level < 0:
            return bottom + 1 + level
        return level

    def dfs_level(self, level=None):
        """
        Returns the list of all nodes at the level specified, found by a
        depth-first descent that stops at the first node on that level.
        Also accepts negative indices (so the bottom level is -1).
        level=None selects the bottom level.
        """
        target = self._resolve_level(level)
        found = []
        stack = [self]
        while stack:
            node = stack.pop()
            if node.level == target:
                found.append(node)
            else:
                stack.extend(reversed(node.children))
        return found

    def leaves(self):
        """
        Returns the list of all leaves among the indexed nodes.
        """
        return [node for node in self.nodes_list if not node.children]

    @property
    def tree_size(self):
        """
        Returns the total size of the tree rooted at this node.
        A root reads the index built by make_index(); any other node
        counts its subtree with a fresh traversal.
        """
        if self.parent is None:
            return len(self.nodes_list)
        return len(self.traverse())

    @property
    def level(self):
        """
        Returns the level of the tree at which this node sits.
        Indexed starting at 1. This might be changed in the future.
        """
        depth = 1
        node = self
        while node.parent is not None:
            depth += 1
            node = node.parent
        return depth

    @property
    def tree_depth(self):
        """
        Returns the depth in levels of the tree rooted at this node.
        Measured along the first-child chain only.
        NOTE(review): presumably all leaves sit at the same depth - confirm.
        """
        depth = 1
        node = self
        while node.children:
            depth += 1
            node = node.children[0]
        return depth

    @property
    def size(self):
        """
        Returns the size of this node (in elements).
        """
        return len(self.elements)

    def sublevel_elements(self, level):
        """
        Returns a list of lists of the elements in level, where level is
        counted relative to this node (this node itself is level 1).
        """
        offset = self.level - 1
        return [x.elements for x in self.nodes_list
                if x.level - offset == level]

    def level_nodes(self, level=None):
        """
        Returns the list of all indexed nodes at the level specified.
        Also accepts negative indices (so the bottom level is -1).
        level=None selects the bottom level.
        """
        target = self._resolve_level(level)
        return [x for x in self.nodes_list if x.level == target]

    def make_index(self):
        """
        Precalculates some things and makes the tree much easier to use.
        Needs to be called after tree construction is finished in all cases.
        Stores the traversal as self.nodes_list and numbers each node in it
        through its idx attribute.
        """
        self.nodes_list = self.traverse()
        for idx, node in enumerate(self.nodes_list):
            node.idx = idx

    def disp_tree(self):
        """
        Prints out crude representation of tree structure by elements in folders.
        No return value.
        """
        for i in range(self.tree_depth):
            # Single formatted string prints identically on Python 2 and 3.
            print("%s %s" % (i, self.sublevel_elements(i + 1)))

    def disp_tree_folder_sizes(self):
        """
        Prints out crude representation of tree structure by folder sizes.
        No return value.
        """
        for i in range(self.tree_depth):
            sizes = sorted(len(x) for x in self.sublevel_elements(i + 1))
            print("%s %s" % (i, sizes))

    def folder_set(self, element):
        """
        Returns the indices of all indexed nodes that contain element.
        """
        return [x.idx for x in self.nodes_list if element in x.elements]

    def level_partition(self, level):
        """
        Returns the entire partition of the tree at the specified level
        as an array of tree.size with the index of the partition containing
        a particular point in each position.
        Elements are used directly as list positions.
        NOTE(review): assumes elements are the integers 0..size-1 - confirm.
        """
        partition = [0] * self.size
        for group_idx, members in enumerate(self.sublevel_elements(level)):
            for member in members:
                partition[member] = group_idx
        return partition

    def all_ancestors(self):
        """
        Returns a list of the indices of all the ancestors of a given node,
        nearest ancestor first.
        """
        ancestors = []
        node = self.parent
        while node is not None:
            ancestors.append(node.idx)
            node = node.parent
        return ancestors

    def all_descendants(self):
        """
        Returns a list of the indices of all the descendants of a given node.
        """
        if not self.children:
            return []
        # traverse() puts this node first; everything after it is below it.
        return [x.idx for x in self.traverse()[1:]]

    def tree_distance(self, i, j):
        """
        Returns the tree distance between elements i and j: 0.0 when i == j,
        otherwise the size of the smallest node at or below this one that
        contains both, divided by the size of the root.
        If this node does not contain both, its own size is used.
        """
        if i == j:
            return 0.0
        root = self
        while root.parent is not None:
            root = root.parent
        root_size = root.size
        node = self
        while i in node.elements and j in node.elements:
            for child in node.children:
                if i in child.elements and j in child.elements:
                    node = child
                    break
            else:
                # No child holds both elements: node is the smallest folder.
                break
        return 1.0 * node.size / root_size

    def tree_distance_mat(self):
        """
        Returns a symmetric size-by-size numpy array of tree distances
        between elements 0..size-1.
        NOTE(review): assumes elements are the integers 0..size-1 - confirm.
        """
        # Local import: the module header does not import numpy.
        import numpy as np
        sz = self.size
        D = np.zeros([sz, sz])
        for i in range(sz):
            for j in range(i, sz):
                D[i, j] = self.tree_distance(i, j)
                D[j, i] = D[i, j]
        return D

    def copy(self):
        """
        Returns a deep copy of this node (and everything reachable from it).
        """
        return copy.deepcopy(self)

    def leaves_idx(self, node_ind):
        """
        Returns the first element of every leaf under the indexed node
        node_ind, in depth-first order.
        """
        leaves = self.nodes_list[node_ind].dfs_leaves()
        return [node.elements[0] for node in leaves]
def dyadic_tree(n):
    """
    Generates the basic dyadic tree on 2**n elements.

    Starts from one single-element leaf per element and repeatedly joins
    neighbouring pairs of nodes under a new parent until one node is left.
    make_index() is called on that root before it is returned.
    """
    layer = [ClusterTreeNode([element]) for element in range(2 ** n)]
    for _ in range(n):
        merged = []
        # Pair nodes (0,1), (2,3), ... without re-slicing the list each step.
        for left, right in zip(layer[0::2], layer[1::2]):
            parent = ClusterTreeNode([])
            left.assign_to_parent(parent)
            right.assign_to_parent(parent)
            merged.append(parent)
        layer = merged
    root = layer[0]
    root.make_index()
    return root