-
-
Notifications
You must be signed in to change notification settings - Fork 125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
change the pdb_paths working style and support for loading both local… #214
base: master
Are you sure you want to change the base?
Changes from all commits
524beb2
f5b017c
b24bdeb
ce6d36b
f5da2c4
9a67631
07cd92a
a3dfff8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,9 +44,9 @@ def __init__( | |
pdb_paths: Optional[List[str]] = None, | ||
pdb_codes: Optional[List[str]] = None, | ||
uniprot_ids: Optional[List[str]] = None, | ||
graph_label_map: Optional[Dict[str, torch.Tensor]] = None, | ||
node_label_map: Optional[Dict[str, torch.Tensor]] = None, | ||
chain_selection_map: Optional[Dict[str, List[str]]] = None, | ||
graph_labels: Optional[List[torch.Tensor]] = None, | ||
node_labels: Optional[List[torch.Tensor]] = None, | ||
chain_selections: Optional[List[str]] = None, | ||
graphein_config: ProteinGraphConfig = ProteinGraphConfig(), | ||
graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor( | ||
src_format="nx", dst_format="pyg" | ||
|
@@ -130,54 +130,61 @@ def __init__( | |
self.pdb_codes = ( | ||
[pdb.lower() for pdb in pdb_codes] | ||
if pdb_codes is not None | ||
else None | ||
else [] | ||
) | ||
self.uniprot_ids = ( | ||
[up.upper() for up in uniprot_ids] | ||
if uniprot_ids is not None | ||
else None | ||
else [] | ||
) | ||
self.pdb_paths = ( | ||
pdb_paths if pdb_paths is not None | ||
else [] | ||
) | ||
|
||
self.pdb_paths = pdb_paths | ||
if self.pdb_paths is None: | ||
if self.pdb_codes and self.uniprot_ids: | ||
self.structures = self.pdb_codes + self.uniprot_ids | ||
elif self.pdb_codes: | ||
self.structures = pdb_codes | ||
elif self.uniprot_ids: | ||
self.structures = uniprot_ids | ||
# Use local saved pdb_files instead of download or move them to self.root/raw dir | ||
else: | ||
if isinstance(self.pdb_paths, list): | ||
self.structures = [ | ||
# make sure root path is unique | ||
if self.pdb_paths: | ||
# add pdb_paths' name into self.structure | ||
self.pdb_paths_name = [ | ||
os.path.splitext(os.path.split(pdb_path)[-1])[0] | ||
for pdb_path in self.pdb_paths | ||
] | ||
self.pdb_path, _ = os.path.split(self.pdb_paths[0]) | ||
|
||
if self.pdb_codes and self.uniprot_ids: | ||
self.structures = self.pdb_codes + self.uniprot_ids | ||
elif self.pdb_codes: | ||
self.structures = pdb_codes | ||
elif self.uniprot_ids: | ||
self.structures = uniprot_ids | ||
self.af_version = af_version | ||
# if root pdb_path is not unique raise error since we will save all pdb into this root pdb_path and take it as the self.raw_dir | ||
if len(set([os.path.split(pdb_path)[0] for pdb_path in pdb_paths])) != 1: | ||
raise ValueError("pdb_paths should have only one root path not so much!") | ||
else: | ||
self.pdb_paths_name = [] | ||
|
||
self.structures = list(set(self.pdb_codes + self.uniprot_ids + self.pdb_paths_name)) # remove some pdb_codes is in pdb_path and loaded repeately | ||
|
||
# Labels & Chains | ||
if graph_labels is not None: | ||
self.graph_label_map = dict(enumerate(graph_labels)) | ||
else: | ||
self.graph_label_map = None | ||
|
||
if node_labels is not None: | ||
self.node_label_map = dict(enumerate(node_labels)) | ||
else: | ||
self.node_label_map = None | ||
if chain_selections is not None: | ||
self.chain_selection_map = dict(enumerate(chain_selections)) | ||
else: | ||
self.chain_selection_map = None | ||
self.validate_input() | ||
self.bad_pdbs: List[ | ||
str | ||
] = [] # list of pdb codes that failed to download | ||
|
||
# Labels & Chains | ||
self.graph_label_map = graph_label_map | ||
self.node_label_map = node_label_map | ||
self.chain_selection_map = chain_selection_map | ||
|
||
# Configs | ||
self.config = graphein_config | ||
self.graph_format_convertor = graph_format_convertor | ||
self.graph_transformation_funcs = graph_transformation_funcs | ||
self.pdb_transform = pdb_transform | ||
self.num_cores = num_cores | ||
self.af_version = af_version | ||
|
||
super().__init__( | ||
root, | ||
transform=transform, | ||
|
@@ -200,10 +207,36 @@ def processed_file_names(self) -> List[str]: | |
@property | ||
def raw_dir(self) -> str: | ||
if self.pdb_paths is not None: | ||
return self.pdb_path # replace raw dir with user local pdb_path | ||
# replace raw dir with user local pdb_path; so pdb_paths should be located in the same place | ||
self.pdb_path, _ = os.path.split(self.pdb_paths[0]) | ||
return self.pdb_path | ||
else: | ||
return os.path.join(self.root, "raw") | ||
|
||
def validate_input(self): | ||
if self.graph_label_map is not None: | ||
assert len(self.structures) == len( | ||
self.graph_label_map | ||
), "Number of proteins and graph labels must match" | ||
if self.node_label_map is not None: | ||
assert len(self.structures) == len( | ||
self.node_label_map | ||
), "Number of proteins and node labels must match" | ||
if self.chain_selection_map is not None: | ||
print(len(self.structures)) | ||
print(len(self.chain_selection_map)) | ||
assert len(self.structures) == len( | ||
self.chain_selection_map | ||
), "Number of proteins and chain selections must match" | ||
assert len( | ||
{ | ||
f"{pdb}_{chain}" | ||
for pdb, chain in zip( | ||
self.structures, self.chain_selection_map | ||
) | ||
} | ||
) == len(self.structures), "Duplicate protein/chain combinations" | ||
|
||
def download(self): | ||
"""Download the PDB files from RCSB or Alphafold.""" | ||
self.config.pdb_dir = Path(self.raw_dir) | ||
|
@@ -225,6 +258,7 @@ def download(self): | |
for pdb in set(self.pdb_codes) | ||
if not os.path.exists(Path(self.raw_dir) / f"{pdb}.pdb") | ||
] | ||
print("downloading uniprotids") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ohhhhhhhhhhh, too sry for these I'll remove them today |
||
if self.uniprot_ids: | ||
[ | ||
download_alphafold_structure( | ||
|
@@ -237,6 +271,7 @@ def download(self): | |
] | ||
|
||
def __len__(self) -> int: | ||
"""Returns length of data set (number of structures and chain split).""" | ||
return len(self.structures) | ||
|
||
def transform_pdbs(self): | ||
|
@@ -252,6 +287,7 @@ def transform_pdbs(self): | |
def process(self): | ||
"""Process structures into PyG format and save to disk.""" | ||
# Read data into huge `Data` list. | ||
print(self.structures) | ||
structure_files = [ | ||
f"{self.raw_dir}/{pdb}.pdb" for pdb in self.structures | ||
] | ||
|
@@ -330,12 +366,9 @@ def __init__( | |
pdb_paths: Optional[List[str]] = None, | ||
pdb_codes: Optional[List[str]] = None, | ||
uniprot_ids: Optional[List[str]] = None, | ||
# graph_label_map: Optional[Dict[str, int]] = None, | ||
graph_labels: Optional[List[torch.Tensor]] = None, | ||
node_labels: Optional[List[torch.Tensor]] = None, | ||
chain_selections: Optional[List[str]] = None, | ||
# node_label_map: Optional[Dict[str, int]] = None, | ||
# chain_selection_map: Optional[Dict[str, List[str]]] = None, | ||
graphein_config: ProteinGraphConfig = ProteinGraphConfig(), | ||
graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor( | ||
src_format="nx", dst_format="pyg" | ||
|
@@ -364,14 +397,12 @@ def __init__( | |
:param uniprot_ids: List of Uniprot IDs to download and parse from | ||
Alphafold Database. Defaults to ``None``. | ||
:type uniprot_ids: Optional[List[str]], optional | ||
:param graph_label_map: Dictionary mapping PDB/Uniprot IDs to | ||
graph-level labels. Defaults to ``None``. | ||
:type graph_label_map: Optional[Dict[str, Tensor]], optional | ||
:param node_label_map: Dictionary mapping PDB/Uniprot IDs to node-level | ||
labels. Defaults to ``None``. | ||
:type node_label_map: Optional[Dict[str, torch.Tensor]], optional | ||
:param chain_selection_map: Dictionary mapping, defaults to ``None``. | ||
:type chain_selection_map: Optional[Dict[str, List[str]]], optional | ||
:param graph_labels: List mapping to self.structures by index to graph-level labels. Defaults to ``None``. | ||
:type graph_labels: Optional[List[torch.Tensor]], optional | ||
:param node_labels: List mapping to self.structures by index to node-level labels. Defaults to ``None``. | ||
:type node_labels: Optional[List[torch.Tensor]], optional | ||
:param chain_selections: List mapping to self.structures by index to chain selection, defaults to ``None``. | ||
:type chain_selections: Optional[List[str]], optional | ||
:param graphein_config: Protein graph construction config, defaults to | ||
``ProteinGraphConfig()``. | ||
:type graphein_config: ProteinGraphConfig, optional | ||
|
@@ -412,34 +443,36 @@ def __init__( | |
self.pdb_codes = ( | ||
[pdb.lower() for pdb in pdb_codes] | ||
if pdb_codes is not None | ||
else None | ||
else [] | ||
) | ||
self.uniprot_ids = ( | ||
[up.upper() for up in uniprot_ids] | ||
if uniprot_ids is not None | ||
else None | ||
else [] | ||
) | ||
self.pdb_paths = pdb_paths | ||
if self.pdb_paths is None: | ||
if self.pdb_codes and self.uniprot_ids: | ||
self.structures = self.pdb_codes + self.uniprot_ids | ||
elif self.pdb_codes: | ||
self.structures = pdb_codes | ||
elif self.uniprot_ids: | ||
self.structures = uniprot_ids | ||
# Use local saved pdb_files instead of download or move them to self.root/raw dir | ||
else: | ||
if isinstance(self.pdb_paths, list): | ||
self.structures = [ | ||
self.pdb_paths = ( | ||
pdb_paths if pdb_paths is not None | ||
else [] | ||
) | ||
|
||
# make sure root path is unique | ||
if self.pdb_paths: | ||
# add pdb_paths' name into self.structure | ||
self.pdb_paths_name = [ | ||
os.path.splitext(os.path.split(pdb_path)[-1])[0] | ||
for pdb_path in self.pdb_paths | ||
] | ||
self.pdb_path, _ = os.path.split(self.pdb_paths[0]) | ||
|
||
# Labels & Chains | ||
# if root pdb_path is not unique raise error since we will save all pdb into this root pdb_path and take it as the self.raw_dir | ||
if len(set([os.path.split(pdb_path)[0] for pdb_path in pdb_paths])) != 1: | ||
raise ValueError("pdb_paths should have only one root path not so much!") | ||
else: | ||
self.pdb_paths_name = [] | ||
|
||
self.structures = list(set(self.pdb_codes + self.uniprot_ids + self.pdb_paths_name)) # remove some pdb_codes is in pdb_path and loaded repeately | ||
|
||
self.examples: Dict[int, str] = dict(enumerate(self.structures)) | ||
|
||
# Labels & Chains | ||
if graph_labels is not None: | ||
self.graph_label_map = dict(enumerate(graph_labels)) | ||
else: | ||
|
@@ -460,9 +493,9 @@ def __init__( | |
# Configs | ||
self.config = graphein_config | ||
self.graph_format_convertor = graph_format_convertor | ||
self.num_cores = num_cores | ||
self.pdb_transform = pdb_transform | ||
self.graph_transformation_funcs = graph_transformation_funcs | ||
self.pdb_transform = pdb_transform | ||
self.num_cores = num_cores | ||
self.af_version = af_version | ||
super().__init__( | ||
root, | ||
|
@@ -492,8 +525,10 @@ def processed_file_names(self) -> List[str]: | |
|
||
@property | ||
def raw_dir(self) -> str: | ||
if self.pdb_paths is not None: | ||
return self.pdb_path # replace raw dir with user local pdb_path | ||
if self.pdb_paths: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I actually think it would be useful to allow users to choose a path for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I agree. I'm not sure about this, i prefer to dict, which key is the names like
If something wrong in my understanding, please tell me 😄 , i'm still reading and learning your code lol. It's really a pythonic code, i learnt a lot 👍 👍 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I don't think this is the best idea. I think being explicit about the paths users want to use is best. For instance, people may want to use only a subset of their dataset (rather than everything in the directory - e.g. imagine where you want to keep all your pdb files together but train/test on different subsets). It also has the potential problem with hidden files like
This was my initial implementation. However, this ran into the problem where you may have different examples in your dataset drawn from different chains of the same PDB. E.g. imagine you have
Thanks!! Me too! |
||
# replace raw dir with user local pdb_path; so pdb_paths should be located in the same place | ||
self.pdb_path, _ = os.path.split(self.pdb_paths[0]) | ||
return self.pdb_path | ||
else: | ||
return os.path.join(self.root, "raw") | ||
|
||
|
@@ -610,7 +645,6 @@ def divide_chunks(l: List[str], n: int = 2) -> Generator: | |
) | ||
if self.graph_transformation_funcs is not None: | ||
graphs = [self.transform_graphein_graphs(g) for g in graphs] | ||
|
||
# Convert to PyTorch Geometric Data | ||
graphs = [self.graph_format_convertor(g) for g in graphs] | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""Dataset loader transform Functions for loading set""" | ||
# Graphein | ||
# Author: xutingfeng <xutingfeng@big.ac.cn> | ||
# License: MIT | ||
|
||
# Project Website: https://github.com/a-r-j/graphein | ||
# Code Repository: https://github.com/a-r-j/graphein | ||
|
||
from torch_geometric.transforms import BaseTransform | ||
from typing import Callable, Dict, Generator, List, Optional, Union | ||
|
||
import torch | ||
import torch_geometric | ||
|
||
|
||
def attr2Tensor(data:torch_geometric.data.Data, attr:str, dtype=torch.float32)->torch_geometric.data.Data: | ||
|
||
""" | ||
``attr2Tensor`` Change specific attribution ``attr`` of ``torch_geometric.data.Data`` to Tensor. Assume these ``attr`` are [``np.array``, ``np.array``]. If the array's shape is ``(dim,)`` will be reshaped as ``(1, dim)``. Finally the shape of the array will be ``(O_1, dim1, dim2, ...)`` or ``(1, dim)``, and will be concated at axis=0 or dim=0. | ||
|
||
:param data: Data to change it's attribution | ||
:type data: torch_geometric.data.Data | ||
:param attr: The name of attribution | ||
:type attr: str | ||
:param dtype: Dtype of new attr, defaults to torch.float32 | ||
:type dtype: _type_, optional | ||
:raises ValueError: Passing the correct name of attribution | ||
:return: torch_geometric.data.Data with specific attr Tensor | ||
:rtype: torch_geometric.data.Data | ||
""" | ||
if not hasattr(data, attr): | ||
raise ValueError(f"doesn't have an attr:{attr}") | ||
attr_data = getattr(data, attr) | ||
|
||
if isinstance(attr_data, list): | ||
attr_data = [torch.tensor(i, dtype=dtype).unsqueeze(0) if len(i.shape) == 1 else torch.tensor(i, dtype=dtype) for i in attr_data ] | ||
|
||
attr_data = torch.concat(attr_data, 0) | ||
setattr(data, attr, attr_data) | ||
|
||
|
||
class Reshape_Attr2Tensor(BaseTransform): | ||
""" | ||
Reshape_Attr2Tensor Convert ``Data.attr``, like ``[np.array, np.array]`` to ``torch.Tensor`` | ||
""" | ||
def __init__(self, attr:Union[str, List]) -> None: | ||
self.attr = attr | ||
|
||
def __call__(self, data:torch_geometric.data.Data)->torch_geometric.data.Data: | ||
""" | ||
:param data: Convert attr of ``torch_geometric.data.Data`` to ``torch.Tensor`` | ||
:type data: torch_geometric.data.Data | ||
""" | ||
def _attr2Tensor_recursive(x, attr:Union[str, List]): | ||
if isinstance(attr, str): | ||
attr2Tensor(x, attr) | ||
return x | ||
elif isinstance(attr, list): # multiple attr will be done by this | ||
for i in attr: | ||
x = _attr2Tensor_recursive(x, i) | ||
return x | ||
else: | ||
raise ValueError(f"attr name or list of attr name") | ||
_attr2Tensor_recursive(data, self.attr) | ||
return data |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this should be a set operation. With chain selections you may want to have e.g.
3eiy_A
and3eiy_B
as different examples in your dataset.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, i guess it would'n make a difference at chain selection, this set operation is to drop duplicate in the result list of
pdb_codes
+uniprot_ids
+paths_name
. As you can see, local dir may have some pdb files like10gs.pdb
, and ifpdb_codes
also have10gs
to download, and self.structures would contain double10gs
and so the finialdataset
object will have duplicateData
object.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It becomes a problem here though (L283), no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, i guess not.
code like below in the
tests/ml/test.ipynb
and before running it:
then run it :
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you see what happens with:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
well, i'll try later
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and why i
['4hhs', '4hhs']
I guess this may need lots of change~