-
Notifications
You must be signed in to change notification settings - Fork 126
/
expansion_strategies.py
162 lines (131 loc) · 6.15 KB
/
expansion_strategies.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
""" Module containing classes that implement different expansion policy strategies
"""
from __future__ import annotations
import abc
from typing import TYPE_CHECKING
import numpy as np
import pandas as pd
from aizynthfinder.chem import TemplatedRetroReaction
from aizynthfinder.utils.models import load_model
from aizynthfinder.utils.logging import logger
from aizynthfinder.context.policy.utils import _make_fingerprint
from aizynthfinder.utils.exceptions import PolicyException
if TYPE_CHECKING:
from aizynthfinder.utils.type_utils import Any, Sequence, List, Tuple
from aizynthfinder.context.config import Configuration
from aizynthfinder.chem import TreeMolecule
from aizynthfinder.chem.reaction import RetroReaction
class ExpansionStrategy(abc.ABC):
    """
    Abstract base class shared by every expansion strategy.

    A strategy is used either through its `get_actions` method
    or by calling the instantiated object with a list of molecules.

    .. code-block::

        expander = MyExpansionStrategy("dummy", config)
        actions, priors = expander.get_actions(molecules)
        actions, priors = expander(molecules)

    Sub-classes list the keyword arguments they cannot work without in
    `_required_kwargs`; the constructor rejects a call that omits any of them.

    :param key: the key or label
    :param config: the configuration of the tree search
    :raises PolicyException: if a required keyword argument is not given
    """

    # Names of the keyword arguments that must be passed to the constructor
    _required_kwargs: List[str] = []

    def __init__(self, key: str, config: Configuration, **kwargs: str) -> None:
        # Collect the required keyword arguments the caller left out
        missing = [name for name in self._required_kwargs if name not in kwargs]
        if missing:
            class_name = self.__class__.__name__
            required = ", ".join(self._required_kwargs)
            raise PolicyException(
                f"A {class_name} class needs to be initiated "
                f"with keyword arguments: {required}"
            )

        self.key = key
        self._config = config
        self._logger = logger()

    def __call__(
        self, molecules: Sequence[TreeMolecule]
    ) -> Tuple[List[RetroReaction], List[float]]:
        # Calling the strategy is shorthand for `get_actions`
        result = self.get_actions(molecules)
        return result

    @abc.abstractmethod
    def get_actions(
        self, molecules: Sequence[TreeMolecule]
    ) -> Tuple[List[RetroReaction], List[float]]:
        """
        Return the probable actions for a set of molecules.

        :param molecules: the molecules to consider
        :return: the actions and the priors of those actions
        """
class TemplateBasedExpansionStrategy(ExpansionStrategy):
    """
    A template-based expansion strategy that will return `TemplatedRetroReaction` objects upon expansion.

    The policy model is loaded with `load_model` and the templates are read
    from the "table" key of an HDF5 file with `pandas.read_hdf`.

    :param key: the key or label
    :param config: the configuration of the tree search
    :param source: the source of the policy model
    :param templatefile: the path to a HDF5 file with the templates
    :raises PolicyException: if the length of the model output vector is not same as the number of templates
    """

    _required_kwargs = ["source", "templatefile"]

    def __init__(self, key: str, config: Configuration, **kwargs: str) -> None:
        super().__init__(key, config, **kwargs)

        source = kwargs["source"]
        templatefile = kwargs["templatefile"]

        self._logger.info(
            f"Loading template-based expansion policy model from {source} to {self.key}"
        )
        self.model = load_model(source, self.key, self._config.use_remote_models)

        self._logger.info(f"Loading templates from {templatefile} to {self.key}")
        self.templates: pd.DataFrame = pd.read_hdf(templatefile, "table")

        # The consistency check is only made for models that expose `output_size`
        if hasattr(self.model, "output_size") and len(self.templates) != self.model.output_size:  # type: ignore
            raise PolicyException(
                f"The number of templates ({len(self.templates)}) does not agree with the "  # type: ignore
                f"output dimensions of the model ({self.model.output_size})"
            )

    # pylint: disable=R0914
    def get_actions(
        self, molecules: Sequence[TreeMolecule]
    ) -> Tuple[List[RetroReaction], List[float]]:
        """
        Get all the probable actions of a set of molecules, using the selected policies and given cutoffs

        For every molecule the model prediction is computed, the top-ranked
        templates are selected with `_cutoff_predictions`, and one
        `TemplatedRetroReaction` is created per selected template. The priors
        are appended in the same order as the actions.

        :param molecules: the molecules to consider
        :return: the actions and the priors of those actions
        """
        possible_actions = []
        priors = []
        # Configuration values are read once; they do not depend on the molecule
        template_column = self._config.template_column
        use_rdchiral = self._config.use_rdchiral

        for mol in molecules:
            all_transforms_prop = self._predict(mol, self.model)
            probable_transforms_idx = self._cutoff_predictions(all_transforms_prop)
            possible_moves = self.templates.iloc[probable_transforms_idx]
            probs = all_transforms_prop[probable_transforms_idx]

            priors.extend(probs)
            for idx, (move_index, move) in enumerate(possible_moves.iterrows()):
                template = move[template_column]
                # All columns of the template row except the template itself
                # are carried over as metadata of the reaction
                metadata = dict(move)
                del metadata[template_column]
                metadata["policy_probability"] = float(probs[idx].round(4))
                metadata["policy_probability_rank"] = idx
                metadata["policy_name"] = self.key
                metadata["template_code"] = move_index
                metadata["template"] = template
                possible_actions.append(
                    TemplatedRetroReaction(
                        mol,
                        smarts=template,
                        metadata=metadata,
                        use_rdchiral=use_rdchiral,
                    )
                )
        return possible_actions, priors  # type: ignore

    def _cutoff_predictions(self, predictions: np.ndarray) -> np.ndarray:
        """
        Get the top transformations, by selecting those that have:
            * cumulative probability less than a threshold (cutoff_cumulative)
            * or at most N (cutoff_number)

        :param predictions: the flattened model output for one molecule
        :return: indices into `predictions`, ordered by descending prediction
        """
        cutoff_cumulative = self._config.cutoff_cumulative
        sortidx = np.argsort(predictions)[::-1]
        cumsum: np.ndarray = np.cumsum(predictions[sortidx])
        # np.any evaluates the mask in NumPy instead of iterating it in Python
        if np.any(cumsum >= cutoff_cumulative):
            # argmin of a boolean array is the position of its first False,
            # i.e. the first entry whose running total reaches the threshold
            maxidx = int(np.argmin(cumsum < cutoff_cumulative))
        else:
            maxidx = len(cumsum)
        # `or 1` makes sure the slice is never limited to zero entries
        maxidx = min(maxidx, self._config.cutoff_number) or 1
        return sortidx[:maxidx]

    @staticmethod
    def _predict(mol: TreeMolecule, model: Any) -> np.ndarray:
        """
        Run the policy model on the fingerprint of a molecule.

        :param mol: the molecule to make a prediction for
        :param model: the model, passed to `_make_fingerprint` and used for `predict`
        :return: the model output flattened to a one-dimensional array
        """
        fp_arr = _make_fingerprint(mol, model)
        return np.array(model.predict(fp_arr)).flatten()