Skip to content

Commit

Permalink
Merge pull request #407 from jnwei/pl_upgrades
Browse files (browse the repository at this point in the history)
PyTorch Lightning upgrades
  • Loading branch information
jnwei committed Feb 19, 2024
2 parents df4dfac + f0fc7d9 commit 49ab053
Show file tree
Hide file tree
Showing 16 changed files with 379 additions and 213 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/undefined_names.yml
Expand Up @@ -5,7 +5,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
- uses: actions/setup-python@v5
- run: pip install --upgrade pip
- run: pip install flake8
- run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
2 changes: 1 addition & 1 deletion .gitignore
Expand Up @@ -9,4 +9,4 @@ dist
data
openfold/resources/
tests/test_data/
cutlass
cutlass/
2 changes: 1 addition & 1 deletion Dockerfile
Expand Up @@ -13,7 +13,7 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/

RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
RUN wget -P /tmp \
"https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh" \
"https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \
&& bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \
&& rm /tmp/Miniforge3-Linux-x86_64.sh
ENV PATH /opt/conda/bin:$PATH
Expand Down
4 changes: 2 additions & 2 deletions README.md
Expand Up @@ -351,7 +351,7 @@ python3 run_pretrained_openfold.py \
--output_dir ./ \
--model_device "cuda:0" \
--config_preset "seq_model_esm1b_ptm" \
--openfold_checkpoint_path openfold/resources/openfold_params/seq_model_esm1b_ptm.pt \
--openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt \
--uniref90_database_path data/uniref90/uniref90.fasta \
--pdb70_database_path data/pdb70/pdb70 \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
Expand Down Expand Up @@ -595,4 +595,4 @@ If you use OpenProteinSet, please also cite:
primaryClass={q-bio.BM}
}
```
Any work that cites OpenFold should also cite AlphaFold.
Any work that cites OpenFold should also cite [AlphaFold](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold-Multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1) if applicable.
458 changes: 295 additions & 163 deletions notebooks/OpenFold.ipynb 100755 → 100644

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions notebooks/environment.yml
Expand Up @@ -3,15 +3,15 @@ channels:
- conda-forge
- bioconda
dependencies:
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- openmm=7.7
- pdbfixer
- ml-collections
- PyYAML==5.4.1
- requests
- typing-extensions
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pip:
- biopython==1.79
- dm-tree==0.1.6
- ml-collections==0.1.0
- PyYAML==5.4.1
- requests==2.26.0
- typing-extensions==3.10.0.2
34 changes: 20 additions & 14 deletions openfold/data/data_pipeline.py
Expand Up @@ -21,14 +21,11 @@
from multiprocessing import cpu_count
import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import subprocess
import numpy as np
import torch
import pickle
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein

FeatureDict = MutableMapping[str, np.ndarray]
Expand Down Expand Up @@ -704,10 +701,10 @@ def __init__(
def _parse_msa_data(
self,
alignment_dir: str,
alignment_index: Optional[Any] = None,
alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
msa_data = {}
if(alignment_index is not None):
if alignment_index is not None:
fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")

def read_msa(start, size):
Expand All @@ -718,14 +715,14 @@ def read_msa(start, size):
for (name, start, size) in alignment_index["files"]:
filename, ext = os.path.splitext(name)

if(ext == ".a3m"):
if ext == ".a3m":
msa = parsers.parse_a3m(
read_msa(start, size)
)
# The "hmm_output" exception is a crude way to exclude
# multimer template hits.
# Multimer "uniprot_hits" processed separately.
elif(ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]):
elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
msa = parsers.parse_stockholm(read_msa(start, size))
else:
continue
Expand All @@ -734,13 +731,22 @@ def read_msa(start, size):

fp.close()
else:
# Now will split the following steps into multiple processes
current_directory = os.path.dirname(os.path.abspath(__file__))
cmd = f"{current_directory}/tools/parse_msa_files.py"
msa_data_path = subprocess.run(['python',cmd, f"--alignment_dir={alignment_dir}"],capture_output=True, text=True)
msa_data_path = msa_data_path.stdout.lstrip().rstrip()
msa_data = pickle.load((open(msa_data_path,'rb')))
os.remove(msa_data_path)
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
filename, ext = os.path.splitext(f)

if ext == ".a3m":
with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read())
elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
with open(path, "r") as fp:
msa = parsers.parse_stockholm(
fp.read()
)
else:
continue

msa_data[f] = msa

return msa_data

Expand Down
4 changes: 2 additions & 2 deletions openfold/data/templates.py
Expand Up @@ -101,8 +101,8 @@ def empty_template_feats(n_res):
"template_all_atom_positions": np.zeros(
(0, n_res, residue_constants.atom_type_num, 3), np.float32
),
"template_domain_names": np.array([''.encode()], dtype=np.object),
"template_sequence": np.array([''.encode()], dtype=np.object),
"template_domain_names": np.array([''.encode()], dtype=object),
"template_sequence": np.array([''.encode()], dtype=object),
"template_sum_probs": np.zeros((0, 1), dtype=np.float32),
}

Expand Down
10 changes: 5 additions & 5 deletions openfold/utils/multi_chain_permutation.py
Expand Up @@ -90,15 +90,15 @@ def get_optimal_transform(

def get_least_asym_entity_or_longest_length(batch, input_asym_id):
"""
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
one of the A as anchor
First check how many subunit(s) one sequence has. Select the subunit that is less
common, e.g. if the protein is AABBB then select one of the A subunits as the anchor.
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor
Args:
batch: in this funtion batch is the full ground truth features
input_asym_id: A list of aym_ids that are in the cropped input features
batch: in this function batch is the full ground truth features
input_asym_id: A list of asym_ids that are in the cropped input features
Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
Expand Down Expand Up @@ -126,7 +126,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
min_asym_count = min(entity_asym_count.values())
least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count]

# If multiple entities have the least asym_id count, return those with the shortest length
# If multiple entities have the least asym_id count, return those with the longest length
if len(least_asym_entities) > 1:
max_length = max([entity_length[entity] for entity in least_asym_entities])
least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length]
Expand Down
2 changes: 1 addition & 1 deletion openfold/utils/script_utils.py
Expand Up @@ -123,7 +123,7 @@ def parse_fasta(data):
][1:]
tags, seqs = lines[::2], lines[1::2]

tags = [t.split()[0] for t in tags]
tags = [re.split('\W| \|', t)[0] for t in tags]

return tags, seqs

Expand Down
5 changes: 1 addition & 4 deletions run_pretrained_openfold.py
Expand Up @@ -63,10 +63,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")

local_alignment_dir = os.path.join(
alignment_dir,
os.path.join(alignment_dir, tag),
)
local_alignment_dir = os.path.join(alignment_dir, tag)

if args.use_precomputed_alignments is None:
logger.info(f"Generating alignments for {tag}...")
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Expand Up @@ -113,10 +113,10 @@ def get_cuda_bare_metal_version(cuda_dir):

setup(
name='openfold',
version='1.0.1',
version='2.0.0',
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
author='Gustaf Ahdritz & DeepMind',
author_email='gahdritz@gmail.com',
author='OpenFold Team',
author_email='jennifer.wei@omsf.io',
license='Apache License, Version 2.0',
url='https://github.com/aqlaboratory/openfold',
packages=find_packages(exclude=["tests", "scripts"]),
Expand Down
29 changes: 28 additions & 1 deletion tests/config.py
@@ -1,6 +1,31 @@
import ml_collections as mlc

consts = mlc.ConfigDict(

# Constants for the monomer flavour of the test configuration.
monomer_consts = mlc.ConfigDict(
    dict(
        # Model preset and mode flag. The multimer counterpart uses
        # "model_1_multimer_v3" with is_multimer=True.
        model="model_1_ptm",
        is_multimer=False,
        # Problem dimensions used by the tests.
        chunk_size=4,
        batch_size=2,
        n_res=22,
        n_seq=13,
        n_templ=3,
        n_extra=17,
        n_heads_extra_msa=8,
        eps=5e-4,
        # Channel widths are kept at their real values: that is the simplest
        # way to stay compatible with DeepMind's pretrained weights.
        c_m=256,
        c_z=128,
        c_s=384,
        c_t=64,
        c_e=64,
        # 23 for the monomer model; the multimer model uses 22.
        msa_logits=23,
        # Left unset here; set for test_multimer_datamodule.
        template_mmcif_dir=None,
    )
)

multimer_consts = mlc.ConfigDict(
{
"model": "model_1_multimer_v3", # monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer": True, # monomer: False, multimer: True
Expand All @@ -24,6 +49,8 @@
}
)

consts = monomer_consts

config = mlc.ConfigDict(
{
"data": {
Expand Down
5 changes: 1 addition & 4 deletions tests/test_deepspeed_evo_attention.py
Expand Up @@ -244,9 +244,6 @@ def test_compare_template_stack(self):
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)

inds = np.random.randint(0, 21, (n_res,))
batch["target_feat"] = np.eye(22)[inds]

batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
template_feats = {
k: v for k, v in batch.items() if k.startswith("template_")
Expand Down Expand Up @@ -309,7 +306,7 @@ def test_compare_model(self):
batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14"
].long()
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], 21).to(torch.float32)
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32)
batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update(
data_transforms.atom37_to_torsion_angles("template_")(batch)
Expand Down
16 changes: 11 additions & 5 deletions tests/test_permutation.py
Expand Up @@ -21,7 +21,6 @@
merge_labels)


@unittest.skip("Tests need to be fixed post-refactor")
class TestPermutation(unittest.TestCase):
def setUp(self):
"""
Expand Down Expand Up @@ -65,10 +64,16 @@ def test_1_selecting_anchors(self):
'seq_length': torch.tensor([57])
}
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id'])
self.assertIn(int(anchor_gt_asym), [1, 2])
self.assertNotIn(int(anchor_gt_asym), [3, 4, 5])
self.assertIn(int(anchor_pred_asym), [1, 2])
self.assertNotIn(int(anchor_pred_asym), [3, 4, 5])
anchor_gt_asym = int(anchor_gt_asym)
anchor_pred_asym = {int(i) for i in anchor_pred_asym}
expected_anchors = {1, 2}
expected_non_anchors = {3, 4, 5}

self.assertIn(anchor_gt_asym, expected_anchors)
self.assertNotIn(anchor_gt_asym, expected_non_anchors)
# Check that predicted anchors are within expected anchor set
self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym)
self.assertEqual(set(), anchor_pred_asym & expected_non_anchors)

def test_2_permutation_pentamer(self):
batch = {
Expand Down Expand Up @@ -114,6 +119,7 @@ def test_2_permutation_pentamer(self):
self.assertIn(aligns, possible_outcome)
self.assertNotIn(aligns, wrong_outcome)

@unittest.skip("Test needs to be fixed post-refactor")
def test_3_merge_labels(self):
nres_pad = 325 - 57 # suppose the cropping size is 325
batch = {
Expand Down
1 change: 1 addition & 0 deletions train_openfold.py
Expand Up @@ -235,6 +235,7 @@ def configure_optimizers(self,

lr_scheduler = AlphaFoldLRScheduler(
optimizer,
last_epoch=self.last_lr_step
)

return {
Expand Down

0 comments on commit 49ab053

Please sign in to comment.