
Commit

Merge branch 'dl_updates' into master
a-r-j committed Jul 14, 2022
2 parents e80b7d4 + 0c2ac86 commit bb4ba76
Showing 1 changed file with 87 additions and 43 deletions.
130 changes: 87 additions & 43 deletions graphein/ml/datasets/torch_geometric_dataset.py
@@ -301,9 +301,12 @@ def __init__(
root,
pdb_codes: Optional[List[str]] = None,
uniprot_ids: Optional[List[str]] = None,
graph_label_map: Optional[Dict[str, int]] = None,
node_label_map: Optional[Dict[str, int]] = None,
chain_selection_map: Optional[Dict[str, List[str]]] = None,
# graph_label_map: Optional[Dict[str, int]] = None,
graph_labels: Optional[List[torch.Tensor]] = None,
node_labels: Optional[List[torch.Tensor]] = None,
chain_selections: Optional[List[str]] = None,
# node_label_map: Optional[Dict[str, int]] = None,
# chain_selection_map: Optional[Dict[str, List[str]]] = None,
graphein_config: ProteinGraphConfig = ProteinGraphConfig(),
graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor(
src_format="nx", dst_format="pyg"
@@ -395,10 +398,25 @@ def __init__(
self.af_version = af_version

# Labels & Chains
self.graph_label_map = graph_label_map
self.node_label_map = node_label_map
self.chain_selection_map = chain_selection_map
self.bad_pdbs: List[str] = []

self.examples: Dict[int, str] = dict(enumerate(self.structures))

if graph_labels is not None:
self.graph_label_map = dict(enumerate(graph_labels))
else:
self.graph_label_map = None

if node_labels is not None:
self.node_label_map = dict(enumerate(node_labels))
else:
self.node_label_map = None

if chain_selections is not None:
self.chain_selection_map = dict(enumerate(chain_selections))
else:
self.chain_selection_map = None
self.validate_input()
self.bad_pdbs: List[str] = []

# Configs
self.config = graphein_config
Expand All @@ -422,7 +440,34 @@ def raw_file_names(self) -> List[str]:
@property
def processed_file_names(self) -> List[str]:
"""Names of processed files to look for"""
return [f"{pdb}.pt" for pdb in self.structures]
if self.chain_selection_map is not None:
return [
f"{pdb}_{chain}.pt"
for pdb, chain in zip(
self.structures, self.chain_selection_map.values()
)
]
else:
return [f"{pdb}.pt" for pdb in self.structures]

def validate_input(self):
if self.graph_label_map is not None:
assert len(self.structures) == len(
self.graph_label_map
), "Number of proteins and graph labels must match"
if self.node_label_map is not None:
assert len(self.structures) == len(
self.node_label_map
), "Number of proteins and node labels must match"
if self.chain_selection_map is not None:
assert len(self.structures) == len(
self.chain_selection_map
), "Number of proteins and chain selections must match"
assert len(
{
f"{pdb}_{chain}"
for pdb, chain in zip(
self.structures, self.chain_selection_map.values()
)
}
) == len(self.structures), "Duplicate protein/chain combinations"

def download(self):
"""Download the PDB files from RCSB or Alphafold."""
@@ -489,48 +534,46 @@ def divide_chunks(l: List[str], n: int = 2) -> List[List[str]]:
for i in range(0, len(l), n):
yield l[i : i + n]

chunks = list(divide_chunks(self.structures, chunk_size))
# chunks = list(divide_chunks(self.structures, chunk_size))
chunks: List[List[int]] = list(
divide_chunks(list(self.examples.keys()), chunk_size)
)

for chunk in tqdm(chunks):
pdbs = [self.examples[idx] for idx in chunk]
# Get chain selections
if self.chain_selection_map:
if self.chain_selection_map is not None:
chain_selections = [
self.chain_selection_map[pdb]
if pdb in self.chain_selection_map.keys()
else "all"
for pdb in chunk
self.chain_selection_map[idx] for idx in chunk
]
else:
chain_selections = None
chain_selections = ["all"] * len(chunk)

# Create graph objects
file_names = [f"{self.raw_dir}/{pdb}.pdb" for pdb in chunk]
file_names = [f"{self.raw_dir}/{pdb}.pdb" for pdb in pdbs]
graphs = construct_graphs_mp(
pdb_path_it=file_names,
config=self.config,
chain_selections=chain_selections,
return_dict=True,
return_dict=False,
)
if self.graph_transformation_funcs is not None:
graphs = {
k: self.transform_graphein_graphs(v)
for k, v in graphs.items()
}
graphs = [self.transform_graphein_graphs(g) for g in graphs]

# Convert to PyTorch Geometric Data
graphs = {
k: self.graph_format_convertor(v) for k, v in graphs.items()
}
graphs = dict(zip(chunk, graphs.values()))
graphs = [self.graph_format_convertor(g) for g in graphs]

# Assign labels
if self.graph_label_map:
for k, v in self.graph_label_map.items():
graphs[k].graph_y = v
labels = [self.graph_label_map[idx] for idx in chunk]
for i, _ in enumerate(chunk):
graphs[i].graph_y = labels[i]
if self.node_label_map:
for k, v in self.node_label_map.items():
graphs[k].node_y = v
labels = [self.node_label_map[idx] for idx in chunk]
for i, _ in enumerate(chunk):
graphs[i].node_y = labels[i]

data_list = list(graphs.values())
data_list = graphs

del graphs

@@ -540,18 +583,11 @@ def divide_chunks(l: List[str], n: int = 2) -> List[List[str]]:
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]

idxs = [
i
for i in range(idx * chunk_size, idx * chunk_size + len(chunk))
]

for data, id in zip(data_list, idxs):
for i, (pdb, chain) in enumerate(zip(pdbs, chain_selections)):

torch.save(
data,
os.path.join(
self.processed_dir, f"{self.structures[id]}.pt"
),
data_list[i],
os.path.join(self.processed_dir, f"{pdb}_{chain}.pt"),
)
idx += 1

Expand All @@ -563,9 +599,17 @@ def get(self, idx: int):
:type idx: int
:return: PyTorch Geometric Data object.
"""
return torch.load(
os.path.join(self.processed_dir, f"{self.structures[idx]}.pt")
)
if self.chain_selection_map is not None:
return torch.load(
os.path.join(
self.processed_dir,
f"{self.structures[idx]}_{self.chain_selection_map[idx]}.pt",
)
)
else:
return torch.load(
os.path.join(self.processed_dir, f"{self.structures[idx]}.pt")
)


class ProteinGraphListDataset(InMemoryDataset):

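For context, a minimal usage sketch of the list-based interface introduced in this diff. It is an illustration only: the import paths, PDB codes, label tensors and chain strings below are assumptions, not taken from the commit; only the parameter names (root, pdb_codes, graph_labels, chain_selections, graphein_config) come from the signature above.

# Hypothetical usage sketch of the new list-based arguments (illustrative values only).
import torch

from graphein.ml import ProteinGraphDataset
from graphein.protein.config import ProteinGraphConfig

pdb_codes = ["3eiy", "4hhb", "1a0q"]

# Labels and chain selections are plain lists aligned positionally with the
# structures (one entry per protein), rather than dicts keyed by PDB code.
ds = ProteinGraphDataset(
    root="data/",
    pdb_codes=pdb_codes,
    graph_labels=[torch.tensor([0]), torch.tensor([1]), torch.tensor([0])],
    chain_selections=["A", "all", "AB"],  # "all" keeps every chain
    graphein_config=ProteinGraphConfig(),
)

# With chain selections supplied, processed graphs are saved as {pdb}_{chain}.pt,
# which is also how get(idx) locates them.
print(ds.processed_file_names[:3])

Internally, the positional lists are converted to index-keyed dicts (graph_label_map, node_label_map, chain_selection_map), which is what validate_input() and the chunked process() loop operate on.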