Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add use_custom_templates option #408

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,10 @@ where `data` is the same directory as in the previous step. If `jackhmmer`,
`/usr/bin`, their `binary_path` command-line arguments can be dropped.
If you've already computed alignments for the query, you have the option to
skip the expensive alignment computation here with
`--use_precomputed_alignments`.
`--use_precomputed_alignments`. If you wish to use a specific template as input,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move this information to a subheading of its own? Maybe called "Custom Templates"?
Do mention that the same template(s) will be used if multiple separate sequences are passed for inference.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your comments. I have addressed the others without any questions.
I am not sure I understood you correctly about this one, though.

Now, I changed the README so that custom templates stand a bit more on their own, but they are still in the monomer inference section.

Should I make a new section on the same level as monomer inference, multimer inference or soloseq?

you can use the argument `--use_custom_template`, which will then read all .cif
files in `template_mmcif_dir`. Make sure the chains of interest have the identifier _A_
and have the same length as the input sequence.

`--openfold_checkpoint_path` or `--jax_param_path` accept comma-delimited lists
of .pt/DeepSpeed OpenFold checkpoints and AlphaFold's .npz JAX parameter files,
Expand Down
19 changes: 16 additions & 3 deletions openfold/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,19 @@
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import numpy as np
import torch
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data import (
templates,
parsers,
mmcif_parsing,
msa_identifiers,
msa_pairing,
feature_processing_multimer,
)
from openfold.data.templates import (
get_custom_template_features,
empty_template_feats,
CustomHitFeaturizer,
)
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.np import residue_constants, protein

Expand All @@ -38,7 +49,9 @@ def make_template_features(
template_featurizer: Any,
) -> FeatureDict:
hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0 or template_featurizer is None):
if template_featurizer is None or (
len(hits_cat) == 0 and not isinstance(template_featurizer, CustomHitFeaturizer)
):
template_features = empty_template_feats(len(input_sequence))
else:
templates_result = template_featurizer.get_templates(
Expand Down
109 changes: 68 additions & 41 deletions openfold/data/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import json
import logging
import os
from pathlib import Path
import re
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple

Expand Down Expand Up @@ -947,49 +948,58 @@ def _process_single_hit(


def get_custom_template_features(
    mmcif_path: str,
    query_sequence: str,
    pdb_id: str,
    chain_id: Optional[str] = "A",
    kalign_binary_path: Optional[str] = None,
):
    """Featurize user-supplied mmCIF template(s) for ``query_sequence``.

    Args:
        mmcif_path: Path to a single mmCIF file, or to a directory whose
            ``*.cif`` files are all used as templates.
        query_sequence: Sequence the templates are featurized against.
        pdb_id: Kept for backward compatibility only. Each template is
            identified by the stem of its file name instead.
        chain_id: Chain of every template to featurize; defaults to ``"A"``.
        kalign_binary_path: Forwarded unchanged to
            ``_extract_template_features``.

    Returns:
        A ``TemplateSearchResult`` whose ``features`` map every name in
        ``TEMPLATE_FEATURES`` to the per-template features stacked along
        axis 0, with ``errors=None`` and one ``warnings`` entry per template.

    Raises:
        ValueError: If ``mmcif_path`` is neither a file nor a directory, or
            if it is a directory that contains no ``*.cif`` file.
    """
    if os.path.isfile(mmcif_path):
        template_paths = [Path(mmcif_path)]
    elif os.path.isdir(mmcif_path):
        # Sorted so the stacking order of templates is reproducible and does
        # not depend on the order in which the filesystem lists the files.
        template_paths = sorted(Path(mmcif_path).glob("*.cif"))
    else:
        logging.error("Custom template path %s does not exist", mmcif_path)
        raise ValueError(f"Custom template path {mmcif_path} does not exist")

    if not template_paths:
        # Without this guard the final stacking step fails with an opaque
        # KeyError on the first feature name.
        logging.error(
            "Custom template path %s contains no .cif files", mmcif_path
        )
        raise ValueError(
            f"Custom template path {mmcif_path} contains no .cif files"
        )

    warnings = []
    collected = {}
    for template_path in template_paths:
        logging.info("Featurizing template: %s", template_path)
        # The identifier is only used for reporting; take it from the file name.
        template_id = template_path.stem
        with open(template_path, "r") as cif_file:
            cif_string = cif_file.read()
        mmcif_parse_result = mmcif_parsing.parse(
            file_id=template_id, mmcif_string=cif_string
        )
        mmcif_object = mmcif_parse_result.mmcif_object
        template_sequence = mmcif_object.chain_to_seqres[chain_id]
        # Identity mapping: index i is mapped to index i for every position
        # of the template sequence, i.e. no alignment is performed here.
        mapping = {index: index for index in range(len(template_sequence))}

        curr_features, curr_warnings = _extract_template_features(
            mmcif_object=mmcif_object,
            pdb_id=template_id,
            mapping=mapping,
            template_sequence=template_sequence,
            query_sequence=query_sequence,
            template_chain_id=chain_id,
            kalign_binary_path=kalign_binary_path,
            _zero_center_positions=True,
        )
        curr_features["template_sum_probs"] = [1.0]
        for feature_name, feature_value in curr_features.items():
            collected.setdefault(feature_name, []).append(feature_value)
        warnings.append(curr_warnings)

    template_features = {
        feature_name: np.stack(collected[feature_name], axis=0).astype(
            feature_type
        )
        for feature_name, feature_type in TEMPLATE_FEATURES.items()
    }
    return TemplateSearchResult(
        features=template_features, errors=None, warnings=warnings
    )
Expand Down Expand Up @@ -1188,6 +1198,23 @@ def get_templates(
)


class CustomHitFeaturizer(TemplateHitFeaturizer):
    """Template featurizer that reads user-supplied mmCIF templates.

    Templates are taken from ``self._mmcif_dir`` instead of from search hits.
    The chain of interest has to be chain A and must have the same number of
    residues as the input sequence.
    """

    def get_templates(
        self,
        query_sequence: str,
        hits: Sequence[parsers.TemplateHit],
    ) -> TemplateSearchResult:
        """Featurize the custom templates for ``query_sequence``.

        ``hits`` is not used by this implementation; the templates come from
        ``self._mmcif_dir``.
        """
        template_source = self._mmcif_dir
        logging.info("Featurizing mmcif_dir: %s", template_source)
        featurizer_kwargs = dict(
            query_sequence=query_sequence,
            pdb_id="test",
            chain_id="A",
            kalign_binary_path=self._kalign_binary_path,
        )
        return get_custom_template_features(template_source, **featurizer_kwargs)
class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
def get_templates(
self,
Expand Down
20 changes: 13 additions & 7 deletions run_pretrained_openfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,15 @@ def main(args):
)

is_multimer = "multimer" in args.config_preset

if is_multimer:
is_custom_template = "use_custom_template" in args and args.use_custom_template
if is_custom_template:
template_featurizer = templates.CustomHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date="9999-12-31", # just dummy, not used
max_hits=-1, # just dummy, not used
kalign_binary_path=args.kalign_binary_path
)
elif is_multimer:
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
Expand All @@ -205,11 +212,9 @@ def main(args):
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path
)

data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)

if is_multimer:
data_processor = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=data_processor,
Expand All @@ -222,7 +227,6 @@ def main(args):

np.random.seed(random_seed)
torch.manual_seed(random_seed + 1)

feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
Expand Down Expand Up @@ -292,7 +296,6 @@ def main(args):
)

feature_dicts[tag] = feature_dict

processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict', is_multimer=is_multimer
)
Expand Down Expand Up @@ -379,6 +382,10 @@ def main(args):
help="""Path to alignment directory. If provided, alignment computation
is skipped and database path arguments are ignored."""
)
parser.add_argument(
"--use_custom_template", action="store_true", default=False,
help="""Use mmcif given with "template_mmcif_dir" argument as template input."""
)
parser.add_argument(
"--use_single_seq_mode", action="store_true", default=False,
help="""Use single sequence embeddings instead of MSAs."""
Expand Down Expand Up @@ -466,5 +473,4 @@ def main(args):
"""The model is being run on CPU. Consider specifying
--model_device for better performance"""
)

main(args)