-
Notifications
You must be signed in to change notification settings - Fork 460
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Duplicate expansion support #419
Merged
jnwei
merged 9 commits into
setup-improvements
from
setup-improvements_additional-scripts
May 13, 2024
Merged
Changes from 4 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
77860bb
Improve type hints and formatting
ljarosch e678050
Add default shard number
ljarosch ee0c5db
Add duplicate chain file support to alignment DB script
ljarosch 94819bf
Add script for expanding the alignment dir with duplicates
ljarosch e2479cb
Add more efficient script to generate all-seqs FASTA
ljarosch 0b5c949
Give script more descriptive name
ljarosch 244970b
Slightly improve comment
ljarosch 78b9706
Set CLI description to more informative module docstring
ljarosch 04410d5
Improve import formatting
ljarosch File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,17 +5,19 @@ | |
run on the output index. Additionally this script uses threading and | ||
multiprocessing and is much faster than the old version. | ||
""" | ||
|
||
import argparse | ||
import json | ||
from collections import defaultdict | ||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed | ||
import json | ||
from math import ceil | ||
from multiprocessing import cpu_count | ||
from pathlib import Path | ||
from typing import List | ||
|
||
from tqdm import tqdm | ||
from math import ceil | ||
|
||
|
||
def split_file_list(file_list, n_shards): | ||
def split_file_list(file_list: list[Path], n_shards: int): | ||
""" | ||
Split up the total file list into n_shards sublists. | ||
""" | ||
|
@@ -29,26 +31,25 @@ def split_file_list(file_list, n_shards): | |
return split_list | ||
|
||
|
||
def chunked_iterator(lst: list, chunk_size: int):
    """Yield successive slices of ``lst``, each at most ``chunk_size`` long.

    The final chunk may be shorter when ``len(lst)`` is not an exact
    multiple of ``chunk_size``.
    """
    start = 0
    while start < len(lst):
        yield lst[start : start + chunk_size]
        start += chunk_size
|
||
|
||
def read_chain_dir(chain_dir) -> dict: | ||
def read_chain_dir(chain_dir: Path) -> dict: | ||
""" | ||
Read all alignment files in a single chain directory and return a dict | ||
mapping chain name to file names and bytes. | ||
""" | ||
if not chain_dir.is_dir(): | ||
raise ValueError(f"chain_dir must be a directory, but is {chain_dir}") | ||
|
||
# ensure that PDB IDs are all lowercase | ||
pdb_id, chain = chain_dir.name.split("_") | ||
pdb_id = pdb_id.lower() | ||
chain_name = f"{pdb_id}_{chain}" | ||
|
||
|
||
|
||
file_data = [] | ||
|
||
for file_path in sorted(chain_dir.iterdir()): | ||
|
@@ -62,7 +63,7 @@ def read_chain_dir(chain_dir) -> dict: | |
return {chain_name: file_data} | ||
|
||
|
||
def process_chunk(chain_files: List[Path]) -> dict: | ||
def process_chunk(chain_files: list[Path]) -> dict: | ||
""" | ||
Returns the file names and bytes for all chains in a chunk of files. | ||
""" | ||
|
@@ -83,7 +84,7 @@ def create_index_default_dict() -> dict: | |
|
||
|
||
def create_shard( | ||
shard_files: List[Path], output_dir: Path, output_name: str, shard_num: int | ||
shard_files: list[Path], output_dir: Path, output_name: str, shard_num: int | ||
) -> dict: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps having a type alias for the index dict could be helpful. If the index file structure is also used in the main library code, then I would consider adding the alias to the main openfold library. But this can also be done in a later PR. |
||
""" | ||
Creates a single shard of the alignment database, and returns the | ||
|
@@ -101,7 +102,11 @@ def create_shard( | |
db_offset = 0 | ||
db_file = open(output_path, "wb") | ||
for files_chunk in tqdm( | ||
chunk_iter, total=ceil(len(shard_files) / CHUNK_SIZE), desc=pbar_desc, position=shard_num, leave=False | ||
chunk_iter, | ||
total=ceil(len(shard_files) / CHUNK_SIZE), | ||
desc=pbar_desc, | ||
position=shard_num, | ||
leave=False, | ||
): | ||
# get processed files for one chunk | ||
chunk_data = process_chunk(files_chunk) | ||
|
@@ -125,9 +130,17 @@ def create_shard( | |
def main(args): | ||
alignment_dir = args.alignment_dir | ||
output_dir = args.output_db_path | ||
output_dir.mkdir(exist_ok=True, parents=True) | ||
output_db_name = args.output_db_name | ||
n_shards = args.n_shards | ||
|
||
n_cpus = cpu_count() | ||
if n_shards > n_cpus: | ||
print( | ||
f"Warning: Your number of shards ({n_shards}) is greater than the number of cores on your machine ({n_cpus}). " | ||
"This may result in slower performance. Consider using a smaller number of shards." | ||
) | ||
|
||
# get all chain dirs in alignment_dir | ||
print("Getting chain directories...") | ||
all_chain_dirs = sorted([f for f in tqdm(alignment_dir.iterdir())]) | ||
|
@@ -153,12 +166,36 @@ def main(args): | |
super_index.update(shard_index) | ||
print("\nCreated all shards.") | ||
|
||
if args.duplicate_chains_file: | ||
print("Extending super index with duplicate chains...") | ||
duplicates_added = 0 | ||
with open(args.duplicate_chains_file, "r") as fp: | ||
duplicate_chains = [line.strip().split() for line in fp] | ||
|
||
for chains in duplicate_chains: | ||
# find representative with alignment | ||
for chain in chains: | ||
if chain in super_index: | ||
representative_chain = chain | ||
break | ||
else: | ||
print(f"No representative chain found for {chains}, skipping...") | ||
continue | ||
|
||
# add duplicates to index | ||
for chain in chains: | ||
if chain != representative_chain: | ||
super_index[chain] = super_index[representative_chain] | ||
duplicates_added += 1 | ||
|
||
print(f"Added {duplicates_added} duplicate chains to index.") | ||
|
||
# write super index to file | ||
print("\nWriting super index...") | ||
index_path = output_dir / f"{output_db_name}.index" | ||
with open(index_path, "w") as fp: | ||
json.dump(super_index, fp, indent=4) | ||
|
||
print("Done.") | ||
|
||
|
||
|
@@ -179,13 +216,27 @@ def main(args): | |
parser.add_argument( | ||
"alignment_dir", | ||
type=Path, | ||
help="""Path to precomputed alignment directory, with one subdirectory | ||
per chain.""", | ||
help="""Path to precomputed flattened alignment directory, with one | ||
subdirectory per chain.""", | ||
) | ||
parser.add_argument("output_db_path", type=Path) | ||
parser.add_argument("output_db_name", type=str) | ||
parser.add_argument(
    "--n_shards",
    type=int,
    help="Number of shards to split the database into",
    default=10,
)
parser.add_argument(
    "--duplicate_chains_file",
    type=Path,
    # NOTE: "%" must be escaped as "%%" in argparse help strings -- argparse
    # %-formats help text, so a bare "%" makes `--help` raise
    # "ValueError: unsupported format character".
    help="""
    Optional path to file containing duplicate chain information, where each
    line contains chains that are 100%% sequence identical. If provided,
    duplicate chains will be added to the index and point to the same
    underlying database entry as their representatives in the alignment dir.
    """,
    default=None,
)
|
||
args = parser.parse_args() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
""" | ||
The RODA database is non-redundant, meaning that it only stores one explicit | ||
representative alignment directory for all PDB chains in a 100% sequence | ||
identity cluster. In order to add explicit alignments for all PDB chains, this | ||
script will add the missing chain directories and symlink them to their | ||
representative alignment directories. | ||
""" | ||
|
||
from argparse import ArgumentParser | ||
from pathlib import Path | ||
|
||
from tqdm import tqdm | ||
|
||
|
||
def create_duplicate_dirs(duplicate_chains: list[list[str]], alignment_dir: Path):
    """
    Symlink every duplicate chain directory to its representative.

    For each group of identical chains, the first chain that already has a
    directory under ``alignment_dir`` is taken as the representative; every
    other chain in the group gets a symlink pointing at that directory.
    Groups with no existing directory are reported and skipped.

    Args:
        duplicate_lists (list[list[str]]): A list of lists, where each inner list
            contains chains that are 100% sequence identical.
        alignment_dir (Path): Path to flattened alignment directory, with one
            subdirectory per chain.
    """
    print("Creating duplicate directory symlinks...")
    dirs_created = 0
    for chains in tqdm(duplicate_chains):
        # Pick the first chain in the group whose alignment dir exists.
        representative_chain = next(
            (c for c in chains if (alignment_dir / c).exists()), None
        )
        if representative_chain is None:
            print(f"No representative chain found for {chains}, skipping...")
            continue

        # Link every remaining chain in the group to the representative.
        for chain in chains:
            if chain == representative_chain:
                continue
            link_path = alignment_dir / chain
            if link_path.exists():
                print(f"Chain {chain} already exists, skipping...")
            else:
                link_path.symlink_to(alignment_dir / representative_chain)
                dirs_created += 1

    print(f"Created directories for {dirs_created} duplicate chains.")
|
||
|
||
def main(alignment_dir: Path, duplicate_chains_file: Path):
    """Parse the duplicate-chains file and create the duplicate symlinks.

    Args:
        alignment_dir (Path): Flattened alignment directory, one
            subdirectory per chain.
        duplicate_chains_file (Path): Text file where every line is a
            whitespace-separated group of 100% identical chain names.
    """
    # Each line of the file becomes one group of duplicate chain names.
    with open(duplicate_chains_file, "r") as fp:
        groups = []
        for line in fp:
            groups.append(list(line.strip().split()))

    create_duplicate_dirs(groups, alignment_dir)
|
||
|
||
if __name__ == "__main__":
    # The module docstring doubles as the CLI description.
    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "alignment_dir",
        type=Path,
        help="""Path to flattened alignment directory, with one subdirectory
        per chain.""",
    )
    parser.add_argument(
        "duplicate_chains_file",
        type=Path,
        # "%%" renders as a literal "%" in argparse help text.
        help="""Path to file containing duplicate chains, where each line
        contains a space-separated list of chains that are 100%%
        sequence identical.
        """,
    )
    args = parser.parse_args()
    main(args.alignment_dir, args.duplicate_chains_file)
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you want, you can also add typing to the dict, i.e. for this specific example, you could write dict[str, Tuple[str, bytes]]
They can sometimes be helpful to know what to expect. For this specific function I don't think it's as needed / helpful.