Skip to content

Commit

Permalink
Fixed bug in triangle multiplicative update and added early stop recycling.
Browse files Browse the repository at this point in the history
  • Loading branch information
christinaflo committed Jun 2, 2023
1 parent 425bdb5 commit c1129be
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 79 deletions.
43 changes: 34 additions & 9 deletions openfold/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,21 +155,38 @@ def model_config(
c.loss.tm.weight = 0.1
elif "multimer" in name:
c.globals.is_multimer = True
c.globals.bfloat16 = True
c.globals.bfloat16_output = False
c.loss.masked_msa.num_classes = 22
c.data.common.max_recycling_iters = 20

for k,v in multimer_model_config_update.items():
c.model[k] = v

# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
c.model.evoformer.num_msa = 252
c.model.evoformer.num_extra_msa= 1152
c.model.evoformer.fuse_projection_weights = False
#c.model.input_embedder.num_msa = 252
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.max_msa_clusters = 252
c.data.predict.max_msa_clusters = 252
c.data.train.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
c.model.evoformer_stack.fuse_projection_weights = False
c.model.extra_msa.extra_msa_stack.fuse_projection_weights = False
c.model.template.template_pair_stack.fuse_projection_weights = False
elif name == 'model_4_multimer_v3':
c.model.evoformer.num_extra_msa = 1152
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
elif name == 'model_5_multimer_v3':
c.model.evoformer.num_extra_msa = 1152

for k,v in multimer_model_config_update.items():
c.model[k] = v
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
else:
c.data.train.max_msa_clusters = 508
c.data.predict.max_msa_clusters = 508
c.data.train.max_extra_msa = 2048
c.data.predict.max_extra_msa = 2048

c.data.common.unsupervised_features.extend([
"msa_mask",
Expand Down Expand Up @@ -646,13 +663,20 @@ def model_config(
"eps": eps,
},
"ema": {"decay": 0.999},
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `max_recycling_iters` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
"recycle_early_stop_tolerance": -1
}
)

multimer_model_config_update = {
"input_embedder": {
"tf_dim": 21,
"msa_dim": 49,
#"num_msa": 508,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
Expand Down Expand Up @@ -703,6 +727,7 @@ def model_config(
"extra_msa_embedder": {
"c_in": 25,
"c_out": c_e,
#"num_extra_msa": 2048
},
"extra_msa_stack": {
"c_m": c_e,
Expand Down Expand Up @@ -788,5 +813,5 @@ def model_config(
"c_out": 37,
},
},

"recycle_early_stop_tolerance": 0.5
}
47 changes: 18 additions & 29 deletions openfold/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def run_msa_tool(
else:
result = msa_runner.query(fasta_path)[0]

assert msa_out_path.split('.')[-1] == msa_format
with open(msa_out_path, "w") as f:
f.write(result[msa_format])

Expand Down Expand Up @@ -321,6 +322,7 @@ def make_sequence_features_with_custom_template(
**template_features.features
}


class AlignmentRunner:
"""Runs alignment tools and saves the results"""
def __init__(
Expand Down Expand Up @@ -372,6 +374,8 @@ def __init__(
Max number of uniref hits
mgnify_max_hits:
Max number of mgnify hits
uniprot_max_hits:
Max number of uniprot hits
"""
db_map = {
"jackhmmer": {
Expand Down Expand Up @@ -468,7 +472,7 @@ def run(
):
"""Runs alignment tools on a sequence"""
if(self.jackhmmer_uniref90_runner is not None):
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.sto")

jackhmmer_uniref90_result = run_msa_tool(
msa_runner=self.jackhmmer_uniref90_runner,
Expand Down Expand Up @@ -505,7 +509,7 @@ def run(
)

if(self.jackhmmer_mgnify_runner is not None):
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.sto")
jackhmmer_mgnify_result = run_msa_tool(
msa_runner=self.jackhmmer_mgnify_runner,
fasta_path=fasta_path,
Expand Down Expand Up @@ -719,16 +723,14 @@ def read_msa(start, size):
msa = parsers.parse_a3m(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
# The "hmm_output" exception is a crude way to exclude
# multimer template hits.
elif(ext == ".sto" and not "hmm_output" == filename):
msa = parsers.parse_stockholm(read_msa(start, size))
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
else:
continue

msa_data[name] = data
msa_data[name] = msa

fp.close()
else:
Expand All @@ -739,17 +741,15 @@ def read_msa(start, size):
if(ext == ".a3m"):
with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read())
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
elif(ext == ".sto" and not "hmm_output" == filename):
with open(path, "r") as fp:
msa = parsers.parse_stockholm(
fp.read()
)
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
else:
continue

msa_data[f] = data
msa_data[f] = msa

return msa_data

Expand Down Expand Up @@ -831,8 +831,6 @@ def read_template(start, size):
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits

return

def _get_msas(self,
alignment_dir: str,
input_sequence: Optional[str] = None,
Expand All @@ -849,24 +847,19 @@ def _get_msas(self,
)

deletion_matrix = [[0 for _ in input_sequence]]
msa_data["dummy"] = {
"msa": parsers.Msa(sequences=input_sequence, deletion_matrix=deletion_matrix, descriptions=None),
"deletion_matrix": deletion_matrix,
}

msas, deletion_matrices = zip(*[
(v["msa"], v["deletion_matrix"]) for v in msa_data.values()
])
msa_data["dummy"] = parsers.Msa(sequences=input_sequence,
deletion_matrix=deletion_matrix,
descriptions=None)

return msas, deletion_matrices
return list(msa_data.values())

def _process_msa_feats(
self,
alignment_dir: str,
input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
msas, deletion_matrices = self._get_msas(
msas = self._get_msas(
alignment_dir, input_sequence, alignment_index
)
msa_features = make_msa_features(
Expand Down Expand Up @@ -944,7 +937,6 @@ def process_mmcif(
input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(
alignment_dir,
input_sequence,
alignment_index)

template_features = make_template_features(
Expand Down Expand Up @@ -994,7 +986,6 @@ def process_pdb(

hits = self._parse_template_hits(
alignment_dir,
input_sequence,
alignment_index
)

Expand Down Expand Up @@ -1080,11 +1071,11 @@ def process_multiseq_fasta(self,
alignment_dir = os.path.join(
super_alignment_dir, desc
)
msas, deletion_mats = self._get_msas(
msas = self._get_msas(
alignment_dir, seq, None
)
msa_list.append(msas)
deletion_mat_list.append(deletion_mats)
msa_list.append([m.sequences for m in msas])
deletion_mat_list.append([m.deletion_matrix for m in msas])

final_msa = []
final_deletion_mat = []
Expand Down Expand Up @@ -1181,12 +1172,10 @@ def _process_single_chain(

def _all_seq_msa_features(self, fasta_path, alignment_dir):
"""Get MSA features for unclustered uniprot, for pairing."""
#TODO: Quick fix, change back to .sto after parsing fixed
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.a3m")
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
with open(uniprot_msa_path, "r") as fp:
uniprot_msa_string = fp.read()
msa = parsers.parse_a3m(uniprot_msa_string)
#msa = parsers.parse_stockholm(uniprot_msa_string)
msa = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers',
Expand Down
4 changes: 2 additions & 2 deletions openfold/data/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@ def _process_single_hit(
% (
hit_pdb_code,
hit_chain_id,
hit.sum_probs,
hit.sum_probs if hit.sum_probs else 0.,
hit.index,
str(e),
parsing_result.errors,
Expand All @@ -919,7 +919,7 @@ def _process_single_hit(
% (
hit_pdb_code,
hit_chain_id,
hit.sum_probs,
hit.sum_probs if hit.sum_probs else 0.,
hit.index,
str(e),
parsing_result.errors,
Expand Down
14 changes: 5 additions & 9 deletions openfold/model/evoformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,16 +525,14 @@ def forward(self,
_attn_chunk_size=_attn_chunk_size
)

m = input_tensors[0]
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
device = z.device
del m, z
assert (sys.getrefcount(input_tensors[0]) == 2)
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[0] = input_tensors[0].to(device)
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
m, _ = input_tensors
else:
m = input_tensors[0]

return m, z

Expand Down Expand Up @@ -713,12 +711,10 @@ def fn(input_tensors):
if (_offload_inference and inplace_safe):
# m: GPU, z: GPU
device = z.device
del m, z
del m
assert (sys.getrefcount(input_tensors[0]) == 2)
assert (sys.getrefcount(input_tensors[1]) == 2)
input_tensors[0] = input_tensors[0].to(device)
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
m, _ = input_tensors

return m, z

Expand Down

0 comments on commit c1129be

Please sign in to comment.