Skip to content

Commit

Permalink
Add updated scripts under aizynthfinder
Browse files Browse the repository at this point in the history
  • Loading branch information
Lakshidaa committed Dec 4, 2023
1 parent 82f19c4 commit ff34bbd
Show file tree
Hide file tree
Showing 54 changed files with 1,456 additions and 2,068 deletions.
41 changes: 30 additions & 11 deletions aizynthfinder/aizynthfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from aizynthfinder.chem import FixedRetroReaction, Molecule, TreeMolecule
from aizynthfinder.context.config import Configuration
from aizynthfinder.context.scoring import CombinedScorer
from aizynthfinder.reactiontree import ReactionTreeFromExpansion
from aizynthfinder.search.andor_trees import AndOrSearchTreeBase
from aizynthfinder.search.mcts import MctsSearchTree
Expand Down Expand Up @@ -60,7 +61,9 @@ class AiZynthFinder:
:param configdict: the config as a dictionary source, defaults to None
"""

def __init__(self, configfile: str = None, configdict: StrDict = None) -> None:
def __init__(
self, configfile: Optional[str] = None, configdict: Optional[StrDict] = None
) -> None:
self._logger = logger()

if configfile:
Expand Down Expand Up @@ -102,7 +105,9 @@ def target_mol(self, mol: Molecule) -> None:
self._target_mol = mol

def build_routes(
self, selection: RouteSelectionArguments = None, scorer: str = "state score"
self,
selection: Optional[RouteSelectionArguments] = None,
scorer: Optional[str] = None,
) -> None:
"""
Build reaction routes
Expand All @@ -114,10 +119,15 @@ def build_routes(
:param scorer: the name of the scorer (a key of ``self.scorers``) used to score the nodes, defaults to ``config.post_processing.route_scorer``
:raises ValueError: if the search tree is not initialized
"""

scorer = scorer or self.config.post_processing.route_scorer

if not self.tree:
raise ValueError("Search tree not initialized")

self.analysis = TreeAnalysis(self.tree, scorer=self.scorers[scorer])
_scorer = self.scorers[scorer]

self.analysis = TreeAnalysis(self.tree, scorer=_scorer)
config_selection = RouteSelectionArguments(
nmin=self.config.post_processing.min_routes,
nmax=self.config.post_processing.max_routes,
Expand Down Expand Up @@ -157,13 +167,17 @@ def prepare_tree(self) -> None:
raise ValueError("Target molecule unsanitizable")

self.stock.reset_exclusion_list()
if self.config.exclude_target_from_stock and self.target_mol in self.stock:
if (
self.config.search.exclude_target_from_stock
and self.target_mol in self.stock
):
self.stock.exclude(self.target_mol)
self._logger.debug("Excluding the target compound from the stock")

self._setup_search_tree()
self.analysis = None
self.routes = RouteCollection([])
self.expansion_policy.reset_cache()

def stock_info(self) -> StrDict:
"""
Expand Down Expand Up @@ -202,9 +216,12 @@ def tree_search(self, show_progress: bool = False) -> float:
time_past = time.time() - time0

if show_progress:
pbar = tqdm(total=self.config.iteration_limit, leave=False)
pbar = tqdm(total=self.config.search.iteration_limit, leave=False)

while time_past < self.config.time_limit and i <= self.config.iteration_limit:
while (
time_past < self.config.search.time_limit
and i <= self.config.search.iteration_limit
):
if show_progress:
pbar.update(1)
self.search_stats["iterations"] += 1
Expand All @@ -218,7 +235,7 @@ def tree_search(self, show_progress: bool = False) -> float:
self.search_stats["first_solution_time"] = time.time() - time0
self.search_stats["first_solution_iteration"] = i

if self.config.return_first and is_solved:
if self.config.search.return_first and is_solved:
self._logger.debug("Found first solved route")
self.search_stats["returned_first"] = True
break
Expand All @@ -234,12 +251,12 @@ def tree_search(self, show_progress: bool = False) -> float:

def _setup_search_tree(self) -> None:
self._logger.debug("Defining tree root: %s" % self.target_smiles)
if self.config.search_algorithm.lower() == "mcts":
if self.config.search.algorithm.lower() == "mcts":
self.tree = MctsSearchTree(
root_smiles=self.target_smiles, config=self.config
)
else:
cls = load_dynamic_class(self.config.search_algorithm)
cls = load_dynamic_class(self.config.search.algorithm)
self.tree = cls(root_smiles=self.target_smiles, config=self.config)


Expand All @@ -260,7 +277,9 @@ class AiZynthExpander:
:param configdict: the config as a dictionary source, defaults to None
"""

def __init__(self, configfile: str = None, configdict: StrDict = None) -> None:
def __init__(
self, configfile: Optional[str] = None, configdict: Optional[StrDict] = None
) -> None:
self._logger = logger()

if configfile:
Expand All @@ -278,7 +297,7 @@ def do_expansion(
self,
smiles: str,
return_n: int = 5,
filter_func: Callable[[RetroReaction], bool] = None,
filter_func: Optional[Callable[[RetroReaction], bool]] = None,
) -> List[Tuple[FixedRetroReaction, ...]]:
"""
Do the expansion of the given molecule returning a list of
Expand Down
5 changes: 3 additions & 2 deletions aizynthfinder/analysis/tree_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Any,
Iterable,
List,
Optional,
Sequence,
StrDict,
Tuple,
Expand All @@ -41,7 +42,7 @@ class TreeAnalysis:
def __init__(
self,
search_tree: Union[MctsSearchTree, AndOrSearchTreeBase],
scorer: Scorer = None,
scorer: Optional[Scorer] = None,
) -> None:
self.search_tree = search_tree
if scorer is None:
Expand All @@ -65,7 +66,7 @@ def best(self) -> Union[MctsNode, ReactionTree]:
return sorted_routes[0]

def sort(
self, selection: RouteSelectionArguments = None
self, selection: Optional[RouteSelectionArguments] = None
) -> Tuple[Union[Sequence[MctsNode], Sequence[ReactionTree]], Sequence[float]]:
"""
Sort and select the nodes or routes in the search tree.
Expand Down
11 changes: 8 additions & 3 deletions aizynthfinder/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
from aizynthfinder.utils.image import make_visjs_page

if TYPE_CHECKING:
from aizynthfinder.utils.type_utils import FrameColors, Sequence, StrDict, Tuple
from aizynthfinder.utils.type_utils import (
FrameColors,
Optional,
Sequence,
StrDict,
Tuple,
)


@dataclass
Expand Down Expand Up @@ -69,7 +75,7 @@ def to_dict(self) -> StrDict:
def to_visjs_page(
self,
filename: str,
in_stock_colors: FrameColors = None,
in_stock_colors: Optional[FrameColors] = None,
) -> None:
"""
Create a visualization of the combined reaction tree using the vis.js network library.
Expand All @@ -93,7 +99,6 @@ def _add_reaction_trees_to_node(
base_node: UniqueMolecule,
rt_node_spec: Sequence[Tuple[UniqueMolecule, nx.DiGraph]],
) -> None:

reaction_groups = defaultdict(list)
# Group the reactions from the nodes at this level based on the reaction smiles
for node, graph in rt_node_spec:
Expand Down
1 change: 0 additions & 1 deletion aizynthfinder/chem/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
)
from aizynthfinder.chem.reaction import (
FixedRetroReaction,
Reaction,
RetroReaction,
SmilesBasedRetroReaction,
TemplatedRetroReaction,
Expand Down
82 changes: 48 additions & 34 deletions aizynthfinder/chem/mol.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ class Molecule:
"""

def __init__(
self, rd_mol: RdMol = None, smiles: str = None, sanitize: bool = False
self,
rd_mol: Optional[RdMol] = None,
smiles: Optional[str] = None,
sanitize: bool = False,
) -> None:
if not rd_mol and not smiles:
raise MoleculeException(
Expand Down Expand Up @@ -148,19 +151,24 @@ def basic_compare(self, other: "Molecule") -> bool:
"""
return self.inchi_key[:14] == other.inchi_key[:14]

def fingerprint(self, radius: int, nbits: int = 2048) -> np.ndarray:
def fingerprint(
self, radius: int, nbits: int = 2048, chiral: bool = False
) -> np.ndarray:
"""
Returns the Morgan fingerprint of the molecule
:param radius: the radius of the fingerprint
:param nbits: the length of the fingerprint
:param chiral: if True, include chirality information
:return: the fingerprint
"""
key = radius, nbits

if key not in self._fingerprints:
self.sanitize()
bitvect = AllChem.GetMorganFingerprintAsBitVect(self.rd_mol, *key)
bitvect = AllChem.GetMorganFingerprintAsBitVect(
self.rd_mol, *key, useChirality=chiral
)
array = np.zeros((1,))
DataStructs.ConvertToNumpyArray(bitvect, array)
self._fingerprints[key] = array
Expand All @@ -187,7 +195,7 @@ def make_unique(self) -> "UniqueMolecule":
"""
return UniqueMolecule(rd_mol=self.rd_mol)

def remove_atom_mapping(self, exceptions: Sequence[int] = None) -> None:
def remove_atom_mapping(self, exceptions: Optional[Sequence[int]] = None) -> None:
"""
Remove all mappings of the atoms and update the smiles
Expand Down Expand Up @@ -247,7 +255,6 @@ class TreeMolecule(Molecule):
:ivar original_smiles: the SMILES as passed when instantiating the class
:ivar parent: parent molecule
:ivar transform: a number corresponding to the depth in the tree
:ivar tracked_atom_indices: tracked atom indices and what indices they correspond to in this molecule
:param parent: a TreeMolecule object that is the parent
:param transform: the transform value, defaults to None
Expand All @@ -262,11 +269,11 @@ class TreeMolecule(Molecule):
def __init__(
self,
parent: Optional["TreeMolecule"],
transform: int = None,
rd_mol: RdMol = None,
smiles: str = None,
transform: Optional[int] = None,
rd_mol: Optional[RdMol] = None,
smiles: Optional[str] = None,
sanitize: bool = False,
mapping_update_callback: Callable[["TreeMolecule"], None] = None,
mapping_update_callback: Optional[Callable[["TreeMolecule"], None]] = None,
) -> None:
super().__init__(rd_mol=rd_mol, smiles=smiles, sanitize=sanitize)
self.parent = parent
Expand All @@ -276,10 +283,10 @@ def __init__(
self.transform = transform or 0

self.original_smiles = smiles
self.tracked_atom_indices: Dict[int, Optional[int]] = {}
self.mapped_mol = Chem.Mol(self.rd_mol)
self._atom_bonds: List[Tuple[int, int]] = []
if not self.parent:
self._init_tracking()
self._set_atom_mappings()
elif mapping_update_callback is not None:
mapping_update_callback(self)

Expand All @@ -288,7 +295,6 @@ def __init__(

if self.parent:
self.remove_atom_mapping()
self._update_tracked_atoms()

@property
def mapping_to_index(self) -> Dict[int, int]:
Expand All @@ -301,28 +307,33 @@ def mapping_to_index(self) -> Dict[int, int]:
}
return self._atom_mappings

def _init_tracking(self):
self.tracked_atom_indices = dict(self.mapping_to_index)
for idx, atom in enumerate(self.mapped_mol.GetAtoms()):
atom.SetAtomMapNum(idx + 1)
@property
def mapped_atom_bonds(self) -> List[Tuple[int, int]]:
    """
    Return the bonds of the mapped molecule as pairs of atom-mapping numbers.

    The begin/end atom indices of every bond in ``self.mapped_mol`` are
    collected first, and each index is then translated through
    ``self.index_to_mapping``. The result is stored on ``self._atom_bonds``
    before it is returned.

    :return: one ``(begin, end)`` tuple per bond
    """
    # First pass: raw atom indices for the two ends of each bond
    index_pairs = [
        (bond.GetBeginAtom().GetIdx(), bond.GetEndAtom().GetIdx())
        for bond in self.mapped_mol.GetBonds()
    ]

    # Second pass: translate atom indices into atom-mapping numbers
    mapped_pairs = []
    for begin_index, end_index in index_pairs:
        mapped_pairs.append(
            (
                self.index_to_mapping[begin_index],
                self.index_to_mapping[end_index],
            )
        )

    self._atom_bonds = mapped_pairs
    return self._atom_bonds

def _set_atom_mappings(self) -> None:
atom_mappings = [
atom.GetAtomMapNum()
for atom in self.mapped_mol.GetAtoms()
if atom.GetAtomMapNum() != 0
]

mapper = max(atom_mappings) + 1 if atom_mappings else 1
self._atom_mappings = {}

def _update_tracked_atoms(self) -> None:
if self.parent is None:
return

if not self.parent.tracked_atom_indices:
return

parent2child_map = {
atom_index: self.mapping_to_index.get(mapping_index)
for mapping_index, atom_index in self.parent.mapping_to_index.items()
}

self.tracked_atom_indices = {
tracked_index: parent2child_map[parent_index] # type: ignore
for tracked_index, parent_index in self.parent.tracked_atom_indices.items()
}
for atom_index, atom in enumerate(self.mapped_mol.GetAtoms()):
if atom.GetAtomMapNum() == 0:
atom.SetAtomMapNum(mapper)
mapper += 1
self._atom_mappings[atom.GetAtomMapNum()] = atom_index


class UniqueMolecule(Molecule):
Expand All @@ -337,7 +348,10 @@ class UniqueMolecule(Molecule):
"""

def __init__(
self, rd_mol: RdMol = None, smiles: str = None, sanitize: bool = False
self,
rd_mol: Optional[RdMol] = None,
smiles: Optional[str] = None,
sanitize: bool = False,
) -> None:
super().__init__(rd_mol=rd_mol, smiles=smiles, sanitize=sanitize)

Expand Down

0 comments on commit ff34bbd

Please sign in to comment.