diff --git a/sklearn/__init__.py b/sklearn/__init__.py index ecb32f9dc0da3..a1b5b75b9dec1 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -89,6 +89,7 @@ __all__ = [ "calibration", + "callback", "cluster", "covariance", "cross_decomposition", diff --git a/sklearn/base.py b/sklearn/base.py index a7c93937ebe72..62c99025b37b3 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -6,15 +6,20 @@ import copy import functools import inspect +import pickle import platform import re import warnings from collections import defaultdict +from functools import partial +from shutil import rmtree import numpy as np from . import __version__ from ._config import config_context, get_config +from .callback import BaseCallback, ComputationTree +from .callback._base import CallbackContext from .exceptions import InconsistentVersionWarning from .utils import _IS_32BIT from .utils._estimator_html_repr import estimator_html_repr @@ -115,6 +120,10 @@ def _clone_parametrized(estimator, *, safe=True): params_set = new_object.get_params(deep=False) + # copy callbacks + if hasattr(estimator, "_callbacks"): + new_object._callbacks = estimator._callbacks + # quick sanity check of the parameters of the clone for name in new_object_params: param1 = new_object_params[name] @@ -641,6 +650,181 @@ class attribute, which is a dictionary `param_name: list of constraints`. See caller_name=self.__class__.__name__, ) + def _set_callbacks(self, callbacks): + """Set callbacks for the estimator. + + Parameters + ---------- + callbacks : callback or list of callbacks + the callbacks to set. + + Returns + ------- + self : estimator instance + The estimator instance itself. + """ + if not isinstance(callbacks, list): + callbacks = [callbacks] + + if not all(isinstance(callback, BaseCallback) for callback in callbacks): + raise TypeError("callbacks must be subclasses of BaseCallback.") + + self._callbacks = callbacks + + return self + + # XXX should be a method of MetaEstimatorMixin but this mixin can't handle all + # meta-estimators. + def _propagate_callbacks(self, sub_estimator, *, parent_node): + """Propagate the auto-propagated callbacks to a sub-estimator + + Parameters + ---------- + sub_estimator : estimator instance + The sub-estimator to propagate the callbacks to. + + parent_node : ComputationNode instance + The computation node in this estimator to set as parent_node to the + computation tree of the sub-estimator. It must be the node where the fit + method of the sub-estimator is called. + """ + if hasattr(sub_estimator, "_callbacks") and any( + callback.auto_propagate for callback in sub_estimator._callbacks + ): + bad_callbacks = [ + callback.__class__.__name__ + for callback in sub_estimator._callbacks + if callback.auto_propagate + ] + raise TypeError( + f"The sub-estimators ({sub_estimator.__class__.__name__}) of a" + f" meta-estimator ({self.__class__.__name__}) can't have" + f" auto-propagated callbacks ({bad_callbacks})." + " Set them directly on the meta-estimator." + ) + + if not hasattr(self, "_callbacks"): + return + + propagated_callbacks = [ + callback for callback in self._callbacks if callback.auto_propagate + ] + + if not propagated_callbacks: + return + + sub_estimator._parent_ct_node = parent_node + + sub_estimator._set_callbacks( + getattr(sub_estimator, "_callbacks", []) + propagated_callbacks + ) + + def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None): + """Evaluate the on_fit_begin method of the callbacks + + The computation tree is also built at this point. 
+ + This method should be called after all data and parameters validation. + + Parameters + ---------- + X : ndarray or sparse matrix, default=None + The training data. + + y : ndarray, default=None + The target. + + levels : list of dict + A description of the nested levels of computation of the estimator to build + the computation tree. It's a list of dict with "descr" and "max_iter" keys. + + Returns + ------- + root : ComputationNode instance + The root of the computation tree. + """ + self._computation_tree = ComputationTree( + estimator_name=self.__class__.__name__, + levels=levels, + parent_node=getattr(self, "_parent_ct_node", None), + ) + + if not hasattr(self, "_callbacks"): + return self._computation_tree.root, None, None, None, None + + X_val, y_val = None, None + + if any(callback.request_validation_split for callback in self._callbacks): + splitter = next( + callback.validation_split + for callback in self._callbacks + if hasattr(callback, "validation_split") + ) + + train, val = next(splitter.split(X)) + if X is not None: + X, X_val = X[train], X[val] + if y is not None: + y, y_val = y[train], y[val] + + # + CallbackContext( + self._callbacks, + finalizer=partial(rmtree, ignore_errors=True), + finalizer_args=self._computation_tree.tree_dir, + ) + + # + file_path = self._computation_tree.tree_dir / "computation_tree.pkl" + with open(file_path, "wb") as f: + pickle.dump(self._computation_tree, f) + + # Only call the on_fit_begin method of callbacks that are not + # propagated from a meta-estimator. + for callback in self._callbacks: + if not callback._is_propagated(estimator=self): + callback.on_fit_begin(estimator=self, X=X, y=y) + + return self._computation_tree.root, X, y, X_val, y_val + + def _eval_callbacks_on_fit_end(self): + """Evaluate the on_fit_end method of the callbacks""" + if not hasattr(self, "_callbacks"): + return + + if not hasattr(self, "_computation_tree"): + return + + self._computation_tree._tree_status[0] = True + + # Only call the on_fit_end method of callbacks that are not + # propagated from a meta-estimator. + for callback in self._callbacks: + if not callback._is_propagated(estimator=self): + callback.on_fit_end() + + def _from_reconstruction_attributes(self, *, reconstruction_attributes): + """Return an as if fitted copy of this estimator + + Parameters + ---------- + reconstruction_attributes : callable + A callable that has no arguments and returns the necessary fitted attributes + to create a working fitted estimator from this instance. + + Using a callable allows lazy evaluation of the potentially costly + reconstruction attributes. + + Returns + ------- + fitted_estimator : estimator instance + The fitted copy of this estimator. + """ + new_estimator = copy.copy(self) # XXX deepcopy ? + for key, val in reconstruction_attributes().items(): + setattr(new_estimator, key, val) + return new_estimator + @property def _repr_html_(self): """HTML representation of estimator. 
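# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the diff): how an estimator is expected to
# wire the new callback hooks together. It mirrors the `Estimator` helper added
# in sklearn/callback/tests/_utils.py later in this PR; the name `MyEstimator`
# and the training-loop placeholder are illustrative only.
from functools import partial

from sklearn.base import BaseEstimator, _fit_context
from sklearn.callback._base import _eval_callbacks_on_fit_iter_end


class MyEstimator(BaseEstimator):
    _parameter_constraints = {}

    def __init__(self, max_iter=10):
        self.max_iter = max_iter

    @_fit_context(prefer_skip_nested_validation=False)
    def fit(self, X, y=None):
        # Build the computation tree and let callbacks request a validation split.
        root, X, y, X_val, y_val = self._eval_callbacks_on_fit_begin(
            levels=[
                {"descr": "fit", "max_iter": self.max_iter},
                {"descr": "iter", "max_iter": None},
            ],
            X=X,
            y=y,
        )

        for i in range(self.max_iter):
            # ... one iteration of the actual training procedure goes here ...
            if _eval_callbacks_on_fit_iter_end(
                estimator=self,
                node=root.children[i],
                from_reconstruction_attributes=partial(
                    self._from_reconstruction_attributes,
                    reconstruction_attributes=lambda: {"n_iter_": i + 1},
                ),
                data={"X": X, "y": y, "X_val": X_val, "y_val": y_val},
            ):
                break  # a callback asked to stop this level of iterations

        self.n_iter_ = i + 1
        # `_eval_callbacks_on_fit_end` is called automatically by the
        # `_fit_context` decorator once `fit` returns.
        return self
# ---------------------------------------------------------------------------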
@@ -1212,7 +1396,10 @@ def wrapper(estimator, *args, **kwargs): prefer_skip_nested_validation or global_skip_validation ) ): - return fit_method(estimator, *args, **kwargs) + try: + return fit_method(estimator, *args, **kwargs) + finally: + estimator._eval_callbacks_on_fit_end() return wrapper diff --git a/sklearn/callback/__init__.py b/sklearn/callback/__init__.py new file mode 100644 index 0000000000000..2069ae3a58681 --- /dev/null +++ b/sklearn/callback/__init__.py @@ -0,0 +1,21 @@ +# License: BSD 3 clause + +from ._base import BaseCallback +from ._computation_tree import ComputationNode, ComputationTree, load_computation_tree +from ._early_stopping import EarlyStopping +from ._monitoring import Monitoring +from ._progressbar import ProgressBar +from ._snapshot import Snapshot +from ._text_verbose import TextVerbose + +__all__ = [ + "BaseCallback", + "ComputationNode", + "ComputationTree", + "load_computation_tree", + "Monitoring", + "EarlyStopping", + "ProgressBar", + "Snapshot", + "TextVerbose", +] diff --git a/sklearn/callback/_base.py b/sklearn/callback/_base.py new file mode 100644 index 0000000000000..42e18b431db52 --- /dev/null +++ b/sklearn/callback/_base.py @@ -0,0 +1,159 @@ +# License: BSD 3 clause + +import weakref +from abc import ABC, abstractmethod + + +# Not a method of BaseEstimator because it might not be directly called from fit but +# by a non-method function called by fit +def _eval_callbacks_on_fit_iter_end(**kwargs): + """Evaluate the on_fit_iter_end method of the callbacks + + This function should be called at the end of each computation node. + + Parameters + ---------- + kwargs : dict + arguments passed to the callback. + + Returns + ------- + stop : bool + Whether or not to stop the fit at this node. + """ + estimator = kwargs.get("estimator") + node = kwargs.get("node") + + if not hasattr(estimator, "_callbacks") or node is None: + return False + + estimator._computation_tree._tree_status[node.tree_status_idx] = True + + # stopping_criterion and reconstruction_attributes can be costly to compute. + # They are passed as lambdas for lazy evaluation. We only actually + # compute them if a callback requests it. + if any(cb.request_stopping_criterion for cb in estimator._callbacks): + kwarg = kwargs.pop("stopping_criterion", lambda: None)() + kwargs["stopping_criterion"] = kwarg + + if any(cb.request_from_reconstruction_attributes for cb in estimator._callbacks): + kwarg = kwargs.pop("from_reconstruction_attributes", lambda: None)() + kwargs["from_reconstruction_attributes"] = kwarg + + return any(callback.on_fit_iter_end(**kwargs) for callback in estimator._callbacks) + + +class BaseCallback(ABC): + """Abstract class for the callbacks""" + + @abstractmethod + def on_fit_begin(self, estimator, *, X=None, y=None): + """Method called at the beginning of the fit method of the estimator + + Only called + + Parameters + ---------- + estimator: estimator instance + The estimator the callback is set on. + X: ndarray or sparse matrix, default=None + The training data. + y: ndarray, default=None + The target. + """ + pass + + @abstractmethod + def on_fit_end(self): + """Method called at the end of the fit method of the estimator""" + pass + + @abstractmethod + def on_fit_iter_end(self, estimator, node, **kwargs): + """Method called at the end of each computation node of the estimator + + Parameters + ---------- + estimator : estimator instance + The caller estimator. It might differ from the estimator passed to the + `on_fit_begin` method for auto-propagated callbacks. 
+ + node : ComputationNode instance + The caller computation node. + + kwargs : dict + arguments passed to the callback. Possible keys are + + - stopping_criterion: float + Usually iterations stop when `stopping_criterion <= tol`. + This is only provided at the innermost level of iterations. + + - tol: float + Tolerance for the stopping criterion. + This is only provided at the innermost level of iterations. + + - from_reconstruction_attributes: estimator instance + A ready to predict, transform, etc ... estimator as if the fit stopped + at this node. Usually it's a copy of the caller estimator with the + necessary attributes set but it can sometimes be an instance of another + class (e.g. LogisticRegressionCV -> LogisticRegression) + + - fit_state: dict + Model specific quantities updated during fit. This is not meant to be + used by generic callbacks but by a callback designed for a specific + estimator instead. + + - extra_verbose: dict + Model specific . This is not meant to be + used by generic callbacks but by a callback designed for a specific + estimator instead. + + Returns + ------- + stop : bool or None + Whether or not to stop the current level of iterations at this node. + """ + pass + + @property + def auto_propagate(self): + """Whether or not this callback should be propagated to sub-estimators. + + An auto-propagated callback (from a meta-estimator to its sub-estimators) must + be set on the meta-estimator. Its `on_fit_begin` and `on_fit_end` methods will + only be called at the beginning and end of the fit method of the meta-estimator, + while its `on_fit_iter_end` method will be called at each computation node of + the meta-estimator and its sub-estimators. + """ + return False + + def _is_propagated(self, estimator): + """Check if this callback attached to estimator has been propagated from a + meta-estimator. + """ + return self.auto_propagate and hasattr(estimator, "_parent_ct_node") + + @property + def request_stopping_criterion(self): + return False + + @property + def request_from_reconstruction_attributes(self): + return False + + @property + def request_validation_split(self): + return False + + def _set_context(self, context): + if not hasattr(self, "_callback_contexts"): + self._callback_contexts = [] + + self._callback_contexts.append(context) + + +class CallbackContext: + def __init__(self, callbacks, finalizer, finalizer_args): + for callback in callbacks: + callback._set_context(self) + weakref.finalize(self, finalizer, finalizer_args) diff --git a/sklearn/callback/_computation_tree.py b/sklearn/callback/_computation_tree.py new file mode 100644 index 0000000000000..e5be721f60ff6 --- /dev/null +++ b/sklearn/callback/_computation_tree.py @@ -0,0 +1,285 @@ +# License: BSD 3 clause + +import os +import pickle +from pathlib import Path +from tempfile import mkdtemp +from uuid import uuid4 + +import numpy as np + + +class ComputationNode: + """A node in a ComputationTree + + Parameters + ---------- + computation_tree : ComputationTree instance + The computation tree it belongs to. + + parent : ComputationNode instance, default=None + The parent node. None means this is the root. + + max_iter : int, default=None + The number of its children. None means it's a leaf. + + description : str, default=None + A description of this computation node. None means it's a leaf. + + tree_status_idx : int, default=0 + The index of the status of this node in the `tree_status` array of its + computation tree. 
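# ---------------------------------------------------------------------------
# Hedged sketch (not part of the diff): a minimal callback satisfying the
# BaseCallback interface defined above. The prints are purely illustrative;
# returning a falsy value (or None) from `on_fit_iter_end` means "do not stop".
from sklearn.callback import BaseCallback


class IterationLogger(BaseCallback):
    def on_fit_begin(self, estimator, *, X=None, y=None):
        print(f"start fitting {estimator.__class__.__name__}")

    def on_fit_iter_end(self, estimator, node, **kwargs):
        # kwargs may also contain stopping_criterion, tol, fit_state, ...
        print(f"done: {node.description} {node.idx} (depth {node.depth})")

    def on_fit_end(self):
        print("fit finished")


# It would then be attached with, e.g.,
# `estimator._set_callbacks(IterationLogger()).fit(X, y)`.
# ---------------------------------------------------------------------------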
+ + idx : int, default=0 + The index of this node in the children list of its parent. + + Attributes + ---------- + children : list + The list of its children nodes. For a leaf, it's an empty list + + depth : int + The depth of this node in its computation tree. The root has a depth of 0. + """ + + def __init__( + self, + computation_tree, + parent=None, + max_iter=None, + description=None, + tree_status_idx=0, + idx=0, + ): + self.computation_tree = computation_tree + self.parent = parent + self.max_iter = max_iter + self.description = description + self.tree_status_idx = tree_status_idx + self.idx = idx + self.children = [] + self.depth = 0 if self.parent is None else self.parent.depth + 1 + + def get_ancestors(self, include_ancestor_trees=True): + """Get the list of all nodes in the path from the node to the root + + Parameters + ---------- + include_ancestor_trees : bool, default=True + If True, propagate to the tree of the `parent_node` of this tree if it + exists and so on. + + Returns + ------- + ancestors : list + The list of ancestors of this node (included). + """ + node = self + ancestors = [node] + + while node.parent is not None: + node = node.parent + ancestors.append(node) + + if include_ancestor_trees: + node_parent_tree = node.computation_tree.parent_node + if node_parent_tree is not None: + ancestors.extend(node_parent_tree.get_ancestors()) + + return ancestors + + def __repr__(self): + return ( + f"ComputationNode(description={self.description}, " + f"depth={self.depth}, idx={self.idx})" + ) + + +class ComputationTree: + """Data structure to store the computation tree of an estimator + + Parameters + ---------- + estimator_name : str + The name of the estimator. + + levels : list of dict + A description of the nested levels of computation of the estimator to build the + tree. It's a list of dict with "descr" and "max_iter" keys. + + parent_node : ComputationNode, default=None + The node where the estimator is used in the computation tree of a + meta-estimator. This node is not set to be the parent of the root of this tree. + + Attributes + ---------- + depth : int + The depth of the tree. It corresponds to the depth of its deepest leaf. + + root : ComputationNode instance + The root of the computation tree. + + tree_dir : pathlib.Path instance + The path of the directory where the computation tree is dumped during the fit of + its estimator. If it has a parent tree, this is a sub-directory of the + `tree_dir` of its parent. + + uid : uuid.UUID + Unique indentifier for a ComputationTree instance. + """ + + def __init__(self, estimator_name, levels, *, parent_node=None): + self.estimator_name = estimator_name + self.parent_node = parent_node + + self.depth = len(levels) - 1 + self.root, self.n_nodes = self._build_tree(levels) + + self.uid = uuid4() + + parent_tree_dir = ( + None + if self.parent_node is None + else self.parent_node.computation_tree.tree_dir + ) + if parent_tree_dir is None: + self.tree_dir = Path(mkdtemp()) + else: + # This tree has a parent tree. Place it in a subdir of its parent dir + # and give it a name that allows from the parent tree to find the sub dir + # of the sub tree of a given leaf. 
+ self.tree_dir = parent_tree_dir / str(parent_node.tree_status_idx) + self.tree_dir.mkdir() + self._filename = self.tree_dir / "tree_status.memmap" + + self._set_tree_status(mode="w+") + self._tree_status[:] = False + + def _build_tree(self, levels): + """Build the computation tree from the description of the levels""" + root = ComputationNode( + computation_tree=self, + max_iter=levels[0]["max_iter"], + description=levels[0]["descr"], + ) + + n_nodes = self._recursive_build_tree(root, levels) + + return root, n_nodes + + def _recursive_build_tree(self, parent, levels, n_nodes=1): + """Recursively build the tree from the root the leaves""" + if parent.depth == self.depth: + return n_nodes + + for i in range(parent.max_iter): + children_max_iter = levels[parent.depth + 1]["max_iter"] + description = levels[parent.depth + 1]["descr"] + + node = ComputationNode( + computation_tree=self, + parent=parent, + max_iter=children_max_iter, + description=description, + tree_status_idx=n_nodes, + idx=i, + ) + parent.children.append(node) + + n_nodes = self._recursive_build_tree(node, levels, n_nodes + 1) + + return n_nodes + + def _set_tree_status(self, mode): + """Create a memory-map to the tree_status array stored on the disk""" + # This has to be done each time we unpickle the tree + self._tree_status = np.memmap( + self._filename, dtype=bool, mode=mode, shape=(self.n_nodes,) + ) + + def get_progress(self, node): + """Return the number of finished child nodes of this node""" + if self._tree_status[node.tree_status_idx]: + return node.max_iter + + # Since the children of a node are not ordered (to account for parallel + # execution), we can't rely on the highest index for which the status is True. + return sum( + [self._tree_status[child.tree_status_idx] for child in node.children] + ) + + def get_child_computation_tree_dir(self, node): + if node.children: + raise ValueError("node is not a leaf") + return self.tree_dir / str(node.tree_status_idx) + + def iterate(self, include_leaves=False): + """Return an iterable over the nodes of the computation tree + + Nodes are discovered in a depth first search manner. + + Parameters + ---------- + include_leaves : bool + Whether or not to include the leaves of the tree in the iterable + + Returns + ------- + nodes_list : list + A list of the nodes of the computation tree. + """ + return self._recursive_iterate(include_leaves=include_leaves) + + def _recursive_iterate(self, node=None, include_leaves=False, node_list=None): + """Recursively constructs the iterable""" + # TODO make it an iterator ? 
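# ---------------------------------------------------------------------------
# Hedged illustration (not part of the diff) of how a `levels` description maps
# to a tree, following `_build_tree` above and the tests added in
# sklearn/callback/tests/test_computation_tree.py.
from sklearn.callback import ComputationTree

levels = [
    {"descr": "fit", "max_iter": 2},  # root has 2 children
    {"descr": "outer", "max_iter": 3},  # each of them has 3 children
    {"descr": "inner", "max_iter": None},  # leaves
]
tree = ComputationTree(estimator_name="MyEstimator", levels=levels)

assert tree.depth == 2
assert len(tree.root.children) == 2
assert all(len(node.children) == 3 for node in tree.root.children)
assert tree.n_nodes == 1 + 2 + 2 * 3  # root + depth-1 nodes + leaves
# ---------------------------------------------------------------------------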
+ if node is None: + node = self.root + node_list = [] + + if node.children or include_leaves: + node_list.append(node) + + for child in node.children: + self._recursive_iterate(child, include_leaves, node_list) + + return node_list + + def __repr__(self): + res = ( + f"[{self.estimator_name}] {self.root.description} : progress " + f"{self.get_progress(self.root)} / {self.root.max_iter}\n" + ) + for node in self.iterate(include_leaves=False): + if node is not self.root: + res += ( + f"{' ' * node.depth}{node.description} {node.idx}: progress " + f"{self.get_progress(node)} / {node.max_iter}\n" + ) + return res + + +def load_computation_tree(directory): + """load the computation tree of a directory + + Parameters + ---------- + directory : pathlib.Path instance + The directory where the computation tree is dumped + + Returns + ------- + computation_tree : ComputationTree instance + The loaded computation tree + """ + file_path = directory / "computation_tree.pkl" + if not file_path.exists() or not os.path.getsize(file_path) > 0: + # Do not try to load the tree when it's created but not yet written + return + + with open(file_path, "rb") as f: + computation_tree = pickle.load(f) + + computation_tree._set_tree_status(mode="r") + + return computation_tree diff --git a/sklearn/callback/_early_stopping.py b/sklearn/callback/_early_stopping.py new file mode 100644 index 0000000000000..b3137c4ff7812 --- /dev/null +++ b/sklearn/callback/_early_stopping.py @@ -0,0 +1,81 @@ +# License: BSD 3 clause + +from . import BaseCallback + + +class EarlyStopping(BaseCallback): + request_from_reconstruction_attributes = True + + def __init__( + self, + monitor="objective_function", + on="validation_set", + higher_is_better=False, + validation_split="auto", + max_no_improvement=10, + threshold=1e-2, + ): + from ..model_selection import KFold + + self.validation_split = validation_split + if validation_split == "auto": + self.validation_split = KFold(n_splits=5, shuffle=True, random_state=42) + self.monitor = monitor + self.on = on + self.higher_is_better = higher_is_better + self.max_no_improvement = max_no_improvement + self.threshold = threshold + + self._no_improvement = {} + self._last_monitored = {} + + def on_fit_begin(self, estimator, X=None, y=None): + pass + + def on_fit_iter_end(self, *, estimator, node, **kwargs): + if node.depth != node.computation_tree.depth: + return + + reconstructed_estimator = kwargs.pop("from_reconstruction_attributes") + data = kwargs.pop("data") + + X = data["X_val"] if self.on == "validation_set" else data["X"] + y = data["y_val"] if self.on == "validation_set" else data["y"] + + if self.monitor == "objective_function": + new_monitored, *_ = reconstructed_estimator.objective_function( + X, y, normalize=True + ) + elif callable(self.monitor): + new_monitored = self.monitor(reconstructed_estimator, X, y) + elif self.monitor is None or isinstance(self.monitor, str): + from ..metrics import check_scoring + + scorer = check_scoring(reconstructed_estimator, self.monitor) + new_monitored = scorer(reconstructed_estimator, X, y) + + if self._score_improved(node, new_monitored): + self._no_improvement[node.parent] = 0 + self._last_monitored[node.parent] = new_monitored + else: + self._no_improvement[node.parent] += 1 + + if self._no_improvement[node.parent] >= self.max_no_improvement: + return True + + def _score_improved(self, node, new_monitored): + if node.parent not in self._last_monitored: + return True + + last_monitored = self._last_monitored[node.parent] + if 
self.higher_is_better: + return new_monitored > last_monitored * (1 + self.threshold) + else: + return new_monitored < last_monitored * (1 - self.threshold) + + def on_fit_end(self): + pass + + @property + def request_validation_split(self): + return self.on == "validation_set" diff --git a/sklearn/callback/_monitoring.py b/sklearn/callback/_monitoring.py new file mode 100644 index 0000000000000..cfff4d1215c3b --- /dev/null +++ b/sklearn/callback/_monitoring.py @@ -0,0 +1,124 @@ +# License: BSD 3 clause + +# import os +from pathlib import Path +from tempfile import TemporaryDirectory + +import matplotlib.pyplot as plt +import pandas as pd + +from . import BaseCallback + + +class Monitoring(BaseCallback): + """Monitor model convergence. + + Parameters + ---------- + monitor : + + X_val : ndarray, default=None + Validation data + + y_val : ndarray, default=None + Validation target + + Attributes + ---------- + data : pandas.DataFrame + The monitored quantities at each iteration. + """ + + request_from_reconstruction_attributes = True + + def __init__( + self, + *, + monitor="objective_function", + on="validation_set", + validation_split="auto", + ): + from ..model_selection import KFold + + self.validation_split = validation_split + if validation_split == "auto": + self.validation_split = KFold(n_splits=5, shuffle=True, random_state=42) + self.monitor = monitor + self.on = on + + self._data_dir = TemporaryDirectory() + self._data_files = {} + + if isinstance(self.monitor, str): + self.monitor_name = self.monitor + elif callable(self.monitor): + self.monitor_name = self.monitor.__name__ + + def on_fit_begin(self, estimator, *, X=None, y=None): + fname = Path(self._data_dir.name) / f"{estimator._computation_tree.uid}.csv" + with open(fname, "w") as file: + file.write(f"iteration,{self.monitor_name}_train,{self.monitor_name}_val\n") + self._data_files[estimator._computation_tree] = fname + + def on_fit_iter_end( + self, *, estimator, node, from_reconstruction_attributes, data, **kwargs + ): + if node.depth != node.computation_tree.depth: + return + + new_estimator = from_reconstruction_attributes + + X, y, X_val, y_val = data["X"], data["y"], data["X_val"], data["y_val"] + + if self.monitor == "objective_function": + new_monitored_train, *_ = new_estimator.objective_function( + X, y, normalize=True + ) + if X_val is not None: + new_monitored_val, *_ = new_estimator.objective_function( + X_val, y_val, normalize=True + ) + elif callable(self.monitor): + new_monitored_train = self.monitor(new_estimator, X, y) + if X_val is not None: + new_monitored_val = self.monitor(new_estimator, X_val, y_val) + elif self.monitor is None or isinstance(self.monitor, str): + from ..metrics import check_scoring + + scorer = check_scoring(new_estimator, self.monitor) + new_monitored_train = scorer(new_estimator, X, y) + if X_val is not None: + new_monitored_val = scorer(new_estimator, X_val, y_val) + + if X_val is None: + new_monitored_val = None + + with open(self._data_files[node.computation_tree], "a") as f: + f.write(f"{node.idx},{new_monitored_train},{new_monitored_val}\n") + + def on_fit_end(self): + pass + + # @property + # def data(self): + + def plot(self): + data_files = [p for p in Path(self._data_dir.name).iterdir() if p.is_file()] + for f in data_files: + data = pd.read_csv(f) + fig, ax = plt.subplots() + ax.plot( + data["iteration"], data[f"{self.monitor_name}_train"], label="train set" + ) + if self.on != "train_set": + ax.plot( + data["iteration"], + data[f"{self.monitor_name}_val"], + 
label="validation set", + ) + + ax.set_xlabel("Number of iterations") + ax.set_ylabel(self.monitor_name) + + ax.legend() + plt.show() diff --git a/sklearn/callback/_progressbar.py b/sklearn/callback/_progressbar.py new file mode 100644 index 0000000000000..f8ed251add34a --- /dev/null +++ b/sklearn/callback/_progressbar.py @@ -0,0 +1,309 @@ +# License: BSD 3 clause + +import importlib +from threading import Event, Thread + +from . import BaseCallback, load_computation_tree + + +def _check_backend_support(backend, caller_name): + """Raise ImportError with detailed error message if backend is not installed. + + Parameters + ---------- + backend : {"rich", "tqdm"} + The requested backend. + + caller_name : str + The name of the caller that requires the backend. + """ + try: + importlib.import_module(backend) # noqa + except ImportError as e: + raise ImportError(f"{caller_name} requires {backend} installed.") from e + + +class ProgressBar(BaseCallback): + """Callback that displays progress bars for each iterative steps of the estimator + + Parameters + ---------- + backend: {"rich", "tqdm"}, default="rich" + The backend for the progress bars display. + + max_depth_show : int, default=None + The maximum nested level of progress bars to display. + + max_depth_keep : int, default=None + The maximum nested level of progress bars to keep displayed when they are + finished. + """ + + auto_propagate = True + + def __init__(self, backend="rich", max_depth_show=None, max_depth_keep=None): + if backend not in ("rich", "tqdm"): + raise ValueError( + f"backend should be 'rich' or 'tqdm', got {self.backend} instead." + ) + _check_backend_support(backend, caller_name="Progressbar") + self.backend = backend + + if max_depth_show is not None and max_depth_show < 0: + raise ValueError("max_depth_show should be >= 0.") + self.max_depth_show = max_depth_show + + if max_depth_keep is not None and max_depth_keep < 0: + raise ValueError("max_depth_keep should be >= 0.") + self.max_depth_keep = max_depth_keep + + def on_fit_begin(self, estimator, X=None, y=None): + self._stop_event = Event() + + if self.backend == "rich": + self.progress_monitor = _RichProgressMonitor( + estimator=estimator, + event=self._stop_event, + max_depth_show=self.max_depth_show, + max_depth_keep=self.max_depth_keep, + ) + elif self.backend == "tqdm": + self.progress_monitor = _TqdmProgressMonitor( + estimator=estimator, + event=self._stop_event, + ) + + self.progress_monitor.start() + + def on_fit_iter_end(self, *, estimator, node, **kwargs): + pass + + def on_fit_end(self): + self._stop_event.set() + self.progress_monitor.join() + + def __getstate__(self): + state = self.__dict__.copy() + if "_stop_event" in state: + del state["_stop_event"] + if "progress_monitor" in state: + del state["progress_monitor"] + return state + + +# Custom Progress class to allow showing the tasks in a given order (given by setting +# the _ordered_tasks attribute). In particular it allows to dynamically create and +# insert tasks between existing tasks. + +try: + from rich.progress import Progress + + class _Progress(Progress): + def get_renderables(self): + table = self.make_tasks_table(getattr(self, "_ordered_tasks", [])) + yield table + +except ImportError: + pass + + +class _RichProgressMonitor(Thread): + """Thread monitoring the progress of an estimator with rich based display + + The display is a list of nested rich tasks using rich.Progress. 
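# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the diff): ProgressBar is auto-propagated,
# so it is set on the (meta-)estimator whose fit should be monitored. It
# assumes `rich` (or `tqdm`) is installed; NMF gains callback support later in
# this PR.
import numpy as np

from sklearn.callback import ProgressBar
from sklearn.decomposition import NMF

X = np.random.RandomState(0).random_sample((100, 20))
nmf = NMF(n_components=5, max_iter=50)._set_callbacks(ProgressBar(backend="rich"))
nmf.fit(X)  # one progress bar per non-leaf node of the computation tree
# ---------------------------------------------------------------------------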
There is one for + each node in the computation tree of the estimator and in the computation trees of + estimators used in the estimator. + + Parameters + ---------- + estimator : estimator instance + The estimator to monitor + + event : threading.Event instance + This thread will run until event is set. + + max_depth_show : int, default=None + The maximum nested level of progress bars to display. + + max_depth_keep : int, default=None + The maximum nested level of progress bars to keep displayed when they are + finished. + """ + + def __init__(self, estimator, event, max_depth_show=None, max_depth_keep=None): + Thread.__init__(self) + self.computation_tree = estimator._computation_tree + self.event = event + self.max_depth_show = max_depth_show + self.max_depth_keep = max_depth_keep + + # _computation_trees is a dict `directory: tuple` where + # - tuple[0] is the computation tree of the directory + # - tuple[1] is a dict `node.tree_status_idx: task_id` + self._computation_trees = {} + + def run(self): + from rich.progress import BarColumn, TextColumn, TimeRemainingColumn + from rich.style import Style + + with _Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn( + complete_style=Style(color="dark_orange"), + finished_style=Style(color="cyan"), + ), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeRemainingColumn(), + auto_refresh=False, + ) as progress_ctx: + self._progress_ctx = progress_ctx + + while not self.event.wait(0.05): + self._recursive_update_tasks() + self._progress_ctx.refresh() + + self._recursive_update_tasks() + self._progress_ctx.refresh() + + def _recursive_update_tasks(self, this_dir=None, depth=0): + """Recursively loop through directories and init or update tasks + + Parameters + ---------- + this_dir : pathlib.Path instance + The directory to + + depth : int + The current depth + """ + if self.max_depth_show is not None and depth > self.max_depth_show: + # Fast exit if this dir is deeper than what we want to show anyway + return + + if this_dir is None: + this_dir = self.computation_tree.tree_dir + # _ordered_tasks holds the list of the tasks in the order we want them to + # be displayed. 
+ self._progress_ctx._ordered_tasks = [] + + if this_dir not in self._computation_trees: + # First time we discover this directory -> store the computation tree + # If the computation tree is not readable yet, skip and try again next time + computation_tree = load_computation_tree(this_dir) + if computation_tree is None: + return + + self._computation_trees[this_dir] = (computation_tree, {}) + + computation_tree, task_ids = self._computation_trees[this_dir] + + for node in computation_tree.iterate(include_leaves=True): + if node.children: + # node is not a leaf, create or update its task + if node.tree_status_idx not in task_ids: + visible = True + if ( + self.max_depth_show is not None + and depth + node.depth > self.max_depth_show + ): + # If this node is deeper than what we want to show, we create + # the task anyway but make it not visible + visible = False + + task_ids[node.tree_status_idx] = self._progress_ctx.add_task( + self._format_task_description(node, computation_tree, depth), + total=node.max_iter, + visible=visible, + ) + + task_id = task_ids[node.tree_status_idx] + task = self._progress_ctx.tasks[task_id] + self._progress_ctx._ordered_tasks.append(task) + + parent_task = self._get_parent_task(node, computation_tree, task_ids) + if parent_task is not None and parent_task.finished: + # If the task of the parent node is finished, make this task + # finished. It can happen if some computations are stopped + # before reaching max_iter. + visible = True + if ( + self.max_depth_keep is not None + and depth + node.depth > self.max_depth_keep + ): + # If this node is deeper than what we want to keep in the output + # make it not visible + visible = False + self._progress_ctx.update( + task_id, completed=node.max_iter, visible=visible, refresh=False + ) + else: + node_progress = computation_tree.get_progress(node) + if node_progress != task.completed: + self._progress_ctx.update( + task_id, completed=node_progress, refresh=False + ) + else: + # node is a leaf, look for tasks of its sub computation tree before + # going to the next node + child_dir = computation_tree.get_child_computation_tree_dir(node) + # child_dir = this_dir / str(node.tree_status_idx) + if child_dir.exists(): + self._recursive_update_tasks( + child_dir, depth + computation_tree.depth + ) + + def _format_task_description(self, node, computation_tree, depth): + """Return a formatted description for the task of the node""" + colors = ["red", "green", "blue", "yellow"] + + indent = f"{' ' * (depth + node.depth)}" + style = f"[{colors[(depth + node.depth)%len(colors)]}]" + + description = f"{computation_tree.estimator_name} - {node.description}" + if node.parent is None and computation_tree.parent_node is not None: + description = ( + f"{computation_tree.parent_node.description} " + f"{computation_tree.parent_node.idx} |" + f" {description}" + ) + if node.parent is not None: + description = f"{description} {node.idx}" + + return f"{style}{indent}{description}" + + def _get_parent_task(self, node, computation_tree, task_ids): + """Get the task of the parent node""" + if node.parent is not None: + # node is not the root, return the task of its parent + task_id = task_ids[node.parent.tree_status_idx] + return self._progress_ctx.tasks[task_id] + if computation_tree.parent_node is not None: + # node is the root, return the task of the parent of the parent_node of + # its computation tree + parent_dir = computation_tree.parent_node.computation_tree.tree_dir + _, parent_tree_task_ids = self._computation_trees[parent_dir] + 
task_id = parent_tree_task_ids[ + computation_tree.parent_node.parent.tree_status_idx + ] + return self._progress_ctx._tasks[task_id] + return + + +class _TqdmProgressMonitor(Thread): + def __init__(self, estimator, event): + Thread.__init__(self) + self.computation_tree = estimator._computation_tree + self.event = event + + def run(self): + from tqdm import tqdm + + root = self.computation_tree.root + + with tqdm(total=len(root.children)) as pbar: + while not self.event.wait(0.05): + node_progress = self.computation_tree.get_progress(root) + if node_progress != pbar.total: + pbar.update(node_progress - pbar.n) + + pbar.update(pbar.total - pbar.n) diff --git a/sklearn/callback/_snapshot.py b/sklearn/callback/_snapshot.py new file mode 100644 index 0000000000000..cfb76c5ec1139 --- /dev/null +++ b/sklearn/callback/_snapshot.py @@ -0,0 +1,65 @@ +# License: BSD 3 clause + +import pickle +from datetime import datetime +from pathlib import Path + +from . import BaseCallback + + +class Snapshot(BaseCallback): + """Take regular snapshots of an estimator + + Parameters + ---------- + keep_last_n : int or None, default=1 + Only the last `keep_last_n` snapshots are kept on the disk. None means all + snapshots are kept. + + base_dir : str or pathlib.Path instance, default=None + The directory where the snapshots should be stored. If None, they are stored in + the current directory. + """ + + request_from_reconstruction_attributes = True + + def __init__(self, keep_last_n=1, base_dir=None): + self.keep_last_n = keep_last_n + if keep_last_n is not None and keep_last_n <= 0: + raise ValueError( + "keep_last_n must be a positive integer, got" + f" {self.keep_last_n} instead." + ) + + self.base_dir = Path("." if base_dir is None else base_dir) + + def on_fit_begin(self, estimator, X=None, y=None): + subdir = self._get_subdir(estimator._computation_tree) + subdir.mkdir() + + def on_fit_iter_end(self, *, estimator, node, **kwargs): + new_estimator = kwargs.get("from_reconstruction_attributes", None) + if new_estimator is None: + return + + subdir = self._get_subdir(node.computation_tree) + snapshot_filename = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}.pkl" + + with open(subdir / snapshot_filename, "wb") as f: + pickle.dump(new_estimator, f) + + if self.keep_last_n is not None: + for snapshot in sorted(subdir.iterdir())[: -self.keep_last_n]: + snapshot.unlink(missing_ok=True) + + def on_fit_end(self): + pass + + def _get_subdir(self, computation_tree): + """Return the sub directory containing the snapshots of the estimator""" + subdir = ( + self.base_dir + / f"snapshots_{computation_tree.estimator_name}_{str(computation_tree.uid)}" + ) + + return subdir diff --git a/sklearn/callback/_text_verbose.py b/sklearn/callback/_text_verbose.py new file mode 100644 index 0000000000000..9773f1c8a6f51 --- /dev/null +++ b/sklearn/callback/_text_verbose.py @@ -0,0 +1,40 @@ +# License: BSD 3 clause + +import time + +from . 
import BaseCallback + + +class TextVerbose(BaseCallback): + auto_propagate = True + request_stopping_criterion = True + + def on_fit_begin(self, estimator, X=None, y=None): + self._start_time = time.perf_counter() + + def on_fit_iter_end(self, *, node, **kwargs): + if node.depth != node.computation_tree.depth: + return + + stopping_criterion = kwargs.get("stopping_criterion", None) + tol = kwargs.get("tol", None) + + current_time = time.perf_counter() - self._start_time + + s = f"{node.description} {node.idx}" + parent = node.parent + while parent is not None and parent.parent is not None: + s = f"{parent.description} {parent.idx} - {s}" + parent = parent.parent + + msg = ( + f"[{parent.computation_tree.estimator_name}] {s} | time {current_time:.5f}s" + ) + + if stopping_criterion is not None and tol is not None: + msg += f" | stopping_criterion={stopping_criterion:.3E} | tol={tol:.3E}" + + print(msg) + + def on_fit_end(self): + pass diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py new file mode 100644 index 0000000000000..d867bdcfa77d2 --- /dev/null +++ b/sklearn/callback/tests/_utils.py @@ -0,0 +1,114 @@ +from functools import partial + +from joblib.parallel import Parallel, delayed + +from sklearn.base import BaseEstimator, _fit_context, clone +from sklearn.callback import BaseCallback +from sklearn.callback._base import _eval_callbacks_on_fit_iter_end + + +class TestingCallback(BaseCallback): + def on_fit_begin(self, estimator, *, X=None, y=None): + pass + + def on_fit_end(self): + pass + + def on_fit_iter_end(self, estimator, node, **kwargs): + pass + + +class TestingAutoPropagatedCallback(TestingCallback): + auto_propagate = True + + +class NotValidCallback: + def on_fit_begin(self, estimator, *, X=None, y=None): + pass + + def on_fit_end(self): + pass + + def on_fit_iter_end(self, estimator, node, **kwargs): + pass + + +class Estimator(BaseEstimator): + _parameter_constraints = {} + + def __init__(self, max_iter=20): + self.max_iter = max_iter + + @_fit_context(prefer_skip_nested_validation=False) + def fit(self, X, y): + root, X, y, X_val, y_val = self._eval_callbacks_on_fit_begin( + levels=[ + {"descr": "fit", "max_iter": self.max_iter}, + {"descr": "iter", "max_iter": None}, + ], + X=X, + y=y, + ) + + for i in range(self.max_iter): + if _eval_callbacks_on_fit_iter_end( + estimator=self, + node=root.children[i], + from_reconstruction_attributes=partial( + self._from_reconstruction_attributes, + reconstruction_attributes=lambda: {"n_iter_": i + 1}, + ), + data={"X": X, "y": y, "X_val": X_val, "y_val": y_val}, + ): + break + + self.n_iter_ = i + 1 + + return self + + def objective_function(self, X, y=None, normalize=False): + return 0, 0, 0 + + +class MetaEstimator(BaseEstimator): + _parameter_constraints = {} + + def __init__( + self, estimator, n_outer=4, n_inner=3, n_jobs=None, prefer="processes" + ): + self.estimator = estimator + self.n_outer = n_outer + self.n_inner = n_inner + self.n_jobs = n_jobs + self.prefer = prefer + + @_fit_context(prefer_skip_nested_validation=False) + def fit(self, X, y): + root, X, y, _, _ = self._eval_callbacks_on_fit_begin( + levels=[ + {"descr": "fit", "max_iter": self.n_outer}, + {"descr": "outer", "max_iter": self.n_inner}, + {"descr": "inner", "max_iter": None}, + ], + X=X, + y=y, + ) + + Parallel(n_jobs=self.n_jobs, prefer=self.prefer)( + delayed(self._func)(self.estimator, X, y, node, i) + for i, node in enumerate(root.children) + ) + + return self + + def _func(self, estimator, X, y, parent_node, i): + for j, 
node in enumerate(parent_node.children): + est = clone(estimator) + self._propagate_callbacks(est, parent_node=node) + est.fit(X, y) + + _eval_callbacks_on_fit_iter_end(estimator=self, node=node) + + _eval_callbacks_on_fit_iter_end(estimator=self, node=parent_node) + + return diff --git a/sklearn/callback/tests/test_base_estimator_callback_methods.py b/sklearn/callback/tests/test_base_estimator_callback_methods.py new file mode 100644 index 0000000000000..4bfdc27db7b52 --- /dev/null +++ b/sklearn/callback/tests/test_base_estimator_callback_methods.py @@ -0,0 +1,126 @@ +# License: BSD 3 clause + +from pathlib import Path + +import pytest + +from sklearn.callback.tests._utils import ( + Estimator, + MetaEstimator, + NotValidCallback, + TestingAutoPropagatedCallback, + TestingCallback, +) + + +@pytest.mark.parametrize( + "callbacks", + [ + TestingCallback(), + [TestingCallback()], + [TestingCallback(), TestingAutoPropagatedCallback()], + ], +) +def test_set_callbacks(callbacks): + """Sanity check for the _set_callbacks method""" + estimator = Estimator() + + set_callbacks_return = estimator._set_callbacks(callbacks) + assert hasattr(estimator, "_callbacks") + assert estimator._callbacks in (callbacks, [callbacks]) + assert set_callbacks_return is estimator + + +@pytest.mark.parametrize("callbacks", [None, NotValidCallback()]) +def test_set_callbacks_error(callbacks): + """Check the error message when not passing a valid callback to _set_callbacks""" + estimator = Estimator() + + with pytest.raises(TypeError, match="callbacks must be subclasses of BaseCallback"): + estimator._set_callbacks(callbacks) + + +def test_propagate_callbacks(): + """Sanity check for the _propagate_callbacks method""" + not_propagated_callback = TestingCallback() + propagated_callback = TestingAutoPropagatedCallback() + + estimator = Estimator() + estimator._set_callbacks([not_propagated_callback, propagated_callback]) + + sub_estimator = Estimator() + estimator._propagate_callbacks(sub_estimator, parent_node=None) + + assert hasattr(sub_estimator, "_parent_ct_node") + assert not_propagated_callback not in sub_estimator._callbacks + assert propagated_callback in sub_estimator._callbacks + + +def test_propagate_callback_no_callback(): + """Check that no callback is propagated if there's no callback""" + estimator = Estimator() + sub_estimator = Estimator() + estimator._propagate_callbacks(sub_estimator, parent_node=None) + + assert not hasattr(estimator, "_callbacks") + assert not hasattr(sub_estimator, "_callbacks") + + +def test_auto_propagated_callbacks(): + """Check that it's not possible to set an auto-propagated callback on the + sub-estimator of a meta-estimator. 
+ """ + estimator = Estimator() + estimator._set_callbacks(TestingAutoPropagatedCallback()) + + meta_estimator = MetaEstimator(estimator=estimator) + + match = ( + r"sub-estimators .*of a meta-estimator .*can't have auto-propagated callbacks" + ) + with pytest.raises(TypeError, match=match): + meta_estimator.fit(X=None, y=None) + + +def test_eval_callbacks_on_fit_begin(): + """Check that _eval_callbacks_on_fit_begin creates and dumps the computation tree""" + estimator = Estimator()._set_callbacks(TestingCallback()) + assert not hasattr(estimator, "_computation_tree") + + levels = [ + {"descr": "fit", "max_iter": 10}, + {"descr": "iter", "max_iter": None}, + ] + ct_root, *_ = estimator._eval_callbacks_on_fit_begin(levels=levels) + assert hasattr(estimator, "_computation_tree") + assert ct_root is estimator._computation_tree.root + + ct_pickle = Path(estimator._computation_tree.tree_dir) / "computation_tree.pkl" + assert ct_pickle.exists() + + +def test_callback_context_finalize(): + """Check that the folder containing the computation tree of the estimator is + deleted when there are no reference left to its callbacks. + """ + callback = TestingCallback() + + # estimator is not fitted, its computation tree is not built yet + est = Estimator()._set_callbacks(callbacks=callback) + assert not hasattr(est, "_computation_tree") + + # estimator is fitted, a folder has been created to hold its computation tree + est.fit(X=None, y=None) + assert hasattr(est, "_computation_tree") + tree_dir = est._computation_tree.tree_dir + assert tree_dir.is_dir() + + # there is no more reference to the estimator, but there is still a reference to the + # callback which might need to access the computation tree + del est + assert tree_dir.is_dir() + + # there is no more reference to the callback, the computation tree folder must be + # deleted + del callback + assert not tree_dir.is_dir() diff --git a/sklearn/callback/tests/test_callbacks.py b/sklearn/callback/tests/test_callbacks.py new file mode 100644 index 0000000000000..5adb16a79bef9 --- /dev/null +++ b/sklearn/callback/tests/test_callbacks.py @@ -0,0 +1,77 @@ +# License: BSD 3 clause + +import pickle +import sys +import tempfile + +import numpy as np +import pytest + +from sklearn.callback import ( + EarlyStopping, + Monitoring, + ProgressBar, + Snapshot, + TextVerbose, +) +from sklearn.callback.tests._utils import Estimator, MetaEstimator + +X = np.zeros((100, 3)) +y = np.zeros(100, dtype=int) + + +@pytest.mark.parametrize( + "Callback", + [ + Monitoring, + EarlyStopping, + ProgressBar, + Snapshot, + TextVerbose, + ], +) +def test_callback_doesnt_hold_ref_to_estimator(Callback): + callback = Callback() + est = Estimator() + callback_refcount = sys.getrefcount(callback) + est_refcount = sys.getrefcount(est) + + est._set_callbacks(callbacks=callback) + est.fit(X, y) + # estimator has a ref on the callback but the callback has no ref to the estimator + assert sys.getrefcount(est) == est_refcount + assert sys.getrefcount(callback) == callback_refcount + 1 + + +@pytest.mark.parametrize("n_jobs", (1, 2)) +@pytest.mark.parametrize("prefer", ("threads", "processes")) +def test_snapshot_meta_estimator(n_jobs, prefer): + """Test for the Snapshot callback""" + estimator = Estimator(max_iter=20) + + with tempfile.TemporaryDirectory() as tmp_dir: + keep_last_n = 5 + callback = Snapshot(keep_last_n=keep_last_n, base_dir=tmp_dir) + estimator._set_callbacks(callback) + metaestimator = MetaEstimator( + estimator=estimator, n_outer=4, n_inner=3, n_jobs=n_jobs, prefer=prefer 
+ ) + + metaestimator.fit(X, y) + + # There's a subdir of base_dir for each clone of estimator fitted in + # metaestimator. There are n_outer * n_inner such clones + snapshot_dirs = list(callback.base_dir.iterdir()) + assert len(snapshot_dirs) == metaestimator.n_outer * metaestimator.n_inner + + for snapshot_dir in snapshot_dirs: + snapshots = sorted(snapshot_dir.iterdir()) + assert len(snapshots) == keep_last_n + + for i, snapshot in enumerate(snapshots): + with open(snapshot, "rb") as f: + loaded_estimator = pickle.load(f) + + # We kept last 5 snapshots out of 20 iterations. + # This one is the 16 + i-th. + assert loaded_estimator.n_iter_ == 16 + i diff --git a/sklearn/callback/tests/test_computation_tree.py b/sklearn/callback/tests/test_computation_tree.py new file mode 100644 index 0000000000000..5a6da95eea469 --- /dev/null +++ b/sklearn/callback/tests/test_computation_tree.py @@ -0,0 +1,107 @@ +# License: BSD 3 clause + +import numpy as np + +from sklearn.callback import ComputationTree + +levels = [ + {"descr": "level0", "max_iter": 3}, + {"descr": "level1", "max_iter": 5}, + {"descr": "level2", "max_iter": 7}, + {"descr": "level3", "max_iter": None}, +] + + +def test_computation_tree(): + """Check the construction of the computation tree""" + computation_tree = ComputationTree(estimator_name="estimator", levels=levels) + assert computation_tree.estimator_name == "estimator" + + root = computation_tree.root + assert root.parent is None + assert root.idx == 0 + + assert len(root.children) == root.max_iter == 3 + assert [node.idx for node in root.children] == list(range(3)) + + for node1 in root.children: + assert len(node1.children) == 5 + assert [n.idx for n in node1.children] == list(range(5)) + + for node2 in node1.children: + assert len(node2.children) == 7 + assert [n.idx for n in node2.children] == list(range(7)) + + for node3 in node2.children: + assert not node3.children + + +def test_n_nodes(): + """Check that the number of node in a comutation tree corresponds to what we expect + from the level descriptions + """ + computation_tree = ComputationTree(estimator_name="", levels=levels) + + max_iter_per_level = [level["max_iter"] for level in levels[:-1]] + expected_n_nodes = 1 + np.sum(np.cumprod(max_iter_per_level)) + + assert computation_tree.n_nodes == expected_n_nodes + assert len(computation_tree.iterate(include_leaves=True)) == expected_n_nodes + assert computation_tree._tree_status.shape == (expected_n_nodes,) + + +def test_tree_status_idx(): + """Check that each node has a unique index in the _tree_status array and that their + order corresponds to the order given by a depth first search. + """ + computation_tree = ComputationTree(estimator_name="", levels=levels) + + indexes = [ + node.tree_status_idx for node in computation_tree.iterate(include_leaves=True) + ] + assert indexes == list(range(computation_tree.n_nodes)) + + +def test_get_ancestors(): + """Check the ancestor search and its propagation to parent trees""" + parent_levels = [ + {"descr": "parent_level0", "max_iter": 2}, + {"descr": "parent_level1", "max_iter": 4}, + {"descr": "parent_level2", "max_iter": None}, + ] + + parent_computation_tree = ComputationTree( + estimator_name="parent_estimator", levels=parent_levels + ) + parent_node = parent_computation_tree.root.children[0].children[2] + # indices of each node (in its parent children) in this chain are 0, 0, 2. + # (root is always 0). 
+ expected_parent_indices = [2, 0, 0] + + computation_tree = ComputationTree( + estimator_name="estimator", levels=levels, parent_node=parent_node + ) + node = computation_tree.root.children[1].children[3].children[5] + expected_node_indices = [5, 3, 1, 0] + + ancestors = node.get_ancestors(include_ancestor_trees=False) + assert ancestors == [ + node, + node.parent, + node.parent.parent, + node.parent.parent.parent, + ] + assert [n.idx for n in ancestors] == expected_node_indices + assert computation_tree.root in ancestors + + ancestors = node.get_ancestors(include_ancestor_trees=True) + assert ancestors == [ + node, + node.parent, + node.parent.parent, + node.parent.parent.parent, + parent_node, + parent_node.parent, + parent_node.parent.parent, + ] + assert [n.idx for n in ancestors] == expected_node_indices + expected_parent_indices diff --git a/sklearn/decomposition/_nmf.py b/sklearn/decomposition/_nmf.py index db46540e26708..46f8645a1f06d 100644 --- a/sklearn/decomposition/_nmf.py +++ b/sklearn/decomposition/_nmf.py @@ -10,6 +10,7 @@ import time import warnings from abc import ABC +from functools import partial from math import sqrt from numbers import Integral, Real @@ -24,6 +25,7 @@ TransformerMixin, _fit_context, ) +from ..callback._base import _eval_callbacks_on_fit_iter_end from ..exceptions import ConvergenceWarning from ..utils import check_array, check_random_state, gen_batches, metadata_routing from ..utils._param_validation import ( @@ -408,6 +410,7 @@ def _update_coordinate_descent(X, W, Ht, l1_reg, l2_reg, shuffle, random_state): def _fit_coordinate_descent( X, + X_val, W, H, tol=1e-4, @@ -420,6 +423,8 @@ def _fit_coordinate_descent( verbose=0, shuffle=False, random_state=None, + estimator=None, + parent_node=None, ): """Compute Non-negative Matrix Factorization (NMF) with Coordinate Descent @@ -432,6 +437,9 @@ def _fit_coordinate_descent( X : array-like of shape (n_samples, n_features) Constant matrix. + X_val : array-like of shape (n_samples_val, n_features) + Constant validation matrix. + W : array-like of shape (n_samples, n_components) Initial guess for the solution. @@ -472,6 +480,12 @@ def _fit_coordinate_descent( results across multiple function calls. See :term:`Glossary `. + estimator : estimator instance, default=None + The estimator calling this function. Used by callbacks. + + parent_node : ComputationNode instance, default=None + The parent node of the current node. Used by callbacks. 
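# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the diff): EarlyStopping requests a
# validation split and a reconstructed estimator at each iteration, and can
# stop the solver before `max_iter`. Data and parameter values are
# illustrative only.
import numpy as np

from sklearn.callback import EarlyStopping
from sklearn.decomposition import NMF

X = np.abs(np.random.RandomState(0).standard_normal((200, 30)))
nmf = NMF(n_components=5, solver="mu", max_iter=200)._set_callbacks(
    EarlyStopping(
        monitor="objective_function", on="validation_set", max_no_improvement=5
    )
)
nmf.fit(X)
print(nmf.n_iter_)  # may be well below max_iter if the criterion stops improving
# ---------------------------------------------------------------------------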
+ Returns ------- W : ndarray of shape (n_samples, n_components) @@ -493,6 +507,8 @@ def _fit_coordinate_descent( # so W and Ht are both in C order in memory Ht = check_array(H.T, order="C") X = check_array(X, accept_sparse="csr") + if X_val is not None: + X_val = check_array(X_val, accept_sparse="csr") rng = check_random_state(random_state) @@ -515,6 +531,25 @@ def _fit_coordinate_descent( if violation_init == 0: break + if _eval_callbacks_on_fit_iter_end( + estimator=estimator, + node=parent_node.children[n_iter - 1] if parent_node is not None else None, + stopping_criterion=lambda: violation / violation_init, + tol=tol, + fit_state={"H": Ht.T, "W": W}, + from_reconstruction_attributes=partial( + estimator._from_reconstruction_attributes, + reconstruction_attributes=lambda: { + "n_components_": Ht.T.shape[0], + "components_": H, + "n_iter_": n_iter, + "reconstruction_err_": _beta_divergence(X, W, Ht.T, 2, True), + }, + ), + data={"X": X, "y": None, "X_val": X_val, "y_val": None}, + ): + break + if verbose: print("violation:", violation / violation_init) @@ -733,6 +768,7 @@ def _multiplicative_update_h( def _fit_multiplicative_update( X, + X_val, W, H, beta_loss="frobenius", @@ -744,6 +780,8 @@ def _fit_multiplicative_update( l2_reg_H=0, update_H=True, verbose=0, + estimator=None, + parent_node=None, ): """Compute Non-negative Matrix Factorization with Multiplicative Update. @@ -756,6 +794,9 @@ def _fit_multiplicative_update( X : array-like of shape (n_samples, n_features) Constant input matrix. + X_val : array-like of shape (n_samples_val, n_features) + Constant validation matrix. + W : array-like of shape (n_samples, n_components) Initial guess for the solution. @@ -796,6 +837,12 @@ def _fit_multiplicative_update( verbose : int, default=0 The verbosity level. + estimator : estimator instance, default=None + The estimator calling this function. Used by callbacks. + + parent_node : ComputationNode instance, default=None + The parent node of the current node. Used by callbacks. 
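# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the diff): Monitoring records the normalized
# objective function at each iteration and plots it after fit. It assumes
# matplotlib and pandas are installed; `on="train_set"` restricts the plot to
# the training curve.
import numpy as np

from sklearn.callback import Monitoring
from sklearn.decomposition import NMF

X = np.abs(np.random.RandomState(0).standard_normal((200, 30)))
monitor = Monitoring(monitor="objective_function", on="train_set")
nmf = NMF(n_components=5, solver="mu", max_iter=100)._set_callbacks(monitor)
nmf.fit(X)
monitor.plot()  # objective function value vs. iteration number
# ---------------------------------------------------------------------------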
+ Returns ------- W : ndarray of shape (n_samples, n_components) @@ -871,6 +918,31 @@ def _fit_multiplicative_update( if beta_loss <= 1: H[H < np.finfo(np.float64).eps] = 0.0 + if _eval_callbacks_on_fit_iter_end( + estimator=estimator, + node=parent_node.children[n_iter - 1] if parent_node is not None else None, + stopping_criterion=lambda: ( + ( + previous_error + - _beta_divergence(X, W, H, beta_loss, square_root=True) + ) + / error_at_init + ), + tol=tol, + fit_state={"H": H, "W": W}, + from_reconstruction_attributes=partial( + estimator._from_reconstruction_attributes, + reconstruction_attributes=lambda: { + "n_components_": H.shape[0], + "components_": H, + "n_iter_": n_iter, + "reconstruction_err_": _beta_divergence(X, W, H, beta_loss, True), + }, + ), + data={"X": X, "y": None, "X_val": X_val, "y_val": None}, + ): + break + # test convergence criterion every 10 iterations if tol > 0 and n_iter % 10 == 0: error = _beta_divergence(X, W, H, beta_loss, square_root=True) @@ -1349,6 +1421,28 @@ def inverse_transform(self, Xt=None, W=None): check_is_fitted(self) return Xt @ self.components_ + def objective_function(self, X, y=None, *, W=None, H=None, normalize=False): + if W is None: + W = self.transform(X) + if H is None: + H = self.components_ + + data_fit = _beta_divergence(X, W, H, self._beta_loss) + + l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H = self._compute_regularization(X) + penalization = ( + l1_reg_W * W.sum() + + l1_reg_H * H.sum() + + l2_reg_W * (W**2).sum() + + l2_reg_H * (H**2).sum() + ) + + if normalize: + data_fit /= X.shape[0] + penalization /= X.shape[0] + + return data_fit + penalization, data_fit, penalization + @property def _n_features_out(self): """Number of transformed output features.""" @@ -1662,20 +1756,28 @@ def fit_transform(self, X, y=None, W=None, H=None): X, accept_sparse=("csr", "csc"), dtype=[np.float64, np.float32] ) - with config_context(assume_finite=True): - W, H, n_iter = self._fit_transform(X, W=W, H=H) + root, X, _, X_val, _ = self._eval_callbacks_on_fit_begin( + levels=[ + {"descr": "fit", "max_iter": self.max_iter}, + {"descr": "iter", "max_iter": None}, + ], + X=X, + ) + + W, H, n_iter = self._fit_transform(X, X_val, W=W, H=H, parent_node=root) self.reconstruction_err_ = _beta_divergence( X, W, H, self._beta_loss, square_root=True ) - self.n_components_ = H.shape[0] self.components_ = H self.n_iter_ = n_iter return W - def _fit_transform(self, X, y=None, W=None, H=None, update_H=True): + def _fit_transform( + self, X, X_val=None, W=None, H=None, update_H=True, parent_node=None + ): """Learn a NMF model for the data X and returns the transformed data. Parameters @@ -1735,6 +1837,7 @@ def _fit_transform(self, X, y=None, W=None, H=None, update_H=True): if self.solver == "cd": W, H, n_iter = _fit_coordinate_descent( X, + X_val, W, H, self.tol, @@ -1747,10 +1850,13 @@ def _fit_transform(self, X, y=None, W=None, H=None, update_H=True): verbose=self.verbose, shuffle=self.shuffle, random_state=self.random_state, + estimator=self, + parent_node=parent_node, ) elif self.solver == "mu": W, H, n_iter, *_ = _fit_multiplicative_update( X, + X_val, W, H, self._beta_loss, @@ -1760,8 +1866,10 @@ def _fit_transform(self, X, y=None, W=None, H=None, update_H=True): l1_reg_H, l2_reg_W, l2_reg_H, - update_H, - self.verbose, + update_H=update_H, + verbose=self.verbose, + estimator=self, + parent_node=parent_node, ) else: raise ValueError("Invalid solver parameter '%s'." 
% self.solver) @@ -2441,3 +2549,8 @@ def partial_fit(self, X, y=None, W=None, H=None): self.n_steps_ += 1 return self + + @property + def _n_features_out(self): + """Number of transformed output features.""" + return self.components_.shape[0] diff --git a/sklearn/decomposition/tests/test_nmf.py b/sklearn/decomposition/tests/test_nmf.py index ce13d9358b5d8..573af0a5e7258 100644 --- a/sklearn/decomposition/tests/test_nmf.py +++ b/sklearn/decomposition/tests/test_nmf.py @@ -1,5 +1,7 @@ +import pickle import re import sys +import tempfile import warnings from io import StringIO @@ -8,6 +10,7 @@ from scipy import linalg from sklearn.base import clone +from sklearn.callback import Snapshot from sklearn.decomposition import NMF, MiniBatchNMF, non_negative_factorization from sklearn.decomposition import _nmf as nmf # For testing internals from sklearn.exceptions import ConvergenceWarning @@ -1060,3 +1063,30 @@ def test_nmf_custom_init_shape_error(): with pytest.raises(ValueError, match="Array with wrong second dimension passed"): nmf.fit(X, H=H, W=rng.random_sample((6, 3))) + + +@pytest.mark.parametrize("solver, beta_loss", [("mu", 0), ("mu", 2), ("cd", 2)]) +def test_nmf_callback_reconstruction_attributes(solver, beta_loss): + # Check that the reconstruction attributes passed to the callback allow to make + # a new estimator as if the fit ended when the callback is called. + X = np.random.RandomState(0).random_sample((100, 20)) + + nmf = NMF(n_components=5, solver=solver, beta_loss=beta_loss, random_state=0) + nmf.fit(X) + + with tempfile.TemporaryDirectory() as tmp_dir: + callback = Snapshot(base_dir=tmp_dir) + nmf._set_callbacks(callback) + nmf.fit(X) + + # load model from last iteration + snapshot_dir = next(callback.base_dir.iterdir()) + snapshot = sorted(snapshot_dir.iterdir())[-1] + with open(snapshot, "rb") as f: + loaded_nmf = pickle.load(f) + + # The model saved during the last iteration is the same as the original model + assert nmf.n_iter_ == loaded_nmf.n_iter_ + assert_allclose(nmf.components_, loaded_nmf.components_) + assert_allclose(nmf.reconstruction_err_, loaded_nmf.reconstruction_err_) + assert_allclose(nmf.transform(X), loaded_nmf.transform(X)) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index c3af930654b73..27a2030ca2d08 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -25,6 +25,7 @@ _fit_context, is_classifier, ) +from ...callback._base import _eval_callbacks_on_fit_iter_end from ...metrics import check_scoring from ...model_selection import train_test_split from ...preprocessing import LabelEncoder @@ -475,6 +476,19 @@ def fit(self, X, y, sample_weight=None): X_train, y_train, sample_weight_train = X, y, sample_weight X_val = y_val = sample_weight_val = None + begin_at_stage = ( + 0 if not (self._is_fitted() and self.warm_start) else self.n_iter_ + ) + + root, X_train, y_train, X_val, y_val = self._eval_callbacks_on_fit_begin( + levels=[ + {"descr": "fit", "max_iter": self.max_iter - begin_at_stage}, + {"descr": "iter", "max_iter": None}, + ], + X=X, + y=y, + ) + # Bin the data # For ease of use of the API, the user-facing GBDT classes accept the # parameter max_bins, which doesn't take into account the bin for @@ -769,6 +783,26 @@ def fit(self, X, y, sample_weight=None): if should_early_stop: break + if _eval_callbacks_on_fit_iter_end( + estimator=self, + node=root.children[iteration - 
begin_at_stage], + fit_state={}, + from_reconstruction_attributes=partial( + self._from_reconstruction_attributes, + reconstruction_attributes=lambda: { + "train_score_": np.asarray(self.train_score_), + "validation_score_": np.asarray(self.validation_score_), + }, + ), + data={ + "X": X_binned_train, + "y": y_train, + "X_val": X_binned_val, + "y_val": y_val, + }, + ): + break + if self.verbose: duration = time() - fit_start_time n_total_leaves = sum( @@ -807,8 +841,22 @@ def fit(self, X, y, sample_weight=None): self.train_score_ = np.asarray(self.train_score_) self.validation_score_ = np.asarray(self.validation_score_) del self._in_fit # hard delete so we're sure it can't be used anymore + return self + def objective_function(self, X, y, *, raw_predictions=None, normalize=False): + if raw_predictions is None: + raw_predictions = self._raw_predict(X) + + loss = self._loss( + y_true=y, + raw_prediction=raw_predictions, + ) + if normalize: + loss /= raw_predictions.shape[0] + + return loss, loss, 0 + def _is_fitted(self): return len(getattr(self, "_predictors", [])) > 0 diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index e6ac6ff087945..8e3886d38b781 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -22,6 +22,7 @@ from .._loss.loss import HalfBinomialLoss, HalfMultinomialLoss from ..base import _fit_context +from ..callback._base import _eval_callbacks_on_fit_iter_end from ..metrics import get_scorer from ..model_selection import check_cv from ..preprocessing import LabelBinarizer, LabelEncoder @@ -127,6 +128,8 @@ def _logistic_regression_path( sample_weight=None, l1_ratio=None, n_threads=1, + estimator=None, + parent_node=None, ): """Compute a Logistic Regression model for a list of regularization parameters. 
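Note: the per-iteration hook above hands callbacks a lazily reconstructed estimator, which is what lets a callback such as `Snapshot` pickle a usable model after every boosting iteration. A hypothetical usage sketch follows; it mirrors the NMF test added earlier in this diff, and the one-pickle-per-iteration directory layout plus the applicability of `Snapshot` to `HistGradientBoostingRegressor` are assumptions, not something this diff demonstrates directly.

```python
import pickle
import tempfile

import numpy as np

from sklearn.callback import Snapshot
from sklearn.ensemble import HistGradientBoostingRegressor

rng = np.random.RandomState(0)
X = rng.random_sample((200, 10))
y = X[:, 0] + 0.1 * rng.standard_normal(200)

with tempfile.TemporaryDirectory() as tmp_dir:
    callback = Snapshot(base_dir=tmp_dir)

    gbdt = HistGradientBoostingRegressor(max_iter=20, random_state=0)
    gbdt._set_callbacks(callback)  # private API introduced in this diff
    gbdt.fit(X, y)

    # Assumed layout (taken from the NMF test in this diff): one per-run
    # sub-directory containing one pickled snapshot per iteration.
    snapshot_dir = next(callback.base_dir.iterdir())
    last_snapshot = sorted(snapshot_dir.iterdir())[-1]
    with open(last_snapshot, "rb") as f:
        partial_model = pickle.load(f)

    # The reconstruction attributes wired in above make the per-iteration
    # training scores available on the snapshot.
    print(partial_model.train_score_)
```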
@@ -452,11 +455,19 @@ def _logistic_regression_path( coefs = list() n_iter = np.zeros(len(Cs), dtype=np.int32) for i, C in enumerate(Cs): + # Distinguish between LogReg and LogRegCV + node = ( + None + if parent_node is None + else parent_node if len(Cs) == 1 else parent_node.children + ) + if solver == "lbfgs": l2_reg_strength = 1.0 / (C * sw_sum) iprint = [-1, 50, 1, 100, 101][ np.searchsorted(np.array([0, 1, 2, 3]), verbose) ] + children = iter(node.children) if node is not None else None opt_res = optimize.minimize( func, w0, @@ -470,6 +481,10 @@ def _logistic_regression_path( "gtol": tol, "ftol": 64 * np.finfo(float).eps, }, + callback=lambda xk: _eval_callbacks_on_fit_iter_end( + estimator=estimator, + node=next(children) if children is not None else None, + ), ) n_iter_i = _check_optimize_result( solver, @@ -482,7 +497,15 @@ def _logistic_regression_path( l2_reg_strength = 1.0 / (C * sw_sum) args = (X, target, sample_weight, l2_reg_strength, n_threads) w0, n_iter_i = _newton_cg( - hess, func, grad, w0, args=args, maxiter=max_iter, tol=tol + hess, + func, + grad, + w0, + args=args, + maxiter=max_iter, + tol=tol, + estimator=estimator, + parent_node=node, ) elif solver == "newton-cholesky": l2_reg_strength = 1.0 / (C * sw_sum) @@ -557,6 +580,8 @@ def _logistic_regression_path( max_squared_sum, warm_start_sag, is_saga=(solver == "saga"), + estimator=estimator, + parent_node=node, ) else: @@ -577,8 +602,20 @@ def _logistic_regression_path( else: coefs.append(w0.copy()) + if len(Cs) > 1: + _eval_callbacks_on_fit_iter_end( + estimator=estimator, + node=node, + ) + n_iter[i] = n_iter_i + if multi_class == "ovr": + _eval_callbacks_on_fit_iter_end( + estimator=estimator, + node=parent_node, + ) + return np.array(coefs), np.array(Cs), n_iter @@ -1296,6 +1333,24 @@ def fit(self, X, y, sample_weight=None): if warm_start_coef is None: warm_start_coef = [None] * n_classes + if len(classes_) == 1: + levels = [ + {"descr": "fit", "max_iter": self.max_iter}, + {"descr": "iter", "max_iter": None}, + ] + else: + levels = [ + {"descr": "fit", "max_iter": len(classes_)}, + {"descr": "class", "max_iter": self.max_iter}, + {"descr": "iter", "max_iter": None}, + ] + root, X, y, X_val, y_val = self._eval_callbacks_on_fit_begin( + levels=levels, X=X, y=y + ) + + # distinguish between multinomial and ovr + nodes = [root] if len(classes_) == 1 else root.children + path_func = delayed(_logistic_regression_path) # The SAG solver releases the GIL so it's more efficient to use @@ -1340,8 +1395,10 @@ def fit(self, X, y, sample_weight=None): max_squared_sum=max_squared_sum, sample_weight=sample_weight, n_threads=n_threads, + estimator=self, + parent_node=node, ) - for class_, warm_start_coef_ in zip(classes_, warm_start_coef) + for class_, warm_start_coef_, node in zip(classes_, warm_start_coef, nodes) ) fold_coefs_, _, n_iter_ = zip(*fold_coefs_) diff --git a/sklearn/linear_model/_sag.py b/sklearn/linear_model/_sag.py index 2626955ec2a7f..88f8f9d50bf0c 100644 --- a/sklearn/linear_model/_sag.py +++ b/sklearn/linear_model/_sag.py @@ -100,6 +100,8 @@ def sag_solver( max_squared_sum=None, warm_start_mem=None, is_saga=False, + estimator=None, + parent_node=None, ): """SAG solver for Ridge and LogisticRegression. 
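To make the threading of `estimator` and `parent_node` through the private solvers easier to follow, here is a condensed, hypothetical sketch (not part of this diff) of the integration pattern the NMF, gradient-boosting and logistic-regression changes all follow: declare the levels of the computation up front, report each outer iteration to the callbacks, and close the tree at the end. `ToyIterativeEstimator` and its trivial `stopping_criterion` are illustrative placeholders only.

```python
from functools import partial

from sklearn.base import BaseEstimator
from sklearn.callback._base import _eval_callbacks_on_fit_iter_end


class ToyIterativeEstimator(BaseEstimator):
    def __init__(self, max_iter=10, tol=1e-4):
        self.max_iter = max_iter
        self.tol = tol

    def fit(self, X, y=None):
        # One "iter" node per outer iteration, mirroring the levels used
        # for NMF and the gradient boosting estimators above.
        root, *_ = self._eval_callbacks_on_fit_begin(
            levels=[
                {"descr": "fit", "max_iter": self.max_iter},
                {"descr": "iter", "max_iter": None},
            ],
            X=X,
            y=y,
        )

        for i in range(self.max_iter):
            ...  # one solver step, updating some internal state

            # Lets callbacks monitor progress, take snapshots or request
            # early stopping; a truthy return value means "stop fitting".
            if _eval_callbacks_on_fit_iter_end(
                estimator=self,
                node=root.children[i],
                stopping_criterion=lambda: 0.0,  # solver-specific progress value
                tol=self.tol,
                from_reconstruction_attributes=partial(
                    self._from_reconstruction_attributes,
                    reconstruction_attributes=lambda: {"n_iter_": i + 1},
                ),
                data={"X": X, "y": y, "X_val": None, "y_val": None},
            ):
                break

        self.n_iter_ = i + 1
        self._eval_callbacks_on_fit_end()
        return self
```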
@@ -344,6 +346,8 @@ def sag_solver( intercept_decay, is_saga, verbose, + estimator=estimator, + parent_node=parent_node, ) if n_iter_ == max_iter: diff --git a/sklearn/linear_model/_sag_fast.pyx.tp b/sklearn/linear_model/_sag_fast.pyx.tp index 97bf3020d6602..a342d95fb9dbb 100644 --- a/sklearn/linear_model/_sag_fast.pyx.tp +++ b/sklearn/linear_model/_sag_fast.pyx.tp @@ -34,6 +34,7 @@ from ._sgd_fast cimport LossFunction from ._sgd_fast cimport Log, SquaredLoss from ..utils._seq_dataset cimport SequentialDataset32, SequentialDataset64 +from ..callback._base import _eval_callbacks_on_fit_iter_end from libc.stdio cimport printf @@ -217,7 +218,9 @@ def sag{{name_suffix}}( {{c_type}}[::1] intercept_sum_gradient_init, double intercept_decay, bint saga, - bint verbose + bint verbose, + estimator, + parent_node, ): """Stochastic Average Gradient (SAG) and SAGA solvers. @@ -538,6 +541,22 @@ def sag{{name_suffix}}( max_weight = fmax{{name_suffix}}(max_weight, fabs(weights[idx])) max_change = fmax{{name_suffix}}(max_change, fabs(weights[idx] - previous_weights[idx])) previous_weights[idx] = weights[idx] + + with gil: + if _eval_callbacks_on_fit_iter_end( + estimator=estimator, + node=parent_node.children[n_iter] if parent_node is not None else None, + stopping_criterion = ( + lambda: max_change / max_weight + if max_weight != 0 + else 0 + if max_weight == max_change == 0 + else np.inf + ), + tol=tol, + ): + break + if ((max_weight != 0 and max_change / max_weight <= tol) or max_weight == 0 and max_change == 0): if verbose: diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 9de03c2c663ec..7b140dc732464 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -18,7 +18,7 @@ from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import partial, reduce -from itertools import product +from itertools import cycle, product import numpy as np from numpy.ma import MaskedArray @@ -900,7 +900,9 @@ def fit(self, X, y=None, **params): all_out = [] all_more_results = defaultdict(list) - def evaluate_candidates(candidate_params, cv=None, more_results=None): + def evaluate_candidates( + candidate_params, cv=None, more_results=None, parent_node=None + ): cv = cv or cv_orig candidate_params = list(candidate_params) n_candidates = len(candidate_params) @@ -913,6 +915,11 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None): ) ) + if parent_node is not None: + nodes = parent_node.children + else: + nodes = cycle([None]) + out = parallel( delayed(_fit_and_score)( clone(base_estimator), @@ -924,10 +931,18 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None): split_progress=(split_idx, n_splits), candidate_progress=(cand_idx, n_candidates), **fit_and_score_kwargs, + caller=self, + node=node, ) - for (cand_idx, parameters), (split_idx, (train, test)) in product( - enumerate(candidate_params), - enumerate(cv.split(X, y, **routed_params.splitter.split)), + for ( + (cand_idx, parameters), + (split_idx, (train, test)), + ), node in zip( + product( + enumerate(candidate_params), + enumerate(cv.split(X, y, **routed_params.splitter.split)), + ), + nodes, ) ) @@ -1522,9 +1537,58 @@ def __init__( ) self.param_grid = param_grid + def fit(self, X, y=None, *, groups=None, **fit_params): + """Run fit with all sets of parameters. 
+ + Parameters + ---------- + + X : array-like of shape (n_samples, n_features) + Training vector, where `n_samples` is the number of samples and + `n_features` is the number of features. + + y : array-like of shape (n_samples, n_output) or (n_samples,), default=None + Target relative to X for classification or regression; + None for unsupervised learning. + + groups : array-like of shape (n_samples,), default=None + Group labels for the samples used while splitting the dataset into + train/test set. Only used in conjunction with a "Group" :term:`cv` + instance (e.g., :class:`~sklearn.model_selection.GroupKFold`). + + **fit_params : dict of str -> object + Parameters passed to the `fit` method of the estimator. + + If a fit parameter is an array-like whose length is equal to + `num_samples` then it will be split across CV groups along with `X` + and `y`. For example, the :term:`sample_weight` parameter is split + because `len(sample_weights) = len(X)`. + + Returns + ------- + self : object + Instance of fitted estimator. + """ + self._param_grid = ParameterGrid(self.param_grid) + + self._checked_cv_orig = check_cv( + self.cv, y, classifier=is_classifier(self.estimator) + ) + n_splits = self._checked_cv_orig.get_n_splits(X, y, groups) + + self._eval_callbacks_on_fit_begin( + levels=[ + {"descr": "fit", "max_iter": len(self._param_grid) * n_splits}, + {"descr": "param - fold", "max_iter": None}, + ], + X=X, + y=y, + ) + super().fit(X, y=y, groups=groups, **fit_params) + def _run_search(self, evaluate_candidates): """Search all candidates in param_grid""" - evaluate_candidates(ParameterGrid(self.param_grid)) + evaluate_candidates(self._param_grid, parent_node=self._computation_tree.root) class RandomizedSearchCV(BaseSearchCV): diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index f3c8735043408..e2e23500a00a2 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -25,6 +25,7 @@ from joblib import logger from ..base import clone, is_classifier +from ..callback._base import _eval_callbacks_on_fit_iter_end from ..exceptions import FitFailedWarning, UnsetMetadataPassedError from ..metrics import check_scoring, get_scorer_names from ..metrics._scorer import _check_multimetric_scoring, _MultimetricScorer @@ -748,6 +749,8 @@ def _fit_and_score( split_progress=None, candidate_progress=None, error_score=np.nan, + caller=None, + node=None, ): """Fit estimator and compute scores for a given dataset split. 
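For the search estimators, callbacks are set on the meta-estimator and `_fit_and_score` forwards the auto-propagated ones to every cloned sub-estimator, with one computation node per (candidate, CV split) pair. A hypothetical usage sketch (not part of this diff; whether `Snapshot` is flagged as auto-propagating is an assumption):

```python
import tempfile

from sklearn.callback import Snapshot
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X, y = make_classification(n_samples=200, random_state=0)

with tempfile.TemporaryDirectory() as tmp_dir:
    search = GridSearchCV(
        LogisticRegression(max_iter=200),
        param_grid={"C": [0.1, 1.0, 10.0]},
        cv=3,
    )
    # Private API from this diff.  The search-level callback is evaluated once
    # per candidate/split pair, i.e. len(param_grid) * n_splits = 9 times here.
    search._set_callbacks(Snapshot(base_dir=tmp_dir))
    search.fit(X, y)
```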
@@ -877,6 +880,9 @@ def _fit_and_score( # ref: https://github.com/scikit-learn/scikit-learn/pull/26786 estimator = estimator.set_params(**clone(parameters, safe=False)) + if caller is not None: + caller._propagate_callbacks(estimator, parent_node=node) + start_time = time.time() X_train, y_train = _safe_split(estimator, X, y, train) @@ -943,6 +949,8 @@ def _fit_and_score( end_msg += result_msg print(end_msg) + _eval_callbacks_on_fit_iter_end(estimator=caller, node=node) + result["test_scores"] = test_scores if return_train_score: result["train_scores"] = train_scores diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 713c7d6116f53..6169084f5ee6d 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -16,9 +16,14 @@ from scipy import sparse from .base import TransformerMixin, _fit_context, clone +from .callback._base import _eval_callbacks_on_fit_iter_end from .exceptions import NotFittedError from .preprocessing import FunctionTransformer -from .utils import Bunch, _print_elapsed_time, check_pandas_support +from .utils import ( + Bunch, + _print_elapsed_time, + check_pandas_support, +) from .utils._estimator_html_repr import _VisualBlock from .utils._metadata_requests import METHODS from .utils._param_validation import HasMethods, Hidden @@ -382,12 +387,23 @@ def _fit(self, X, y=None, routed_params=None): # Setup the memory memory = check_memory(self.memory) + root, *_ = self._eval_callbacks_on_fit_begin( + levels=[ + {"descr": "fit", "max_iter": len(self.steps)}, + {"descr": "step", "max_iter": None}, + ], + X=X, + y=y, + ) + fit_transform_one_cached = memory.cache(_fit_transform_one) for step_idx, name, transformer in self._iter( with_final=False, filter_passthrough=False ): + node = root.children[step_idx] if transformer is None or transformer == "passthrough": + _eval_callbacks_on_fit_iter_end(estimator=self, node=node) with _print_elapsed_time("Pipeline", self._log_message(step_idx)): continue @@ -397,6 +413,9 @@ def _fit(self, X, y=None, routed_params=None): cloned_transformer = transformer else: cloned_transformer = clone(transformer) + + self._propagate_callbacks(cloned_transformer, parent_node=node) + # Fit or load from cache the current transformer X, fitted_transformer = fit_transform_one_cached( cloned_transformer, @@ -411,6 +430,9 @@ def _fit(self, X, y=None, routed_params=None): # transformer. This is necessary when loading the transformer # from the cache. 
self.steps[step_idx] = (name, fitted_transformer) + + _eval_callbacks_on_fit_iter_end(estimator=self, node=node) + return X @_fit_context( @@ -464,9 +486,14 @@ def fit(self, X, y=None, **params): Xt = self._fit(X, y, routed_params) with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): if self._final_estimator != "passthrough": + node = self._computation_tree.root.children[-1] + self._propagate_callbacks(self._final_estimator, parent_node=node) + last_step_params = routed_params[self.steps[-1][0]] self._final_estimator.fit(Xt, y, **last_step_params["fit"]) + _eval_callbacks_on_fit_iter_end(estimator=self, node=node) + return self def _can_fit_transform(self): diff --git a/sklearn/utils/optimize.py b/sklearn/utils/optimize.py index 024b0bcaf95ee..807c78810df37 100644 --- a/sklearn/utils/optimize.py +++ b/sklearn/utils/optimize.py @@ -18,6 +18,7 @@ import numpy as np import scipy +from ..callback._base import _eval_callbacks_on_fit_iter_end from ..exceptions import ConvergenceWarning from .fixes import line_search_wolfe1, line_search_wolfe2 @@ -156,6 +157,8 @@ def _newton_cg( maxinner=200, line_search=True, warn=True, + estimator=None, + parent_node=None, ): """ Minimization of scalar function of one or more variables using the @@ -217,7 +220,17 @@ def _newton_cg( fgrad, fhess_p = grad_hess(xk, *args) absgrad = np.abs(fgrad) - if np.max(absgrad) <= tol: + max_absgrad = np.max(absgrad) + + if _eval_callbacks_on_fit_iter_end( + estimator=estimator, + node=None if parent_node is None else parent_node.children[k], + stopping_criterion=lambda: max_absgrad, + tol=tol, + ): + break + + if max_absgrad <= tol: break maggrad = np.sum(absgrad)
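The pipeline changes follow the same meta-estimator pattern as `_fit_and_score`: one computation node per step, with callbacks propagated to the cloned step just before it is fitted. A hypothetical end-to-end sketch (not part of this diff; that `Snapshot` auto-propagates to the steps, and its on-disk layout, are assumptions):

```python
import tempfile

import numpy as np

from sklearn.callback import Snapshot
from sklearn.decomposition import NMF
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler

X = np.random.RandomState(0).standard_normal((100, 20))

with tempfile.TemporaryDirectory() as tmp_dir:
    pipe = Pipeline([("scale", MinMaxScaler()), ("nmf", NMF(n_components=5))])
    # Private API from this diff.  If Snapshot auto-propagates, the NMF step
    # is expected to dump one snapshot per inner iteration under tmp_dir.
    pipe._set_callbacks(Snapshot(base_dir=tmp_dir))
    pipe.fit(X)
```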