diff --git a/CHANGELOG.md b/CHANGELOG.md index 75d174c..555dc56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ # CHANGELOG +## Version 3.2.0 2022-02-24 + +### Features + +- Profiling feature enabled in search trees +- New, customizable configuration of training pre-processing tools +- Generic post-processing support in aizynthcli +- Introduce short aliases for filter policies +- Reaction shape support in GraphViz visualisation of routes + ## Version 3.1.0 2021-12-21 ### Features diff --git a/aizynthfinder/analysis/tree_analysis.py b/aizynthfinder/analysis/tree_analysis.py index 07a1ca3..7984dc3 100644 --- a/aizynthfinder/analysis/tree_analysis.py +++ b/aizynthfinder/analysis/tree_analysis.py @@ -152,6 +152,7 @@ def _tree_statistics_andor(self) -> StrDict: "precursors_not_in_stock": mols_not_in_stock, "precursors_availability": availability, "policy_used_counts": policy_used_counts, + "profiling": getattr(self.search_tree, "profiling", {}), } def _tree_statistics_mcts(self) -> StrDict: @@ -192,6 +193,7 @@ def _tree_statistics_mcts(self) -> StrDict: "precursors_not_in_stock": mols_not_in_stock, "precursors_availability": ";".join(top_state.stock_availability), "policy_used_counts": policy_used_counts, + "profiling": getattr(self.search_tree, "profiling", {}), } @staticmethod diff --git a/aizynthfinder/chem/reaction.py b/aizynthfinder/chem/reaction.py index 67ccccc..f06798b 100644 --- a/aizynthfinder/chem/reaction.py +++ b/aizynthfinder/chem/reaction.py @@ -256,6 +256,13 @@ def smiles(self) -> str: self._smiles = "" # noqa return self._smiles + @property + def unqueried(self) -> bool: + """ + Return True if the reactants have never been retrieved + """ + return self._reactants is None + def copy(self, index: int = None) -> "RetroReaction": """ Shallow copy of this instance.
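The two hunks above expose the new `profiling` counters of the search tree through the statistics returned by `AiZynthFinder.extract_statistics()`, and the `unqueried` property lets the MCTS code count how many times reactants are actually generated rather than re-read from the cache. A minimal sketch of reading the counters after a search (illustrative only; the configuration path, stock/policy keys and target SMILES are placeholders):

```python
# Minimal sketch (not part of the diff): reading the new profiling counters.
# Assumes a local configuration file that defines a stock and an expansion
# policy; the keys and the target SMILES below are placeholders.
from aizynthfinder.aizynthfinder import AiZynthFinder

finder = AiZynthFinder(configfile="config_local.yml")
finder.stock.select("my_stock")
finder.expansion_policy.select("my_policy")
finder.target_smiles = "CCOC(=O)c1ccccc1"

finder.prepare_tree()
finder.tree_search()
finder.build_routes()

stats = finder.extract_statistics()
# "profiling" is filled from MctsSearchTree.profiling, e.g.
# {'expansion_calls': 120, 'reactants_generations': 87}
print(stats.get("profiling", {}))
```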
diff --git a/aizynthfinder/context/policy/expansion_strategies.py b/aizynthfinder/context/policy/expansion_strategies.py index 4c0ff4d..5d54f5b 100644 --- a/aizynthfinder/context/policy/expansion_strategies.py +++ b/aizynthfinder/context/policy/expansion_strategies.py @@ -130,6 +130,7 @@ def get_actions( metadata["policy_probability_rank"] = idx metadata["policy_name"] = self.key metadata["template_code"] = move_index + metadata["template"] = move[self._config.template_column] possible_actions.append( TemplatedRetroReaction( mol, diff --git a/aizynthfinder/context/policy/filter_strategies.py b/aizynthfinder/context/policy/filter_strategies.py index 8de8007..c730b50 100644 --- a/aizynthfinder/context/policy/filter_strategies.py +++ b/aizynthfinder/context/policy/filter_strategies.py @@ -80,9 +80,7 @@ def __init__(self, key: str, config: Configuration, **kwargs: Any) -> None: self.model = load_model(source, key, self._config.use_remote_models) self._prod_fp_name = kwargs.get("prod_fp_name", "input_1") self._rxn_fp_name = kwargs.get("rxn_fp_name", "input_2") - self._exclude_from_policy: List[str] = kwargs.get( - "exclude_from_policy", [] - ) + self._exclude_from_policy: List[str] = kwargs.get("exclude_from_policy", []) def apply(self, reaction: RetroReaction) -> None: if reaction.metadata.get("policy_name", "") in self._exclude_from_policy: @@ -144,3 +142,10 @@ def apply(self, reaction: RetroReaction) -> None: raise RejectionException( f"{reaction} was filtered out because number of reactants disagree with the template" ) + + +FILTER_STRATEGY_ALIAS = { + "feasibility": "QuickKerasFilter", + "quick_keras_filter": "QuickKerasFilter", + "reactants_count": "ReactantsCountFilter", +} diff --git a/aizynthfinder/context/policy/policies.py b/aizynthfinder/context/policy/policies.py index 33a1b99..b6cbc80 100644 --- a/aizynthfinder/context/policy/policies.py +++ b/aizynthfinder/context/policy/policies.py @@ -13,6 +13,7 @@ from aizynthfinder.context.policy.filter_strategies import ( FilterStrategy, QuickKerasFilter, + FILTER_STRATEGY_ALIAS, ) from aizynthfinder.context.policy.expansion_strategies import ( __name__ as expansion_strategy_module, @@ -196,8 +197,9 @@ def load_from_config(self, **config: Any) -> None: for strategy_spec, strategy_config in config.items(): if strategy_spec in ["files", "quick-filter"]: continue + strategy_spec2 = FILTER_STRATEGY_ALIAS.get(strategy_spec, strategy_spec) cls = load_dynamic_class( - strategy_spec, filter_strategy_module, PolicyException + strategy_spec2, filter_strategy_module, PolicyException ) for key, policy_spec in strategy_config.items(): obj = cls(key, self._config, **(policy_spec or {})) diff --git a/aizynthfinder/data/default_training.yml b/aizynthfinder/data/default_training.yml index 94331d3..7a8657d 100644 --- a/aizynthfinder/data/default_training.yml +++ b/aizynthfinder/data/default_training.yml @@ -1,5 +1,14 @@ library_headers: ["index", "ID", "reaction_hash", "reactants", "products", "classification", "retro_template", "template_hash", "selectivity", "outcomes", "template_code"] +column_map: + reaction_hash: reaction_hash + reactants: reactants + products: products + retro_template: retro_template + template_hash: template_hash metadata_headers: ["template_hash", "classification"] +in_csv_headers: False +csv_sep: "," +reaction_smiles_column: "" output_path: "." 
file_prefix: "" file_postfix: diff --git a/aizynthfinder/data/templates/reaction_tree.dot b/aizynthfinder/data/templates/reaction_tree.dot index 08ab976..22ed3b2 100644 --- a/aizynthfinder/data/templates/reaction_tree.dot +++ b/aizynthfinder/data/templates/reaction_tree.dot @@ -14,11 +14,11 @@ node [label="\N"]; image="{{ image_filepath }}" ]; {% endfor %} - {% for reaction in reactions %} + {% for reaction, reaction_shape in reactions %} {{ id(reaction) }} [ label="", fillcolor="black", - shape="circle", + shape="{{ reaction_shape }}", style="filled", width="0.1", fixedsize="true" diff --git a/aizynthfinder/interfaces/aizynthcli.py b/aizynthfinder/interfaces/aizynthcli.py index 8334696..e0c375f 100644 --- a/aizynthfinder/interfaces/aizynthcli.py +++ b/aizynthfinder/interfaces/aizynthcli.py @@ -19,7 +19,9 @@ from aizynthfinder.utils.logging import logger, setup_logger if TYPE_CHECKING: - from aizynthfinder.utils.type_utils import StrDict + from aizynthfinder.utils.type_utils import StrDict, Callable, List, Optional + + _PostProcessingJob = Callable[[AiZynthFinder], StrDict] def _do_clustering( @@ -41,6 +43,13 @@ def _do_clustering( results["distance_matrix"] = finder.routes.distance_matrix().tolist() +def _do_post_processing( + finder: AiZynthFinder, results: StrDict, jobs: List[_PostProcessingJob] +) -> None: + for job in jobs: + results.update(job(finder)) + + def _get_arguments() -> argparse.Namespace: parser = argparse.ArgumentParser("aizynthcli") parser.add_argument( @@ -87,9 +96,28 @@ def _get_arguments() -> argparse.Namespace: "--route_distance_model", help="if provided, calculate route distances for clustering with this ML model", ) + parser.add_argument( + "--post_processing", + nargs="+", + help="a number of modules that performs post-processing tasks", + ) return parser.parse_args() +def _load_postprocessing_jobs(modules: Optional[List[str]]) -> List[_PostProcessingJob]: + jobs: List[_PostProcessingJob] = [] + for module_name in modules or []: + try: + module = importlib.import_module(module_name) + except ModuleNotFoundError: + pass + else: + if hasattr(module, "post_processing"): + print(f"Adding post-processing job from {module_name}") + jobs.append(getattr(module, "post_processing")) + return jobs + + def _select_stocks(finder: AiZynthFinder, args: argparse.Namespace) -> None: stocks = list(args.stocks) try: @@ -108,7 +136,8 @@ def _process_single_smiles( finder: AiZynthFinder, output_name: str, do_clustering: bool, - route_distance_model: str = None, + route_distance_model: Optional[str], + post_processing: List[_PostProcessingJob], ) -> None: output_name = output_name or "trees.json" finder.target_smiles = smiles @@ -128,6 +157,7 @@ def _process_single_smiles( _do_clustering( finder, stats, detailed_results=False, model_path=route_distance_model ) + _do_post_processing(finder, stats, post_processing) stats_str = "\n".join( f"{key.replace('_', ' ')}: {value}" for key, value in stats.items() ) @@ -139,7 +169,8 @@ def _process_multi_smiles( finder: AiZynthFinder, output_name: str, do_clustering: bool, - route_distance_model: str = None, + route_distance_model: Optional[str], + post_processing: List[_PostProcessingJob], ) -> None: output_name = output_name or "output.hdf5" with open(filename, "r") as fileobj: @@ -159,6 +190,7 @@ def _process_multi_smiles( _do_clustering( finder, stats, detailed_results=True, model_path=route_distance_model ) + _do_post_processing(finder, stats, post_processing) for key, value in stats.items(): results[key].append(value) 
results["top_scores"].append( @@ -195,6 +227,8 @@ def create_cmd(index, filename): cmd_args.append("--cluster") if args.route_distance_model: cmd_args.extend(["--route_distance_model", args.route_distance_model]) + if args.post_processing: + cmd_args.extend(["--post_processing"] + args.post_processing) return cmd_args if not os.path.exists(args.smiles): @@ -228,11 +262,19 @@ def main() -> None: finder = AiZynthFinder(configfile=args.config) _select_stocks(finder, args) + post_processing = _load_postprocessing_jobs(args.post_processing) finder.expansion_policy.select(args.policy or finder.expansion_policy.items[0]) finder.filter_policy.select(args.filter) func = _process_multi_smiles if multi_smiles else _process_single_smiles - func(args.smiles, finder, args.output, args.cluster, args.route_distance_model) + func( + args.smiles, + finder, + args.output, + args.cluster, + args.route_distance_model, + post_processing, + ) if __name__ == "__main__": diff --git a/aizynthfinder/search/mcts/node.py b/aizynthfinder/search/mcts/node.py index 509ce2a..05151b2 100644 --- a/aizynthfinder/search/mcts/node.py +++ b/aizynthfinder/search/mcts/node.py @@ -244,6 +244,9 @@ def expand(self) -> None: self.is_expandable = False self.is_expanded = False + if self.tree: + self.tree.profiling["expansion_calls"] += 1 + def is_terminal(self) -> bool: """ Node is terminal if its unexpandable, or the internal state is terminal (solved) @@ -413,7 +416,9 @@ def _select_child(self, child_idx: int) -> Optional["MctsNode"]: return self._children[child_idx] reaction = self._children_actions[child_idx] - if not reaction.reactants: + if reaction.unqueried: + if self.tree: + self.tree.profiling["reactants_generations"] += 1 _ = reaction.reactants if not self._check_child_reaction(reaction): diff --git a/aizynthfinder/search/mcts/search.py b/aizynthfinder/search/mcts/search.py index 753d62a..ba10344 100644 --- a/aizynthfinder/search/mcts/search.py +++ b/aizynthfinder/search/mcts/search.py @@ -34,6 +34,10 @@ def __init__(self, config: Configuration, root_smiles: str = None) -> None: self.root = None self.config = config self._graph: Optional[nx.DiGraph] = None + self.profiling = { + "expansion_calls": 0, + "reactants_generations": 0, + } @classmethod def from_json(cls, filename: str, config: Configuration) -> "MctsSearchTree": diff --git a/aizynthfinder/training/make_false_products.py b/aizynthfinder/training/make_false_products.py index c7c32b0..197fcdb 100644 --- a/aizynthfinder/training/make_false_products.py +++ b/aizynthfinder/training/make_false_products.py @@ -16,6 +16,7 @@ reverse_template, reaction_hash, reactants_to_fingerprint, + split_reaction_smiles, ) from aizynthfinder.utils.models import CUSTOM_OBJECTS, load_keras_model @@ -27,6 +28,7 @@ Any, Callable, List, + Sequence, ) _DfGenerator = Iterable[Optional[pd.DataFrame]] @@ -71,7 +73,9 @@ def recommender_application(library: pd.DataFrame, config: Config, _) -> _DfGene topn = config["negative_data"]["recommender_topn"] def prediction_sampler(row): - fingerprint = reactants_to_fingerprint([row.reactants], config) + fingerprint = reactants_to_fingerprint( + [row[config["column_map"]["reactants"]]], config + ) fingerprint = fingerprint.reshape([1, config["fingerprint_len"]]) prediction = model.predict(fingerprint).flatten() prediction_indices = prediction.argsort()[::-1][:topn] @@ -105,21 +109,23 @@ def strict_application( def _apply_forward_reaction( template_row: pd.Series, config: Config ) -> Optional[pd.DataFrame]: - smarts_fwd = 
reverse_template(template_row.retro_template) - mols = create_reactants_molecules(template_row.reactants) + smarts_fwd = reverse_template(template_row[config["column_map"]["retro_template"]]) + mols = create_reactants_molecules(template_row[config["column_map"]["reactants"]]) try: - ref_mol = Molecule(smiles=template_row.products, sanitize=True) + ref_mol = Molecule( + smiles=template_row[config["column_map"]["products"]], sanitize=True + ) except MoleculeException as err: raise _ReactionException( - f"reaction {template_row.reaction_hash} failed with msg {str(err)}" + f"reaction {template_row[config['column_map']['reaction_hash']]} failed with msg {str(err)}" ) try: products = Reaction(mols=mols, smarts=smarts_fwd).apply() except ValueError as err: raise _ReactionException( - f"reaction {template_row.reaction_hash} failed with msg {str(err)}" + f"reaction {template_row[config['column_map']['reaction_hash']]} failed with msg {str(err)}" ) new_products = {product[0] for product in products if product[0] != ref_mol} @@ -131,7 +137,7 @@ def _apply_forward_reaction( } if not correct_products: raise _ReactionException( - f"reaction {template_row.reaction_hash} failed to produce correct product" + f"reaction {template_row[config['column_map']['reaction_hash']]} failed to produce correct product" ) return _new_dataframe( @@ -139,13 +145,14 @@ def _apply_forward_reaction( config, nrows=len(new_products), reaction_hash=[ - reaction_hash(template_row.reactants, product) for product in new_products + reaction_hash(template_row[config["column_map"]["reactants"]], product) + for product in new_products ], products=[product.smiles for product in new_products], ) -def _get_config() -> Tuple[Config, str]: +def _get_config(optional_args: Optional[Sequence[str]] = None) -> Tuple[Config, str]: parser = argparse.ArgumentParser("Tool to generate artificial negative reactions") parser.add_argument("config", help="the filename to a configuration file") parser.add_argument( @@ -153,7 +160,7 @@ def _get_config() -> Tuple[Config, str]: choices=["strict", "random", "recommender"], help="the method to create random data", ) - args = parser.parse_args() + args = parser.parse_args(optional_args) return Config(args.config), args.method @@ -161,8 +168,13 @@ def _get_config() -> Tuple[Config, str]: def _new_dataframe( original: pd.Series, config: Config, nrows: int = 1, **kwargs: Any ) -> pd.DataFrame: - dict_ = {"index": 0} - for column in config["library_headers"][1:]: + dict_ = {} + if config["in_csv_headers"]: + columns = list(original.index) + else: + dict_["index"] = 0 + columns = config["library_headers"][1:] + for column in columns: dict_[column] = kwargs.get(column, [original[column]] * nrows) return pd.DataFrame(dict_) @@ -171,18 +184,25 @@ def _sample_library( library: pd.DataFrame, config: Config, sampler_func: Callable ) -> _DfGenerator: for _, row in library.iterrows(): - mols = create_reactants_molecules(row.reactants) + mols = create_reactants_molecules(row[config["column_map"]["reactants"]]) try: - ref_mol = Molecule(smiles=row.products, sanitize=True) + ref_mol = Molecule( + smiles=row[config["column_map"]["products"]], sanitize=True + ) except MoleculeException: yield None continue new_product = None for template_row in sampler_func(row): - if row.template_hash == template_row.template_hash: + if ( + row[config["column_map"]["template_hash"]] + == template_row[config["column_map"]["template_hash"]] + ): continue - smarts_fwd = reverse_template(template_row.retro_template) + smarts_fwd =
reverse_template( + template_row[config["column_map"]["retro_template"]] + ) try: new_product = Reaction(mols=mols, smarts=smarts_fwd).apply()[0][0] except (ValueError, IndexError): @@ -199,18 +219,20 @@ def _sample_library( yield _new_dataframe( row, config, - reaction_hash=[reaction_hash(row.reactants, new_product)], + reaction_hash=[ + reaction_hash(row[config["column_map"]["reactants"]], new_product) + ], products=[new_product.smiles], classification=[""], - retro_template=[template_row.retro_template], - template_hash=[template_row.template_hash], + retro_template=[template_row[config["column_map"]["retro_template"]]], + template_hash=[template_row[config["column_map"]["template_hash"]]], selectivity=[0], outcomes=[1], - template_code=[template_row.template_code], + template_code=[template_row["template_code"]], ) -def main() -> None: +def main(optional_args: Optional[Sequence[str]] = None) -> None: """Entry-point for the make_false_products tool""" methods = { "strict": strict_application, @@ -218,15 +240,18 @@ def main() -> None: "recommender": recommender_application, } - config, selected_method = _get_config() + config, selected_method = _get_config(optional_args) filename = config.filename("library") library = pd.read_csv( filename, index_col=False, - header=None, - names=config["library_headers"], + header=0 if config["in_csv_headers"] else None, + names=None if config["in_csv_headers"] else config["library_headers"], + sep=config["csv_sep"], ) - false_lib = pd.DataFrame({column: [] for column in config["library_headers"]}) + if config["reaction_smiles_column"]: + library = split_reaction_smiles(library, config) + false_lib = pd.DataFrame({column: [] for column in library.columns}) progress_bar = tqdm.tqdm(total=len(library)) errors: List[str] = [] @@ -238,9 +263,9 @@ def main() -> None: false_lib.to_csv( config.filename("false_library"), - mode="w", - header=False, + header=config["in_csv_headers"], index=False, + sep=config["csv_sep"], ) with open(config.filename("_errors.txt"), "w") as fileobj: fileobj.write("\n".join(errors)) diff --git a/aizynthfinder/training/preprocess_expansion.py b/aizynthfinder/training/preprocess_expansion.py index 54292be..15cb7fc 100644 --- a/aizynthfinder/training/preprocess_expansion.py +++ b/aizynthfinder/training/preprocess_expansion.py @@ -2,6 +2,7 @@ """ import argparse import os +from typing import Sequence, Optional import pandas as pd import numpy as np @@ -13,6 +14,7 @@ split_and_save_data, smiles_to_fingerprint, is_sanitizable, + split_reaction_smiles, ) @@ -28,59 +30,75 @@ def _filter_dataset(config: Config) -> pd.DataFrame: full_data = pd.read_csv( filename, index_col=False, - header=None, - names=config["library_headers"][:-1], + header=0 if config["in_csv_headers"] else None, + names=None if config["in_csv_headers"] else config["library_headers"][:-1], + sep=config["csv_sep"], ) + if config["reaction_smiles_column"]: + full_data = split_reaction_smiles(full_data, config) if config["remove_unsanitizable_products"]: - products = full_data["products"].to_numpy() + products = full_data[config["column_map"]["products"]].to_numpy() idx = np.apply_along_axis(is_sanitizable, 0, [products]) full_data = full_data[idx] - full_data = full_data.drop_duplicates(subset="reaction_hash") - template_group = full_data.groupby("template_hash") + template_hash_col = config["column_map"]["template_hash"] + full_data = full_data.drop_duplicates(subset=config["column_map"]["reaction_hash"]) + template_group = full_data.groupby(template_hash_col) 
template_group = template_group.size().sort_values(ascending=False) min_index = template_group[template_group >= config["template_occurrence"]].index - dataset = full_data[full_data["template_hash"].isin(min_index)] + dataset = full_data[full_data[template_hash_col].isin(min_index)] template_labels = LabelEncoder() dataset = dataset.assign( - template_code=template_labels.fit_transform(dataset["template_hash"]) + template_code=template_labels.fit_transform(dataset[template_hash_col]) ) dataset.to_csv( config.filename("library"), mode="w", - header=False, + header=config["in_csv_headers"], index=False, + sep=config["csv_sep"], ) return dataset -def _get_config() -> Config: +def _get_config(optional_args: Optional[Sequence[str]] = None) -> Config: parser = argparse.ArgumentParser( "Tool to pre-process a template library to be used in training a expansion network policy" ) parser.add_argument("config", help="the filename to a configuration file") - args = parser.parse_args() + args = parser.parse_args(optional_args) return Config(args.config) def _save_unique_templates(dataset: pd.DataFrame, config: Config) -> None: - template_group = dataset.groupby("template_hash", sort=False).size() - dataset = dataset[["retro_template", "template_code"] + config["metadata_headers"]] + template_hash_col = config["column_map"]["template_hash"] + template_group = dataset.groupby(template_hash_col, sort=False).size() + dataset = dataset[ + [config["column_map"]["retro_template"], "template_code"] + + config["metadata_headers"] + ] if "classification" in dataset.columns: dataset["classification"].fillna("-", inplace=True) dataset = dataset.drop_duplicates(subset="template_code", keep="first") dataset["library_occurrence"] = template_group.values dataset.set_index("template_code", inplace=True) dataset = dataset.sort_index() + dataset.rename( + columns={ + template_hash_col: "template_hash", + config["column_map"]["retro_template"]: "retro_template", + }, + inplace=True, + ) dataset.to_hdf(config.filename("unique_templates"), "table") -def main() -> None: +def main(optional_args: Optional[Sequence[str]] = None) -> None: """Entry-point for the preprocess_expansion tool""" - config = _get_config() + config = _get_config(optional_args) if config["library_headers"][-1] != "template_code": config["library_headers"].append("template_code") @@ -91,17 +109,20 @@ def main() -> None: dataset = pd.read_csv( filename, index_col=False, - header=None, - names=config["library_headers"], + header=0 if config["in_csv_headers"] else None, + names=None if config["in_csv_headers"] else config["library_headers"], + sep=config["csv_sep"], ) + if config["reaction_smiles_column"]: + dataset = split_reaction_smiles(dataset, config) print("Dataset filtered/loaded, generating labels...", flush=True) labelb = LabelBinarizer(neg_label=0, pos_label=1, sparse_output=True) - labels = labelb.fit_transform(dataset["template_hash"]) + labels = labelb.fit_transform(dataset[config["column_map"]["template_hash"]]) split_and_save_data(labels, "labels", config) print("Labels created and split, generating inputs...", flush=True) - products = dataset["products"].to_numpy() + products = dataset[config["column_map"]["products"]].to_numpy() inputs = np.apply_along_axis(smiles_to_fingerprint, 0, [products], config) inputs = sparse.lil_matrix(inputs.T).tocsr() split_and_save_data(inputs, "inputs", config) diff --git a/aizynthfinder/training/preprocess_filter.py b/aizynthfinder/training/preprocess_filter.py index 5448921..2727a7c 100644 --- 
a/aizynthfinder/training/preprocess_filter.py +++ b/aizynthfinder/training/preprocess_filter.py @@ -1,6 +1,7 @@ """ Module routines for pre-processing data for filter policy training """ import argparse +from typing import Sequence, Optional import pandas as pd import numpy as np @@ -11,46 +12,52 @@ split_and_save_data, smiles_to_fingerprint, reaction_to_fingerprints, + split_reaction_smiles, ) -def _get_config() -> Config: +def _get_config(optional_args: Optional[Sequence[str]] = None) -> Config: parser = argparse.ArgumentParser( "Tool to pre-process a template library to be used to train a in-scope filter network policy" ) parser.add_argument("config", help="the filename to a configuration file") - args = parser.parse_args() + args = parser.parse_args(optional_args) return Config(args.config) -def main() -> None: +def main(optional_args: Optional[Sequence[str]] = None) -> None: """Entry-point for the preprocess_filter tool""" - config = _get_config() + config = _get_config(optional_args) true_dataset = pd.read_csv( config.filename("library"), index_col=False, - header=None, - names=config["library_headers"][:-1], + header=0 if config["in_csv_headers"] else None, + names=None if config["in_csv_headers"] else config["library_headers"][:-1], + sep=config["csv_sep"], ) true_dataset["true_product"] = 1 false_dataset = pd.read_csv( config.filename("false_library"), index_col=False, - header=None, - names=config["library_headers"][:-1], + header=0 if config["in_csv_headers"] else None, + names=None if config["in_csv_headers"] else config["library_headers"][:-1], + sep=config["csv_sep"], ) false_dataset["true_product"] = 0 dataset = true_dataset.append(false_dataset, sort=False) + if config["reaction_smiles_column"]: + dataset = split_reaction_smiles(dataset, config) + print("Dataset loaded, generating Labels...", flush=True) labels = dataset["true_product"].to_numpy() split_and_save_data(labels, "labels", config) print("Labels created and split, generating Inputs...", flush=True) - products = dataset["products"].to_numpy() - reactants = dataset["reactants"].to_numpy() + products = dataset[config["column_map"]["products"]].to_numpy() + reactants = dataset[config["column_map"]["reactants"]].to_numpy() inputs = np.apply_along_axis( reaction_to_fingerprints, 0, [products, reactants], config ).astype(np.int8) diff --git a/aizynthfinder/training/preprocess_recommender.py b/aizynthfinder/training/preprocess_recommender.py index af65400..ccd7b35 100644 --- a/aizynthfinder/training/preprocess_recommender.py +++ b/aizynthfinder/training/preprocess_recommender.py @@ -1,6 +1,7 @@ """ Module routines for pre-processing data for recommender training """ import argparse +from typing import Sequence, Optional import pandas as pd import numpy as np @@ -11,46 +12,50 @@ Config, split_and_save_data, reactants_to_fingerprint, + split_reaction_smiles, ) -def _get_config() -> Config: +def _get_config(optional_args: Optional[Sequence[str]] = None) -> Config: parser = argparse.ArgumentParser( "Tool to pre-process a template library to be used to train a recommender network" ) parser.add_argument("config", help="the filename to a configuration file") - args = parser.parse_args() + args = parser.parse_args(optional_args) return Config(args.config) def _save_unique_templates(dataset: pd.DataFrame, config: Config) -> None: - dataset = dataset[["retro_template", "template_code"]] + dataset = dataset[[config["column_map"]["retro_template"], "template_code"]] dataset = dataset.drop_duplicates(subset="template_code", 
keep="first") dataset.set_index("template_code", inplace=True) dataset = dataset.sort_index() dataset.to_hdf(config.filename("unique_templates"), "table") -def main() -> None: +def main(optional_args: Optional[Sequence[str]] = None) -> None: """Entry-point for the preprocess_recommender tool""" - config = _get_config() + config = _get_config(optional_args) filename = config.filename("library") dataset = pd.read_csv( filename, index_col=False, - header=None, - names=config["library_headers"], + header=0 if config["in_csv_headers"] else None, + names=None if config["in_csv_headers"] else config["library_headers"], + sep=config["csv_sep"], ) + if config["reaction_smiles_column"]: + dataset = split_reaction_smiles(dataset, config) print("Dataset loaded, generating Labels...", flush=True) labelbin = LabelBinarizer(neg_label=0, pos_label=1, sparse_output=True) - labels = labelbin.fit_transform(dataset["template_hash"]) + labels = labelbin.fit_transform(dataset[config["column_map"]["template_hash"]]) split_and_save_data(labels, "labels", config) print("Labels created and split, generating Inputs...", flush=True) - reactants = dataset["reactants"].to_numpy() + reactants = dataset[config["column_map"]["reactants"]].to_numpy() inputs = np.apply_along_axis(reactants_to_fingerprint, 0, [reactants], config) inputs = sparse.lil_matrix(inputs.T).tocsr() split_and_save_data(inputs, "inputs", config) diff --git a/aizynthfinder/training/training.py b/aizynthfinder/training/training.py index ebe6dd2..e08fcd5 100644 --- a/aizynthfinder/training/training.py +++ b/aizynthfinder/training/training.py @@ -1,6 +1,7 @@ """ Module containing routines to setup the training of policies. """ import argparse +from typing import Optional, Sequence from aizynthfinder.training.utils import Config from aizynthfinder.training.keras_models import ( @@ -10,7 +11,7 @@ ) -def main() -> None: +def main(optional_args: Optional[Sequence[str]] = None) -> None: """Entry-point for the aizynth_training tool""" parser = argparse.ArgumentParser("Tool to train a network policy") parser.add_argument("config", help="the filename to a configuration file") @@ -19,7 +20,7 @@ def main() -> None: choices=["expansion", "filter", "recommender"], help="the model to train", ) - args = parser.parse_args() + args = parser.parse_args(optional_args) config = Config(args.config) if args.model == "expansion": diff --git a/aizynthfinder/training/utils.py b/aizynthfinder/training/utils.py index 92898cc..120f685 100644 --- a/aizynthfinder/training/utils.py +++ b/aizynthfinder/training/utils.py @@ -193,7 +193,12 @@ def split_and_save_data( for label_prefix, arr in array_dict.items(): filename = config.filename(label_prefix + data_label) if isinstance(data, pd.DataFrame): - arr.to_csv(filename, mode="w", header=False, index=False) + arr.to_csv( + filename, + sep=config["csv_sep"], + header=config["in_csv_headers"], + index=False, + ) elif isinstance(data, np.ndarray): np.savez(filename, arr) else: @@ -257,3 +262,20 @@ def reaction_to_fingerprints(args: Sequence[str], config: Config) -> np.ndarray: reactant_fp = reactants_to_fingerprint([reactants_smiles], config) return (product_fp - reactant_fp).astype(np.int8) + + +def split_reaction_smiles(data: pd.DataFrame, config: Config) -> pd.DataFrame: + """ + Split a column of reaction SMILES into reactant and product columns + + :param data: the dateframe to process + :param config: the training configuration + :return: the new dataframe + """ + smiles_split = data[config["reaction_smiles_column"]].str.split(">", 
expand=True) + return data.assign( + **{ + config["column_map"]["reactants"]: smiles_split[0], + config["column_map"]["products"]: smiles_split[2], + } + ) diff --git a/aizynthfinder/utils/image.py b/aizynthfinder/utils/image.py index 41ccb55..c18d25e 100644 --- a/aizynthfinder/utils/image.py +++ b/aizynthfinder/utils/image.py @@ -51,23 +51,28 @@ def _clean_up_images() -> None: pass -def molecule_to_image(mol: Molecule, frame_color: PilColor) -> PilImage: +def molecule_to_image( + mol: Molecule, frame_color: PilColor, size: int = 300 +) -> PilImage: """ Create a pretty image of a molecule, with a colored frame around it :param mol: the molecule :param frame_color: the color of the frame + :param size: the size of the image :return: the produced image """ mol = Chem.MolFromSmiles(mol.smiles) - img = Draw.MolToImage(mol) + img = Draw.MolToImage(mol, size=(size, size)) cropped_img = crop_image(img) return draw_rounded_rectangle(cropped_img, frame_color) def molecules_to_images( - mols: Sequence[Molecule], frame_colors: Sequence[PilColor] + mols: Sequence[Molecule], + frame_colors: Sequence[PilColor], + size: int = 300, ) -> List[PilImage]: """ Create pretty images of molecules with a colored frame around each one of them. @@ -76,9 +81,9 @@ def molecules_to_images( :param mols: the molecules :param frame_colors: the color of the frame for each molecule + :param size: the sub-image size :return: the produced images """ - size = 300 # Make sanitized copies of all molecules mol_copies = [mol.make_unique() for mol in mols] for mol in mol_copies: @@ -172,7 +177,7 @@ def draw_rounded_rectangle( def save_molecule_images( - molecules: Sequence[Molecule], frame_colors: Sequence[PilColor] + molecules: Sequence[Molecule], frame_colors: Sequence[PilColor], size: int = 300 ) -> Dict[Molecule, str]: """ Create images of a list of molecules and save them to disc @@ -180,16 +185,17 @@ def save_molecule_images( :param molecules: the molecules to save as images :param frame_colors: the color of the frame around each image + :param size: the sub-image size for each molecule :return: the filename of the created images """ global IMAGE_FOLDER try: - images = molecules_to_images(molecules, frame_colors) + images = molecules_to_images(molecules, frame_colors, size) # pylint: disable=broad-except except Exception: # noqa images = [ - molecule_to_image(molecule, frame_color) + molecule_to_image(molecule, frame_color, size) for molecule, frame_color in zip(molecules, frame_colors) ] @@ -206,6 +212,7 @@ def make_graphviz_image( reactions: Union[Sequence[RetroReaction], Sequence[FixedRetroReaction]], edges: Sequence[Tuple[Any, Any]], frame_colors: Sequence[PilColor], + reaction_shapes: Sequence[str] = None, use_splines: bool = True, ) -> PilImage: """ @@ -216,6 +223,7 @@ def make_graphviz_image( :param reactions: the reaction nodes :param edges: the edges of the graph :param frame_colors: the color of the frame around each image + :param reaction_shapes: optional specification of shapes for each reaction :param use_splines: if True tries to use splines to connect nodes in image :raises FileNotFoundError: if the image could not be produced :return: the create image @@ -224,7 +232,7 @@ def make_graphviz_image( def _create_image(use_splines): txt = template.render( molecules=mol_spec, - reactions=reactions, + reactions=rxn_spec, edges=edges, use_splines=use_splines, ) @@ -242,6 +250,8 @@ def _create_image(use_splines): return output_img2 mol_spec = save_molecule_images(molecules, frame_colors) + reaction_shapes = 
reaction_shapes or ["circle"] * len(reactions) + rxn_spec = zip(reactions, reaction_shapes) template_filepath = os.path.join(data_path(), "templates", "reaction_tree.dot") with open(template_filepath, "r") as fileobj: diff --git a/docs/conf.py b/docs/conf.py index d62bc7b..2830fbc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -6,7 +6,7 @@ project = "aizynthfinder" copyright = "2020, Molecular AI group" author = "Molecular AI group" -release = "3.1.0" +release = "3.2.0" # This make sure that the cli_help.txt file is properly formated with open("cli_help.txt", "r") as fileobj: diff --git a/docs/configuration.rst b/docs/configuration.rst index 1003676..28fcb6c 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -57,7 +57,7 @@ A more detailed configuration file is shown below filter: files: my_policy: /path/to/keras/model/weights.hdf5 - ReactantsCountFilter: + reactants_count: filter_tag: stock: files: @@ -69,7 +69,7 @@ The (expansion) policy models are specified using two files * a HDF5 file containing templates. The filter policy model is specified using a single checkpoint file from Keras in hdf5 format. Any additional -filters can be specified using the classname, and a tag as seen above using ``ReactantsCountFilter`` with the flag +filters can be specified using the classname, and a tag as seen above using ``reactants_count`` with the flag ``filter_tag``. The template file should be readable by ``pandas`` using the ``table`` key and the ``retro_template`` column. diff --git a/pyproject.toml b/pyproject.toml index b7fb8aa..268f40e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aizynthfinder" -version = "3.1.0" +version = "3.2.0" description = "Retrosynthetic route finding using neural network guided Monte-Carlo tree search" authors = ["Molecular AI group "] license = "MIT" diff --git a/tests/context/test_policy.py b/tests/context/test_policy.py index 7ea5d27..86fd404 100644 --- a/tests/context/test_policy.py +++ b/tests/context/test_policy.py @@ -224,9 +224,7 @@ def test_load_filter_policy_from_config_custom(default_config, mock_keras_model) filter_policy.load_from_config( **{ "QuickKerasFilter": {"policy1": {"source": "dummy1"}}, - "aizynthfinder.context.policy.QuickKerasFilter": { - "policy2": {"source": "dummy1"} - }, + "feasibility": {"policy2": {"source": "dummy1"}}, } ) assert "policy1" in filter_policy.items diff --git a/tests/data/dummy2_raw_template_library.csv b/tests/data/dummy2_raw_template_library.csv new file mode 100644 index 0000000..4b0d4d7 --- /dev/null +++ b/tests/data/dummy2_raw_template_library.csv @@ -0,0 +1,13 @@ +ID,PseudoHash,RSMI,classification,retro_template,template_hash +0,AAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX +0,AAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX +0,BAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX 
+0,ABA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX +0,AAB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX +0,BBA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXX +0,ABB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY +0,BAB,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY +0,CAA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY +0,ACA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY +0,AAC,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),XXY +0,CCA,[C:0]NO.[C:1]CCCOc1ccc(CC(=O)Cl)cc1>>CCCCOc1ccc(CC(=O)N(C)O)cc1,unclassified,([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])>>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6]),DXX \ No newline at end of file diff --git a/tests/data/make_false2_template_library.csv b/tests/data/make_false2_template_library.csv new file mode 100644 index 0000000..88ba6e3 --- /dev/null +++ b/tests/data/make_false2_template_library.csv @@ -0,0 +1,4 @@ +PseudoHash RSMI retro_template template_hash template_code +039c3d4c9304218d50b06ed252b67c1ed287ea5331fa2755fdd984e6 O=C1CCC(=O)N1[Br:1].[F:2][c:3]1[cH:4][cH:5][c:6]2[c:7]([cH:8]1)[N:9]([CH3:10])[C:11](=[O:12])[CH2:13][CH2:14]2.CCOC(C)=O.CN(C)C=O.O>>[Br:1][c:4]1[c:3]([F:2])[cH:8][c:7]2[c:6]([cH:5]1)[CH2:14][CH2:13][C:11](=[O:12])[N:9]2[CH3:10] ([Br;H0;D1;+0:1]-[c;H0;D3;+0:3](:[c:2]):[c:4])>>(O=C1-C-C-C(=O)-N-1-[Br;H0;D1;+0:1]).([c:2]:[cH;D2;+0:3]:[c:4]) 04aac57316494f5621172b80523cd496e017f4d0104703ad33b2fa4b 0 +17774f52ad0f5f143d46c931fd76d0019b368a6127051cd9630916da [C:1]([C:2]#[CH:3])(=[O:4])[O:5][CH2:6][CH3:7].[N:8](=[N+:9]=[N-:10])[CH2:11][CH2:12][CH2:13][CH2:14][n:15]1[c:16](=[O:17])[cH:18][c:19]([NH:20][C:21]([CH2:22][c:23]2[cH:24][cH:25][cH:26][cH:27][cH:28]2)=[O:29])[cH:30][cH:31]1.CC(=O)O.CCN(C(C)C)C(C)C.ClCCl.[Cu]I>>[C:1]([c:2]1[cH:3][n:8]([CH2:11][CH2:12][CH2:13][CH2:14][n:15]2[c:16](=[O:17])[cH:18][c:19]([NH:20][C:21]([CH2:22][c:23]3[cH:24][cH:25][cH:26][cH:27][cH:28]3)=[O:29])[cH:30][cH:31]2)[n:9][n:10]1)(=[O:4])[O:5][CH2:6][CH3:7] 
([C:8]-[n;H0;D3;+0:9]1:[cH;D2;+0:7]:[c;H0;D3;+0:6](-[C:4](=[O;D1;H0:5])-[#8:3]-[C:2]-[C;D1;H3:1]):[n;H0;D2;+0:11]:[n;H0;D2;+0:10]:1)>>([C;D1;H3:1]-[C:2]-[#8:3]-[C:4](=[O;D1;H0:5])-[C;H0;D2;+0:6]#[CH;D1;+0:7]).([C:8]-[N;H0;D2;+0:9]=[N+;H0;D2:10]=[N-;H0;D1:11]) 9af1e46e34a295ad5081f37e7127e808912f4d1725b636f26cc50a4b 1 +6fde2b6d9c02503241835575c60690f5530e5207ed95efe44724f1d2 CC[O:1][C:2]([C:3]([CH2:4][c:5]1[cH:6][cH:7][cH:8][cH:9][cH:10]1)([CH3:11])[S:12](=[O:13])(=[O:14])[CH2:15][c:16]1[cH:17][cH:18][c:19]([O:20][CH3:21])[cH:22][cH:23]1)=[O:24].CO.[Na+].[OH-]>>[OH:1][C:2]([C:3]([CH2:4][c:5]1[cH:6][cH:7][cH:8][cH:9][cH:10]1)([CH3:11])[S:12](=[O:13])(=[O:14])[CH2:15][c:16]1[cH:17][cH:18][c:19]([O:20][CH3:21])[cH:22][cH:23]1)=[O:24] ([C:3]-[C:2](=[O;D1;H0:4])-[OH;D1;+0:1])>>(C-C-[O;H0;D2;+0:1]-[C:2](-[C:3])=[O;D1;H0:4]) 9a4c5ceadb3b4a753bb8f03491f138cd875a84f1d327e818237f90df 2 diff --git a/tests/data/post_processing_test.py b/tests/data/post_processing_test.py new file mode 100644 index 0000000..61caeed --- /dev/null +++ b/tests/data/post_processing_test.py @@ -0,0 +1,2 @@ +def post_processing(finder): + return {"quantity": 5, "another_quantity": 10} diff --git a/tests/test_cli.py b/tests/test_cli.py index b2b7ece..ed8e963 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,6 @@ import os import glob +import sys import pandas as pd import yaml @@ -126,6 +127,29 @@ def test_cli_multiple_smiles( assert f"Output saved to {output_name}" in output.out +def test_cli_single_smile_with_postprocessing( + mocker, add_cli_arguments, tmpdir, capsys, shared_datadir +): + module_path = str(shared_datadir) + sys.path.append(module_path) + finder_patch = mocker.patch("aizynthfinder.interfaces.aizynthcli.AiZynthFinder") + finder_patch.return_value.extract_statistics.return_value = {"a": 1, "b": 2} + mocker.patch("aizynthfinder.interfaces.aizynthcli.json.dump") + output_name = str(tmpdir / "trees.json") + add_cli_arguments( + "--post_processing post_processing_test --smiles COO --config config_local.yml --output " + + output_name + ) + + cli_main() + + output = capsys.readouterr() + assert "quantity: 5" in output.out + assert "another quantity: 10" in output.out + + sys.path.remove(module_path) + + def test_make_stock_from_plain_file( create_dummy_smiles_source, tmpdir, add_cli_arguments, default_config ): @@ -158,9 +182,8 @@ def test_preprocess_expansion(write_yaml, shared_datadir, add_cli_arguments): "split_size": {"training": 0.6, "testing": 0.2, "validation": 0.2}, } ) - add_cli_arguments(config_path) - expansion_main() + expansion_main([config_path]) with open(shared_datadir / "dummy_template_library.csv", "r") as fileobj: lines = fileobj.read().splitlines() @@ -204,9 +227,8 @@ def test_preprocess_expansion_no_class(write_yaml, shared_datadir, add_cli_argum "split_size": {"training": 0.6, "testing": 0.2, "validation": 0.2}, } ) - add_cli_arguments(config_path) - expansion_main() + expansion_main([config_path]) with open(shared_datadir / "dummy_noclass_template_library.csv", "r") as fileobj: lines = fileobj.read().splitlines() @@ -233,6 +255,46 @@ def test_preprocess_expansion_no_class(write_yaml, shared_datadir, add_cli_argum assert column in data.columns +def test_preprocess_expansion_csv_headers(write_yaml, shared_datadir): + config_path = write_yaml( + { + "file_prefix": str(shared_datadir / "dummy2"), + "split_size": {"training": 0.6, "testing": 0.2, "validation": 0.2}, + "column_map": { + "reaction_hash": "PseudoHash", + }, + "in_csv_headers": True, + "reaction_smiles_column": "RSMI", 
+ } + ) + + expansion_main([config_path]) + + with open(shared_datadir / "dummy2_template_library.csv", "r") as fileobj: + lines = fileobj.read().splitlines() + assert len(lines) == 11 + + with open(shared_datadir / "dummy2_training.csv", "r") as fileobj: + lines = fileobj.read().splitlines() + assert len(lines) == 7 + + with open(shared_datadir / "dummy2_testing.csv", "r") as fileobj: + lines = fileobj.read().splitlines() + assert len(lines) == 3 + + with open(shared_datadir / "dummy2_validation.csv", "r") as fileobj: + lines = fileobj.read().splitlines() + assert len(lines) == 3 + + data = pd.read_hdf(shared_datadir / "dummy2_unique_templates.hdf5", "table") + config = Config(config_path) + assert len(data) == 2 + assert "retro_template" in data.columns + assert "library_occurrence" in data.columns + for column in config["metadata_headers"]: + assert column in data.columns + + def test_preprocess_expansion_bad_product( write_yaml, shared_datadir, add_cli_arguments ): @@ -242,9 +304,9 @@ def test_preprocess_expansion_bad_product( "split_size": {"training": 0.6, "testing": 0.2, "validation": 0.2}, } ) - add_cli_arguments(config_path) + with pytest.raises(MoleculeException): - expansion_main() + expansion_main([config_path]) def test_preprocess_expansion_skip_bad_product( @@ -257,9 +319,8 @@ def test_preprocess_expansion_skip_bad_product( "remove_unsanitizable_products": True, } ) - add_cli_arguments(config_path) - expansion_main() + expansion_main([config_path]) with open(shared_datadir / "dummy_sani_template_library.csv", "r") as fileobj: lines = fileobj.read().splitlines() @@ -273,9 +334,8 @@ def test_preprocess_recommender(write_yaml, shared_datadir, add_cli_arguments): "split_size": {"training": 0.6, "testing": 0.2, "validation": 0.2}, } ) - add_cli_arguments(config_path) - expansion_main() + expansion_main([config_path]) with open(shared_datadir / "dummy_template_library.csv", "r") as fileobj: lines = fileobj.read().splitlines() @@ -286,7 +346,7 @@ def test_preprocess_recommender(write_yaml, shared_datadir, add_cli_arguments): os.remove(shared_datadir / "dummy_validation.csv") os.remove(shared_datadir / "dummy_unique_templates.hdf5") - recommender_main() + recommender_main([config_path]) with open(shared_datadir / "dummy_training.csv", "r") as fileobj: lines = fileobj.read().splitlines() @@ -304,6 +364,50 @@ def test_preprocess_recommender(write_yaml, shared_datadir, add_cli_arguments): assert len(data) == 2 +def test_preprocess_recommender_csv_headers( + write_yaml, shared_datadir, add_cli_arguments +): + config_path = write_yaml( + { + "file_prefix": str(shared_datadir / "dummy2"), + "split_size": {"training": 0.6, "testing": 0.2, "validation": 0.2}, + "column_map": { + "reaction_hash": "PseudoHash", + }, + "in_csv_headers": True, + "reaction_smiles_column": "RSMI", + } + ) + + expansion_main([config_path]) + + with open(shared_datadir / "dummy2_template_library.csv", "r") as fileobj: + lines = fileobj.read().splitlines() + assert len(lines) == 11 + + os.remove(shared_datadir / "dummy2_training.csv") + os.remove(shared_datadir / "dummy2_testing.csv") + os.remove(shared_datadir / "dummy2_validation.csv") + os.remove(shared_datadir / "dummy2_unique_templates.hdf5") + + recommender_main([config_path]) + + with open(shared_datadir / "dummy2_training.csv", "r") as fileobj: + lines = fileobj.read().splitlines() + assert len(lines) == 7 + + with open(shared_datadir / "dummy2_testing.csv", "r") as fileobj: + lines = fileobj.read().splitlines() + assert len(lines) == 3 + + with 
open(shared_datadir / "dummy2_validation.csv", "r") as fileobj: + lines = fileobj.read().splitlines() + assert len(lines) == 3 + + data = pd.read_hdf(shared_datadir / "dummy2_unique_templates.hdf5", "table") + assert len(data) == 2 + + def test_preprocess_filter(write_yaml, shared_datadir, add_cli_arguments): def duplicate_file(filename): with open(shared_datadir / filename, "r") as fileobj: @@ -328,14 +432,12 @@ def duplicate_file(filename): } ) - add_cli_arguments(f"{config_path} strict") - make_false_main() + make_false_main([config_path, "strict"]) duplicate_file("make_false_template_library_false.csv") duplicate_file("make_false_template_library.csv") - add_cli_arguments(config_path) - filter_main() + filter_main([config_path]) with open(shared_datadir / "make_false_training.csv", "r") as fileobj: lines = fileobj.read().splitlines() @@ -350,6 +452,55 @@ def duplicate_file(filename): assert len(lines) == 2 +def test_preprocess_filter_csv_headers(write_yaml, shared_datadir, add_cli_arguments): + def duplicate_file(filename): + with open(shared_datadir / filename, "r") as fileobj: + lines = fileobj.read().splitlines() + lines = lines + lines[1:] + with open(shared_datadir / filename, "w") as fileobj: + fileobj.write("\n".join(lines)) + + config_path = write_yaml( + { + "file_prefix": str(shared_datadir / "make_false2"), + "split_size": {"training": 0.6, "testing": 0.2, "validation": 0.2}, + "library_headers": [ + "index", + "reaction_hash", + "reactants", + "products", + "retro_template", + "template_hash", + "template_code", + ], + "column_map": { + "reaction_hash": "PseudoHash", + }, + "in_csv_headers": True, + "reaction_smiles_column": "RSMI", + "csv_sep": "\t", + } + ) + + make_false_main([config_path, "strict"]) + + duplicate_file("make_false2_template_library_false.csv") + duplicate_file("make_false2_template_library.csv") + + filter_main([config_path]) + with open(shared_datadir / "make_false2_training.csv", "r") as fileobj: + lines = fileobj.read().splitlines() + assert len(lines) == 7 + + with open(shared_datadir / "make_false2_testing.csv", "r") as fileobj: + lines = fileobj.read().splitlines() + assert len(lines) == 3 + + with open(shared_datadir / "make_false2_validation.csv", "r") as fileobj: + lines = fileobj.read().splitlines() + assert len(lines) == 3 + + def test_make_false_products(write_yaml, shared_datadir, add_cli_arguments): config_path = write_yaml( { @@ -365,9 +516,8 @@ def test_make_false_products(write_yaml, shared_datadir, add_cli_arguments): ], } ) - add_cli_arguments(f"{config_path} strict") - make_false_main() + make_false_main([config_path, "strict"]) with open(shared_datadir / "make_false_template_library_false.csv", "r") as fileobj: lines = fileobj.read().splitlines()
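The new ``--post_processing`` flag exercised by the tests above expects importable modules that expose a ``post_processing(finder)`` callable returning a dict, which ``_do_post_processing`` merges into the output statistics (cf. ``tests/data/post_processing_test.py``). A minimal sketch of such a module (the module name and the reported quantity are illustrative):

```python
# my_post_processing.py -- illustrative module name; place it on sys.path or in
# the working directory and run e.g.:
#   aizynthcli --config config.yml --smiles target.smi --post_processing my_post_processing
# aizynthcli calls post_processing(finder) after build_routes() and merges the
# returned dict into the statistics that are printed and saved.


def post_processing(finder):
    """Return extra quantities to add to the aizynthcli statistics."""
    # Any JSON-serializable dict works, cf. tests/data/post_processing_test.py
    return {"number_of_routes": len(finder.routes.reaction_trees)}
```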