/
twisst_node_depth.py
348 lines (243 loc) · 13.9 KB
/
twisst_node_depth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
import matplotlib.pyplot as plt
import ete3
import numpy as np
np.seterr(divide='ignore', invalid='ignore')
import wquantiles
import itertools,argparse,gzip,sys
def getMidPos_method2(pos1, pos2, depth1, depth2, depth3):
    """Return a depth-weighted average of the positions pos1 and pos2.

    pos1 is weighted by (depth3 - depth1) and pos2 by (depth2 - depth1); the
    weighted sum is divided by (depth3 - 2*depth1 + depth2).

    NOTE(review): the formula has the same shape as method 1 of getNodePos,
    so depth1 is presumably the parent depth and depth2/depth3 the depths of
    the children sitting at pos1/pos2 -- confirm against callers.
    """
    weight_for_pos2 = depth2 - depth1
    weight_for_pos1 = depth3 - depth1
    # the leading 1. forces floating point arithmetic for integer inputs
    weighted_sum = 1. * pos2 * weight_for_pos2 + pos1 * weight_for_pos1
    total_weight = depth3 - 2 * depth1 + depth2
    return weighted_sum / total_weight
def getNodePos(node, method = 1):
    """Return the plotting position of an internal node from its children.

    Each child must already carry a ``pos`` attribute; method 1 additionally
    reads ``depth`` on the node and on both children.

    method 1: depth-weighted average of the two child positions, where each
              child's position is weighted by the *other* child's depth
              difference to the node. Only defined for bifurcating nodes.
    method 2: plain mean of the positions of all children.

    Raises ValueError if method 1 is used on a node that does not have exactly
    two children, or if method is neither 1 nor 2.
    """
    children = node.get_children()
    if method == 1:
        # validate BEFORE unpacking, otherwise the unpacking itself fails with
        # an unrelated message and this explanation is never shown
        if len(children) != 2:
            raise ValueError("Position method one only works for bifurcating nodes.")
        c1,c2 = children
        return((1.*c2.pos*(c1.depth-node.depth)+c1.pos*(c2.depth-node.depth))/(c2.depth-2*node.depth+c1.depth))
    elif method == 2:
        return(np.mean([c.pos for c in children]))
    else:
        # a bare string cannot be raised; use a real exception type
        raise ValueError("Position method can only be 1 or 2.")
def drawTree(tree, leafPos = None, depthDict = None, depthRangeDict=None, extendTips=False, rootIsZero=False,
             show=True, posMethod=1, col="black",linewidth=2,alpha=1,direction="down",taxColDict=None):
    """Draw a tree on the current matplotlib axes.

    The tree is copied first, so the caller's tree is not modified.

    tree           -- tree object (copied with tree.copy("newick"))
    leafPos        -- dict of leaf name -> position; by default leaves are
                      numbered 0..n-1 in the order returned by get_leaves()
    depthDict      -- dict of node name -> depth; by default the depth of a
                      node is its distance from the root
    depthRangeDict -- optional dict of node name -> (lower, upper) depth,
                      drawn as a thin range bar with a tick at the node depth
    extendTips     -- set all leaf depths to the maximum leaf depth
    rootIsZero     -- if False, the maximum leaf depth is subtracted from all
                      depths so that the deepest leaf sits at zero
    show           -- call plt.show() when done
    posMethod      -- method passed to getNodePos for internal nodes
    col, linewidth, alpha -- line styling
    direction      -- "down" puts depth on the y axis (x ticks hidden); any
                      other value puts depth on the x axis (y ticks hidden)
    taxColDict     -- optional dict of leaf name -> label colour
    """
    tree = tree.copy("newick")
    #get node depths
    for node in tree.traverse():
        if depthDict is None: node.add_feature("depth", node.get_distance(tree))
        else: node.add_feature("depth", depthDict[node.name])
    #extend tips to align them if needed
    if extendTips:
        maxDP = max([leaf.depth for leaf in tree.iter_leaves()])
        for leaf in tree.iter_leaves(): leaf.depth = maxDP
    #adjust depths so that they are aligned at zero
    if not rootIsZero:
        maxDP = max([leaf.depth for leaf in tree.iter_leaves()])
        for node in tree.traverse():
            node.depth -= maxDP
    #set leaf positions
    if leafPos is None:
        leafNames = [leaf.name for leaf in tree.get_leaves()]
        leafPos = dict(zip(leafNames,range(len(leafNames))))
    for leaf in tree.iter_leaves(): leaf.add_feature("pos", leafPos[leaf.name])
    #set positions for all other nodes relative to their children
    #(postorder guarantees the children are positioned before their parent)
    for node in tree.traverse(strategy="postorder"):
        if not node.is_leaf():
            node.add_feature("pos", getNodePos(node,method=posMethod))
    #draw
    #compare by value: an identity test against a string literal only works
    #for interned strings and would silently pick the wrong orientation
    vertical = direction == "down"

    def orient(pos, depth):
        #map (position, depth) onto (x, y) for the chosen orientation
        return (pos, depth) if vertical else (depth, pos)

    #hide the ticks of the position axis (once, rather than once per node)
    if vertical: plt.setp(plt.gca(),xticks=[])
    else: plt.setp(plt.gca(),yticks=[])
    labelAlign = 'center' if vertical else 'left'
    for node in tree.traverse():
        for child in node.get_children():
            plt.plot(*orient([node.pos,child.pos],[node.depth,child.depth]),
                     color=col,linewidth=linewidth,alpha=alpha, solid_capstyle="round")
        if depthRangeDict:
            #range bar along the depth axis, plus a short tick at the node depth
            plt.plot(*orient([node.pos]*2,depthRangeDict[node.name]),
                     color=col,linewidth=1,alpha=alpha, solid_capstyle="round")
            plt.plot(*orient([node.pos-.1,node.pos+.1],[node.depth]*2),
                     color=col,linewidth=1,alpha=alpha, solid_capstyle="round")
        if node.is_leaf():
            textX,textY = orient(node.pos, node.depth - 0.1)
            plt.text(textX, textY, node.name,
                     horizontalalignment=labelAlign, verticalalignment='center',
                     color=taxColDict[node.name] if taxColDict else "black")
    if show: plt.show()
def subset(things,subLen):
    """Split a sliceable sequence into consecutive chunks of length subLen.

    The last chunk is shorter when len(things) is not a multiple of subLen.
    """
    return [things[start:start+subLen] for start in range(0,len(things),subLen)]
def asColumn(a):
    """Return a with a new length-1 axis inserted after its first axis.

    For a 1-D array of length n this gives an (n, 1) column.
    """
    # np.newaxis is an alias for None
    return a[:, None]
def addNodeNames(tree):
    """Rename every node whose name is not None after its descendant leaves.

    The new name is the node's leaf names, sorted and joined with "_".

    NOTE: a second function called addNodeNames is defined further down in
    this file; being defined later, that one is what the name refers to by
    the time the script body runs.
    """
    nodes_with_name = (n for n in tree.traverse() if n.name is not None)
    for n in nodes_with_name:
        leaf_names = sorted(n.get_leaf_names())
        n.name = "_".join(leaf_names)
def allIsNaN(x):
    """Return True when every element of x is NaN."""
    return np.all(np.isnan(x))
def normLength(tree, outgroup):
    """Divide every branch length in tree by the distance from tree to outgroup.

    The tree is modified in place; nothing is returned.
    """
    # the 1. factor makes the divisor a float
    scale = 1.*tree.get_distance(outgroup)
    for node in tree.traverse():
        node.dist = node.dist / scale
def addNodeNames(tree):
    """Give every unnamed node a name built from its descendant leaves.

    A node whose name is the empty string is renamed to its leaf names,
    sorted and joined with "_". Nodes that already have a name are left
    alone. The tree is modified in place; nothing is returned.
    """
    for n in tree.traverse():
        # compare by value: "is" tests object identity, which only happens to
        # work for the empty string because of interpreter string interning
        if n.name == "": n.name = "_".join(sorted(n.get_leaf_names()))
def treeToParentChildTable(tree):
    """Return a list of (parent name, node name, branch length) tuples.

    One tuple is produced, in traversal order, for every node that has a
    parent; nodes without a parent (the root) are skipped.
    """
    table = []
    for node in tree.traverse():
        parent = node.up
        if parent is None:
            continue
        table.append((parent.name, node.name, node.dist))
    return table
def getLeafPairs(node):
    """Iterate over all (leaf, leaf) name pairs that span the two children of node.

    The first element of each pair is a leaf name below the first child, the
    second a leaf name below the second child. Asserts that node has exactly
    two children.
    """
    assert len(node.children) == 2, "Node {} does not have two children.".format(node.name)
    first_child = node.children[0]
    second_child = node.children[1]
    first_leaves = first_child.get_leaf_names()
    second_leaves = second_child.get_leaf_names()
    return itertools.product(first_leaves, second_leaves)
def make2DarrayFrom1DupperTriangle(upperTriangle1D,N,includesDiagnol=False):
    """Build a symmetric N x N array from a flattened upper triangle.

    upperTriangle1D -- upper-triangle values in row-major order, i.e.
                       (0,1), (0,2), ..., (1,2), ... without the diagonal, or
                       (0,0), (0,1), ..., (1,1), ... when includesDiagnol is True
    N               -- side length of the returned array
    includesDiagnol -- whether upperTriangle1D contains the diagonal; when it
                       does not, the diagonal of the result is zero

    Returns a float array in which both [i, j] and [j, i] hold the value
    given for the pair (i, j).
    """
    a = np.zeros([N,N])
    # k=1 starts one above the diagonal; k=0 includes it
    rows,cols = np.triu_indices(N, k=0 if includesDiagnol else 1)
    # index with an explicit (rows, cols) pair: a *list* of index arrays is
    # no longer interpreted as multidimensional indexing by NumPy
    a[rows,cols] = upperTriangle1D
    a[cols,rows] = upperTriangle1D
    return a
########################## plot topos with average branch lengths
# Command line interface. Weights, dists and the output figure are required;
# everything else tunes which taxa are drawn and how the panels look.
parser = argparse.ArgumentParser()

# input / output
parser.add_argument("-w", "--weightsFiles", nargs="+", required=True,
                    help="Input weights file(s) from Twisst")
parser.add_argument("-d", "--distsFiles", nargs="+", required=True,
                    help="Input dists file(s) from Twisst")
parser.add_argument("-f", "--figFile", required=True,
                    help="File for output figure")
parser.add_argument("--figFormat", default="pdf",
                    help="Format of figFile")
parser.add_argument("--figSize", nargs=2, type=float, default=(10, 10),
                    help="Size of figFile")

# tree construction
parser.add_argument("--posMethod", type=int, choices=(1, 2), default=2,
                    help="Node positioning method")
parser.add_argument("--quantiles", type=float, nargs=2,
                    help="Add quantiles for each node in tree")
parser.add_argument("--plotTaxa", nargs="+",
                    help="Prune tree to include on the specifed taxa")
parser.add_argument("--taxOrder", default="levelorder",
                    choices=("levelorder", "preorder", "postorder", "predefined"),
                    help="How to determine order of taxa in plots")

# appearance
parser.add_argument("--lineWidth", type=float, default=4,
                    help="Width for tree lines")
parser.add_argument("--scaleLinesByWeights", action="store_true",
                    help="Scale tree lines in figure by weights")
parser.add_argument("--orderByWeights", action="store_true",
                    help="Order tree plots by weights")
parser.add_argument("--cols", nargs="+",
                    help="Topology colours")
parser.add_argument("--taxCols", nargs="+",
                    help="Taxon name colours")
parser.add_argument("--alpha", type=float, default=1.,
                    help="Topology alpha")
parser.add_argument("--layout", nargs=2, type=int,
                    help="Rows and columns to plot")
parser.add_argument("--tight", nargs=2, type=float,
                    help="Pading for tight edges")

args = parser.parse_args()
# Load all distance tables (first line of each file is a header) and stack
# them row-wise into one array.
sys.stderr.write("\nReading distances file...")
dists = np.vstack([np.loadtxt(f, skiprows=1) for f in args.distsFiles])

#get topologies
# The topologies are read from the top of the first weights file: the last
# whitespace-separated field of each line is parsed as a tree, and reading
# stops at the first line that cannot be parsed (or at end of file).
sys.stderr.write("\nGetting topologies...")
topos = []
# "rt" so that lines are str rather than bytes, which the tree parser needs
with gzip.open(args.weightsFiles[0], "rt") as wf:
    while True:
        try: topos.append(ete3.Tree(wf.readline().split()[-1]))
        # a parse failure is the expected end-of-header signal; Exception
        # (unlike a bare except) still lets KeyboardInterrupt through
        except Exception: break

if not topos:
    raise ValueError("No topologies could be read from {}".format(args.weightsFiles[0]))

nTopos = len(topos)

for t in topos: addNodeNames(t)

nTaxa = len(topos[0].get_leaves())

sys.stderr.write("\nThere are {} topologies and {} taxa".format(nTopos,nTaxa))

#make a separate set of topologies for plotting
plotTopos = [t.copy("newick") for t in topos]

if args.plotTaxa:
    for t in plotTopos: t.prune(args.plotTaxa)

#order of the leaves along the position axis, one list per topology
if args.taxOrder == "predefined":
    # explicit check rather than assert, which is skipped under python -O
    if args.plotTaxa is None:
        raise ValueError("Predefined taxa order must be given using --plotTaxa.")
    taxOrder = [args.plotTaxa]*nTopos
else:
    taxOrder = [[node.name for node in topo.traverse(strategy=args.taxOrder) if node.is_leaf()] for topo in plotTopos]

#label colours are paired with --plotTaxa by position
if args.taxCols:
    if args.plotTaxa is None:
        raise ValueError("To plot coloured taxon labels, you must specify names of taxa using --plotTaxa")
    taxColDict = dict(zip(args.plotTaxa,args.taxCols))
else:
    taxColDict = None

#panel grid: explicit --layout, or a default for the common topology counts
if args.layout: nRow,nCol = args.layout
elif nTopos == 3: nRow,nCol = (1,3,)
elif nTopos == 15: nRow,nCol = (3,5,)
# 105 panels are needed here; a 3 x 5 grid only has room for 15
elif nTopos == 105: nRow,nCol = (7,15,)
else: raise ValueError("Please specify number of rows and columns in plot using --layout")
#get pair names in the dists file. The order is essential here.
#We use the order of the first headers, and assume that the rest follow the same pattern.
#for example, if the taxa are called A, B, C and D, the headers should be:
#Topo1_A_B Topo1_A_C Topo1_A_D Topo1_B_C Topo1_B_D Topo1_C_D Topo2_A_B Topo2_A_C ... etc
# "rt" so that the header is str rather than bytes and can be split on "_"
with gzip.open(args.distsFiles[0], "rt") as df: headers = df.readline().split()

#the first nTaxa-1 headers pair the first taxon with each of the others
pairs = [header.split("_")[1:] for header in headers[:nTaxa-1]]
#both names of the first pair, then the second name of each remaining pair
taxonNames = pairs[0] + [pair[1] for pair in pairs[1:]]

#get columns for dists separated by topology
#floor division keeps the chunk length an integer, as range() requires
if dists.shape[1] % nTopos != 0:
    raise ValueError("Number of columns in dists file(s) ({}) is not a multiple of the number of topologies ({}).".format(dists.shape[1], nTopos))
topo_column_indices = subset(list(range(dists.shape[1])), dists.shape[1]//nTopos)

#split topologies into a third dimension: stacking gives [line, pair, topo]
dists = np.dstack([dists[:,i] for i in topo_column_indices])

#reorder to [line, topo, pair]
dists = np.swapaxes(dists,1,2)

#set all missing dists to zero. Necessary for averaging, and doesn't impact results, because weighting for these is zero
dists = np.nan_to_num(dists)
############# weights

# Load the weights of every file (skipping the nTopos+1 leading lines of each)
# and stack them row-wise.
sys.stderr.write("\nReading weights file...")
weights = np.vstack([np.loadtxt(weightsFile, skiprows=nTopos+1) for weightsFile in args.weightsFiles])

# weights and dists must describe the same lines and the same topologies
assert weights.shape[0] == dists.shape[0]
assert weights.shape[1] == dists.shape[1]

# Turn each row into proportions of its own total. Rows that sum to zero
# produce nan in the division; those are replaced by zero.
rowSums = np.apply_along_axis(np.sum, 1, weights)
weights = np.nan_to_num(weights / rowSums.reshape([len(rowSums),1]))

# mean proportion of each topology over all lines
meanWeights = np.apply_along_axis(np.mean, 0, weights)
# now we need to get the average distance between leaves for each node in each topo
# the first step here is to get the two sets of leaves that descend from each node
nodes_all = [list(tree.traverse()) for tree in topos]

nodeNames = [[n.name for n in nodes] for nodes in nodes_all]

# For every internal node: a (row indices, column indices) pair addressing,
# in the taxon-by-taxon distance matrix, all leaf pairs that span the node's
# two children. None for leaves.
# The zip is materialised as a tuple because a bare zip object is a one-shot
# iterator and cannot be used (let alone reused on every line) as an index.
nodeLeafPairs = [[tuple(zip(*[(taxonNames.index(x),taxonNames.index(y),) for x,y in getLeafPairs(node)])) if not node.is_leaf() else None
                  for node in nodes] for nodes in nodes_all]

#make a depths array that gives the depth of each node for each topo at each line: [line, topo, node]
depths = np.zeros([dists.shape[0], dists.shape[1], len(nodeNames[0])])

#now we go line by line, topology by topology and retrieve the depth
# as the average pairwise distance between all leaf pairs for each node
# unless the node is a leaf, in which case depth is zero.
sys.stderr.write("\nComputing depth for each node for each topology for each line in input...")
for x in range(depths.shape[0]):
    for y in range(nTopos):
        distMat = make2DarrayFrom1DupperTriangle(dists[x,y,:], nTaxa)
        for z,leafPairs in enumerate(nodeLeafPairs[y]):
            depths[x,y,z] = distMat[leafPairs].mean() if leafPairs is not None else 0.0

#scale depths by dividing by the depth of the root
#the first node in each topo is the root, as the traversal goes to the root first
depths = depths / np.repeat(depths[:,:,0,np.newaxis], depths.shape[2], axis=2)

#anywhere we have nan is where the root depth was zero. This happens where we had missing data. So we can set all these tree depths to zero.
depths = np.nan_to_num(depths)

#weighted summaries over lines, per topology and node
#(depths_average is not used by the plotting code visible in this file)
depths_average = np.average(depths, axis = 0, weights=np.repeat(weights[:,:,np.newaxis], depths.shape[2], axis=2))

depths_median = [[wquantiles.median(depths[:,j,k], weights=weights[:,j]) for k in range(depths.shape[2])] for j in range(depths.shape[1])]

#optional lower and upper weighted quantiles, used for the range bars
if args.quantiles:
    depths_qL = [[wquantiles.quantile(depths[:,j,k], weights[:,j], args.quantiles[0]) for k in range(depths.shape[2])] for j in range(depths.shape[1])]
    depths_qU = [[wquantiles.quantile(depths[:,j,k], weights[:,j], args.quantiles[1]) for k in range(depths.shape[2])] for j in range(depths.shape[1])]
#cols = np.array([
#"#2BCE48", #Green
#"#005C31", #Forest
#"#94FFB5", #Jade
#"#9DCC00", #Lime
#"#426600", #Quagmire
#"#00998F", #Turquoise
#"#5EF1F2", #Sky
#"#0075DC", #Blue
#"#003380", #Navy
#"#740AFF", #Violet
#"#FF5005", #Zinnia
#"#F0A3FF", #Amethyst
#"#FFA405", #Orpiment
#"#FF0010", #Red
#"#C20088"]) #Mallow
#one colour per topology; black unless --cols is given
cols = args.cols if args.cols else ["#000000"]*nTopos
if len(cols) < nTopos:
    raise ValueError("{} colours given with --cols, but there are {} topologies.".format(len(cols), nTopos))

#one line width per topology, optionally scaled so the most heavily weighted topology gets the full width
lineWidths = np.array([args.lineWidth]*nTopos, dtype=float)
if args.scaleLinesByWeights: lineWidths *= (meanWeights/meanWeights.max())

#panel order: by descending mean weight if requested, otherwise file order
plotOrder = np.argsort(meanWeights)[::-1] if args.orderByWeights else range(nTopos)

#fail with a clear message rather than part-way through plotting
if nRow*nCol < nTopos:
    raise ValueError("A layout of {} x {} cannot hold {} topologies.".format(nRow, nCol, nTopos))

sys.stderr.write("\nMaking plot.")

plt.figure(figsize=args.figSize, frameon=False)

for i,x in enumerate(plotOrder):
    plt.subplot(nRow,nCol,i+1)
    #light grey guide lines at every tenth of the scaled depth
    for y in np.arange(0,1.1,0.1): plt.plot([0,len(taxOrder[x])+1],[y,y],color="#CCCCCC")
    #tree drawn at the weighted median depths, leaves at positions 1..n
    drawTree(plotTopos[x], leafPos = dict(zip(taxOrder[x],np.arange(1,len(taxOrder[x])+1))),
             depthDict = dict(zip(nodeNames[x],depths_median[x])),
             depthRangeDict = dict(zip(nodeNames[x],zip(depths_qL[x],depths_qU[x]))) if args.quantiles else None,
             show=False, alpha = args.alpha, posMethod = args.posMethod,
             linewidth = lineWidths[x], col=cols[x], taxColDict=taxColDict)
    axes = plt.gca()
    axes.set_ylim([-0.1,1.1])
    axes.set_xlim([0.5,len(taxOrder[x])+.5])
    for side in ("top", "right", "bottom", "left"):
        axes.spines[side].set_visible(False)
    #panel label: topology number, counted from 1
    plt.text(1,.95,"T"+str(x+1),color=cols[x],
             horizontalalignment='left', verticalalignment='center', size=12)

#plt.show()
if args.tight: plt.tight_layout(h_pad=args.tight[0], w_pad = args.tight[1])

#figure size and frame were already set when the figure was created;
#savefig does not take figsize, and no longer takes frameon
plt.savefig(args.figFile, format=args.figFormat)
plt.close()

sys.stderr.write("\nDone.\n")