Duplicate expansion support #419
@@ -0,0 +1,144 @@
"""
This script generates a FASTA file for all chains in an alignment directory or
alignment DB.
"""

import json
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional

from tqdm import tqdm


def chain_dir_to_fasta(chain_dir: Path) -> str:
    """
    Generates a FASTA string from a chain directory.
    """
    # take the first alignment file that exists, in priority order
    for alignment_file_type in [
        "mgnify_hits.a3m",
        "uniref90_hits.a3m",
        "bfd_uniclust_hits.a3m",
    ]:
        alignment_file = chain_dir / alignment_file_type
        if alignment_file.exists():
            break
    else:
        raise FileNotFoundError(f"No alignment file found in {chain_dir}.")

    with open(alignment_file, "r") as f:
        next(f)  # skip the first line
        seq = next(f).strip()

        try:
            next_line = next(f)
        except StopIteration:
            pass
        else:
            assert next_line.startswith(">")  # ensure that the sequence ended

    chain_id = chain_dir.name

    return f">{chain_id}\n{seq}\n"


def index_entry_to_fasta(index_entry: dict, db_dir: Path, chain_id: str) -> str:
    """
    Generates a FASTA string from an alignment-db index entry.
    """
    db_file = db_dir / index_entry["db"]

    # look for an alignment file, using the same priority order as above
    for alignment_file_type in [
        "mgnify_hits.a3m",
        "uniref90_hits.a3m",
        "bfd_uniclust_hits.a3m",
    ]:
        for file_info in index_entry["files"]:
            if file_info[0] == alignment_file_type:
                start, size = file_info[1], file_info[2]
                break
        else:
            continue  # file type not in this entry, try the next one
        break  # found an alignment file, stop searching
    else:
        raise FileNotFoundError(f"No alignment file found for {chain_id}.")

    with open(db_file, "rb") as f:
        f.seek(start)
        msa_lines = f.read(size).decode("utf-8").splitlines()
        seq = msa_lines[1]

    try:
        next_line = msa_lines[2]
    except IndexError:
        pass
    else:
        assert next_line.startswith(">")  # ensure that the sequence ended

    return f">{chain_id}\n{seq}\n"


def main(
    output_path: Path, alignment_db_index: Optional[Path], alignment_dir: Optional[Path]
) -> None:
    """
    Generate a FASTA file from either an alignment-db index or an alignment
    directory, using multithreading.
    """
    fasta = []

    if alignment_dir and alignment_db_index:
        raise ValueError(
            "Only one of alignment_db_index and alignment_dir can be provided."
        )

    if alignment_dir:
        print("Creating FASTA from alignment directory...")
        chain_dirs = list(alignment_dir.iterdir())

        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(chain_dir_to_fasta, chain_dir)
                for chain_dir in chain_dirs
            ]
            for future in tqdm(as_completed(futures), total=len(chain_dirs)):
                fasta.append(future.result())

    elif alignment_db_index:
        print("Creating FASTA from alignment dbs...")

        with open(alignment_db_index, "r") as f:
            index = json.load(f)

        db_dir = alignment_db_index.parent

        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(index_entry_to_fasta, index_entry, db_dir, chain_id)
                for chain_id, index_entry in index.items()
            ]
            for future in tqdm(as_completed(futures), total=len(index)):
                fasta.append(future.result())
    else:
        raise ValueError("Either alignment_db_index or alignment_dir must be provided.")

    with open(output_path, "w") as f:
        f.write("".join(fasta))
    print(f"FASTA file written to {output_path}.")


if __name__ == "__main__":
    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "output_path",
        type=Path,
        help="Path to output FASTA file.",
    )
    parser.add_argument(
        "--alignment_db_index",
        type=Path,
        help="Path to alignment-db index file.",
    )
    parser.add_argument(
        "--alignment_dir",
        type=Path,
        help="Path to alignment directory.",
    )

    args = parser.parse_args()
    main(args.output_path, args.alignment_db_index, args.alignment_dir)
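
For reference, a hypothetical invocation of the script above (the file name is illustrative, since this diff does not show it) might look like:

python chains_to_fasta.py all_chains.fasta --alignment_db_index alignment_dbs/alignment_db.index

or, when reading from a flattened alignment directory:

python chains_to_fasta.py all_chains.fasta --alignment_dir alignments/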
@@ -5,17 +5,19 @@
run on the output index. Additionally this script uses threading and
multiprocessing and is much faster than the old version.
"""

import argparse
+import json
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
-import json
+from math import ceil
from multiprocessing import cpu_count
from pathlib import Path
from typing import List

from tqdm import tqdm
-from math import ceil


-def split_file_list(file_list, n_shards):
+def split_file_list(file_list: list[Path], n_shards: int):
    """
    Split up the total file list into n_shards sublists.
    """
@@ -29,26 +31,25 @@ def split_file_list(file_list, n_shards):
    return split_list


-def chunked_iterator(lst, chunk_size):
+def chunked_iterator(lst: list, chunk_size: int):
    """Iterate over a list in chunks of size chunk_size."""
    for i in range(0, len(lst), chunk_size):
        yield lst[i : i + chunk_size]


-def read_chain_dir(chain_dir) -> dict:
+def read_chain_dir(chain_dir: Path) -> dict:
Review comment: If you want, you can also add typing to the dict; for this specific example, you could write dict[str, Tuple[str, bytes]]. Type hints can sometimes be helpful for knowing what to expect, though for this specific function I don't think it's as needed/helpful.
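A minimal sketch of that suggestion, assuming (per the docstring below) that each value is a list of (file_name, file_bytes) pairs:

def read_chain_dir(chain_dir: Path) -> dict[str, list[tuple[str, bytes]]]:
    ...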
""" | ||
Read all alignment files in a single chain directory and return a dict | ||
mapping chain name to file names and bytes. | ||
""" | ||
if not chain_dir.is_dir(): | ||
raise ValueError(f"chain_dir must be a directory, but is {chain_dir}") | ||
|
||
# ensure that PDB IDs are all lowercase | ||
pdb_id, chain = chain_dir.name.split("_") | ||
pdb_id = pdb_id.lower() | ||
chain_name = f"{pdb_id}_{chain}" | ||
|
||
|
||
|
||
file_data = [] | ||
|
||
for file_path in sorted(chain_dir.iterdir()): | ||
|
@@ -62,7 +63,7 @@ def read_chain_dir(chain_dir) -> dict: | |
return {chain_name: file_data} | ||
|
||
|
||
def process_chunk(chain_files: List[Path]) -> dict: | ||
def process_chunk(chain_files: list[Path]) -> dict: | ||
""" | ||
Returns the file names and bytes for all chains in a chunk of files. | ||
""" | ||
|
@@ -83,7 +84,7 @@ def create_index_default_dict() -> dict:


def create_shard(
-    shard_files: List[Path], output_dir: Path, output_name: str, shard_num: int
+    shard_files: list[Path], output_dir: Path, output_name: str, shard_num: int
) -> dict:
Review comment: Perhaps having a type alias for the index dict could be helpful. If the index file structure is also used in the main library code, then I would consider adding the alias to the main openfold library. But this can also be done in a later PR.
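One possible shape for such an alias, sketched from the index structure documented in the comment below; the names and the TypedDict approach are illustrative, not part of the PR:

from typing import TypedDict

class IndexEntry(TypedDict):
    db: str  # name of the shard file holding this chain's alignment bytes
    files: list[tuple[str, int, int]]  # (file_name, db_offset, file_length)

SuperIndex = dict[str, IndexEntry]  # keyed by chain name, e.g. "1abc_A"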
""" | ||
Creates a single shard of the alignment database, and returns the | ||
|
@@ -92,7 +93,7 @@ def create_shard( | |
CHUNK_SIZE = 200 | ||
shard_index = defaultdict( | ||
create_index_default_dict | ||
) # {chain_name: {db: str, files: [(file_name, db_offset, file_length)]}, ...} | ||
) # e.g. {chain_name: {db: str, files: [(file_name, db_offset, file_length)]}, ...} | ||
chunk_iter = chunked_iterator(shard_files, CHUNK_SIZE) | ||
|
||
pbar_desc = f"Shard {shard_num}" | ||
|
@@ -101,7 +102,11 @@ def create_shard(
    db_offset = 0
    db_file = open(output_path, "wb")
    for files_chunk in tqdm(
-        chunk_iter, total=ceil(len(shard_files) / CHUNK_SIZE), desc=pbar_desc, position=shard_num, leave=False
+        chunk_iter,
+        total=ceil(len(shard_files) / CHUNK_SIZE),
+        desc=pbar_desc,
+        position=shard_num,
+        leave=False,
    ):
        # get processed files for one chunk
        chunk_data = process_chunk(files_chunk)
@@ -125,9 +130,17 @@ def create_shard(
def main(args):
    alignment_dir = args.alignment_dir
    output_dir = args.output_db_path
+    output_dir.mkdir(exist_ok=True, parents=True)
    output_db_name = args.output_db_name
    n_shards = args.n_shards

+    n_cpus = cpu_count()
+    if n_shards > n_cpus:
+        print(
+            f"Warning: Your number of shards ({n_shards}) is greater than the number of cores on your machine ({n_cpus}). "
+            "This may result in slower performance. Consider using a smaller number of shards."
+        )

    # get all chain dirs in alignment_dir
    print("Getting chain directories...")
    all_chain_dirs = sorted([f for f in tqdm(alignment_dir.iterdir())])
@@ -153,12 +166,36 @@ def main(args):
    super_index.update(shard_index)
    print("\nCreated all shards.")

+    if args.duplicate_chains_file:
+        print("Extending super index with duplicate chains...")
+        duplicates_added = 0
+        with open(args.duplicate_chains_file, "r") as fp:
+            duplicate_chains = [line.strip().split() for line in fp]
+
+        for chains in duplicate_chains:
+            # find representative with alignment
+            for chain in chains:
+                if chain in super_index:
+                    representative_chain = chain
+                    break
+            else:
+                print(f"No representative chain found for {chains}, skipping...")
+                continue
+
+            # add duplicates to index
+            for chain in chains:
+                if chain != representative_chain:
+                    super_index[chain] = super_index[representative_chain]
+                    duplicates_added += 1
+
+        print(f"Added {duplicates_added} duplicate chains to index.")
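
As a concrete illustration (chain IDs hypothetical): if a duplicates line reads 1abc_A 2xyz_B 3def_C and only 1abc_A is present in the super index, the loop above yields:

# before: super_index == {"1abc_A": {"db": ..., "files": [...]}}
# after:  super_index["2xyz_B"] is super_index["1abc_A"]  # True
#         super_index["3def_C"] is super_index["1abc_A"]  # True

All three chains then resolve to the same shard offsets without duplicating any alignment data.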
    # write super index to file
    print("\nWriting super index...")
    index_path = output_dir / f"{output_db_name}.index"
    with open(index_path, "w") as fp:
        json.dump(super_index, fp, indent=4)

    print("Done.")
@@ -179,13 +216,27 @@ def main(args):
    parser.add_argument(
        "alignment_dir",
        type=Path,
-        help="""Path to precomputed alignment directory, with one subdirectory
-        per chain.""",
+        help="""Path to precomputed flattened alignment directory, with one
+        subdirectory per chain.""",
    )
    parser.add_argument("output_db_path", type=Path)
    parser.add_argument("output_db_name", type=str)
    parser.add_argument(
-        "n_shards", type=int, help="Number of shards to split the database into"
+        "--n_shards",
+        type=int,
+        help="Number of shards to split the database into",
+        default=10,
    )
+    parser.add_argument(
+        "--duplicate_chains_file",
+        type=Path,
+        help="""
+        Optional path to a file containing duplicate chain information, where
+        each line lists chains that are 100%% sequence identical. If provided,
+        duplicate chains will be added to the index and point to the same
+        underlying database entry as their representatives in the alignment
+        dir.
+        """,
+        default=None,
+    )
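
For example, a duplicate-chains file might look like this (chain IDs hypothetical), with one whitespace-separated cluster of 100% identical chains per line:

1abc_A 2xyz_B 3def_C
4ghi_D 5jkl_E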

args = parser.parse_args()
Review comment (nit): Consider making the expected alignment_file_types a global variable, especially since they get used in multiple functions. Also, does this mean that this code would not support .sto alignment files? I don't think we need support for that now, but it's perhaps good to mention somewhere.
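
A minimal sketch of that first suggestion (the constant name is hypothetical):

# Priority order for choosing a representative alignment file; this could be
# defined once at module scope and reused by both chain_dir_to_fasta and
# index_entry_to_fasta.
ALIGNMENT_FILE_TYPES = [
    "mgnify_hits.a3m",
    "uniref90_hits.a3m",
    "bfd_uniclust_hits.a3m",
]

As the comment notes, only .a3m files are considered by the FASTA script; .sto alignments would currently be ignored.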