Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ml dataset #232

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 40 additions & 5 deletions graphein/ml/datasets/torch_geometric_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
from pathlib import Path
from typing import Callable, Dict, Generator, List, Optional
from urllib.error import HTTPError

import networkx as nx
from loguru import logger as log
Expand Down Expand Up @@ -167,6 +168,9 @@ def __init__(
self.bad_pdbs: List[
str
] = [] # list of pdb codes that failed to download
self.bad_uniprot_ids: List[
str
] = [] # list of uniprot ids that failed to download

# Labels & Chains
self.graph_label_map = graph_label_map
Expand Down Expand Up @@ -282,6 +286,9 @@ def process(self):
return_dict=True,
num_cores=self.num_cores,
)
# Keep only graphs that were successfully constructed
graphs = [g for g in graphs if g is not None]

# Transform graphs
if self.graph_transformation_funcs is not None:
print("Transforming Nx Graphs...")
Expand Down Expand Up @@ -412,6 +419,8 @@ def __init__(
defaults to ``2``.
:type af_version: int, optional
"""


self.pdb_codes = (
[pdb.lower() for pdb in pdb_codes]
if pdb_codes is not None
Expand Down Expand Up @@ -460,7 +469,7 @@ def __init__(
self.chain_selection_map = None
self.validate_input()
self.bad_pdbs: List[str] = []

self.bad_uniprot_ids: List[str] = []
# Configs
self.config = graphein_config
self.graph_format_convertor = graph_format_convertor
Expand Down Expand Up @@ -533,6 +542,7 @@ def download(self):
for pdb in set(self.pdb_codes)
if not os.path.exists(Path(self.raw_dir) / f"{pdb}.pdb")
]

download_pdb_multiprocessing(
to_download,
self.raw_dir,
Expand All @@ -545,15 +555,34 @@ def download(self):
if not os.path.exists(Path(self.raw_dir) / f"{pdb}.pdb")
]
if self.uniprot_ids:
[
download_alphafold_structure(

# Only download undownloaded Uniprot IDs
to_download = [
uniprot
for uniprot in set(self.uniprot_ids)
if not os.path.exists(Path(self.raw_dir) / f"{uniprot}.pdb")
]

for uniprot in tqdm(to_download):


fn = download_alphafold_structure(
uniprot,
out_dir=self.raw_dir,
version=self.af_version,
aligned_score=False,
rename=True,
)
for uniprot in tqdm(self.uniprot_ids)

self.bad_uniprot_ids = self.bad_uniprot_ids + [
uniprot
for uniprot in set(self.uniprot_ids)
if not os.path.exists(Path(self.raw_dir) / f"{uniprot}.pdb")
]



# TODO: remove bad uniprot / pdb ids from self.structures

def len(self) -> int:
"""Returns length of data set (number of structures)."""
Expand All @@ -580,7 +609,7 @@ def process(self):
if self.pdb_transform:
self.transform_pdbs()

idx = 0
idx = 0
# Chunk dataset for parallel processing
chunk_size = 128

Expand Down Expand Up @@ -612,6 +641,12 @@ def divide_chunks(l: List[str], n: int = 2) -> Generator:
chain_selections=chain_selections,
return_dict=False,
)
graphs = [
g
for g in graphs
if g is not None
]

if self.graph_transformation_funcs is not None:
graphs = [self.transform_graphein_graphs(g) for g in graphs]

Expand Down
2 changes: 1 addition & 1 deletion graphein/protein/edges/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,7 +1263,7 @@ def node_coords(G: nx.Graph, n: str) -> Tuple[float, float, float]:
:return: Tuple of coordinates ``(x, y, z)``
:rtype: Tuple[float, float, float]
"""
x, y, z = tuple(G.nodes[n]["coords"])
(x, y, z) = tuple(G.nodes[n]["coords"])
return x, y, z


Expand Down