Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change the pdb_paths working style and support for loading both local… #214

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ jobs:
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install BLAST
run: sudo apt install ncbi-blast+
- name: Install notebook utils
run: pip install nbformat
- name: Install Graphein
run: pip install -e .
- name: Install Extras
Expand Down
162 changes: 98 additions & 64 deletions graphein/ml/datasets/torch_geometric_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def __init__(
pdb_paths: Optional[List[str]] = None,
pdb_codes: Optional[List[str]] = None,
uniprot_ids: Optional[List[str]] = None,
graph_label_map: Optional[Dict[str, torch.Tensor]] = None,
node_label_map: Optional[Dict[str, torch.Tensor]] = None,
chain_selection_map: Optional[Dict[str, List[str]]] = None,
graph_labels: Optional[List[torch.Tensor]] = None,
node_labels: Optional[List[torch.Tensor]] = None,
chain_selections: Optional[List[str]] = None,
graphein_config: ProteinGraphConfig = ProteinGraphConfig(),
graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor(
src_format="nx", dst_format="pyg"
Expand Down Expand Up @@ -130,54 +130,61 @@ def __init__(
self.pdb_codes = (
[pdb.lower() for pdb in pdb_codes]
if pdb_codes is not None
else None
else []
)
self.uniprot_ids = (
[up.upper() for up in uniprot_ids]
if uniprot_ids is not None
else None
else []
)
self.pdb_paths = (
pdb_paths if pdb_paths is not None
else []
)

self.pdb_paths = pdb_paths
if self.pdb_paths is None:
if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
self.structures = pdb_codes
elif self.uniprot_ids:
self.structures = uniprot_ids
# Use local saved pdb_files instead of download or move them to self.root/raw dir
else:
if isinstance(self.pdb_paths, list):
self.structures = [
# make sure root path is unique
if self.pdb_paths:
# add pdb_paths' name into self.structure
self.pdb_paths_name = [
os.path.splitext(os.path.split(pdb_path)[-1])[0]
for pdb_path in self.pdb_paths
]
self.pdb_path, _ = os.path.split(self.pdb_paths[0])

if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
self.structures = pdb_codes
elif self.uniprot_ids:
self.structures = uniprot_ids
self.af_version = af_version
# if root pdb_path is not unique raise error since we will save all pdb into this root pdb_path and take it as the self.raw_dir
if len(set([os.path.split(pdb_path)[0] for pdb_path in pdb_paths])) != 1:
raise ValueError("pdb_paths should have only one root path not so much!")
else:
self.pdb_paths_name = []

self.structures = list(set(self.pdb_codes + self.uniprot_ids + self.pdb_paths_name)) # remove some pdb_codes is in pdb_path and loaded repeately
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this should be a set operation. With chain selections you may want to have e.g. 3eiy_A and 3eiy_B as different examples in your dataset.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I guess it wouldn't make a difference for chain selection; this set operation is there to drop duplicates from the resulting list of pdb_codes + uniprot_ids + paths_name. As you can see, the local dir may contain some PDB files like 10gs.pdb, and if pdb_codes also lists 10gs for download, self.structures would contain 10gs twice, so the final dataset object would have duplicate Data objects.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It becomes a problem here though (L283), no?

    def process(self):
        """Process structures into PyG format and save to disk."""
        # Read data into huge `Data` list.
        structure_files = [
            f"{self.raw_dir}/{pdb}.pdb" for pdb in self.structures
        ]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, i guess not.
code like below in the tests/ml/test.ipynb

from graphein.ml.datasets import InMemoryProteinGraphDataset

local_dir = "../protein/test_data"
pdb_paths = [osp.join(local_dir, pdb_file) for pdb_file in os.listdir(local_dir) if pdb_file.endswith(".pdb")]

ds = InMemoryProteinGraphDataset(root = "../protein/test_data/InMemoryProteinGraphDataset",
                    name = "InMemoryProteinGraphDataset_test",
                    pdb_paths=pdb_paths,
                    pdb_codes=["10gs"],
                    uniprot_ids=["A0A6J1BG53", "A0A6P5Z5F7"],
                    af_version=3)

and before running it:
image

then run it :
image

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you see what happens with:

from graphein.ml.datasets import InMemoryProteinGraphDataset

ds = InMemoryProteinGraphDataset(root = "../protein/test_data/InMemoryProteinGraphDataset", pdb_paths=pdb_paths, pdb_codes = ["4hhb", "4hhb"], chain_selections=["A","B"])

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, i'll try later

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you see what happens with:

from graphein.ml.dataset import InMemoryProteinGraphDataset

ds = InMemoryProteinGraphDataset(root = ""../protein.test_data/InMemoryProteinGraphDataset", pdb_paths=pdb_paths, pdb_codes = ["4hhb", "4hhb"], chain_selection=["A","B"])

image

and why i ['4hhs', '4hhs']

image

I guess this may need lots of change~


# Labels & Chains
if graph_labels is not None:
self.graph_label_map = dict(enumerate(graph_labels))
else:
self.graph_label_map = None

if node_labels is not None:
self.node_label_map = dict(enumerate(node_labels))
else:
self.node_label_map = None
if chain_selections is not None:
self.chain_selection_map = dict(enumerate(chain_selections))
else:
self.chain_selection_map = None
self.validate_input()
self.bad_pdbs: List[
str
] = [] # list of pdb codes that failed to download

# Labels & Chains
self.graph_label_map = graph_label_map
self.node_label_map = node_label_map
self.chain_selection_map = chain_selection_map

# Configs
self.config = graphein_config
self.graph_format_convertor = graph_format_convertor
self.graph_transformation_funcs = graph_transformation_funcs
self.pdb_transform = pdb_transform
self.num_cores = num_cores
self.af_version = af_version

super().__init__(
root,
transform=transform,
Expand All @@ -200,10 +207,36 @@ def processed_file_names(self) -> List[str]:
@property
def raw_dir(self) -> str:
    """Return the directory that holds the raw ``.pdb`` files.

    If local ``pdb_paths`` were supplied, the directory of the first path
    is used (all paths are expected to live in the same directory) and is
    also cached on ``self.pdb_path``. Otherwise the default
    ``<root>/raw`` directory is returned.

    :return: Path of the raw data directory.
    :rtype: str
    """
    # ``pdb_paths`` is normalised to ``[]`` when the user passes nothing, so
    # an ``is not None`` check would always succeed and ``pdb_paths[0]``
    # would raise ``IndexError``. Test truthiness instead.
    if self.pdb_paths:
        self.pdb_path = os.path.dirname(self.pdb_paths[0])
        return self.pdb_path
    return os.path.join(self.root, "raw")

def validate_input(self):
    """Check that labels and chain selections line up with ``self.structures``.

    Each of ``graph_label_map``, ``node_label_map`` and
    ``chain_selection_map`` is optional (``None`` skips its check). When
    present, it must contain exactly one entry per structure. In addition,
    every structure/chain pair must be unique.

    :raises AssertionError: If a map's length differs from the number of
        structures, or if a structure/chain combination is repeated.
    """
    n_structures = len(self.structures)

    if self.graph_label_map is not None:
        assert n_structures == len(
            self.graph_label_map
        ), "Number of proteins and graph labels must match"

    if self.node_label_map is not None:
        assert n_structures == len(
            self.node_label_map
        ), "Number of proteins and node labels must match"

    if self.chain_selection_map is not None:
        assert n_structures == len(
            self.chain_selection_map
        ), "Number of proteins and chain selections must match"
        # The map is keyed by example index, so iterating the dict itself
        # would pair each structure with an integer key (always unique).
        # Pair structures with the chain *values* to detect real duplicates.
        combinations = {
            f"{pdb}_{chain}"
            for pdb, chain in zip(
                self.structures, self.chain_selection_map.values()
            )
        }
        assert (
            len(combinations) == n_structures
        ), "Duplicate protein/chain combinations"

def download(self):
"""Download the PDB files from RCSB or Alphafold."""
self.config.pdb_dir = Path(self.raw_dir)
Expand All @@ -225,6 +258,7 @@ def download(self):
for pdb in set(self.pdb_codes)
if not os.path.exists(Path(self.raw_dir) / f"{pdb}.pdb")
]
print("downloading uniprotids")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using log would be better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohhhhhhhhhhh, too sry for these print, forget to remove them, XD.

I'll remove them today

if self.uniprot_ids:
[
download_alphafold_structure(
Expand All @@ -237,6 +271,7 @@ def download(self):
]

def __len__(self) -> int:
"""Return the number of examples in the dataset, i.e. ``len(self.structures)``."""
return len(self.structures)

def transform_pdbs(self):
Expand All @@ -252,6 +287,7 @@ def transform_pdbs(self):
def process(self):
"""Process structures into PyG format and save to disk."""
# Read data into huge `Data` list.
print(self.structures)
structure_files = [
f"{self.raw_dir}/{pdb}.pdb" for pdb in self.structures
]
Expand Down Expand Up @@ -330,12 +366,9 @@ def __init__(
pdb_paths: Optional[List[str]] = None,
pdb_codes: Optional[List[str]] = None,
uniprot_ids: Optional[List[str]] = None,
# graph_label_map: Optional[Dict[str, int]] = None,
graph_labels: Optional[List[torch.Tensor]] = None,
node_labels: Optional[List[torch.Tensor]] = None,
chain_selections: Optional[List[str]] = None,
# node_label_map: Optional[Dict[str, int]] = None,
# chain_selection_map: Optional[Dict[str, List[str]]] = None,
graphein_config: ProteinGraphConfig = ProteinGraphConfig(),
graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor(
src_format="nx", dst_format="pyg"
Expand Down Expand Up @@ -364,14 +397,12 @@ def __init__(
:param uniprot_ids: List of Uniprot IDs to download and parse from
Alphafold Database. Defaults to ``None``.
:type uniprot_ids: Optional[List[str]], optional
:param graph_label_map: Dictionary mapping PDB/Uniprot IDs to
graph-level labels. Defaults to ``None``.
:type graph_label_map: Optional[Dict[str, Tensor]], optional
:param node_label_map: Dictionary mapping PDB/Uniprot IDs to node-level
labels. Defaults to ``None``.
:type node_label_map: Optional[Dict[str, torch.Tensor]], optional
:param chain_selection_map: Dictionary mapping, defaults to ``None``.
:type chain_selection_map: Optional[Dict[str, List[str]]], optional
:param graph_labels: List mapping to self.structures by index to graph-level labels. Defaults to ``None``.
:type graph_labels: Optional[List[torch.Tensor]], optional
:param node_labels: List mapping to self.structures by index to node-level labels. Defaults to ``None``.
:type node_labels: Optional[List[torch.Tensor]], optional
:param chain_selections: List mapping to self.structures by index to chain selection, defaults to ``None``.
:type chain_selections: Optional[List[str]], optional
:param graphein_config: Protein graph construction config, defaults to
``ProteinGraphConfig()``.
:type graphein_config: ProteinGraphConfig, optional
Expand Down Expand Up @@ -412,34 +443,36 @@ def __init__(
self.pdb_codes = (
[pdb.lower() for pdb in pdb_codes]
if pdb_codes is not None
else None
else []
)
self.uniprot_ids = (
[up.upper() for up in uniprot_ids]
if uniprot_ids is not None
else None
else []
)
self.pdb_paths = pdb_paths
if self.pdb_paths is None:
if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
self.structures = pdb_codes
elif self.uniprot_ids:
self.structures = uniprot_ids
# Use local saved pdb_files instead of download or move them to self.root/raw dir
else:
if isinstance(self.pdb_paths, list):
self.structures = [
self.pdb_paths = (
pdb_paths if pdb_paths is not None
else []
)

# make sure root path is unique
if self.pdb_paths:
# add pdb_paths' name into self.structure
self.pdb_paths_name = [
os.path.splitext(os.path.split(pdb_path)[-1])[0]
for pdb_path in self.pdb_paths
]
self.pdb_path, _ = os.path.split(self.pdb_paths[0])

# Labels & Chains
# if root pdb_path is not unique raise error since we will save all pdb into this root pdb_path and take it as the self.raw_dir
if len(set([os.path.split(pdb_path)[0] for pdb_path in pdb_paths])) != 1:
raise ValueError("pdb_paths should have only one root path not so much!")
else:
self.pdb_paths_name = []

self.structures = list(set(self.pdb_codes + self.uniprot_ids + self.pdb_paths_name)) # remove some pdb_codes is in pdb_path and loaded repeately

self.examples: Dict[int, str] = dict(enumerate(self.structures))

# Labels & Chains
if graph_labels is not None:
self.graph_label_map = dict(enumerate(graph_labels))
else:
Expand All @@ -460,9 +493,9 @@ def __init__(
# Configs
self.config = graphein_config
self.graph_format_convertor = graph_format_convertor
self.num_cores = num_cores
self.pdb_transform = pdb_transform
self.graph_transformation_funcs = graph_transformation_funcs
self.pdb_transform = pdb_transform
self.num_cores = num_cores
self.af_version = af_version
super().__init__(
root,
Expand Down Expand Up @@ -492,8 +525,10 @@ def processed_file_names(self) -> List[str]:

@property
def raw_dir(self) -> str:
if self.pdb_paths is not None:
return self.pdb_path # replace raw dir with user local pdb_path
if self.pdb_paths:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually think it would be useful to allow users to choose a path for raw_dir when initialising the Dataset objects.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree.
If we simply change self.raw_dir instead of self.pdb_paths, where the former is a folder dir the latter is a list containing pdb_file dir, i guess we will use os.listdir to get local pdb files dir.
And the question is if os.listdir in the func, and then the order of self.structure maybe hard to match the order of graph_labels and node_labels, since we match the labels by index of list, i guess.
image

image

image

I'm not sure about this; I would prefer a dict whose keys are the structure names, e.g. {'10gs': 0} rather than {0: 0}. Then we could just change raw_dir, use os.listdir to get a list of PDB file paths covering both local and downloaded PDB files, and process each PDB file, assigning its node labels, chain selection or graph label by name (the path with its root and suffix removed, e.g. ./test/10gs.pdb -> 10gs) rather than by enumerated index (with which I think it is hard to match the order of the PDB files when passing graph_labels).

This description is not very clear, i'll try to make it clear later...

If something wrong in my understanding, please tell me 😄 , i'm still reading and learning your code lol. It's really a pythonic code, i learnt a lot 👍 👍

Copy link
Owner

@a-r-j a-r-j Sep 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we simply change self.raw_dir instead of self.pdb_paths, where the former is a folder dir the latter is a list containing pdb_file dir, i guess we will use `os.listdir` to get local pdb files dir.

I don't think this is the best idea. I think being explicit about the paths users want to use is best. For instance, people may want to use only a subset of their dataset (rather than everything in the directory - e.g. imagine where you want to keep all your pdb files together but train/test on different subsets). It also has the potential problem with hidden files like .DS_Store etc. You're also completely right about the matching the list to node labels etc.

I'm not sure about this, i prefer to dict, which key is the names like {'10gs':0} would be better than {0:0}

This was my initial implementation. However, this ran into the problem where you may have different examples in your dataset drawn from different chains of the same PDB. E.g. imagine you have 3eiy_A and 3eiy_B with different labels. The current implementation allows for this, whereas indexing on the PDB name does not.

If something wrong in my understanding, please tell me 😄 , i'm still reading and learning your code lol. It's really a pythonic code, i learnt a lot 👍 👍

Thanks!! Me too!

# replace raw dir with user local pdb_path; so pdb_paths should be located in the same place
self.pdb_path, _ = os.path.split(self.pdb_paths[0])
return self.pdb_path
else:
return os.path.join(self.root, "raw")

Expand Down Expand Up @@ -610,7 +645,6 @@ def divide_chunks(l: List[str], n: int = 2) -> Generator:
)
if self.graph_transformation_funcs is not None:
graphs = [self.transform_graphein_graphs(g) for g in graphs]

# Convert to PyTorch Geometric Data
graphs = [self.graph_format_convertor(g) for g in graphs]

Expand Down
65 changes: 65 additions & 0 deletions graphein/ml/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Dataset loader transform Functions for loading set"""
# Graphein
# Author: xutingfeng <xutingfeng@big.ac.cn>
# License: MIT

# Project Website: https://github.com/a-r-j/graphein
# Code Repository: https://github.com/a-r-j/graphein

from torch_geometric.transforms import BaseTransform
from typing import Callable, Dict, Generator, List, Optional, Union

import torch
import torch_geometric


def attr2Tensor(
    data: torch_geometric.data.Data, attr: str, dtype=torch.float32
) -> torch_geometric.data.Data:
    """Convert attribute ``attr`` of ``data`` into a single ``torch.Tensor``.

    The attribute is expected to be a list of array-likes, e.g.
    ``[np.array, np.array]``. Each element is converted to a tensor of type
    ``dtype``; 1-D elements of shape ``(dim,)`` are reshaped to ``(1, dim)``.
    The resulting tensors are concatenated along dimension 0 and written
    back to ``data`` in place.

    If the attribute is not a list it is handed to ``torch.cat`` unchanged,
    so it must then already be a sequence of tensors.

    :param data: Object whose attribute is converted (modified in place).
    :type data: torch_geometric.data.Data
    :param attr: Name of the attribute to convert.
    :type attr: str
    :param dtype: Dtype of the new tensor, defaults to ``torch.float32``.
    :type dtype: torch.dtype, optional
    :raises ValueError: If ``data`` has no attribute named ``attr``.
    :return: The same ``data`` object, with ``attr`` replaced by a tensor.
    :rtype: torch_geometric.data.Data
    """
    if not hasattr(data, attr):
        raise ValueError(f"doesn't have an attr:{attr}")
    attr_data = getattr(data, attr)

    if isinstance(attr_data, list):
        # Convert first, then inspect the tensor rank; this also works for
        # elements (e.g. plain lists) that have no ``.shape`` attribute.
        tensors = [torch.tensor(item, dtype=dtype) for item in attr_data]
        attr_data = [t.unsqueeze(0) if t.dim() == 1 else t for t in tensors]

    attr_data = torch.cat(attr_data, 0)
    setattr(data, attr, attr_data)
    # Return the object so the annotated return type is honoured; callers
    # relying on the in-place update keep working.
    return data


class Reshape_Attr2Tensor(BaseTransform):
    """Transform that converts one or more ``Data`` attributes to tensors.

    Each named attribute (e.g. a list such as ``[np.array, np.array]``) is
    converted in place to a ``torch.Tensor`` via ``attr2Tensor``.

    :param attr: Attribute name, or a (possibly nested) list of attribute
        names, to convert.
    :type attr: Union[str, List]
    """

    def __init__(self, attr: Union[str, List]) -> None:
        self.attr = attr

    def __call__(
        self, data: torch_geometric.data.Data
    ) -> torch_geometric.data.Data:
        """Convert the configured attribute(s) of ``data`` to ``torch.Tensor``.

        :param data: Object whose attributes are converted in place.
        :type data: torch_geometric.data.Data
        :raises ValueError: If ``attr`` (or a nested entry) is neither a
            string nor a list.
        :return: The same ``data`` object.
        :rtype: torch_geometric.data.Data
        """
        self._convert(data, self.attr)
        return data

    @staticmethod
    def _convert(data, attr: Union[str, List]) -> None:
        """Apply ``attr2Tensor`` to ``attr``, recursing into lists of names."""
        if isinstance(attr, str):
            # The conversion mutates ``data`` in place; the return value is
            # deliberately not relied upon.
            attr2Tensor(data, attr)
        elif isinstance(attr, list):
            for name in attr:
                Reshape_Attr2Tensor._convert(data, name)
        else:
            raise ValueError(
                "attr must be an attribute name (str) or a list of "
                f"attribute names, got {type(attr).__name__}"
            )
24 changes: 12 additions & 12 deletions graphein/ml/visualisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,15 @@ def plot_pyg_data(
d["colour"] = float(edge_colour_tensor[i])

return plotly_protein_structure_graph(
nx_graph,
plot_title,
figsize,
node_alpha,
node_size_min,
node_size_multiplier,
label_node_ids,
node_colour_map,
edge_color_map,
colour_nodes_by if node_colour_tensor is None else "colour",
colour_edges_by if edge_colour_tensor is None else "colour",
)
G = nx_graph,
plot_title = plot_title,
figsize = figsize,
node_alpha = node_alpha,
node_size_min = node_size_min,
node_size_multiplier = node_size_multiplier,
label_node_ids = label_node_ids,
node_colour_map = node_colour_map,
edge_color_map = edge_color_map,
colour_nodes_by = colour_nodes_by if node_colour_tensor is None else "colour",
colour_edges_by = colour_edges_by if edge_colour_tensor is None else "colour",
)
3 changes: 2 additions & 1 deletion graphein/protein/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,9 @@ def download_alphafold_structure(
query_url = f"{BASE_URL}AF-{uniprot_id}-F1-model_v{version}.cif"
if pdb:
query_url = f"{BASE_URL}AF-{uniprot_id}-F1-model_v{version}.pdb"

structure_filename = wget.download(query_url, out=out_dir)

if rename:
extension = ".pdb" if pdb else ".cif"
os.rename(
Expand Down