Duplicate expansion support #419

Merged (9 commits) on May 13, 2024
scripts/alignment_db_scripts/create_alignment_db_sharded.py (67 additions, 16 deletions)
@@ -5,17 +5,19 @@
run on the output index. Additionally this script uses threading and
multiprocessing and is much faster than the old version.
"""

import argparse
+import json
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
-import json
+from math import ceil
+from multiprocessing import cpu_count
from pathlib import Path
-from typing import List

from tqdm import tqdm
-from math import ceil


-def split_file_list(file_list, n_shards):
+def split_file_list(file_list: list[Path], n_shards: int):
"""
Split up the total file list into n_shards sublists.
"""
@@ -29,26 +29,25 @@ def split_file_list(file_list, n_shards):
return split_list


-def chunked_iterator(lst, chunk_size):
+def chunked_iterator(lst: list, chunk_size: int):
"""Iterate over a list in chunks of size chunk_size."""
for i in range(0, len(lst), chunk_size):
yield lst[i : i + chunk_size]
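As a quick illustration of the chunking behavior (toy input, not from the PR):

# chunked_iterator yields successive slices of at most chunk_size items:
list(chunked_iterator([1, 2, 3, 4, 5], chunk_size=2))
# -> [[1, 2], [3, 4], [5]]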


-def read_chain_dir(chain_dir) -> dict:
+def read_chain_dir(chain_dir: Path) -> dict:
Review comment (Collaborator): If you want, you can also add typing to the dict, i.e. for this specific example, you could write dict[str, Tuple[str, bytes]]. They can sometimes be helpful to know what to expect. For this specific function I don't think it's as needed / helpful.
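Following up on that suggestion: judging from the function body, each chain name actually maps to a list of (file name, file bytes) pairs, so a fully typed sketch might look like this (the FileEntry alias name is hypothetical):

from pathlib import Path

# Hypothetical alias; one (file_name, file_bytes) pair per alignment file.
FileEntry = tuple[str, bytes]

def read_chain_dir(chain_dir: Path) -> dict[str, list[FileEntry]]:
    ...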

"""
Read all alignment files in a single chain directory and return a dict
mapping chain name to file names and bytes.
"""
if not chain_dir.is_dir():
raise ValueError(f"chain_dir must be a directory, but is {chain_dir}")

# ensure that PDB IDs are all lowercase
pdb_id, chain = chain_dir.name.split("_")
pdb_id = pdb_id.lower()
chain_name = f"{pdb_id}_{chain}"

file_data = []

for file_path in sorted(chain_dir.iterdir()):
@@ -62,7 +63,7 @@ def read_chain_dir(chain_dir) -> dict:
return {chain_name: file_data}


-def process_chunk(chain_files: List[Path]) -> dict:
+def process_chunk(chain_files: list[Path]) -> dict:
"""
Returns the file names and bytes for all chains in a chunk of files.
"""
@@ -83,7 +84,7 @@ def create_index_default_dict() -> dict:


def create_shard(
-shard_files: List[Path], output_dir: Path, output_name: str, shard_num: int
+shard_files: list[Path], output_dir: Path, output_name: str, shard_num: int
) -> dict:
Review comment (Collaborator): Perhaps having a type alias for the index dict could be helpful. If the index file structure is also used in the main library code, then I would consider adding the alias to the main openfold library. But this can also be done in a later PR.
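A sketch of what such an alias could look like, assuming each index entry records its shard file plus (file_name, offset, length) triples; the names IndexEntry and SuperIndex are hypothetical:

from typing import TypedDict

class IndexEntry(TypedDict):
    db: str  # name of the shard file the chain's alignments live in
    files: list[tuple[str, int, int]]  # (file_name, db_offset, file_length)

SuperIndex = dict[str, IndexEntry]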

"""
Creates a single shard of the alignment database, and returns the
@@ -101,7 +102,11 @@
db_offset = 0
db_file = open(output_path, "wb")
for files_chunk in tqdm(
-chunk_iter, total=ceil(len(shard_files) / CHUNK_SIZE), desc=pbar_desc, position=shard_num, leave=False
+chunk_iter,
+total=ceil(len(shard_files) / CHUNK_SIZE),
+desc=pbar_desc,
+position=shard_num,
+leave=False,
):
# get processed files for one chunk
chunk_data = process_chunk(files_chunk)
@@ -125,9 +130,17 @@
def main(args):
alignment_dir = args.alignment_dir
output_dir = args.output_db_path
output_dir.mkdir(exist_ok=True, parents=True)
output_db_name = args.output_db_name
n_shards = args.n_shards

n_cpus = cpu_count()
if n_shards > n_cpus:
print(
f"Warning: Your number of shards ({n_shards}) is greater than the number of cores on your machine ({n_cpus}). "
"This may result in slower performance. Consider using a smaller number of shards."
)

# get all chain dirs in alignment_dir
print("Getting chain directories...")
all_chain_dirs = sorted([f for f in tqdm(alignment_dir.iterdir())])
Expand All @@ -153,12 +166,36 @@ def main(args):
super_index.update(shard_index)
print("\nCreated all shards.")

if args.duplicate_chains_file:
print("Extending super index with duplicate chains...")
duplicates_added = 0
with open(args.duplicate_chains_file, "r") as fp:
duplicate_chains = [line.strip().split() for line in fp]

for chains in duplicate_chains:
# find representative with alignment
for chain in chains:
if chain in super_index:
representative_chain = chain
break
else:
print(f"No representative chain found for {chains}, skipping...")
continue

# add duplicates to index
for chain in chains:
if chain != representative_chain:
super_index[chain] = super_index[representative_chain]
duplicates_added += 1

print(f"Added {duplicates_added} duplicate chains to index.")

# write super index to file
print("\nWriting super index...")
index_path = output_dir / f"{output_db_name}.index"
with open(index_path, "w") as fp:
json.dump(super_index, fp, indent=4)

print("Done.")


@@ -179,13 +216,27 @@ def main(args):
parser.add_argument(
"alignment_dir",
type=Path,
help="""Path to precomputed alignment directory, with one subdirectory
per chain.""",
help="""Path to precomputed flattened alignment directory, with one
subdirectory per chain.""",
)
parser.add_argument("output_db_path", type=Path)
parser.add_argument("output_db_name", type=str)
parser.add_argument(
"n_shards", type=int, help="Number of shards to split the database into"
"--n_shards",
type=int,
help="Number of shards to split the database into",
default=10,
)
parser.add_argument(
"--duplicate_chains_file",
type=Path,
help="""
Optional path to file containing duplicate chain information, where each
line contains chains that are 100%% sequence identical. If provided,
duplicate chains will be added to the index and point to the same
underlying database entry as their representatives in the alignment dir.
""",
default=None,
)

args = parser.parse_args()
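For concreteness, each line of the duplicate chains file holds space-separated chains that are 100% sequence identical, parsed as in this sketch (chain IDs are hypothetical):

# One hypothetical line of the duplicates file and how the script parses it:
line = "1abc_A 2xyz_B 3def_C\n"
chains = line.strip().split()  # -> ["1abc_A", "2xyz_B", "3def_C"]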
scripts/expand_roda_duplicates.py (75 additions, 0 deletions)
@@ -0,0 +1,75 @@
"""
The RODA database is non-redundant, meaning that it only stores one explicit
representative alignment directory for all PDB chains in a 100% sequence
identity cluster. In order to add explicit alignments for all PDB chains, this
script will add the missing chain directories and symlink them to their
representative alignment directories.
"""

from argparse import ArgumentParser
from pathlib import Path

from tqdm import tqdm


def create_duplicate_dirs(duplicate_chains: list[list[str]], alignment_dir: Path):
"""
Create duplicate directory symlinks for all chains in the given duplicate lists.

Args:
duplicate_chains (list[list[str]]): A list of lists, where each inner list
contains chains that are 100% sequence identical.
alignment_dir (Path): Path to flattened alignment directory, with one
subdirectory per chain.
"""
print("Creating duplicate directory symlinks...")
dirs_created = 0
for chains in tqdm(duplicate_chains):
# find the chain that has an alignment
for chain in chains:
if (alignment_dir / chain).exists():
representative_chain = chain
break
else:
print(f"No representative chain found for {chains}, skipping...")
continue

# create symlinks for all other chains
for chain in chains:
if chain != representative_chain:
target_path = alignment_dir / chain
if target_path.exists():
print(f"Chain {chain} already exists, skipping...")
else:
target_path.symlink_to(alignment_dir / representative_chain)
dirs_created += 1

print(f"Created directories for {dirs_created} duplicate chains.")


def main(alignment_dir: Path, duplicate_chains_file: Path):
# read duplicate chains file
with open(duplicate_chains_file, "r") as fp:
duplicate_chains = [list(line.strip().split()) for line in fp]

create_duplicate_dirs(duplicate_chains, alignment_dir)


if __name__ == "__main__":
parser = ArgumentParser(description=__doc__)
parser.add_argument(
"alignment_dir",
type=Path,
help="""Path to flattened alignment directory, with one subdirectory
per chain.""",
)
parser.add_argument(
"duplicate_chains_file",
type=Path,
help="""Path to file containing duplicate chains, where each line
contains a space-separated list of chains that are 100%%
sequence identical.
""",
)
args = parser.parse_args()
main(args.alignment_dir, args.duplicate_chains_file)
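A minimal sketch of the on-disk effect of this script, under hypothetical paths; note that Path.symlink_to() stores the given target verbatim, so absolute (resolved) paths are the safest choice:

from pathlib import Path

alignment_dir = Path("/data/alignments")   # hypothetical flattened dir
representative = alignment_dir / "1abc_A"  # has the real alignment files
duplicate = alignment_dir / "2xyz_B"       # 100% identical chain

if not duplicate.exists():
    duplicate.symlink_to(representative)   # 2xyz_B -> 1abc_A

assert duplicate.resolve() == representative.resolve()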