Version 3.2.0 public release
SGenheden committed Feb 24, 2022
1 parent d1fb308 commit 52243b9
Showing 26 changed files with 455 additions and 110 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,15 @@
# CHANGELOG

## Version 3.2.0 2022-02-24

### Features

- Profiling feature enabled in search trees
- New, customizable configuration of training pre-processing tools
- Generic post-processing support in aizynthcli
- Introduce short aliases for filter policies
- Reaction shape support in GraphViz visualisation of routes

## Version 3.1.0 2021-12-21

### Features
2 changes: 2 additions & 0 deletions aizynthfinder/analysis/tree_analysis.py
@@ -152,6 +152,7 @@ def _tree_statistics_andor(self) -> StrDict:
"precursors_not_in_stock": mols_not_in_stock,
"precursors_availability": availability,
"policy_used_counts": policy_used_counts,
"profiling": getattr(self.search_tree, "profiling", {}),
}

def _tree_statistics_mcts(self) -> StrDict:
@@ -192,6 +193,7 @@ def _tree_statistics_mcts(self) -> StrDict:
"precursors_not_in_stock": mols_not_in_stock,
"precursors_availability": ";".join(top_state.stock_availability),
"policy_used_counts": policy_used_counts,
"profiling": getattr(self.search_tree, "profiling", {}),
}

@staticmethod
7 changes: 7 additions & 0 deletions aizynthfinder/chem/reaction.py
@@ -256,6 +256,13 @@ def smiles(self) -> str:
self._smiles = "" # noqa
return self._smiles

@property
def unqueried(self) -> bool:
"""
Return True if the reactants have never been retrieved
"""
return self._reactants is None

def copy(self, index: int = None) -> "RetroReaction":
"""
Shallow copy of this instance.
1 change: 1 addition & 0 deletions aizynthfinder/context/policy/expansion_strategies.py
@@ -130,6 +130,7 @@ def get_actions(
metadata["policy_probability_rank"] = idx
metadata["policy_name"] = self.key
metadata["template_code"] = move_index
metadata["template"] = move[self._config.template_column]
possible_actions.append(
TemplatedRetroReaction(
mol,
11 changes: 8 additions & 3 deletions aizynthfinder/context/policy/filter_strategies.py
@@ -80,9 +80,7 @@ def __init__(self, key: str, config: Configuration, **kwargs: Any) -> None:
self.model = load_model(source, key, self._config.use_remote_models)
self._prod_fp_name = kwargs.get("prod_fp_name", "input_1")
self._rxn_fp_name = kwargs.get("rxn_fp_name", "input_2")
self._exclude_from_policy: List[str] = kwargs.get(
"exclude_from_policy", []
)
self._exclude_from_policy: List[str] = kwargs.get("exclude_from_policy", [])

def apply(self, reaction: RetroReaction) -> None:
if reaction.metadata.get("policy_name", "") in self._exclude_from_policy:
@@ -144,3 +142,10 @@ def apply(self, reaction: RetroReaction) -> None:
raise RejectionException(
f"{reaction} was filtered out because number of reactants disagree with the template"
)


FILTER_STRATEGY_ALIAS = {
"feasibility": "QuickKerasFilter",
"quick_keras_filter": "QuickKerasFilter",
"reactants_count": "ReactantsCountFilter",
}
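
The new `FILTER_STRATEGY_ALIAS` table lets a filter policy be configured with a short alias instead of the full class name; the filter policy's `load_from_config` (see the `policies.py` hunk below) resolves the alias before loading the class dynamically. A minimal sketch of that resolution, mirroring the lookup used in `load_from_config`:

```python
from aizynthfinder.context.policy.filter_strategies import FILTER_STRATEGY_ALIAS

# "feasibility" and "quick_keras_filter" both resolve to the QuickKerasFilter
# class, so either short form can be used as the strategy key in a configuration
strategy_spec = "feasibility"
class_name = FILTER_STRATEGY_ALIAS.get(strategy_spec, strategy_spec)
print(class_name)  # QuickKerasFilter
```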
4 changes: 3 additions & 1 deletion aizynthfinder/context/policy/policies.py
@@ -13,6 +13,7 @@
from aizynthfinder.context.policy.filter_strategies import (
FilterStrategy,
QuickKerasFilter,
FILTER_STRATEGY_ALIAS,
)
from aizynthfinder.context.policy.expansion_strategies import (
__name__ as expansion_strategy_module,
@@ -196,8 +197,9 @@ def load_from_config(self, **config: Any) -> None:
for strategy_spec, strategy_config in config.items():
if strategy_spec in ["files", "quick-filter"]:
continue
strategy_spec2 = FILTER_STRATEGY_ALIAS.get(strategy_spec, strategy_spec)
cls = load_dynamic_class(
strategy_spec, filter_strategy_module, PolicyException
strategy_spec2, filter_strategy_module, PolicyException
)
for key, policy_spec in strategy_config.items():
obj = cls(key, self._config, **(policy_spec or {}))
9 changes: 9 additions & 0 deletions aizynthfinder/data/default_training.yml
@@ -1,5 +1,14 @@
library_headers: ["index", "ID", "reaction_hash", "reactants", "products", "classification", "retro_template", "template_hash", "selectivity", "outcomes", "template_code"]
column_map:
reaction_hash: reaction_hash
reactants: reactants
products: products
retro_template: retro_template
template_hash: template_hash
metadata_headers: ["template_hash", "classification"]
in_csv_headers: False
csv_sep: ","
reaction_smiles_column: ""
output_path: "."
file_prefix: ""
file_postfix:
4 changes: 2 additions & 2 deletions aizynthfinder/data/templates/reaction_tree.dot
@@ -14,11 +14,11 @@ node [label="\N"];
image="{{ image_filepath }}"
];
{% endfor %}
{% for reaction in reactions %}
{% for reaction, reaction_shape in reactions %}
{{ id(reaction) }} [
label="",
fillcolor="black",
shape="circle",
shape="{{ reaction_shape }}",
style="filled",
width="0.1",
fixedsize="true"
50 changes: 46 additions & 4 deletions aizynthfinder/interfaces/aizynthcli.py
@@ -19,7 +19,9 @@
from aizynthfinder.utils.logging import logger, setup_logger

if TYPE_CHECKING:
from aizynthfinder.utils.type_utils import StrDict
from aizynthfinder.utils.type_utils import StrDict, Callable, List, Optional

_PostProcessingJob = Callable[[AiZynthFinder], StrDict]


def _do_clustering(
@@ -41,6 +43,13 @@ def _do_clustering(
results["distance_matrix"] = finder.routes.distance_matrix().tolist()


def _do_post_processing(
finder: AiZynthFinder, results: StrDict, jobs: List[_PostProcessingJob]
) -> None:
for job in jobs:
results.update(job(finder))


def _get_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser("aizynthcli")
parser.add_argument(
@@ -87,9 +96,28 @@ def _get_arguments() -> argparse.Namespace:
"--route_distance_model",
help="if provided, calculate route distances for clustering with this ML model",
)
parser.add_argument(
"--post_processing",
nargs="+",
help="a number of modules that performs post-processing tasks",
)
return parser.parse_args()


def _load_postprocessing_jobs(modules: Optional[List[str]]) -> List[_PostProcessingJob]:
jobs: List[_PostProcessingJob] = []
for module_name in modules or []:
try:
module = importlib.import_module(module_name)
except ModuleNotFoundError:
pass
else:
if hasattr(module, "post_processing"):
print(f"Adding post-processing job from {module_name}")
jobs.append(getattr(module, "post_processing"))
return jobs


def _select_stocks(finder: AiZynthFinder, args: argparse.Namespace) -> None:
stocks = list(args.stocks)
try:
@@ -108,7 +136,8 @@ def _process_single_smiles(
finder: AiZynthFinder,
output_name: str,
do_clustering: bool,
route_distance_model: str = None,
route_distance_model: Optional[str],
post_processing: List[_PostProcessingJob],
) -> None:
output_name = output_name or "trees.json"
finder.target_smiles = smiles
@@ -128,6 +157,7 @@
_do_clustering(
finder, stats, detailed_results=False, model_path=route_distance_model
)
_do_post_processing(finder, stats, post_processing)
stats_str = "\n".join(
f"{key.replace('_', ' ')}: {value}" for key, value in stats.items()
)
@@ -139,7 +169,8 @@ def _process_multi_smiles(
finder: AiZynthFinder,
output_name: str,
do_clustering: bool,
route_distance_model: str = None,
route_distance_model: Optional[str],
post_processing: List[_PostProcessingJob],
) -> None:
output_name = output_name or "output.hdf5"
with open(filename, "r") as fileobj:
@@ -159,6 +190,7 @@
_do_clustering(
finder, stats, detailed_results=True, model_path=route_distance_model
)
_do_post_processing(finder, stats, post_processing)
for key, value in stats.items():
results[key].append(value)
results["top_scores"].append(
@@ -195,6 +227,8 @@ def create_cmd(index, filename):
cmd_args.append("--cluster")
if args.route_distance_model:
cmd_args.extend(["--route_distance_model", args.route_distance_model])
if args.post_processing:
cmd_args.extend(["--post_processing"] + args.post_processing)
return cmd_args

if not os.path.exists(args.smiles):
@@ -228,11 +262,19 @@ def main() -> None:

finder = AiZynthFinder(configfile=args.config)
_select_stocks(finder, args)
post_processing = _load_postprocessing_jobs(args.post_processing)
finder.expansion_policy.select(args.policy or finder.expansion_policy.items[0])
finder.filter_policy.select(args.filter)

func = _process_multi_smiles if multi_smiles else _process_single_smiles
func(args.smiles, finder, args.output, args.cluster, args.route_distance_model)
func(
args.smiles,
finder,
args.output,
args.cluster,
args.route_distance_model,
post_processing,
)


if __name__ == "__main__":
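
The new `--post_processing` option takes one or more importable module names; `_load_postprocessing_jobs` above keeps every module that exposes a `post_processing` callable, and `_do_post_processing` merges each callable's returned dictionary into the output statistics. A minimal sketch of such a module (the module name and the returned key are hypothetical):

```python
# my_post_processing.py -- hypothetical module passed as `--post_processing my_post_processing`
from aizynthfinder.aizynthfinder import AiZynthFinder


def post_processing(finder: AiZynthFinder) -> dict:
    """Return extra key/value pairs that aizynthcli adds to the per-target statistics."""
    # Anything returned here ends up alongside the standard statistics columns
    return {"target_smiles_length": len(finder.target_smiles)}
```

It would then be picked up with, for example, `aizynthcli --config config.yml --smiles targets.txt --post_processing my_post_processing` (file names are placeholders).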
7 changes: 6 additions & 1 deletion aizynthfinder/search/mcts/node.py
@@ -244,6 +244,9 @@ def expand(self) -> None:
self.is_expandable = False
self.is_expanded = False

if self.tree:
self.tree.profiling["expansion_calls"] += 1

def is_terminal(self) -> bool:
"""
Node is terminal if it is unexpandable, or the internal state is terminal (solved)
@@ -413,7 +416,9 @@ def _select_child(self, child_idx: int) -> Optional["MctsNode"]:
return self._children[child_idx]

reaction = self._children_actions[child_idx]
if not reaction.reactants:
if reaction.unqueried:
if self.tree:
self.tree.profiling["reactants_generations"] += 1
_ = reaction.reactants

if not self._check_child_reaction(reaction):
4 changes: 4 additions & 0 deletions aizynthfinder/search/mcts/search.py
@@ -34,6 +34,10 @@ def __init__(self, config: Configuration, root_smiles: str = None) -> None:
self.root = None
self.config = config
self._graph: Optional[nx.DiGraph] = None
self.profiling = {
"expansion_calls": 0,
"reactants_generations": 0,
}

@classmethod
def from_json(cls, filename: str, config: Configuration) -> "MctsSearchTree":
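
With the counters above, the MCTS tree records how many times node expansion is invoked and how many times reactants are generated, and the `tree_analysis.py` change at the top of this commit copies them into the route statistics under a `profiling` key. A rough sketch of inspecting the counters directly, assuming the finder exposes the search tree as `finder.tree` (configuration, stock, policy and target names below are placeholders):

```python
from aizynthfinder.aizynthfinder import AiZynthFinder

finder = AiZynthFinder(configfile="config.yml")  # placeholder configuration file
finder.stock.select("zinc")                      # placeholder stock key
finder.expansion_policy.select("uspto")          # placeholder expansion-policy key
finder.target_smiles = "CCO"                     # placeholder target SMILES
finder.tree_search()

# Counters introduced in search.py above and incremented in node.py
print(finder.tree.profiling)
# e.g. {'expansion_calls': 42, 'reactants_generations': 117}
```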
