diff --git a/CHANGELOG.md b/CHANGELOG.md index 555dc56..54ffd96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # CHANGELOG +## Version 3.3.0 2022-02-24 + +### Features + +- Support for Retro* tree search +- Support for breadth-first exhaustive tree search +- Support for depth-first proof-number tree search +- Possible to save concatenated reaction trees to separate file + +### Bugfixes + +- RouteCostScorer fix for rare routes + ## Version 3.2.0 2022-02-24 ### Features diff --git a/aizynthfinder/context/cost/collection.py b/aizynthfinder/context/cost/collection.py index 1ad5072..1a8741a 100644 --- a/aizynthfinder/context/cost/collection.py +++ b/aizynthfinder/context/cost/collection.py @@ -71,3 +71,4 @@ def load_from_config(self, **costs_config: Any) -> None: ) self._logger.info(f"Loaded cost: '{repr(obj)}'{config_str}") self._items[repr(obj)] = obj + self.select_last() diff --git a/aizynthfinder/context/scoring/scorers.py b/aizynthfinder/context/scoring/scorers.py index a2b5799..455df64 100644 --- a/aizynthfinder/context/scoring/scorers.py +++ b/aizynthfinder/context/scoring/scorers.py @@ -297,7 +297,7 @@ def _score_node(self, node: MctsNode) -> float: updated_scores = { id(mol): scores[id(mol)] for mol in pnode.state.mols - if mol != reaction.mol + if mol is not reaction.mol } child_sum = sum( 1 / self.average_yield * score diff --git a/aizynthfinder/search/andor_trees.py b/aizynthfinder/search/andor_trees.py index 6303e1e..66d2b5f 100644 --- a/aizynthfinder/search/andor_trees.py +++ b/aizynthfinder/search/andor_trees.py @@ -83,9 +83,11 @@ def __init__( else: self._sampling_cutoff = max_routes self._partition_search_tree(graph, root_node) - self.routes = [ + routes_list = [ ReactionTreeFromAndOrTrace(trace, stock).tree for trace in self._traces ] + routes_map = {route.hash_key(): route for route in routes_list} + self.routes = list(routes_map.values()) def _partition_search_tree(self, graph: _AndOrTrace, node: TreeNodeMixin) -> None: # fmt: off @@ -190,6 +192,8 @@ def _load(self, andor_trace: nx.DiGraph, stock: Stock) -> None: # type: ignore in_stock=self._trace_root.prop["mol"] in self._stock, ) for node1, node2 in andor_trace.edges(): + if "reaction" in node2.prop and not andor_trace[node2]: + continue rt_node1 = self._make_rt_node(node1) rt_node2 = self._make_rt_node(node2) self.tree.graph.add_edge(rt_node1, rt_node2) diff --git a/aizynthfinder/search/breadth_first/__init__.py b/aizynthfinder/search/breadth_first/__init__.py new file mode 100644 index 0000000..69131c7 --- /dev/null +++ b/aizynthfinder/search/breadth_first/__init__.py @@ -0,0 +1,3 @@ +""" Sub-package containing breadth first routines +""" +from aizynthfinder.search.breadth_first.search_tree import SearchTree diff --git a/aizynthfinder/search/breadth_first/nodes.py b/aizynthfinder/search/breadth_first/nodes.py new file mode 100644 index 0000000..dbd815b --- /dev/null +++ b/aizynthfinder/search/breadth_first/nodes.py @@ -0,0 +1,245 @@ +""" Module containing a classes representation various tree nodes +""" +from __future__ import annotations +from typing import TYPE_CHECKING + +from aizynthfinder.chem import TreeMolecule +from aizynthfinder.search.andor_trees import TreeNodeMixin +from aizynthfinder.chem.serialization import deserialize_action, serialize_action + +if TYPE_CHECKING: + from aizynthfinder.context.config import Configuration + from aizynthfinder.chem.serialization import ( + MoleculeDeserializer, + MoleculeSerializer, + ) + from aizynthfinder.utils.type_utils import ( + StrDict, + Sequence, + Set, + List, + ) + from aizynthfinder.chem import RetroReaction + + +class MoleculeNode(TreeNodeMixin): + """ + An OR node representing a molecule + + :ivar expandable: if True, this node is part of the frontier + :ivar mol: the molecule represented by the node + :ivar in_stock: if True the molecule is in stock and hence should not be expanded + :ivar parent: the parent of the node + + :param mol: the molecule to be represented by the node + :param config: the configuration of the search + :param parent: the parent of the node, optional + """ + + def __init__( + self, mol: TreeMolecule, config: Configuration, parent: ReactionNode = None + ) -> None: + self.mol = mol + self._config = config + self.in_stock = mol in config.stock + self.parent = parent + + self._children: List[ReactionNode] = [] + # Makes it unexpandable if we have reached maximum depth + self.expandable = self.mol.transform <= self._config.max_transforms + + if self.in_stock: + self.expandable = False + + @classmethod + def create_root(cls, smiles: str, config: Configuration) -> "MoleculeNode": + """ + Create a root node for a tree using a SMILES. + + :param smiles: the SMILES representation of the root state + :param config: settings of the tree search algorithm + :return: the created node + """ + mol = TreeMolecule(parent=None, transform=0, smiles=smiles) + return MoleculeNode(mol=mol, config=config) + + @classmethod + def from_dict( + cls, + dict_: StrDict, + config: Configuration, + molecules: MoleculeDeserializer, + parent: ReactionNode = None, + ) -> "MoleculeNode": + """ + Create a new node from a dictionary, i.e. deserialization + + :param dict_: the serialized node + :param config: settings of the tree search algorithm + :param molecules: the deserialized molecules + :param parent: the parent node + :return: a deserialized node + """ + mol = molecules.get_tree_molecules([dict_["mol"]])[0] + node = MoleculeNode(mol, config, parent) + node.expandable = dict_["expandable"] + node.children = [ + ReactionNode.from_dict(child, config, molecules, parent=node) + for child in dict_["children"] + ] + return node + + @property # type: ignore + def children(self) -> List[ReactionNode]: # type: ignore + """ Gives the reaction children nodes """ + return self._children + + @children.setter + def children(self, value: List[ReactionNode]) -> None: + self._children = value + + @property + def prop(self) -> StrDict: + return {"solved": self.in_stock, "mol": self.mol} + + def add_stub(self, reaction: RetroReaction) -> Sequence[MoleculeNode]: + """ + Add a stub / sub-tree to this node + + :param reaction: the reaction creating the stub + :return: list of all newly added molecular nodes + """ + reactants = reaction.reactants[reaction.index] + if not reactants: + return [] + + ancestors = self.ancestors() + for mol in reactants: + if mol in ancestors: + return [] + + rxn_node = ReactionNode.create_stub( + reaction=reaction, parent=self, config=self._config + ) + self._children.append(rxn_node) + + return rxn_node.children + + def ancestors(self) -> Set[TreeMolecule]: + """ + Return the ancestors of this node + + :return: the ancestors + :rtype: set + """ + if not self.parent: + return {self.mol} + + ancestors = self.parent.parent.ancestors() + ancestors.add(self.mol) + return ancestors + + def serialize(self, molecule_store: MoleculeSerializer) -> StrDict: + """ + Serialize the node object to a dictionary + + :param molecule_store: the serialized molecules + :return: the serialized node + """ + dict_: StrDict = {"expandable": self.expandable} + dict_["mol"] = molecule_store[self.mol] + dict_["children"] = [child.serialize(molecule_store) for child in self.children] + return dict_ + + +class ReactionNode(TreeNodeMixin): + """ + An AND node representing a reaction + + :ivar parent: the parent of the node + :ivar reaction: the reaction represented by the node + + :param cost: the cost of the reaction + :param reaction: the reaction to be represented by the node + :param parent: the parent of the node + """ + + def __init__(self, reaction: RetroReaction, parent: MoleculeNode) -> None: + self.parent = parent + self.reaction = reaction + + self._children: List[MoleculeNode] = [] + + @classmethod + def create_stub( + cls, + reaction: RetroReaction, + parent: MoleculeNode, + config: Configuration, + ) -> ReactionNode: + """ + Create a ReactionNode and creates all the MoleculeNode objects + that are the children of the node. + + :param reaction: the reaction to be represented by the node + :param parent: the parent of the node + :param config: the configuration of the search tree + """ + node = cls(reaction, parent) + reactants = reaction.reactants[reaction.index] + node.children = [ + MoleculeNode(mol=mol, config=config, parent=node) for mol in reactants + ] + return node + + @classmethod + def from_dict( + cls, + dict_: StrDict, + config: Configuration, + molecules: MoleculeDeserializer, + parent: MoleculeNode, + ) -> ReactionNode: + """ + Create a new node from a dictionary, i.e. deserialization + + :param dict_: the serialized node + :param config: the configuration of the tree search + :param molecules: the deserialized molecules + :param parent: the parent node + :return: a deserialized node + """ + reaction = deserialize_action(dict_["reaction"], molecules) + node = cls(reaction, parent) + + node.children = [ + MoleculeNode.from_dict(child, config, molecules, parent=node) + for child in dict_["children"] + ] + return node + + @property # type: ignore + def children(self) -> List[MoleculeNode]: # type: ignore + """ Gives the molecule children nodes """ + return self._children + + @children.setter + def children(self, value: List[MoleculeNode]) -> None: + self._children = value + + @property + def prop(self) -> StrDict: + return {"solved": False, "reaction": self.reaction} + + def serialize(self, molecule_store: MoleculeSerializer) -> StrDict: + """ + Serialize the node object to a dictionary + + :param molecule_store: the serialized molecules + :return: the serialized node + """ + dict_ = { + "reaction": serialize_action(self.reaction, molecule_store), + "children": [child.serialize(molecule_store) for child in self.children], + } + return dict_ diff --git a/aizynthfinder/search/breadth_first/search_tree.py b/aizynthfinder/search/breadth_first/search_tree.py new file mode 100644 index 0000000..e905562 --- /dev/null +++ b/aizynthfinder/search/breadth_first/search_tree.py @@ -0,0 +1,160 @@ +""" Module containing a class that holds the tree search +""" +from __future__ import annotations +import json +from typing import TYPE_CHECKING + +from aizynthfinder.search.breadth_first.nodes import MoleculeNode +from aizynthfinder.utils.logging import logger +from aizynthfinder.search.andor_trees import AndOrSearchTreeBase, SplitAndOrTree +from aizynthfinder.chem.serialization import MoleculeDeserializer, MoleculeSerializer + +if TYPE_CHECKING: + from aizynthfinder.context.config import Configuration + from aizynthfinder.reactiontree import ReactionTree + from aizynthfinder.chem import RetroReaction + from aizynthfinder.utils.type_utils import Optional, Sequence, List + + +class SearchTree(AndOrSearchTreeBase): + """ + Encapsulation of the a breadth-first exhaustive search algorithm + + :ivar config: settings of the tree search algorithm + :ivar root: the root node + + :param config: settings of the tree search algorithm + :param root_smiles: the root will be set to a node representing this molecule, defaults to None + """ + + def __init__(self, config: Configuration, root_smiles: str = None) -> None: + super().__init__(config, root_smiles) + self._mol_nodes: List[MoleculeNode] = [] + self._added_mol_nodes: List[MoleculeNode] = [] + self._logger = logger() + + if root_smiles: + self.root: Optional[MoleculeNode] = MoleculeNode.create_root( + root_smiles, config + ) + self._mol_nodes.append(self.root) + else: + self.root = None + + self._routes: List[ReactionTree] = [] + + self.profiling = { + "expansion_calls": 0, + "reactants_generations": 0, + } + + @classmethod + def from_json(cls, filename: str, config: Configuration) -> SearchTree: + """ + Create a new search tree by deserialization from a JSON file + + :param filename: the path to the JSON node + :param config: the configuration of the search tree + :return: a deserialized tree + """ + + def _find_mol_nodes(node): + for child_ in node.children: + tree._mol_nodes.append(child_) # pylint: disable=protected-access + for grandchild in child_.children: + _find_mol_nodes(grandchild) + + tree = cls(config) + with open(filename, "r") as fileobj: + dict_ = json.load(fileobj) + mol_deser = MoleculeDeserializer(dict_["molecules"]) + tree.root = MoleculeNode.from_dict(dict_["tree"], config, mol_deser) + tree._mol_nodes.append(tree.root) # pylint: disable=protected-access + for child in tree.root.children: + _find_mol_nodes(child) + return tree + + @property + def mol_nodes(self) -> Sequence[MoleculeNode]: # type: ignore + """ Return the molecule nodes of the tree """ + return self._mol_nodes + + def one_iteration(self) -> bool: + """ + Perform one iteration expansion. + Expands all expandable molecule nodes in the tree, which should be + on the same depth of the tree. + + :raises StopIteration: if the search should be pre-maturely terminated + :return: if a solution was found + :rtype: bool + """ + if self.root is None: + raise ValueError("Root is undefined. Cannot make an iteration") + + self._routes = [] + self._added_mol_nodes = [] + + for next_node in self._mol_nodes: + if next_node.expandable: + self._expand(next_node) + + if not self._added_mol_nodes: + self._logger.debug("No new nodes added in breadth-first iteration") + raise StopIteration + + self._mol_nodes.extend(self._added_mol_nodes) + solved = all(node.in_stock for node in self._mol_nodes if not node.children) + return solved + + def routes(self) -> List[ReactionTree]: + """ + Extracts and returns routes from the AND/OR tree + + :return: the routes + """ + if self.root is None: + return [] + if not self._routes: + self._routes = SplitAndOrTree(self.root, self.config.stock).routes + return self._routes + + def serialize(self, filename: str) -> None: + """ + Seralize the search tree to a JSON file + + :param filename: the path to the JSON file + :type filename: str + """ + if self.root is None: + raise ValueError("Cannot serialize tree as root is not defined") + + mol_ser = MoleculeSerializer() + dict_ = {"tree": self.root.serialize(mol_ser), "molecules": mol_ser.store} + with open(filename, "w") as fileobj: + json.dump(dict_, fileobj, indent=2) + + def _expand(self, node: MoleculeNode) -> None: + node.expandable = False + reactions, _ = self.config.expansion_policy([node.mol]) + self.profiling["expansion_calls"] += 1 + + if not reactions: + return + + reactions_to_expand = [] + for reaction in reactions: + try: + self.profiling["reactants_generations"] += 1 + _ = reaction.reactants + except: # pylint: disable=bare-except + continue + if not reaction.reactants: + continue + for idx, _ in enumerate(reaction.reactants): + rxn_copy = reaction.copy(idx) + reactions_to_expand.append(rxn_copy) + + for rxn in reactions_to_expand: + new_nodes = node.add_stub(rxn) + self._added_mol_nodes.extend(new_nodes) diff --git a/aizynthfinder/search/dfpn/__init__.py b/aizynthfinder/search/dfpn/__init__.py new file mode 100644 index 0000000..4646949 --- /dev/null +++ b/aizynthfinder/search/dfpn/__init__.py @@ -0,0 +1,3 @@ +""" Sub-package containing DFPN routines +""" +from aizynthfinder.search.dfpn.search_tree import SearchTree diff --git a/aizynthfinder/search/dfpn/nodes.py b/aizynthfinder/search/dfpn/nodes.py new file mode 100644 index 0000000..eee4eab --- /dev/null +++ b/aizynthfinder/search/dfpn/nodes.py @@ -0,0 +1,343 @@ +""" Module containing a classes representation various tree nodes +""" +from __future__ import annotations +from typing import TYPE_CHECKING + +import numpy as np + +from aizynthfinder.chem import TreeMolecule +from aizynthfinder.search.andor_trees import TreeNodeMixin + + +if TYPE_CHECKING: + from aizynthfinder.context.config import Configuration + from aizynthfinder.utils.type_utils import StrDict, Sequence, Set, List, Optional + from aizynthfinder.chem import RetroReaction + from aizynthfinder.search.dfpn import SearchTree + +BIG_INT = int(1e10) + + +class _SuperNode(TreeNodeMixin): + def __init__(self) -> None: + # pylint: disable=invalid-name + self.pn = 1 # Proof-number + self.dn = 1 # Disproof-number + self.pn_threshold = BIG_INT + self.dn_threshold = BIG_INT + self._children: List["_SuperNode"] = [] + self.expandable = True + + @property # type: ignore + def children(self) -> List[ReactionNode]: # type: ignore + """ Gives the reaction children nodes """ + return self._children # type: ignore + + @property + def closed(self) -> bool: + """ Return if the node is proven or disproven """ + return self.proven or self.disproven + + @property + def proven(self) -> bool: + """ Return if the node is proven""" + return self.pn == 0 + + @property + def disproven(self) -> bool: + """ Return if the node is disproven""" + return self.dn == 0 + + def explorable(self) -> bool: + """ Return if the node can be explored by the search algorithm""" + return not ( + self.closed or self.pn > self.pn_threshold or self.dn > self.dn_threshold + ) + + def reset(self) -> None: + """ Reset the thresholds """ + if self.closed or self.expandable: + return + for child in self._children: + child.reset() + self.update() + self.pn_threshold = BIG_INT + self.dn_threshold = BIG_INT + + def update(self) -> None: + """ Update the proof and disproof numbers """ + raise NotImplementedError("Implement a child class") + + def _set_disproven(self) -> None: + self.pn = BIG_INT + self.dn = 0 + + def _set_proven(self) -> None: + self.pn = 0 + self.dn = BIG_INT + + +class MoleculeNode(_SuperNode): + """ + An OR node representing a molecule + + :ivar expandable: if True, this node is part of the frontier + :ivar mol: the molecule represented by the node + :ivar in_stock: if True the molecule is in stock and hence should not be expanded + :ivar parent: the parent of the node + :ivar pn: the proof number + :ivar dn: the disproof number + :ivar pn_threshold: the threshold for proof number + :ivar dn_threshold: the threshold for disproof number + + :param mol: the molecule to be represented by the node + :param config: the configuration of the search + :param parent: the parent of the node, optional + """ + + def __init__( + self, + mol: TreeMolecule, + config: Configuration, + owner: SearchTree, + parent: ReactionNode = None, + ) -> None: + super().__init__() + + self.mol = mol + self._config = config + self.in_stock = mol in config.stock + self.parent = parent + self._edge_costs: List[int] = [] + self.tree = owner + + # Makes it unexpandable if we have reached maximum depth + self.expandable = self.mol.transform <= self._config.max_transforms + + if self.in_stock: + self.expandable = False + self._set_proven() + elif not self.expandable: + self._set_disproven() + + @classmethod + def create_root( + cls, smiles: str, config: Configuration, owner: SearchTree + ) -> "MoleculeNode": + """ + Create a root node for a tree using a SMILES. + + :param smiles: the SMILES representation of the root state + :param config: settings of the tree search algorithm + :return: the created node + """ + mol = TreeMolecule(parent=None, transform=0, smiles=smiles) + return MoleculeNode(mol=mol, config=config, owner=owner) + + @property + def prop(self) -> StrDict: + return {"solved": self.proven, "mol": self.mol} + + def expand(self) -> None: + """ Expand the molecule by utilising an expansion policy """ + self.expandable = False + reactions, priors = self._config.expansion_policy([self.mol]) + self.tree.profiling["expansion_calls"] += 1 + + if not reactions: + self._set_disproven() + return + + costs = -np.log(np.clip(priors, 1e-3, 1.0)) + reaction_costs = [] + reactions_to_expand = [] + for reaction, cost in zip(reactions, costs): + try: + _ = reaction.reactants + self.tree.profiling["reactants_generations"] += 1 + except: # pylint: disable=bare-except + continue + if not reaction.reactants: + continue + for idx, _ in enumerate(reaction.reactants): + rxn_copy = reaction.copy(idx) + reactions_to_expand.append(rxn_copy) + reaction_costs.append(cost) + + for cost, rxn in zip(reaction_costs, reactions_to_expand): + self._add_child(rxn, cost) + + if not self._children: + self._set_disproven() + + def promising_child(self) -> Optional[ReactionNode]: + """ + Find and return the most promising child for exploration + Updates the thresholds on that child + """ + min_indices = np.argsort( + [ + edge_cost + child.pn if not child.closed else BIG_INT + for edge_cost, child in zip(self._edge_costs, self._children) + ] + ) + best_child = self._children[min_indices[0]] + if len(self._children) > 1 and not self._children[min_indices[1]].closed: + s2_pn = self._children[min_indices[1]].pn + else: + s2_pn = BIG_INT + + best_child.pn_threshold = ( + min(self.pn_threshold, s2_pn + 2) - self._edge_costs[min_indices[0]] + ) + best_child.dn_threshold = self.dn_threshold - self.dn + best_child.dn + return best_child + + def update(self) -> None: + """ Update the proof and disproof numbers """ + func = all if self.parent is None else any + if func(child.proven for child in self._children): + self._set_proven() + return + if all(child.disproven for child in self._children): + self._set_disproven() + return + + child_dns = [child.dn for child in self._children if not child.closed] + if not child_dns: + self._set_proven() + return + + self.dn = sum(child_dns) + if self.dn >= BIG_INT: + self.pn = 0 + else: + self.pn = min( + edge_cost + child.pn + for edge_cost, child in zip(self._edge_costs, self._children) + if not child.closed + ) + return + + def _add_child(self, reaction: RetroReaction, _: float) -> None: + reactants = reaction.reactants[reaction.index] + if not reactants: + return + + ancestors = self._ancestors() + for mol in reactants: + if mol in ancestors: + return + + rxn_node = ReactionNode( + reaction=reaction, config=self._config, owner=self.tree, parent=self + ) + self._children.append(rxn_node) + self._edge_costs.append(1) + + def _ancestors(self) -> Set[TreeMolecule]: + if not self.parent: + return {self.mol} + + # pylint: disable=protected-access + ancestors = self.parent.parent._ancestors() + ancestors.add(self.mol) + return ancestors + + +class ReactionNode(_SuperNode): + """ + An AND node representing a reaction + + :ivar parent: the parent of the node + :ivar reaction: the reaction represented by the node + :ivar pn: the proof number + :ivar dn: the disproof number + :ivar pn_threshold: the threshold for proof number + :ivar dn_threshold: the threshold for disproof number + :ivar expandable: if the node is expandable + + :param reaction: the reaction to be represented by the node + :param config: the configuration of the search + :param parent: the parent of the node + """ + + def __init__( + self, + reaction: RetroReaction, + config: Configuration, + owner: SearchTree, + parent: MoleculeNode, + ) -> None: + super().__init__() + self._config = config + self.parent = parent + self.reaction = reaction + self.tree = owner + + @property # type: ignore + def children(self) -> List[MoleculeNode]: # type: ignore + """ Gives the molecule children nodes """ + return self._children # type: ignore + + @property + def prop(self) -> StrDict: + return {"solved": self.proven, "reaction": self.reaction} + + @property + def proven(self) -> bool: + """ Return if the node is proven """ + if self.expandable: + return False + if self.pn == 0: + return True + return all(child.proven for child in self._children) + + @property + def disproven(self) -> bool: + """ Return if the node is disproven """ + if self.expandable: + return False + if self.dn == 0: + return True + return any(child.disproven for child in self._children) + + def expand(self) -> None: + """ Expand the node by creating nodes for each reactant """ + self.expandable = False + reactants = self.reaction.reactants[self.reaction.index] + self._children = [ + MoleculeNode(mol=mol, config=self._config, owner=self.tree, parent=self) + for mol in reactants + ] + + def promising_child(self) -> Optional[MoleculeNode]: + """ + Find and return the most promising child for exploration + Updates the thresholds on that child + """ + min_indices = np.argsort( + [child.dn if not child.closed else BIG_INT for child in self._children] + ) + + best_child = self._children[min_indices[0]] + if len(self._children) > 1 and not self._children[min_indices[1]].closed: + s2_dn = self._children[min_indices[1]].dn + else: + s2_dn = BIG_INT + + best_child.pn_threshold = self.pn_threshold - self.pn + best_child.pn + best_child.dn_threshold = min(self.dn_threshold, s2_dn + 1) + return best_child + + def update(self) -> None: + """ Update the proof and disproof numbers""" + if all(child.proven for child in self._children): + self._set_proven() + return + if any(child.disproven for child in self._children): + self._set_disproven() + return + + self.pn = sum(child.pn for child in self._children if not child.closed) + self.dn = min(child.dn for child in self._children if not child.closed) diff --git a/aizynthfinder/search/dfpn/search_tree.py b/aizynthfinder/search/dfpn/search_tree.py new file mode 100644 index 0000000..26c0319 --- /dev/null +++ b/aizynthfinder/search/dfpn/search_tree.py @@ -0,0 +1,141 @@ +""" Module containing a class that holds the tree search +""" +from __future__ import annotations +from typing import TYPE_CHECKING + +from aizynthfinder.search.dfpn.nodes import MoleculeNode, ReactionNode +from aizynthfinder.utils.logging import logger +from aizynthfinder.search.andor_trees import AndOrSearchTreeBase, SplitAndOrTree +from aizynthfinder.reactiontree import ReactionTree + + +if TYPE_CHECKING: + from aizynthfinder.search.andor_trees import TreeNodeMixin + from aizynthfinder.context.config import Configuration + from aizynthfinder.chem import RetroReaction + from aizynthfinder.utils.type_utils import Optional, Sequence, List, Union + + +class SearchTree(AndOrSearchTreeBase): + """ + Encapsulation of the Depth-First Proof-Number (DFPN) search algorithm. + + This algorithm does not support: + 1. Filter policy + 2. Serialization and deserialization + + :ivar config: settings of the tree search algorithm + :ivar root: the root node + + :param config: settings of the tree search algorithm + :param root_smiles: the root will be set to a node representing this molecule, defaults to None + """ + + def __init__(self, config: Configuration, root_smiles: str = None) -> None: + super().__init__(config, root_smiles) + self._mol_nodes: List[MoleculeNode] = [] + self._logger = logger() + self._root_smiles = root_smiles + if root_smiles: + self.root: Optional[MoleculeNode] = MoleculeNode.create_root( + root_smiles, config, self + ) + self._mol_nodes.append(self.root) + else: + self.root = None + + self._routes: List[ReactionTree] = [] + self._frontier: Optional[Union[MoleculeNode, ReactionNode]] = None + self._initiated = False + + self.profiling = { + "expansion_calls": 0, + "reactants_generations": 0, + } + + @property + def mol_nodes(self) -> Sequence[MoleculeNode]: # type: ignore + """ Return the molecule nodes of the tree """ + return self._mol_nodes + + def one_iteration(self) -> bool: + """ + Perform one iteration of expansion. + + If possible expand the frontier node twice, i.e. expanding an OR + node and then and AND node. If frontier not expandable step up in the + tree and find a new frontier to expand. + + If a solution is found, mask that tree for exploration and start over. + + :raises StopIteration: if the search should be pre-maturely terminated + :return: if a solution was found + :rtype: bool + """ + if not self._initiated: + if self.root is None: + raise ValueError("Root is undefined. Cannot make an iteration") + + self._routes = [] + self._frontier = self.root + assert self.root is not None + + while True: + # Expand frontier, should be OR node + assert isinstance(self._frontier, MoleculeNode) + expanded_or = self._search_step() + expanded_and = False + if self._frontier: + # Expand frontier again, this time an AND node + assert isinstance(self._frontier, ReactionNode) + expanded_and = self._search_step() + if ( + expanded_or + or expanded_and + or self._frontier is None + or self._frontier is self.root + ): + break + + found_solution = any(child.proven for child in self.root.children) + if self._frontier is self.root: + self.root.reset() + + if self._frontier is None: + raise StopIteration() + + return found_solution + + def routes(self) -> List[ReactionTree]: + """ + Extracts and returns routes from the AND/OR tree + + :return: the routes + """ + if self.root is None: + return [] + if not self._routes: + self._routes = SplitAndOrTree(self.root, self.config.stock).routes + return self._routes + + def _search_step(self) -> bool: + assert self._frontier is not None + expanded = False + if self._frontier.expandable: + self._frontier.expand() + expanded = True + if isinstance(self._frontier, ReactionNode): + self._mol_nodes.extend(self._frontier.children) + + self._frontier.update() + if not self._frontier.explorable(): + self._frontier = self._frontier.parent + return False + + child = self._frontier.promising_child() + if not child: + self._frontier = self._frontier.parent + return False + + self._frontier = child + return expanded diff --git a/aizynthfinder/search/retrostar/__init__.py b/aizynthfinder/search/retrostar/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aizynthfinder/search/retrostar/cost.py b/aizynthfinder/search/retrostar/cost.py new file mode 100644 index 0000000..5ba0980 --- /dev/null +++ b/aizynthfinder/search/retrostar/cost.py @@ -0,0 +1,78 @@ +""" Module containing Retro* cost model """ +from __future__ import annotations +from typing import TYPE_CHECKING +import pickle + +import numpy as np + +if TYPE_CHECKING: + from aizynthfinder.utils.type_utils import Tuple, List + from aizynthfinder.chem import Molecule + + +class RetroStarCost: + """ + Encapsulation of a the original Retro* molecular cost model + + Numpy implementation of original pytorch model + + The predictions of the score is made on a Molecule object + + .. code-block:: + + mol = Molecule(smiles="CCC") + scorer = RetroStarCost() + score = scorer(mol) + + The model provided when creating the scorer object should be a pickled + tuple. + The first item of the tuple should be a list of the model weights for each layer. + The second item of the tuple should be a list of the model biases for each layer. + + :param model_path: the filename of the model weights and biases + :param fingerprint_length: the number of bits in the fingerprint + :param fingerprint_radius: the radius of the fingerprint + :param dropout_rate: the dropout_rate + """ + + def __init__( + self, + model_path: str, + fingerprint_length: int = 2048, + fingerprint_radius: int = 2, + dropout_rate: float = 0.1, + ): + self._dropout_prob = 1.0 - dropout_rate + self._fingerprint_length = fingerprint_length + self._fingerprint_radius = fingerprint_radius + self._weights, self._biases = self._load_model(model_path) + + def __call__(self, mol: Molecule) -> float: + # pylint: disable=invalid-name + mol.sanitize() + vec = mol.fingerprint( + radius=self._fingerprint_radius, nbits=self._fingerprint_length + ) + for W, b in zip(self._weights[:-1], self._biases[:-1]): + vec = np.matmul(vec, W) + b + vec *= vec > 0 # ReLU + # Drop-out + vec *= np.random.binomial(1, self._dropout_prob, size=vec.shape) / ( + self._dropout_prob + ) + vec = np.matmul(vec, self._weights[-1]) + self._biases[-1] + return float(np.log(1 + np.exp(vec))) + + def __repr__(self) -> str: + return "retrostar" + + @staticmethod + def _load_model(model_path: str) -> Tuple[List[np.ndarray], List[np.ndarray]]: + + with open(model_path, "rb") as fileobj: + weights, biases = pickle.load(fileobj) + + return ( + [np.asarray(item) for item in weights], + [np.asarray(item) for item in biases], + ) diff --git a/aizynthfinder/search/retrostar/nodes.py b/aizynthfinder/search/retrostar/nodes.py new file mode 100644 index 0000000..31acbac --- /dev/null +++ b/aizynthfinder/search/retrostar/nodes.py @@ -0,0 +1,353 @@ +""" Module containing a classes representation various tree nodes +""" +from __future__ import annotations +from typing import TYPE_CHECKING + +import numpy as np + +from aizynthfinder.chem import TreeMolecule +from aizynthfinder.search.andor_trees import TreeNodeMixin +from aizynthfinder.chem.serialization import deserialize_action, serialize_action + +if TYPE_CHECKING: + from aizynthfinder.context.config import Configuration + from aizynthfinder.chem.serialization import ( + MoleculeDeserializer, + MoleculeSerializer, + ) + from aizynthfinder.utils.type_utils import ( + StrDict, + Sequence, + Set, + List, + ) + from aizynthfinder.chem import RetroReaction + + +class MoleculeNode(TreeNodeMixin): + """ + An OR node representing a molecule + + :ivar cost: the cost of synthesizing the molecule + :ivar expandable: if True, this node is part of the frontier + :ivar mol: the molecule represented by the node + :ivar in_stock: if True the molecule is in stock and hence should not be expanded + :ivar parent: the parent of the node + :ivar solved: if True the molecule is in stock or at least one child node is solved + :ivar value: the current rn(m|T) + + :param mol: the molecule to be represented by the node + :param config: the configuration of the search + :param parent: the parent of the node, optional + """ + + def __init__( + self, mol: TreeMolecule, config: Configuration, parent: ReactionNode = None + ) -> None: + self.mol = mol + self._config = config + self.cost = config.molecule_cost(mol) + self.value = self.cost + self.in_stock = mol in config.stock + self.parent = parent + + self._children: List[ReactionNode] = [] + self.solved = self.in_stock + # Makes it unexpandable if we have reached maximum depth + self.expandable = self.mol.transform <= self._config.max_transforms + + if self.in_stock: + self.expandable = False + self.value = 0 + + @classmethod + def create_root(cls, smiles: str, config: Configuration) -> "MoleculeNode": + """ + Create a root node for a tree using a SMILES. + + :param smiles: the SMILES representation of the root state + :param config: settings of the tree search algorithm + :return: the created node + """ + mol = TreeMolecule(parent=None, transform=0, smiles=smiles) + return MoleculeNode(mol=mol, config=config) + + @classmethod + def from_dict( + cls, + dict_: StrDict, + config: Configuration, + molecules: MoleculeDeserializer, + parent: ReactionNode = None, + ) -> "MoleculeNode": + """ + Create a new node from a dictionary, i.e. deserialization + + :param dict_: the serialized node + :param config: settings of the tree search algorithm + :param molecules: the deserialized molecules + :param parent: the parent node + :return: a deserialized node + """ + mol = molecules.get_tree_molecules([dict_["mol"]])[0] + node = MoleculeNode(mol, config, parent) + for attr in ["cost", "expandable", "value"]: + setattr(node, attr, dict_[attr]) + node.children = [ + ReactionNode.from_dict(child, config, molecules, parent=node) + for child in dict_["children"] + ] + return node + + @property # type: ignore + def children(self) -> List[ReactionNode]: # type: ignore + """ Gives the reaction children nodes """ + return self._children + + @children.setter + def children(self, value: List[ReactionNode]) -> None: + self._children = value + + @property + def target_value(self) -> float: + """ + The V_t(m|T) value, + the current cost of the tree containing this node + + :return: the value + """ + if self.parent: + return self.parent.target_value + return self.value + + @property + def prop(self) -> StrDict: + return {"solved": self.solved, "mol": self.mol} + + def add_stub(self, cost: float, reaction: RetroReaction) -> Sequence[MoleculeNode]: + """ + Add a stub / sub-tree to this node + + :param cost: the cost of the reaction + :param reaction: the reaction creating the stub + :return: list of all newly added molecular nodes + """ + reactants = reaction.reactants[reaction.index] + if not reactants: + return [] + + ancestors = self.ancestors() + for mol in reactants: + if mol in ancestors: + return [] + + rxn_node = ReactionNode.create_stub( + cost=cost, reaction=reaction, parent=self, config=self._config + ) + self._children.append(rxn_node) + + return rxn_node.children + + def ancestors(self) -> Set[TreeMolecule]: + """ + Return the ancestors of this node + + :return: the ancestors + :rtype: set + """ + if not self.parent: + return {self.mol} + + ancestors = self.parent.parent.ancestors() + ancestors.add(self.mol) + return ancestors + + def close(self) -> float: + """ + Updates the values of this node after expanding it. + + :return: the delta V value + :rtype: float + """ + self.solved = any(child.solved for child in self.children) + if self.children: + new_value = np.min([child.value for child in self.children]) + else: + new_value = np.inf + + v_delta = new_value - self.value + self.value = new_value + + self.expandable = False + return v_delta + + def serialize(self, molecule_store: MoleculeSerializer) -> StrDict: + """ + Serialize the node object to a dictionary + + :param molecule_store: the serialized molecules + :return: the serialized node + """ + dict_ = {attr: getattr(self, attr) for attr in ["cost", "expandable", "value"]} + dict_["mol"] = molecule_store[self.mol] + dict_["children"] = [child.serialize(molecule_store) for child in self.children] + return dict_ + + def update(self, solved: bool) -> None: + """ + Update the node as part of the update algorithm, + calling the `update()` method of its parent if available. + + :param solved: if the child node was solved + """ + new_value = np.min([child.value for child in self.children]) + new_solv = self.solved or solved + updated = (self.value != new_value) or (self.solved != new_solv) + + v_delta = new_value - self.value + self.value = new_value + self.solved = new_solv + + if updated and self.parent: + self.parent.update(v_delta, from_mol=self.mol) + + +class ReactionNode(TreeNodeMixin): + """ + An AND node representing a reaction + + :ivar cost: the cost of the reaction + :ivar parent: the parent of the node + :ivar reaction: the reaction represented by the node + :ivar solved: if True all children nodes are solved + :ivar target_value: the V(m|T) for the children, the current cost + :ivar value: the current rn(r|T) + + :param cost: the cost of the reaction + :param reaction: the reaction to be represented by the node + :param parent: the parent of the node + """ + + def __init__( + self, cost: float, reaction: RetroReaction, parent: MoleculeNode + ) -> None: + self.parent = parent + self.cost = cost + self.reaction = reaction + + self._children: List[MoleculeNode] = [] + self.solved = False + # rn(R|T) + self.value = self.cost + # V(R|T) = V(m|T) for m in children + self.target_value = self.parent.target_value - self.parent.value + self.value + + @classmethod + def create_stub( + cls, + cost: float, + reaction: RetroReaction, + parent: MoleculeNode, + config: Configuration, + ) -> ReactionNode: + """ + Create a ReactionNode and creates all the MoleculeNode objects + that are the children of the node. + + :param cost: the cost of the reaction + :param reaction: the reaction to be represented by the node + :param parent: the parent of the node + :param config: the configuration of the search tree + """ + node = cls(cost, reaction, parent) + reactants = reaction.reactants[reaction.index] + node.children = [ + MoleculeNode(mol=mol, config=config, parent=node) for mol in reactants + ] + node.solved = all(child.solved for child in node.children) + # rn(R|T) + node.value = node.cost + sum(child.value for child in node.children) + # V(R|T) = V(m|T) for m in children + node.target_value = node.parent.target_value - node.parent.value + node.value + return node + + @classmethod + def from_dict( + cls, + dict_: StrDict, + config: Configuration, + molecules: MoleculeDeserializer, + parent: MoleculeNode, + ) -> ReactionNode: + """ + Create a new node from a dictionary, i.e. deserialization + + :param dict_: the serialized node + :param config: the configuration of the tree search + :param molecules: the deserialized molecules + :param parent: the parent node + :return: a deserialized node + """ + reaction = deserialize_action(dict_["reaction"], molecules) + node = cls(0, reaction, parent) + for attr in ["cost", "value", "target_value"]: + setattr(node, attr, dict_[attr]) + node.children = [ + MoleculeNode.from_dict(child, config, molecules, parent=node) + for child in dict_["children"] + ] + node.solved = all(child.solved for child in node.children) + return node + + @property # type: ignore + def children(self) -> List[MoleculeNode]: # type: ignore + """ Gives the molecule children nodes """ + return self._children + + @children.setter + def children(self, value: List[MoleculeNode]) -> None: + self._children = value + + @property + def prop(self) -> StrDict: + return {"solved": self.solved, "reaction": self.reaction} + + def serialize(self, molecule_store: MoleculeSerializer) -> StrDict: + """ + Serialize the node object to a dictionary + + :param molecule_store: the serialized molecules + :return: the serialized node + """ + dict_ = { + attr: getattr(self, attr) for attr in ["cost", "value", "target_value"] + } + dict_["reaction"] = serialize_action(self.reaction, molecule_store) + dict_["children"] = [child.serialize(molecule_store) for child in self.children] + return dict_ + + def update(self, value: float, from_mol: TreeMolecule = None) -> None: + """ + Update the node as part of the update algorithm, + calling the `update()` method of its parent + + :param value: the delta V value + :param from_mol: the molecule being expanded, used for excluding propagation + """ + self.value += value + self.target_value += value + self.solved = all(node.solved for node in self.children) + + if value != 0: + self._propagate(value, exclude=from_mol) + + self.parent.update(self.solved) + + def _propagate(self, value: float, exclude: TreeMolecule = None) -> None: + if not exclude: + self.target_value += value + + for child in self.children: + if exclude is None or child.mol is not exclude: + for grandchild in child.children: + grandchild._propagate(value) # pylint: disable=protected-access diff --git a/aizynthfinder/search/retrostar/search_tree.py b/aizynthfinder/search/retrostar/search_tree.py new file mode 100644 index 0000000..fac2d8d --- /dev/null +++ b/aizynthfinder/search/retrostar/search_tree.py @@ -0,0 +1,198 @@ +""" Module containing a class that holds the tree search +""" +from __future__ import annotations +import json +from typing import TYPE_CHECKING + +import numpy as np + +from aizynthfinder.search.retrostar.nodes import MoleculeNode +from aizynthfinder.utils.logging import logger +from aizynthfinder.search.andor_trees import AndOrSearchTreeBase, SplitAndOrTree +from aizynthfinder.chem.serialization import MoleculeDeserializer, MoleculeSerializer +from aizynthfinder.utils.exceptions import RejectionException + +if TYPE_CHECKING: + from aizynthfinder.context.config import Configuration + from aizynthfinder.reactiontree import ReactionTree + from aizynthfinder.chem import RetroReaction + from aizynthfinder.utils.type_utils import Optional, Sequence, List + + +class SearchTree(AndOrSearchTreeBase): + """ + Encapsulation of the Retro* search tree (an AND/OR tree). + + :ivar config: settings of the tree search algorithm + :ivar root: the root node + + :param config: settings of the tree search algorithm + :param root_smiles: the root will be set to a node representing this molecule, defaults to None + """ + + def __init__(self, config: Configuration, root_smiles: str = None) -> None: + super().__init__(config, root_smiles) + self._mol_nodes: List[MoleculeNode] = [] + self._logger = logger() + + if root_smiles: + self.root: Optional[MoleculeNode] = MoleculeNode.create_root( + root_smiles, config + ) + self._mol_nodes.append(self.root) + else: + self.root = None + + self._routes: List[ReactionTree] = [] + + self.profiling = { + "expansion_calls": 0, + "reactants_generations": 0, + } + + @classmethod + def from_json(cls, filename: str, config: Configuration) -> SearchTree: + """ + Create a new search tree by deserialization from a JSON file + + :param filename: the path to the JSON node + :param config: the configuration of the search tree + :return: a deserialized tree + """ + + def _find_mol_nodes(node): + for child_ in node.children: + tree._mol_nodes.append(child_) # pylint: disable=protected-access + for grandchild in child_.children: + _find_mol_nodes(grandchild) + + tree = cls(config) + with open(filename, "r") as fileobj: + dict_ = json.load(fileobj) + mol_deser = MoleculeDeserializer(dict_["molecules"]) + tree.root = MoleculeNode.from_dict(dict_["tree"], config, mol_deser) + tree._mol_nodes.append(tree.root) # pylint: disable=protected-access + for child in tree.root.children: + _find_mol_nodes(child) + return tree + + @property + def mol_nodes(self) -> Sequence[MoleculeNode]: # type: ignore + """ Return the molecule nodes of the tree """ + return self._mol_nodes + + def one_iteration(self) -> bool: + """ + Perform one iteration of + 1. Selection + 2. Expansion + 3. Update + + :raises StopIteration: if the search should be pre-maturely terminated + :return: if a solution was found + :rtype: bool + """ + if self.root is None: + raise ValueError("Root is undefined. Cannot make an iteration") + + self._routes = [] + + next_node = self._select() + + if not next_node: + self._logger.debug("No expandable nodes in Retro* iteration") + raise StopIteration + + self._expand(next_node) + + if not next_node.children: + next_node.expandable = False + + self._update(next_node) + + return self.root.solved + + def routes(self) -> List[ReactionTree]: + """ + Extracts and returns routes from the AND/OR tree + + :return: the routes + """ + if self.root is None: + return [] + if not self._routes: + self._routes = SplitAndOrTree(self.root, self.config.stock).routes + return self._routes + + def serialize(self, filename: str) -> None: + """ + Seralize the search tree to a JSON file + + :param filename: the path to the JSON file + :type filename: str + """ + if self.root is None: + raise ValueError("Cannot serialize tree as root is not defined") + + mol_ser = MoleculeSerializer() + dict_ = {"tree": self.root.serialize(mol_ser), "molecules": mol_ser.store} + with open(filename, "w") as fileobj: + json.dump(dict_, fileobj, indent=2) + + def _expand(self, node: MoleculeNode) -> None: + reactions, priors = self.config.expansion_policy([node.mol]) + self.profiling["expansion_calls"] += 1 + + if not reactions: + return + + costs = -np.log(np.clip(priors, 1e-3, 1.0)) + reactions_to_expand = [] + reaction_costs = [] + for reaction, cost in zip(reactions, costs): + try: + self.profiling["reactants_generations"] += 1 + _ = reaction.reactants + except: # pylint: disable=bare-except + continue + if not reaction.reactants: + continue + for idx, _ in enumerate(reaction.reactants): + rxn_copy = reaction.copy(idx) + if self._filter_reaction(rxn_copy): + continue + reactions_to_expand.append(rxn_copy) + reaction_costs.append(cost) + + for cost, rxn in zip(reaction_costs, reactions_to_expand): + new_nodes = node.add_stub(cost, rxn) + self._mol_nodes.extend(new_nodes) + + def _filter_reaction(self, reaction: RetroReaction) -> bool: + if not self.config.filter_policy.selection: + return False + try: + self.config.filter_policy(reaction) + except RejectionException as err: + self._logger.debug(str(err)) + return True + return False + + def _select(self) -> Optional[MoleculeNode]: + scores = np.asarray( + [ + node.target_value if node.expandable else np.inf + for node in self._mol_nodes + ] + ) + + if scores.min() == np.inf: + return None + + return self._mol_nodes[int(np.argmin(scores))] + + @staticmethod + def _update(node: MoleculeNode) -> None: + v_delta = node.close() + if node.parent and np.isfinite(v_delta): + node.parent.update(v_delta, from_mol=node.mol) diff --git a/aizynthfinder/tools/cat_output.py b/aizynthfinder/tools/cat_output.py index cece78f..79eb680 100644 --- a/aizynthfinder/tools/cat_output.py +++ b/aizynthfinder/tools/cat_output.py @@ -20,9 +20,13 @@ def main() -> None: default="output.hdf5", help="the name of the concatenate output file ", ) + parser.add_argument( + "--trees", + help="if given, save all trees to this file", + ) args = parser.parse_args() - cat_hdf_files(args.files, args.output) + cat_hdf_files(args.files, args.output, args.trees) if __name__ == "__main__": diff --git a/aizynthfinder/training/keras_models.py b/aizynthfinder/training/keras_models.py index 5d67eb9..b7449ba 100644 --- a/aizynthfinder/training/keras_models.py +++ b/aizynthfinder/training/keras_models.py @@ -5,6 +5,8 @@ from typing import TYPE_CHECKING import numpy as np + +# pylint: disable=no-name-in-module from tensorflow.keras.layers import Dense, Dropout, Input, Dot from tensorflow.keras.models import Sequential, Model from tensorflow.keras.optimizers import Adam diff --git a/aizynthfinder/utils/files.py b/aizynthfinder/utils/files.py index a83b15a..d6b9ef9 100644 --- a/aizynthfinder/utils/files.py +++ b/aizynthfinder/utils/files.py @@ -4,6 +4,8 @@ import subprocess import time import warnings +import json +import gzip from typing import TYPE_CHECKING import more_itertools @@ -15,21 +17,45 @@ from aizynthfinder.utils.type_utils import List, Sequence, Any, Callable -def cat_hdf_files(input_files: List[str], output_name: str) -> None: +def cat_hdf_files( + input_files: List[str], output_name: str, trees_name: str = None +) -> None: """ Concatenate hdf5 files with the key "table" + if `tree_name` is given, will take out the `trees` column + from the tables and save it to a gzipped-json file. + :param input_files: the paths to the files to concatenate :param output_name: the name of the concatenated file + :param trees_name: the name of the concatenated trees """ data = pd.read_hdf(input_files[0], key="table") + if "trees" not in data.columns: + trees_name = None + + if trees_name: + columns = list(data.columns) + columns.remove("trees") + trees = list(data["trees"].values) + data = data[columns] + for filename in input_files[1:]: new_data = pd.read_hdf(filename, key="table") + if trees_name: + trees.extend(new_data["trees"].values) + new_data = new_data[columns] data = data.append(new_data) with warnings.catch_warnings(): # This wil suppress a PerformanceWarning warnings.simplefilter("ignore") - data.to_hdf(output_name, key="table") + data.reset_index().to_hdf(output_name, key="table") + + if trees_name: + if not trees_name.endswith(".gz"): + trees_name += ".gz" + with gzip.open(trees_name, "wt", encoding="UTF-8") as fileobj: + json.dump(trees, fileobj) def split_file(filename: str, nparts: int) -> List[str]: diff --git a/aizynthfinder/utils/models.py b/aizynthfinder/utils/models.py index 13c4a9d..dc8d4ee 100644 --- a/aizynthfinder/utils/models.py +++ b/aizynthfinder/utils/models.py @@ -15,6 +15,8 @@ get_model_metadata_pb2, prediction_service_pb2_grpc, ) + +# pylint: disable=no-name-in-module from tensorflow.keras.metrics import top_k_categorical_accuracy from tensorflow.keras.models import load_model as load_keras_model diff --git a/docs/conf.py b/docs/conf.py index 2830fbc..90b56e9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,9 +4,9 @@ sys.path.insert(0, os.path.abspath(".")) project = "aizynthfinder" -copyright = "2020, Molecular AI group" +copyright = "2020-2022, Molecular AI group" author = "Molecular AI group" -release = "3.2.0" +release = "3.3.0" # This make sure that the cli_help.txt file is properly formated with open("cli_help.txt", "r") as fileobj: diff --git a/pyproject.toml b/pyproject.toml index 268f40e..d0a23fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aizynthfinder" -version = "3.2.0" +version = "3.3.0" description = "Retrosynthetic route finding using neural network guided Monte-Carlo tree search" authors = ["Molecular AI group "] license = "MIT" diff --git a/tests/breadth_first/__init__.py b/tests/breadth_first/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/breadth_first/test_nodes.py b/tests/breadth_first/test_nodes.py new file mode 100644 index 0000000..7f2f4c9 --- /dev/null +++ b/tests/breadth_first/test_nodes.py @@ -0,0 +1,75 @@ +import pytest + +from aizynthfinder.search.breadth_first.nodes import MoleculeNode +from aizynthfinder.chem.serialization import MoleculeSerializer, MoleculeDeserializer + + +@pytest.fixture +def setup_root(default_config): + def wrapper(smiles): + return MoleculeNode.create_root(smiles, config=default_config) + + return wrapper + + +def test_create_root_node(setup_root): + node = setup_root("CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1") + + assert node.ancestors() == {node.mol} + assert node.expandable + assert not node.children + + +def test_create_stub(setup_root, get_action): + root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" + root = setup_root(root_smiles) + reaction = get_action() + + root.add_stub(reaction=reaction) + + assert len(root.children) == 1 + assert len(root.children[0].children) == 2 + rxn_node = root.children[0] + assert rxn_node.reaction is reaction + exp_list = [node.mol for node in rxn_node.children] + assert exp_list == list(reaction.reactants[0]) + + +def test_initialize_stub_one_solved_leaf( + setup_root, get_action, default_config, setup_stock +): + root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" + root = setup_root(root_smiles) + reaction = get_action() + setup_stock(default_config, reaction.reactants[0][0]) + + root.add_stub(reaction=reaction) + + assert not root.children[0].children[0].expandable + assert root.children[0].children[1].expandable + + +def test_serialization_deserialization( + setup_root, get_action, default_config, setup_stock +): + root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" + root = setup_root(root_smiles) + reaction = get_action() + setup_stock(default_config, *reaction.reactants[0]) + root.add_stub(reaction=reaction) + + molecule_serializer = MoleculeSerializer() + dict_ = root.serialize(molecule_serializer) + + molecule_deserializer = MoleculeDeserializer(molecule_serializer.store) + node = MoleculeNode.from_dict(dict_, default_config, molecule_deserializer) + + assert node.mol == root.mol + assert len(node.children) == len(root.children) + + rxn_node = node.children[0] + assert rxn_node.reaction.smarts == reaction.smarts + assert rxn_node.reaction.metadata == reaction.metadata + + for grandchild1, grandchild2 in zip(rxn_node.children, root.children[0].children): + assert grandchild1.mol == grandchild2.mol diff --git a/tests/breadth_first/test_search.py b/tests/breadth_first/test_search.py new file mode 100644 index 0000000..d24306a --- /dev/null +++ b/tests/breadth_first/test_search.py @@ -0,0 +1,106 @@ +import random + +import pytest + +from aizynthfinder.search.breadth_first.search_tree import SearchTree + + +def test_one_iteration(default_config, setup_policies, setup_stock): + root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" + child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"] + child2_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"] + grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"] + lookup = { + root_smi: [ + {"smiles": ".".join(child1_smi), "prior": 0.7}, + {"smiles": ".".join(child2_smi), "prior": 0.3}, + ], + child1_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7}, + child2_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7}, + } + stock = [child1_smi[0], child1_smi[2]] + grandchild_smi + setup_policies(lookup, config=default_config) + setup_stock(default_config, *stock) + tree = SearchTree(default_config, root_smi) + + assert len(tree.mol_nodes) == 1 + + assert not tree.one_iteration() + + assert len(tree.mol_nodes) == 6 + smiles = [node.mol.smiles for node in tree.mol_nodes] + assert smiles == [root_smi] + child1_smi + child2_smi + + assert tree.one_iteration() + + assert len(tree.mol_nodes) == 10 + smiles = [node.mol.smiles for node in tree.mol_nodes] + assert ( + smiles == [root_smi] + child1_smi + child2_smi + grandchild_smi + grandchild_smi + ) + + with pytest.raises(StopIteration): + tree.one_iteration() + + +def test_search_incomplete(default_config, setup_policies, setup_stock): + root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" + child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"] + child2_smi = ["ClC(=O)c1ccc(F)cc1", "CN1CCC(CC1)C(=O)c1cccc(N)c1F"] + grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"] + lookup = { + root_smi: [ + {"smiles": ".".join(child1_smi), "prior": 0.7}, + {"smiles": ".".join(child2_smi), "prior": 0.3}, + ], + child1_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7}, + } + stock = [child1_smi[0], child1_smi[2]] + [grandchild_smi[0]] + setup_policies(lookup, config=default_config) + setup_stock(default_config, *stock) + tree = SearchTree(default_config, root_smi) + + assert len(tree.mol_nodes) == 1 + + tree.one_iteration() + assert len(tree.mol_nodes) == 6 + + assert not tree.one_iteration() + + assert len(tree.mol_nodes) == 8 + + with pytest.raises(StopIteration): + tree.one_iteration() + + +def test_routes(default_config, setup_policies, setup_stock): + random.seed(666) + root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" + child1_smi = ["O", "CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"] + child2_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"] + grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"] + lookup = { + root_smi: [ + {"smiles": ".".join(child1_smi), "prior": 0.7}, + {"smiles": ".".join(child2_smi), "prior": 0.3}, + ], + child1_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7}, + } + stock = [child1_smi[0], child1_smi[2]] + grandchild_smi + setup_policies(lookup, config=default_config) + setup_stock(default_config, *stock) + tree = SearchTree(default_config, root_smi) + + while True: + try: + tree.one_iteration() + except StopIteration: + break + + routes = tree.routes() + + assert len(routes) == 2 + smiles = [mol.smiles for mol in routes[1].molecules()] + assert smiles == [root_smi] + child1_smi + grandchild_smi + smiles = [mol.smiles for mol in routes[0].molecules()] + assert smiles == [root_smi] + child2_smi + grandchild_smi diff --git a/tests/dfpn/__init__.py b/tests/dfpn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/dfpn/test_nodes.py b/tests/dfpn/test_nodes.py new file mode 100644 index 0000000..9e2d6f5 --- /dev/null +++ b/tests/dfpn/test_nodes.py @@ -0,0 +1,103 @@ +import pytest + +from aizynthfinder.search.dfpn.nodes import MoleculeNode, BIG_INT +from aizynthfinder.search.dfpn import SearchTree + + +@pytest.fixture +def setup_root(default_config): + def wrapper(smiles): + owner = SearchTree(default_config) + return MoleculeNode.create_root(smiles, config=default_config, owner=owner) + + return wrapper + + +def test_create_root_node(setup_root): + node = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1") + + assert node.expandable + assert not node.children + assert node.dn == 1 + assert node.pn == 1 + + +def test_expand_mol_node( + default_config, setup_root, setup_policies, get_linear_expansion +): + node = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1") + setup_policies(get_linear_expansion) + + node.expand() + + assert not node.expandable + assert len(node.children) == 1 + + +def test_promising_child( + default_config, setup_root, setup_policies, get_linear_expansion +): + node = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1") + setup_policies(get_linear_expansion) + node.expand() + + child = node.promising_child() + + assert child is node.children[0] + assert child.pn_threshold == BIG_INT - 1 + assert child.dn_threshold == BIG_INT + + +def test_expand_reaction_node( + default_config, setup_root, setup_policies, get_linear_expansion +): + node = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1") + setup_policies(get_linear_expansion) + node.expand() + child = node.promising_child() + + child.expand() + + assert len(child.children) == 2 + + +def test_promising_child_reaction_node( + default_config, + setup_root, + setup_policies, + get_linear_expansion, +): + node = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1") + setup_policies(get_linear_expansion) + node.expand() + + child = node.promising_child() + child.expand() + + grandchild = child.promising_child() + + assert grandchild.mol.smiles == "OOc1ccc(-c2ccccc2)cc1" + assert grandchild.pn_threshold == BIG_INT - 1 + assert grandchild.dn_threshold == 2 + + +def test_update( + default_config, + setup_root, + setup_policies, + get_linear_expansion, + setup_stock, +): + node = setup_root("OOc1ccc(-c2ccc(NC3CCCC(C4C=CC=C4)C3)cc2)cc1") + setup_stock(default_config, "OOc1ccc(-c2ccccc2)cc1", "NC1CCCC(C2C=CC=C2)C1") + setup_policies(get_linear_expansion) + node.expand() + + child = node.promising_child() + child.expand() + child.update() + + node.update() + + assert node.proven + assert not node.disproven diff --git a/tests/dfpn/test_search.py b/tests/dfpn/test_search.py new file mode 100644 index 0000000..c4a029e --- /dev/null +++ b/tests/dfpn/test_search.py @@ -0,0 +1,71 @@ +import random + +import pytest + +from aizynthfinder.search.dfpn.search_tree import SearchTree + + +def test_search(default_config, setup_policies, setup_stock): + root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" + child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"] + child2_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"] + grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"] + lookup = { + root_smi: [ + {"smiles": ".".join(child1_smi), "prior": 0.7}, + {"smiles": ".".join(child2_smi), "prior": 0.3}, + ], + child1_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7}, + child2_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7}, + } + stock = [child1_smi[0], child1_smi[2]] + grandchild_smi + setup_policies(lookup, config=default_config) + setup_stock(default_config, *stock) + tree = SearchTree(default_config, root_smi) + + assert not tree.one_iteration() + + routes = tree.routes() + assert all([not route.is_solved for route in routes]) + assert len(tree.mol_nodes) == 4 + + assert not tree.one_iteration() + assert tree.one_iteration() + assert tree.one_iteration() + assert tree.one_iteration() + assert tree.one_iteration() + + routes = tree.routes() + assert len(routes) == 2 + assert all(route.is_solved for route in routes) + + with pytest.raises(StopIteration): + tree.one_iteration() + + +def test_search_incomplete(default_config, setup_policies, setup_stock): + root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" + child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"] + child2_smi = ["ClC(=O)c1ccc(F)cc1", "CN1CCC(CC1)C(=O)c1cccc(N)c1F"] + grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"] + lookup = { + root_smi: [ + {"smiles": ".".join(child1_smi), "prior": 0.7}, + {"smiles": ".".join(child2_smi), "prior": 0.3}, + ], + child1_smi[1]: {"smiles": ".".join(grandchild_smi), "prior": 0.7}, + } + stock = [child1_smi[0], child1_smi[2]] + [grandchild_smi[0]] + setup_policies(lookup, config=default_config) + setup_stock(default_config, *stock) + tree = SearchTree(default_config, root_smi) + + while True: + try: + tree.one_iteration() + except StopIteration: + break + + routes = tree.routes() + assert len(routes) == 2 + assert all(not route.is_solved for route in routes) diff --git a/tests/retrostar/__init__.py b/tests/retrostar/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/retrostar/conftest.py b/tests/retrostar/conftest.py new file mode 100644 index 0000000..7ddb20d --- /dev/null +++ b/tests/retrostar/conftest.py @@ -0,0 +1,59 @@ +import pytest +import numpy as np + + +from aizynthfinder.search.retrostar.search_tree import SearchTree +from aizynthfinder.search.retrostar.nodes import MoleculeNode +from aizynthfinder.aizynthfinder import AiZynthFinder + + +@pytest.fixture +def setup_aizynthfinder(setup_policies, setup_stock): + def wrapper(expansions, stock): + finder = AiZynthFinder() + root_smi = list(expansions.keys())[0] + setup_policies(expansions, config=finder.config) + setup_stock(finder.config, *stock) + finder.target_smiles = root_smi + finder.config.search_algorithm = ( + "aizynthfinder.search.retrostar.search_tree.SearchTree" + ) + return finder + + return wrapper + + +@pytest.fixture +def setup_mocked_model(mocker): + biases = [np.zeros(10), np.zeros(1)] + weights = [np.ones([10, 10]), np.ones([10, 1])] + + mocker.patch("builtins.open") + mocked_pickle_load = mocker.patch("aizynthfinder.search.retrostar.cost.pickle.load") + mocked_pickle_load.return_value = weights, biases + + +@pytest.fixture +def setup_search_tree(default_config, setup_policies, setup_stock): + root_smiles = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" + tree = SearchTree(config=default_config, root_smiles=root_smiles) + lookup = { + root_smiles: { + "smiles": "CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O", + "prior": 1.0, + } + } + setup_policies(lookup) + + setup_stock( + default_config, "CN1CCC(Cl)CC1", "O", "N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1" + ) + return tree + + +@pytest.fixture +def setup_star_root(default_config): + def wrapper(smiles): + return MoleculeNode.create_root(smiles, config=default_config) + + return wrapper diff --git a/tests/retrostar/data/andor_tree_for_clustering.json b/tests/retrostar/data/andor_tree_for_clustering.json new file mode 100644 index 0000000..87b9669 --- /dev/null +++ b/tests/retrostar/data/andor_tree_for_clustering.json @@ -0,0 +1,175 @@ +{ + "tree": { + "cost": 0.0, + "expandable": true, + "value": 0.0, + "mol": 140605332219984, + "children": [ + { + "cost": 5.0, + "value": 5.0, + "target_value": 5.0, + "reaction": { + "mol": 140605332219984, + "smarts": "([c:8]-[NH;D2;+0:7]-[c;H0;D3;+0:1]1:[c:2]:[c:3]:[#7;a:4]:[c:5]:[c:6]:1)>>(Cl-[c;H0;D3;+0:1]1:[c:2]:[c:3]:[#7;a:4]:[c:5]:[c:6]:1).([NH2;D1;+0:7]-[c:8])", + "index": 0, + "metadata": {} + }, + "children": [ + { + "cost": 0.0, + "expandable": false, + "value": 0, + "mol": 140605332201936, + "children": [] + }, + { + "cost": 0.0, + "expandable": false, + "value": 0, + "mol": 140605332205072, + "children": [] + } + ] + }, + { + "cost": 10.0, + "value": 10.0, + "target_value": 10.0, + "reaction": { + "mol": 140605332219984, + "smarts": "([S;D1;H0:3]=[C;H0;D3;+0:4](-[NH;D2;+0:1]-[c:2])-[NH;D2;+0:5]-[c:6])>>([NH2;D1;+0:1]-[c:2]).([S;D1;H0:3]=[C;H0;D2;+0:4]=[N;H0;D2;+0:5]-[c:6])", + "index": 1, + "metadata": {} + }, + "children": [ + { + "cost": 0.0, + "expandable": true, + "value": 0.0, + "mol": 140598039857296, + "children": [ + { + "cost": 2.5, + "value": 2.5, + "target_value": 12.5, + "reaction": { + "mol": 140598039857296, + "smarts": "([c:8]-[NH;D2;+0:7]-[c;H0;D3;+0:1]1:[c:2]:[c:3]:[#7;a:4]:[c:5]:[c:6]:1)>>(Cl-[c;H0;D3;+0:1]1:[c:2]:[c:3]:[#7;a:4]:[c:5]:[c:6]:1).([NH2;D1;+0:7]-[c:8])", + "index": 0, + "metadata": {} + }, + "children": [ + { + "cost": 0.0, + "expandable": false, + "value": 0, + "mol": 140598039857616, + "children": [] + }, + { + "cost": 0.0, + "expandable": false, + "value": 0, + "mol": 140605332368528, + "children": [] + } + ] + }, + { + "cost": 7.0, + "value": 7.0, + "target_value": 17.0, + "reaction": { + "mol": 140598039857296, + "smarts": "([c:2]:[c;H0;D3;+0:1](:[c:3])-[NH;D2;+0:4]-[c:5])>>(Br-[c;H0;D3;+0:1](:[c:2]):[c:3]).([NH2;D1;+0:4]-[c:5])", + "index": 1, + "metadata": {} + }, + "children": [ + { + "cost": 0.0, + "expandable": true, + "value": 0.0, + "mol": 140605332368656, + "children": [] + }, + { + "cost": 0.0, + "expandable": false, + "value": 0, + "mol": 140602867801296, + "children": [] + } + ] + } + ] + }, + { + "cost": 0.0, + "expandable": false, + "value": 0, + "mol": 140605332204496, + "children": [] + } + ] + } + ] + }, + "molecules": { + "140605332219984": { + "smiles": "Cc1ccc2nc3ccccc3c(Nc3ccc(NC(=S)Nc4ccccc4)cc3)c2c1", + "class": "TreeMolecule", + "parent": null, + "transform": 0 + }, + "140605332201936": { + "smiles": "Cc1ccc2nc3ccccc3c(Cl)c2c1", + "class": "TreeMolecule", + "parent": 140605332219984, + "transform": 1 + }, + "140605332205072": { + "smiles": "Nc1ccc(NC(=S)Nc2ccccc2)cc1", + "class": "TreeMolecule", + "parent": 140605332219984, + "transform": 1 + }, + "140598039857296": { + "smiles": "Cc1ccc2nc3ccccc3c(Nc3ccc(N=C=S)cc3)c2c1", + "class": "TreeMolecule", + "parent": 140605332219984, + "transform": 1 + }, + "140598039857616": { + "smiles": "Cc1ccc2nc3ccccc3c(Cl)c2c1", + "class": "TreeMolecule", + "parent": 140598039857296, + "transform": 2 + }, + "140605332368528": { + "smiles": "Nc1ccc(N=C=S)cc1", + "class": "TreeMolecule", + "parent": 140598039857296, + "transform": 2 + }, + "140605332368656": { + "smiles": "Cc1ccc2nc3ccccc3c(Br)c2c1", + "class": "TreeMolecule", + "parent": 140598039857296, + "transform": 2 + }, + "140602867801296": { + "smiles": "Nc1ccc(N=C=S)cc1", + "class": "TreeMolecule", + "parent": 140598039857296, + "transform": 2 + }, + "140605332204496": { + "smiles": "Nc1ccccc1", + "class": "TreeMolecule", + "parent": 140605332219984, + "transform": 1 + } + } +} \ No newline at end of file diff --git a/tests/retrostar/test_retrostar.py b/tests/retrostar/test_retrostar.py new file mode 100644 index 0000000..b0e1ff6 --- /dev/null +++ b/tests/retrostar/test_retrostar.py @@ -0,0 +1,139 @@ +from aizynthfinder.search.retrostar.search_tree import SearchTree +from aizynthfinder.chem.serialization import MoleculeSerializer + + +def test_one_iteration(setup_search_tree): + tree = setup_search_tree + + tree.one_iteration() + + assert len(tree.root.children) == 1 + assert len(tree.root.children[0].children) == 3 + + +def test_one_iteration_filter_unfeasible(setup_search_tree): + tree = setup_search_tree + smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1>>CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O" + tree.config.filter_policy["dummy"].lookup[smi] = 0.0 + + tree.one_iteration() + assert len(tree.root.children) == 0 + + +def test_one_iteration_filter_feasible(setup_search_tree): + tree = setup_search_tree + smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1>>CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O" + tree.config.filter_policy["dummy"].lookup[smi] = 0.5 + + tree.one_iteration() + assert len(tree.root.children) == 1 + + +def test_one_expansion_with_finder(setup_aizynthfinder): + """ + Test the building of this tree: + root + | + child 1 + """ + root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1" + child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"] + lookup = {root_smi: {"smiles": ".".join(child1_smi), "prior": 1.0}} + finder = setup_aizynthfinder(lookup, child1_smi) + + # Test first with return_first + finder.config.return_first = True + finder.tree_search() + + nodes = finder.tree.mol_nodes + assert len(nodes) == 4 + assert nodes[0].mol.smiles == root_smi + assert nodes[1].mol.smiles == child1_smi[0] + assert finder.search_stats["iterations"] == 1 + assert finder.search_stats["returned_first"] + + # then test with iteration limit + finder.config.return_first = False + finder.config.iteration_limit = 45 + finder.prepare_tree() + finder.tree_search() + + assert len(finder.tree.mol_nodes) == 4 + # It will not continue because it cannot expand any more nodes + assert finder.search_stats["iterations"] == 2 + assert not finder.search_stats["returned_first"] + + +def test_serialization_deserialization( + mocker, setup_search_tree, tmpdir, default_config +): + tree = setup_search_tree + tree.one_iteration() + + mocked_json_dump = mocker.patch( + "aizynthfinder.search.retrostar.search_tree.json.dump" + ) + serializer = MoleculeSerializer() + filename = str(tmpdir / "dummy.json") + + # Test serialization + + tree.serialize(filename) + + expected_dict = { + "tree": tree.root.serialize(serializer), + "molecules": serializer.store, + } + + mocked_json_dump.assert_called_once_with( + expected_dict, mocker.ANY, indent=mocker.ANY + ) + + # Test deserialization + + mocker.patch( + "aizynthfinder.search.retrostar.search_tree.json.load", + return_value=expected_dict, + ) + mocker.patch( + "aizynthfinder.search.retrostar.nodes.deserialize_action", return_value=None + ) + + new_tree = SearchTree.from_json(filename, default_config) + + assert new_tree.root.mol == tree.root.mol + assert len(new_tree.root.children) == len(tree.root.children) + + +def test_split_andor_tree(shared_datadir, default_config): + tree = SearchTree.from_json( + str(shared_datadir / "andor_tree_for_clustering.json"), default_config + ) + + routes = tree.routes() + + assert len(routes) == 3 + + +def test_update(shared_datadir, default_config, setup_stock): + setup_stock( + default_config, + "Nc1ccc(NC(=S)Nc2ccccc2)cc1", + "Cc1ccc2nc3ccccc3c(Cl)c2c1", + "Nc1ccccc1", + "Nc1ccc(N=C=S)cc1", + "Cc1ccc2nc3ccccc3c(Br)c2c1", + "Nc1ccc(Br)cc1", + ) + tree = SearchTree.from_json( + str(shared_datadir / "andor_tree_for_clustering.json"), default_config + ) + + saved_root_value = tree.root.value + tree.mol_nodes[-1].parent.update(35, from_mol=tree.mol_nodes[-1].mol) + + assert [child.value for child in tree.root.children] == [5.0, 45.0] + assert tree.root.value != saved_root_value + assert tree.root.value == 5 + + tree.serialize("temp.json") diff --git a/tests/retrostar/test_retrostar_cost.py b/tests/retrostar/test_retrostar_cost.py new file mode 100644 index 0000000..43a6d73 --- /dev/null +++ b/tests/retrostar/test_retrostar_cost.py @@ -0,0 +1,25 @@ +import numpy as np + +import pytest + +from aizynthfinder.context.cost import MoleculeCost +from aizynthfinder.search.retrostar.cost import RetroStarCost +from aizynthfinder.chem import Molecule + + +def test_retrostar_cost(setup_mocked_model): + mol = Molecule(smiles="CCCC") + + cost = RetroStarCost(model_path="dummy", fingerprint_length=10, dropout_rate=0.0) + assert pytest.approx(cost(mol), abs=0.001) == 30 + + +def test_load_cost_from_config(setup_mocked_model): + cost = MoleculeCost() + + dict_ = { + "aizynthfinder.search.retrostar.cost.RetroStarCost": {"model_path": "dummy"} + } + cost.load_from_config(**dict_) + + assert len(cost) == 2 diff --git a/tests/retrostar/test_retrostar_nodes.py b/tests/retrostar/test_retrostar_nodes.py new file mode 100644 index 0000000..a784fce --- /dev/null +++ b/tests/retrostar/test_retrostar_nodes.py @@ -0,0 +1,140 @@ +import numpy as np +import networkx as nx + +from aizynthfinder.search.retrostar.nodes import MoleculeNode +from aizynthfinder.chem.serialization import MoleculeSerializer, MoleculeDeserializer +from aizynthfinder.search.andor_trees import ReactionTreeFromAndOrTrace + + +def test_create_root_node(setup_star_root): + node = setup_star_root("CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1") + + assert node.target_value == 0 + assert node.ancestors() == {node.mol} + assert node.expandable + assert not node.children + + +def test_close_single_node(setup_star_root): + node = setup_star_root("CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1") + + assert node.expandable + + delta = node.close() + + assert not np.isfinite(delta) + assert not node.solved + assert not node.expandable + + +def test_create_stub(setup_star_root, get_action): + root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" + root = setup_star_root(root_smiles) + reaction = get_action() + + root.add_stub(cost=5.0, reaction=reaction) + + assert len(root.children) == 1 + assert len(root.children[0].children) == 2 + rxn_node = root.children[0] + assert rxn_node.reaction is reaction + exp_list = [node.mol for node in rxn_node.children] + assert exp_list == list(reaction.reactants[0]) + assert rxn_node.value == rxn_node.target_value == 5 + assert not rxn_node.solved + + # This is done after a node has been expanded + delta = root.close() + + assert delta == 5.0 + assert root.value == 5.0 + assert rxn_node.children[0].ancestors() == {root.mol, rxn_node.children[0].mol} + + +def test_initialize_stub_one_solved_leaf( + setup_star_root, get_action, default_config, setup_stock +): + root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" + root = setup_star_root(root_smiles) + reaction = get_action() + setup_stock(default_config, reaction.reactants[0][0]) + + root.add_stub(cost=5.0, reaction=reaction) + root.close() + + assert not root.children[0].solved + assert not root.solved + assert root.children[0].children[0].solved + + +def test_initialize_stub_two_solved_leafs( + setup_star_root, get_action, default_config, setup_stock +): + root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" + root = setup_star_root(root_smiles) + reaction = get_action() + setup_stock(default_config, *reaction.reactants[0]) + + root.add_stub(cost=5.0, reaction=reaction) + root.close() + + assert root.children[0].solved + assert root.solved + + +def test_serialization_deserialization( + setup_star_root, get_action, default_config, setup_stock +): + root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" + root = setup_star_root(root_smiles) + reaction = get_action() + setup_stock(default_config, *reaction.reactants[0]) + root.add_stub(cost=5.0, reaction=reaction) + root.close() + + molecule_serializer = MoleculeSerializer() + dict_ = root.serialize(molecule_serializer) + + molecule_deserializer = MoleculeDeserializer(molecule_serializer.store) + node = MoleculeNode.from_dict(dict_, default_config, molecule_deserializer) + + assert node.mol == root.mol + assert node.value == root.value + assert node.cost == root.cost + assert len(node.children) == len(root.children) + + rxn_node = node.children[0] + assert rxn_node.reaction.smarts == reaction.smarts + assert rxn_node.reaction.metadata == reaction.metadata + assert rxn_node.cost == root.children[0].cost + assert rxn_node.value == root.children[0].value + + for grandchild1, grandchild2 in zip(rxn_node.children, root.children[0].children): + assert grandchild1.mol == grandchild2.mol + + +def test_converstion_to_reaction_tree( + setup_star_root, get_action, default_config, setup_stock +): + root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1" + root = setup_star_root(root_smiles) + reaction = get_action() + setup_stock(default_config, *reaction.reactants[0]) + root.add_stub(cost=5.0, reaction=reaction) + root.close() + graph = nx.DiGraph() + graph.add_edge(root, root.children[0]) + graph.add_edge(root.children[0], root.children[0].children[0]) + graph.add_edge(root.children[0], root.children[0].children[1]) + + rt = ReactionTreeFromAndOrTrace(graph, default_config.stock).tree + + molecules = list(rt.molecules()) + rt_reactions = list(rt.reactions()) + assert len(molecules) == 3 + assert len(list(rt.leafs())) == 2 + assert len(rt_reactions) == 1 + assert molecules[0].inchi_key == root.mol.inchi_key + assert molecules[1].inchi_key == root.children[0].children[0].mol.inchi_key + assert molecules[2].inchi_key == root.children[0].children[1].mol.inchi_key + assert rt_reactions[0].reaction_smiles() == reaction.reaction_smiles() diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py index c18c53c..5f051cb 100644 --- a/tests/utils/test_file_utils.py +++ b/tests/utils/test_file_utils.py @@ -1,4 +1,6 @@ import os +import gzip +import json import pytest import pandas as pd @@ -88,3 +90,32 @@ def test_cat_hdf(create_dummy_stock1, create_dummy_stock2, tmpdir): "UHOVQNZJYSORNB-UHFFFAOYSA-N", "ISWSIDIOOBJBQZ-UHFFFAOYSA-N", ] + + +def test_cat_hdf_no_trees(tmpdir, create_dummy_stock1, create_dummy_stock2): + hdf_filename = str(tmpdir / "output.hdf") + tree_filename = str(tmpdir / "trees.json") + inputs = [create_dummy_stock1("hdf5"), create_dummy_stock2] + + cat_hdf_files(inputs, hdf_filename, tree_filename) + + assert not os.path.exists(tree_filename) + + +def test_cat_hdf_trees(tmpdir): + hdf_filename = str(tmpdir / "output.hdf") + tree_filename = str(tmpdir / "trees.json") + filename1 = str(tmpdir / "file1.hdf5") + filename2 = str(tmpdir / "file2.hdf5") + trees1 = [[1], [2]] + trees2 = [[3], [4]] + pd.DataFrame({"mol": ["A", "B"], "trees": trees1}).to_hdf(filename1, "table") + pd.DataFrame({"mol": ["A", "B"], "trees": trees2}).to_hdf(filename2, "table") + + cat_hdf_files([filename1, filename2], hdf_filename, tree_filename) + + assert os.path.exists(tree_filename + ".gz") + with gzip.open(tree_filename + ".gz", "rt", encoding="UTF-8") as fileobj: + trees_cat = json.load(fileobj) + assert trees_cat == trees1 + trees2 + assert "trees" not in pd.read_hdf(hdf_filename, "table")