/
adalead.py
175 lines (148 loc) · 6.25 KB
/
adalead.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""Defines the Adalead explorer class."""
import random
from typing import Optional, Tuple
import numpy as np
import pandas as pd
import flexs
from flexs.utils import sequence_utils as s_utils
class Adalead(flexs.Explorer):
    """
    Adalead explorer.

    Algorithm works as follows:
        Initialize set of top sequences whose fitnesses are at least
            (1 - threshold) of the maximum fitness so far
        While we can still make model queries in this batch
            Recombine top sequences and append to parents
            Rollout from parents and append to mutants.
    """

    def __init__(
        self,
        model: flexs.Model,
        rounds: int,
        sequences_batch_size: int,
        model_queries_per_batch: int,
        starting_sequence: str,
        alphabet: str,
        mu: int = 1,
        recomb_rate: float = 0,
        threshold: float = 0.05,
        rho: int = 0,
        eval_batch_size: int = 20,
        log_file: Optional[str] = None,
    ):
        """
        Create the explorer.

        `model`, `rounds`, `sequences_batch_size`, `model_queries_per_batch`,
        `starting_sequence` and `log_file` are forwarded unchanged to
        `flexs.Explorer.__init__`.

        Args:
            alphabet: Alphabet handed to `s_utils.generate_random_mutant` when
                mutating a sequence.
            mu: Expected number of mutations to the full sequence (mu/L per position).
            recomb_rate: The probability of a crossover at any position in a sequence.
            threshold: At each round only sequences with fitness of at least
                f_max * (1 - sign(f_max) * threshold) are retained as parents for
                generating the next set of sequences.
            rho: Number of recombination passes applied to the parent population
                on every iteration of the proposal loop.
            eval_batch_size: For code optimization; size of batches sent to model.
        """
        name = f"Adalead_mu={mu}_threshold={threshold}"

        super().__init__(
            model,
            name,
            rounds,
            sequences_batch_size,
            model_queries_per_batch,
            starting_sequence,
            log_file,
        )
        self.threshold = threshold
        self.recomb_rate = recomb_rate
        self.alphabet = alphabet
        self.mu = mu  # number of mutations per *sequence*.
        self.rho = rho
        self.eval_batch_size = eval_batch_size

    def _recombine_population(self, gen):
        """
        Recombine a population by crossing over randomly chosen pairs.

        The population is shuffled, consecutive members are paired, and each
        pair is replaced by two recombinants. While walking along the pair, the
        strand each recombinant copies from flips with probability
        `self.recomb_rate` at every position.

        Args:
            gen: Indexable collection of sequences. Each pair is walked over
                the length of its first member.

        Returns:
            A new list with the same number of members as `gen`. If `gen` has
            fewer than two members it is returned as is. The input is never
            modified.
        """
        # With fewer than two members there is nobody to recombine with.
        if len(gen) < 2:
            return gen

        # Shuffle a copy so that the caller's container is left untouched.
        gen = list(gen)
        random.shuffle(gen)

        ret = []
        for i in range(0, len(gen) - 1, 2):
            strA = []
            strB = []
            switch = False
            for ind in range(len(gen[i])):
                if random.random() < self.recomb_rate:
                    switch = not switch

                # putting together recombinants
                if switch:
                    strA.append(gen[i][ind])
                    strB.append(gen[i + 1][ind])
                else:
                    strB.append(gen[i][ind])
                    strA.append(gen[i + 1][ind])

            ret.append("".join(strA))
            ret.append("".join(strB))

        # An odd-sized population leaves one member without a partner. Carry
        # it over unchanged rather than dropping it, otherwise the population
        # would shrink by one on every recombination pass.
        if len(gen) % 2 == 1:
            ret.append(gen[-1])

        return ret

    def propose_sequences(
        self, measured_sequences: pd.DataFrame
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Propose top `sequences_batch_size` sequences for evaluation.

        Args:
            measured_sequences: Data frame read through its "sequence" and
                "true_score" columns.

        Returns:
            A tuple `(sequences, predictions)` holding at most
            `sequences_batch_size` newly generated sequences and their model
            fitnesses, ordered from highest to lowest fitness.

        Raises:
            ValueError: If no new sequence could be generated.
        """
        measured_sequence_set = set(measured_sequences["sequence"])

        # Keep every sequence whose score is within a relative `self.threshold`
        # of the best score. The sign term keeps the cutoff below the maximum
        # even when the maximum is negative.
        top_fitness = measured_sequences["true_score"].max()
        top_inds = measured_sequences["true_score"] >= top_fitness * (
            1 - np.sign(top_fitness) * self.threshold
        )

        # `np.resize` repeats (or truncates) the selected sequences so that
        # there are exactly `sequences_batch_size` parents.
        parents = np.resize(
            measured_sequences["sequence"][top_inds].to_numpy(),
            self.sequences_batch_size,
        )

        sequences = {}
        previous_model_cost = self.model.cost
        while self.model.cost - previous_model_cost < self.model_queries_per_batch:
            # generate recombinant mutants
            for _ in range(self.rho):
                parents = self._recombine_population(parents)

            for i in range(0, len(parents), self.eval_batch_size):
                # Here we do rollouts from each parent (root of rollout tree)
                roots = parents[i : i + self.eval_batch_size]
                root_fitnesses = self.model.get_fitness(roots)

                nodes = list(enumerate(roots))

                # Only start another rollout step if a further
                # `eval_batch_size` queries still fit below the budget.
                while (
                    len(nodes) > 0
                    and self.model.cost - previous_model_cost + self.eval_batch_size
                    < self.model_queries_per_batch
                ):
                    child_idxs = []
                    children = []
                    # NOTE(review): this loop only exits once every node has
                    # produced an unseen child; confirm that
                    # `generate_random_mutant` can always produce one
                    # (e.g. when `mu` is 0).
                    while len(children) < len(nodes):
                        # The index is offset by one, so the first child comes
                        # from the last node (-1 wraps around); every node
                        # still contributes exactly one child.
                        idx, node = nodes[len(children) - 1]

                        child = s_utils.generate_random_mutant(
                            node,
                            self.mu * 1 / len(node),
                            self.alphabet,
                        )

                        # Stop when we generate new child that has never been seen
                        # before
                        if (
                            child not in measured_sequence_set
                            and child not in sequences
                        ):
                            child_idxs.append(idx)
                            children.append(child)

                    # Stop the rollout once the child has worse predicted
                    # fitness than the root of the rollout tree.
                    # Otherwise, set node = child and add child to the list
                    # of sequences to propose.
                    fitnesses = self.model.get_fitness(children)
                    sequences.update(zip(children, fitnesses))

                    nodes = []
                    for idx, child, fitness in zip(child_idxs, children, fitnesses):
                        if fitness >= root_fitnesses[idx]:
                            nodes.append((idx, child))

        if len(sequences) == 0:
            raise ValueError(
                "No sequences generated. If `model_queries_per_batch` is small, try "
                "making `eval_batch_size` smaller"
            )

        # We propose the top `self.sequences_batch_size` new sequences we have generated
        new_seqs = np.array(list(sequences.keys()))
        preds = np.array(list(sequences.values()))
        # Reverse first, then truncate: a single negative-step slice
        # `[:-k:-1]` stops one element early and yields only k - 1 items
        # (and none at all for k == 1).
        sorted_order = np.argsort(preds)[::-1][: self.sequences_batch_size]

        return new_seqs[sorted_order], preds[sorted_order]