[WIP] Callback API continued #22000

Draft: wants to merge 26 commits into main
Changes from 2 commits
1 change: 1 addition & 0 deletions sklearn/__init__.py
@@ -84,6 +84,7 @@

__all__ = [
"calibration",
"callback",
"cluster",
"covariance",
"cross_decomposition",
136 changes: 136 additions & 0 deletions sklearn/base.py
@@ -9,6 +9,7 @@
import platform
import inspect
import re
import pickle

import numpy as np

@@ -28,6 +29,9 @@
from .utils.validation import check_is_fitted
from .utils._estimator_html_repr import estimator_html_repr
from .utils.validation import _get_feature_names
from .callback import BaseCallback
from .callback import AutoPropagatedMixin
from .callback import ComputationTree


def clone(estimator, *, safe=True):
@@ -84,6 +88,10 @@ def clone(estimator, *, safe=True):
new_object = klass(**new_object_params)
params_set = new_object.get_params(deep=False)

# copy callbacks
if hasattr(estimator, "_callbacks"):
new_object._callbacks = clone(estimator._callbacks, safe=False)

# quick sanity check of the parameters of the clone
for name in new_object_params:
param1 = new_object_params[name]
@@ -597,6 +605,134 @@ def _validate_data(

return out

def _set_callbacks(self, callbacks):
"""Set callbacks for the estimator.

Parameters
----------
callbacks : callback or list of callbacks
the callbacks to set.
"""
if not isinstance(callbacks, list):
callbacks = [callbacks]

if not all(isinstance(callback, BaseCallback) for callback in callbacks):
raise TypeError(f"callbacks must be subclasses of BaseCallback.")

self._callbacks = callbacks
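
For illustration only (not part of the diff), a minimal usage sketch of _set_callbacks; it assumes ProgressBar, added below in sklearn/callback, can be constructed without arguments:

from sklearn.linear_model import LogisticRegression
from sklearn.callback import ProgressBar

est = LogisticRegression()
est._set_callbacks(ProgressBar())    # a single callback is wrapped in a list
est._set_callbacks([ProgressBar()])  # or pass a list of callbacks directly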

# XXX should be a method of MetaEstimatorMixin but this mixin can't handle all
# meta-estimators.
def _propagate_callbacks(self, sub_estimator, parent_node):
"""Propagate the auto-propagated callbacks to a sub-estimator

Parameters
----------
sub_estimator : estimator instance
The sub-estimator to propagate the callbacks to.

parent_node : ComputationNode instance
The computation node in this estimator to set as parent_node to the
computation tree of the sub-estimator. It must be the node where the fit
method of the sub-estimator is called.
"""
if not hasattr(self, "_callbacks"):
return

if hasattr(sub_estimator, "_callbacks") and any(
isinstance(callback, AutoPropagatedMixin)
for callback in sub_estimator._callbacks
):
bad_callbacks = [
callback.__class__.__name__
for callback in sub_estimator._callbacks
if isinstance(callback, AutoPropagatedMixin)
]
raise TypeError(
f"The sub-estimators ({sub_estimator.__class__.__name__}) of a"
f" meta-estimator ({self.__class__.__name__}) can't have"
f" auto-propagated callbacks ({bad_callbacks})."
" Set them directly on the meta-estimator."
)

propagated_callbacks = [
callback
for callback in self._callbacks
if isinstance(callback, AutoPropagatedMixin)
]

if not propagated_callbacks:
return

sub_estimator._parent_node = parent_node

if not hasattr(sub_estimator, "_callbacks"):
sub_estimator._callbacks = propagated_callbacks
else:
sub_estimator._callbacks.extend(propagated_callbacks)
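
For illustration only (not part of the diff), the intended calling pattern inside a meta-estimator's fit; how the ComputationNode node is obtained is estimator-specific and only assumed here:

# Inside a meta-estimator's fit, right before fitting a sub-estimator.
# `node` is the ComputationNode of this meta-estimator at which the
# sub-estimator's fit is called (assumed to be available at this point).
self._propagate_callbacks(sub_estimator, parent_node=node)
sub_estimator.fit(X, y)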

def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None):
Member

Would it be too magical to have a single _eval_callbacks_begin that internally inspects the call stack to infer which method called it?

Of course it would make sense only if the same methods are expected to be called for fit/predict/etc.

Member Author

> Of course it would make sense only if the same methods are expected to be called for fit/predict/etc.

Well, that's not obvious at all and I haven't really thought about it. This first iteration is all about fit. I think it will be easier not to try to be too magical for now.
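
For reference, a minimal sketch (not part of the diff) of the kind of call-stack inspection discussed above, using the standard inspect module:

import inspect

def _eval_callbacks_begin(self, **kwargs):
    # Name of the estimator method (e.g. "fit" or "predict") that called
    # this helper, inferred from the caller's frame.
    caller = inspect.stack()[1].function
    ...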

"""Evaluate the on_fit_begin method of the callbacks

The computation tree is also built at this point.

This method should be called after all data and parameters validation.

Parameters
----------
X : ndarray or sparse matrix, default=None
The training data.

y : ndarray, default=None
The target.

levels : list of dict
A description of the nested levels of computation of the estimator to build
the computation tree. It's a list of dicts with "descr" and "max_iter" keys.

Returns
-------
root : ComputationNode instance
The root of the computation tree.
"""
self._computation_tree = ComputationTree(
estimator_name=self.__class__.__name__,
levels=levels,
parent_node=getattr(self, "_parent_node", None),
)

if hasattr(self, "_callbacks"):
file_path = self._computation_tree.tree_dir / "computation_tree.pkl"
with open(file_path, "wb") as f:
pickle.dump(self._computation_tree, f)

for callback in self._callbacks:
is_propagated = hasattr(self, "_parent_node") and isinstance(
callback, AutoPropagatedMixin
)
if not is_propagated:
# Only call the on_fit_begin method of callbacks that are not
# propagated from a meta-estimator.
callback.on_fit_begin(estimator=self, X=X, y=y)

return self._computation_tree.root

def _eval_callbacks_on_fit_end(self):
"""Evaluate the on_fit_end method of the callbacks"""
if not hasattr(self, "_callbacks"):
return

self._computation_tree._tree_status[0] = True

for callback in self._callbacks:
is_propagated = isinstance(callback, AutoPropagatedMixin) and hasattr(
self, "_parent_node"
)
if not is_propagated:
# Only call the on_fit_end method of callbacks that are not
# propagated from a meta-estimator.
callback.on_fit_end()
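
For illustration only (not part of the diff), a sketch of how an estimator's fit is expected to use these two hooks; the levels entries are assumptions that follow the docstring above:

def fit(self, X, y):
    # ... validate parameters and data first ...
    root = self._eval_callbacks_on_fit_begin(
        levels=[
            {"descr": "fit", "max_iter": self.max_iter},
            {"descr": "iter", "max_iter": None},
        ],
        X=X,
        y=y,
    )
    for i in range(self.max_iter):
        ...  # one iteration of the solver
    self._eval_callbacks_on_fit_end()
    return self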

@property
def _repr_html_(self):
"""HTML representation of estimator.
25 changes: 25 additions & 0 deletions sklearn/callback/__init__.py
@@ -0,0 +1,25 @@
# License: BSD 3 clause

from ._base import AutoPropagatedMixin
from ._base import BaseCallback
from ._computation_tree import ComputationNode
from ._computation_tree import ComputationTree
from ._computation_tree import load_computation_tree
from ._convergence_monitor import ConvergenceMonitor
from ._early_stopping import EarlyStopping
from ._progressbar import ProgressBar
from ._snapshot import Snapshot
from ._text_verbose import TextVerbose

__all__ = [
"AutoPropagatedMixin",
"Basecallback",
"ComputationNode",
"ComputationTree",
"load_computation_tree",
"ConvergenceMonitor",
"EarlyStopping",
"ProgressBar",
"Snapshot",
"TextVerbose",
]
126 changes: 126 additions & 0 deletions sklearn/callback/_base.py
@@ -0,0 +1,126 @@
# License: BSD 3 clause

from abc import ABC, abstractmethod


# Not a method of BaseEstimator because it might be called from an external function
def _eval_callbacks_on_fit_iter_end(**kwargs):
"""Evaluate the on_fit_iter_end method of the callbacks

This function should be called at the end of each computation node.

Parameters
----------
kwargs : dict
arguments passed to the callback.

Returns
-------
stop : bool
Whether or not to stop the fit at this node.
"""
estimator = kwargs.get("estimator")
node = kwargs.get("node")

if not hasattr(estimator, "_callbacks") or node is None:
return False

estimator._computation_tree._tree_status[node.tree_status_idx] = True

# stopping_criterion and reconstruction_attributes can be costly to compute. They
# are passed as lambdas for lazy evaluation. We only actually compute them if a
# callback requests it.
if any(
getattr(callback, "request_stopping_criterion", False)
for callback in estimator._callbacks
):
kwarg = kwargs.pop("stopping_criterion", lambda: None)()
kwargs["stopping_criterion"] = kwarg

if any(
getattr(callback, "request_reconstruction_attributes", False)
for callback in estimator._callbacks
):
kwarg = kwargs.pop("reconstruction_attributes", lambda: None)()
kwargs["reconstruction_attributes"] = kwarg

return any(callback.on_fit_iter_end(**kwargs) for callback in estimator._callbacks)
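
For illustration only (not part of the diff), a sketch of how a fit loop could call this helper; the lambda keeps the costly stopping criterion lazy as described above, and node and compute_criterion are hypothetical names:

from sklearn.callback._base import _eval_callbacks_on_fit_iter_end

for i in range(self.max_iter):
    coef = self._solver_iteration(X, y, coef)  # hypothetical helper
    if _eval_callbacks_on_fit_iter_end(
        estimator=self,
        node=node,  # ComputationNode of this iteration (assumed available)
        stopping_criterion=lambda: compute_criterion(coef),  # evaluated lazily
        tol=self.tol,
    ):
        break  # a callback (e.g. EarlyStopping) requested to stop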


class BaseCallback(ABC):
"""Abstract class for the callbacks"""

@abstractmethod
def on_fit_begin(self, estimator, *, X=None, y=None):
"""Method called at the beginning of the fit method of the estimator

Parameters
----------
estimator : estimator instance
The estimator the callback is set on.

X : ndarray or sparse matrix, default=None
The training data.

y : ndarray, default=None
The target.
"""
pass

@abstractmethod
def on_fit_end(self):
"""Method called at the end of the fit method of the estimator"""
pass

@abstractmethod
def on_fit_iter_end(self, estimator, node, **kwargs):
"""Method called at the end of each computation node of the estimator

Parameters
----------
estimator : estimator instance
The caller estimator. It might differ from the estimator passed to the
`on_fit_begin` method for auto-propagated callbacks.

node : ComputationNode instance
The caller computation node.

kwargs : dict
arguments passed to the callback. Possible keys are

- stopping_criterion: float
Usually iterations stop when `stopping_criterion <= tol`.
This is only provided at the innermost level of iterations.

- tol: float
Tolerance for the stopping criterion.
This is only provided at the innermost level of iterations.

- reconstruction_attributes: dict
Necessary attributes to construct an estimator (by copying this
estimator and setting these as attributes) which will behave as if
the fit stopped at this node.
This is only provided at the outermost level of iterations.

- fit_state: dict
Model specific quantities updated during fit. This is not meant to be
used by generic callbacks but by a callback designed for a specific
estimator instead.

Returns
-------
stop : bool or None
Whether or not to stop the current level of iterations at this node.
"""
pass


class AutoPropagatedMixin:
"""Mixin for auto-propagated callbacks

An auto-propagated callback (from a meta-estimator to its sub-estimators) must be
set on the meta-estimator. Its `on_fit_begin` and `on_fit_end` methods will only be
called at the beginning and end of the fit method of the meta-estimator, while its
`on_fit_iter_end` method will be called at each computation node of the
meta-estimator and its sub-estimators.
"""

pass
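
For illustration only (not part of the diff), a minimal custom auto-propagated callback; it must be set on the meta-estimator, and the print statements are placeholders:

from sklearn.callback import AutoPropagatedMixin, BaseCallback

class FitMonitor(BaseCallback, AutoPropagatedMixin):
    def on_fit_begin(self, estimator, *, X=None, y=None):
        print(f"fit started: {estimator.__class__.__name__}")

    def on_fit_end(self):
        print("fit finished")

    def on_fit_iter_end(self, estimator, node, **kwargs):
        # Return True to request stopping the iterations at this node.
        return False

# Set on the meta-estimator only; it is then propagated to its sub-estimators:
# meta_estimator._set_callbacks(FitMonitor())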