Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Callback API continued #22000

Draft
wants to merge 26 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions sklearn/__init__.py
Expand Up @@ -89,6 +89,7 @@

__all__ = [
"calibration",
"callback",
"cluster",
"covariance",
"cross_decomposition",
Expand Down
189 changes: 188 additions & 1 deletion sklearn/base.py
Expand Up @@ -6,15 +6,20 @@
import copy
import functools
import inspect
import pickle
import platform
import re
import warnings
from collections import defaultdict
from functools import partial
from shutil import rmtree

import numpy as np

from . import __version__
from ._config import config_context, get_config
from .callback import BaseCallback, ComputationTree
from .callback._base import CallbackContext
from .exceptions import InconsistentVersionWarning
from .utils import _IS_32BIT
from .utils._estimator_html_repr import estimator_html_repr
Expand Down Expand Up @@ -115,6 +120,10 @@ def _clone_parametrized(estimator, *, safe=True):

params_set = new_object.get_params(deep=False)

# copy callbacks
if hasattr(estimator, "_callbacks"):
new_object._callbacks = estimator._callbacks

# quick sanity check of the parameters of the clone
for name in new_object_params:
param1 = new_object_params[name]
Expand Down Expand Up @@ -641,6 +650,181 @@ class attribute, which is a dictionary `param_name: list of constraints`. See
caller_name=self.__class__.__name__,
)

def _set_callbacks(self, callbacks):
    """Attach one or several callbacks to this estimator.

    Parameters
    ----------
    callbacks : callback or list of callbacks
        The callbacks to set. Anything that is not a list is wrapped in a
        one-element list.

    Returns
    -------
    self : estimator instance
        The estimator instance itself.

    Raises
    ------
    TypeError
        If any of the callbacks is not an instance of BaseCallback.
    """
    callback_list = callbacks if isinstance(callbacks, list) else [callbacks]

    # Reject the whole input as soon as one element is not a BaseCallback.
    for candidate in callback_list:
        if not isinstance(candidate, BaseCallback):
            raise TypeError("callbacks must be subclasses of BaseCallback.")

    self._callbacks = callback_list
    return self

# XXX should be a method of MetaEstimatorMixin but this mixin can't handle all
# meta-estimators.
def _propagate_callbacks(self, sub_estimator, *, parent_node):
    """Hand the auto-propagated callbacks of this estimator to a sub-estimator

    Parameters
    ----------
    sub_estimator : estimator instance
        The sub-estimator to propagate the callbacks to.

    parent_node : ComputationNode instance
        The computation node in this estimator to set as parent_node to the
        computation tree of the sub-estimator. It must be the node where the fit
        method of the sub-estimator is called.

    Raises
    ------
    TypeError
        If the sub-estimator already holds callbacks flagged as
        auto-propagated.
    """
    # Auto-propagated callbacks may only be set on the meta-estimator: collect
    # the class names of the ones already held by the sub-estimator, if any.
    bad_callbacks = [
        callback.__class__.__name__
        for callback in getattr(sub_estimator, "_callbacks", [])
        if callback.auto_propagate
    ]
    if bad_callbacks:
        raise TypeError(
            f"The sub-estimators ({sub_estimator.__class__.__name__}) of a"
            f" meta-estimator ({self.__class__.__name__}) can't have"
            f" auto-propagated callbacks ({bad_callbacks})."
            " Set them directly on the meta-estimator."
        )

    # An estimator without callbacks simply has nothing to propagate.
    propagated_callbacks = [
        callback
        for callback in getattr(self, "_callbacks", [])
        if callback.auto_propagate
    ]
    if not propagated_callbacks:
        return

    # Record where the computation tree of the sub-estimator must be attached.
    sub_estimator._parent_ct_node = parent_node

    # Keep the sub-estimator's own callbacks and append the propagated ones.
    existing_callbacks = getattr(sub_estimator, "_callbacks", [])
    sub_estimator._set_callbacks(existing_callbacks + propagated_callbacks)

def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None):
    """Evaluate the on_fit_begin method of the callbacks

    The computation tree is also built at this point.

    This method should be called after all data and parameters validation.

    Parameters
    ----------
    levels : list of dict
        A description of the nested levels of computation of the estimator to build
        the computation tree. It's a list of dict with "descr" and "max_iter" keys.

    X : ndarray or sparse matrix, default=None
        The training data.

    y : ndarray, default=None
        The target.

    Returns
    -------
    root : ComputationNode instance
        The root of the computation tree.

    X : ndarray, sparse matrix or None
        The training data. It is restricted to the training indices when a
        callback requested a validation split, and returned unchanged otherwise.

    y : ndarray or None
        The target, restricted in the same way as `X`.

    X_val : ndarray, sparse matrix or None
        The validation part of `X`, or None when no validation split was made.

    y_val : ndarray or None
        The validation part of `y`, or None when no validation split was made.
    """
    # NOTE(review): a reviewer asked whether it would be too magical to have a
    # single `_eval_callbacks_begin` that inspects the call stack to infer which
    # method called it, which only makes sense if the same hooks are expected
    # for fit/predict/etc. The author's answer: that is not obvious, this first
    # iteration is all about fit, and it is easier not to be too magical for now.
    self._computation_tree = ComputationTree(
        estimator_name=self.__class__.__name__,
        levels=levels,
        parent_node=getattr(self, "_parent_ct_node", None),
    )

    if not hasattr(self, "_callbacks"):
        # Without callbacks no validation split can be requested: give the
        # data back untouched instead of replacing X and y by None.
        return self._computation_tree.root, X, y, None, None

    X_val, y_val = None, None

    if any(callback.request_validation_split for callback in self._callbacks):
        # The splitter is taken from the first callback exposing a
        # `validation_split` attribute; only its first split is used.
        splitter = next(
            callback.validation_split
            for callback in self._callbacks
            if hasattr(callback, "validation_split")
        )

        train, val = next(splitter.split(X))
        if X is not None:
            X, X_val = X[train], X[val]
        if y is not None:
            y, y_val = y[train], y[val]

    # NOTE(review): the context object is not kept. Presumably it registers the
    # finalizer (removal of the tree directory, ignoring errors) for the
    # callbacks -- confirm against CallbackContext in callback/_base.py.
    CallbackContext(
        self._callbacks,
        finalizer=partial(rmtree, ignore_errors=True),
        finalizer_args=self._computation_tree.tree_dir,
    )

    # Dump the computation tree into its own directory.
    file_path = self._computation_tree.tree_dir / "computation_tree.pkl"
    with open(file_path, "wb") as f:
        pickle.dump(self._computation_tree, f)

    # Only call the on_fit_begin method of callbacks that are not
    # propagated from a meta-estimator.
    for callback in self._callbacks:
        if not callback._is_propagated(estimator=self):
            callback.on_fit_begin(estimator=self, X=X, y=y)

    return self._computation_tree.root, X, y, X_val, y_val

def _eval_callbacks_on_fit_end(self):
    """Run the on_fit_end hook of the callbacks set on this estimator.

    Nothing happens when the estimator has no callbacks or when no computation
    tree has been built for it.
    """
    if not (hasattr(self, "_callbacks") and hasattr(self, "_computation_tree")):
        return

    # Flag the first entry of the tree status.
    self._computation_tree._tree_status[0] = True

    # Callbacks propagated from a meta-estimator are skipped here. The
    # generator is lazy, so each check is immediately followed by its hook.
    own_callbacks = (
        candidate
        for candidate in self._callbacks
        if not candidate._is_propagated(estimator=self)
    )
    for candidate in own_callbacks:
        candidate.on_fit_end()

def _from_reconstruction_attributes(self, *, reconstruction_attributes):
    """Build a copy of this estimator that looks as if it had been fitted

    Parameters
    ----------
    reconstruction_attributes : callable
        A callable that has no arguments and returns the necessary fitted attributes
        to create a working fitted estimator from this instance.

        Using a callable allows lazy evaluation of the potentially costly
        reconstruction attributes.

    Returns
    -------
    fitted_estimator : estimator instance
        The fitted copy of this estimator.
    """
    # Shallow copy: attribute values are shared with this instance.
    # XXX deepcopy ?
    fitted_estimator = copy.copy(self)

    fitted_attributes = reconstruction_attributes()
    for attr_name, attr_value in fitted_attributes.items():
        setattr(fitted_estimator, attr_name, attr_value)

    return fitted_estimator

@property
def _repr_html_(self):
"""HTML representation of estimator.
Expand Down Expand Up @@ -1212,7 +1396,10 @@ def wrapper(estimator, *args, **kwargs):
prefer_skip_nested_validation or global_skip_validation
)
):
return fit_method(estimator, *args, **kwargs)
try:
return fit_method(estimator, *args, **kwargs)
finally:
estimator._eval_callbacks_on_fit_end()

return wrapper

Expand Down
21 changes: 21 additions & 0 deletions sklearn/callback/__init__.py
@@ -0,0 +1,21 @@
# License: BSD 3 clause

from ._base import BaseCallback
from ._computation_tree import ComputationNode, ComputationTree, load_computation_tree
from ._early_stopping import EarlyStopping
from ._monitoring import Monitoring
from ._progressbar import ProgressBar
from ._snapshot import Snapshot
from ._text_verbose import TextVerbose

# Public names of the callback subpackage; each one is imported above from a
# private submodule.
__all__ = [
"BaseCallback",
"ComputationNode",
"ComputationTree",
"load_computation_tree",
"Monitoring",
"EarlyStopping",
"ProgressBar",
"Snapshot",
"TextVerbose",
]