FEA Callbacks base infrastructure + progress bars #27663

Status: Open. jeremiedbb wants to merge 57 commits into base: callbacks.

Changes from 36 commits

Commits (57)
272e75f
callback API
jeremiedbb Dec 14, 2021
584bdf7
cln nmf and test reconstruction attributes
jeremiedbb Dec 17, 2021
bb32ff3
cln snapshot + test snapshot + uuid for computation tree
jeremiedbb Dec 20, 2021
7a1825d
cln
jeremiedbb Dec 31, 2021
3e3b25f
black
jeremiedbb Dec 31, 2021
26dbb69
lint
jeremiedbb Dec 31, 2021
eb7b824
wip
jeremiedbb Feb 14, 2022
9b913fd
Merge branch 'master' into callback-api
jeremiedbb Feb 14, 2022
f78442e
class
jeremiedbb Feb 23, 2022
34bab15
more tests
jeremiedbb Feb 23, 2022
596a58e
cln
jeremiedbb Feb 23, 2022
4f9363c
wip
jeremiedbb Sep 12, 2022
030f68b
Merge remote-tracking branch 'upstream/main' into callback-api
jeremiedbb Sep 12, 2022
35c5284
wip
jeremiedbb Sep 16, 2022
115e184
wip
jeremiedbb Sep 16, 2022
bdb4990
wip
jeremiedbb Sep 21, 2022
d1bb5eb
Merge remote-tracking branch 'upstream/main' into callback-api
jeremiedbb Sep 23, 2022
7a43c30
wip
jeremiedbb Sep 23, 2022
573fd5d
Merge remote-tracking branch 'upstream/main' into callback-api
jeremiedbb Oct 11, 2022
a218068
wip
jeremiedbb Oct 13, 2022
f794694
update poor_score
jeremiedbb Oct 13, 2022
ab74f19
Merge remote-tracking branch 'upstream/main' into pr/jeremiedbb/22000
jeremiedbb Jun 19, 2023
37e569b
wip
jeremiedbb Jun 21, 2023
d7208fa
wip
jeremiedbb Jun 29, 2023
774ff69
Merge remote-tracking branch 'upstream/main' into pr/jeremiedbb/22000
jeremiedbb Oct 17, 2023
b8ac1a5
cln
jeremiedbb Oct 18, 2023
e544cc4
Merge remote-tracking branch 'upstream/main' into pr/jeremiedbb/22000
jeremiedbb Oct 20, 2023
b644430
wip
jeremiedbb Oct 25, 2023
3ab3d7f
wip
jeremiedbb Oct 25, 2023
39c04cc
wip
jeremiedbb Oct 25, 2023
73ecb31
wip
jeremiedbb Oct 25, 2023
9058919
mypy
jeremiedbb Oct 27, 2023
309f755
add test for progressbars
jeremiedbb Oct 27, 2023
3569329
can't guarantee same order of tasks
jeremiedbb Oct 27, 2023
2e28e4a
cln
jeremiedbb Oct 27, 2023
57b30b1
Merge branch 'callbacks' into base
jeremiedbb Oct 27, 2023
5270bad
address nitpicks
jeremiedbb Nov 21, 2023
ae5facc
make rich soft dependency
jeremiedbb Nov 22, 2023
df50ab3
missing arg
jeremiedbb Nov 22, 2023
aaa2dec
improve coverage
jeremiedbb Nov 23, 2023
a3e2b35
Merge branch 'callbacks' into base
jeremiedbb Nov 23, 2023
e13516d
Merge branch 'callbacks' into base
jeremiedbb Feb 9, 2024
a0667c4
mixin for callback propagation
jeremiedbb Feb 9, 2024
2fdbda3
rename _skl_callbacks
jeremiedbb Feb 17, 2024
aea9af7
clone callbacks
jeremiedbb Feb 19, 2024
44b615a
some renaming and cleanup
jeremiedbb Feb 20, 2024
fabe932
Merge branch 'callbacks' into base
jeremiedbb Feb 20, 2024
07a6875
Merge branch 'callbacks' into base (continued)
jeremiedbb Feb 20, 2024
02ecb2e
Merge branch 'callbacks' into base
jeremiedbb Feb 21, 2024
6433ba3
fix imports
jeremiedbb Feb 21, 2024
052f9d2
Merge remote-tracking branch 'upstream/callbacks' into base
jeremiedbb Feb 23, 2024
268d5cf
update lock files
jeremiedbb Feb 23, 2024
2381645
Merge remote-tracking branch 'upstream/callbacks' into base
jeremiedbb Mar 6, 2024
d392b63
debug ci
jeremiedbb Mar 6, 2024
9177757
iter
jeremiedbb Mar 6, 2024
436bcad
iter
jeremiedbb Mar 6, 2024
5bf6608
iter
jeremiedbb Mar 6, 2024
1 change: 1 addition & 0 deletions sklearn/__init__.py
@@ -89,6 +89,7 @@

__all__ = [
    "calibration",
    "callback",
    "cluster",
    "covariance",
    "cross_decomposition",
131 changes: 130 additions & 1 deletion sklearn/base.py
@@ -15,6 +15,7 @@

from . import __version__
from ._config import config_context, get_config
from .callback import BaseCallback, build_computation_tree
from .exceptions import InconsistentVersionWarning
from .utils import _IS_32BIT
from .utils._estimator_html_repr import _HTMLDocumentationLinkMixin, estimator_html_repr
@@ -115,6 +116,10 @@

    params_set = new_object.get_params(deep=False)

    # copy callbacks

Member: This is not really a copy here, is it?

    if hasattr(estimator, "_callbacks"):

Member: BTW, this might break something for https://github.com/microsoft/FLAML/blob/0415638dd1e1d3149fb17fb8760520af975d16f6/flaml/automl/model.py#L1587, which also adds this attribute to scikit-learn's base estimator in their library. But there are no reserved private namespaces, so they could probably adapt if it does.

Member: This can also break quite easily if the callback object keeps references to attributes of the old estimator. Why aren't we creating a copy here?

Member Author:

> BTW, this might break something for https://github.com/microsoft/FLAML/blob/0415638dd1e1d3149fb17fb8760520af975d16f6/flaml/automl/model.py#L1587 which also adds this attribute to scikit-learn's base estimator in their library

I can change the name to _sk_callbacks or any derivative of that.

Member Author:

> Why aren't we creating a copy here?

The motivation is a use case like this:

    monitoring = Monitoring()
    lr = LogisticRegression()._set_callbacks(monitoring)
    GridSearchCV(lr, param_grid).fit(X, y)
    monitoring.plot()

The monitoring callback gathers information across all the copies of the logistic regression made during the grid search. If we made a copy of the callback in clone, we couldn't retrieve any information once the grid search is finished.
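
For illustration, a sketch of what such a shared-state Monitoring callback could look like under the BaseCallback API introduced in this PR (the record structure and what plot would do are hypothetical):

    from sklearn.callback import BaseCallback

    class Monitoring(BaseCallback):
        """Hypothetical callback accumulating events across all clones sharing it."""

        def __init__(self):
            # Shared list: every clone holds a reference to this same instance,
            # so records from all the grid-search fits end up here.
            self.records = []

        def on_fit_begin(self, estimator, *, X=None, y=None):
            self.records.append(("fit_begin", estimator.__class__.__name__))

        def on_fit_iter_end(self, estimator, node, **kwargs):
            self.records.append(("iter_end", estimator.__class__.__name__))
            return False  # never request early stopping

        def on_fit_end(self):
            self.records.append(("fit_end", None))

As the follow-up comments note, this sharing pattern breaks down once clones happen in subprocesses, which is why the final design clones callbacks like any other param.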

Member: The object itself can disable copying by implementing __copy__ and __deepcopy__; then they would be in "we know what we're doing" territory and we wouldn't need to worry about it.

I think the kind of state you're referring to here is something that can live outside the callback object, like a file, a database, or an external singleton object; the callback method just writes into that storage, and at the end one can use that data to plot, investigate, etc.

Member Author:

> I think the kind of state you're referring to here is something that can live outside the callback object, like a file, a database, or an external singleton object

Good point. In my latest design of the callbacks I no longer rely on a shared state, so you can ignore my previous comment :) We can't get around copies anyway, since the clones can happen in subprocesses (in a grid search, for instance). I updated the code to clone the callbacks like any other param.

        new_object._callbacks = estimator._callbacks

    # quick sanity check of the parameters of the clone
    for name in new_object_params:
        param1 = new_object_params[name]

@@ -641,6 +646,127 @@
            caller_name=self.__class__.__name__,
        )

    def _set_callbacks(self, callbacks):
        """Set callbacks for the estimator.

        Parameters
        ----------
        callbacks : callback or list of callbacks
            The callbacks to set.

        Returns
        -------
        self : estimator instance
            The estimator instance itself.
        """
        if not isinstance(callbacks, list):
            callbacks = [callbacks]

        if not all(isinstance(callback, BaseCallback) for callback in callbacks):
            raise TypeError("callbacks must be subclasses of BaseCallback.")

        self._callbacks = callbacks

        return self
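
A minimal usage sketch of this new API, assuming the `ProgressBar` callback added later in this PR and an estimator already instrumented with the callback hooks (only a subset of estimators is wired up here):

    from sklearn.callback import ProgressBar
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(random_state=0)

    # _set_callbacks returns the estimator, so it chains with fit.
    est = LogisticRegression()._set_callbacks(ProgressBar())
    est.fit(X, y)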

    # XXX should be a method of MetaEstimatorMixin but this mixin can't handle all
    # meta-estimators.
    def _propagate_callbacks(self, sub_estimator, *, parent_node):
        """Propagate the auto-propagated callbacks to a sub-estimator.

        Parameters
        ----------
        sub_estimator : estimator instance
            The sub-estimator to propagate the callbacks to.

        parent_node : ComputationNode instance
            The computation node in this estimator to set as parent_node to the
            computation tree of the sub-estimator. It must be the node where the
            fit method of the sub-estimator is called.
        """
        if hasattr(sub_estimator, "_callbacks") and any(
            callback.auto_propagate for callback in sub_estimator._callbacks
        ):

Member: What happens if you call this method twice? Wouldn't this be raised on the second run? If so, it feels a bit fragile. What's the harm of not checking this?

Member Author: It should never be called twice; the meta-estimator only propagates its callbacks once to its sub-estimator(s).

This error matters so that we have an informative error message if a user tries something like

    lr = LogisticRegression()._set_callbacks(ProgressBar())
    GridSearchCV(lr, param_grid)

Without it, this would crash without telling the user the right way to do it.

            bad_callbacks = [
                callback.__class__.__name__
                for callback in sub_estimator._callbacks
                if callback.auto_propagate
            ]
            raise TypeError(
                f"The sub-estimators ({sub_estimator.__class__.__name__}) of a"
                f" meta-estimator ({self.__class__.__name__}) can't have"
                f" auto-propagated callbacks ({bad_callbacks})."
                " Set them directly on the meta-estimator."
            )

        if not hasattr(self, "_callbacks"):
            return

        propagated_callbacks = [
            callback for callback in self._callbacks if callback.auto_propagate
        ]

        if not propagated_callbacks:
            return

[Codecov warning: added line sklearn/base.py#L710 was not covered by tests]

        sub_estimator._parent_node = parent_node

        sub_estimator._set_callbacks(
            getattr(sub_estimator, "_callbacks", []) + propagated_callbacks
        )
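
To make the intended call site concrete, a sketch of how a meta-estimator's fit could use this method; the `children` attribute of `ComputationNode` is an assumption about the tree API, which is not shown in this diff:

    # Hypothetical snippet from a meta-estimator's fit. `root` is the root of
    # this meta-estimator's computation tree, as returned by
    # _eval_callbacks_on_fit_begin (defined just below), with one child node
    # assumed per sub-estimator fit.
    for node, sub_estimator in zip(root.children, self.estimators):
        self._propagate_callbacks(sub_estimator, parent_node=node)
        sub_estimator.fit(X, y)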

    def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None):
        """Evaluate the on_fit_begin method of the callbacks.

        The computation tree is also built at this point.

        This method should be called after all data and parameter validation.

        Parameters
        ----------
        levels : list of dict
            A description of the nested levels of computation of the estimator,
            used to build the computation tree. It's a list of dicts with
            "descr" and "max_iter" keys.

        X : ndarray or sparse matrix, default=None
            The training data.

        y : ndarray, default=None
            The target.

        Returns
        -------
        root : ComputationNode instance
            The root of the computation tree.
        """
        self._computation_tree = build_computation_tree(
            estimator_name=self.__class__.__name__,
            levels=levels,
            parent=getattr(self, "_parent_node", None),
        )

        if not hasattr(self, "_callbacks"):
            return self._computation_tree

        # Only call the on_fit_begin method of callbacks that are not
        # propagated from a meta-estimator.
        for callback in self._callbacks:
            if not callback._is_propagated(estimator=self):
                callback.on_fit_begin(estimator=self, X=X, y=y)

        return self._computation_tree
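
To make the `levels` format concrete, a sketch for an estimator with a single training loop; the dict shape follows the docstring above:

    # Inside a hypothetical iterative estimator's fit, after validation:
    root = self._eval_callbacks_on_fit_begin(
        levels=[
            {"descr": "fit", "max_iter": None},  # outer level: the fit itself
            {"descr": "iter", "max_iter": self.max_iter},  # inner: iterations
        ],
        X=X,
        y=y,
    )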

    def _eval_callbacks_on_fit_end(self):
        """Evaluate the on_fit_end method of the callbacks."""
        if not hasattr(self, "_callbacks") or not hasattr(self, "_computation_tree"):
            return

        # Only call the on_fit_end method of callbacks that are not
        # propagated from a meta-estimator.
        for callback in self._callbacks:
            if not callback._is_propagated(estimator=self):
                callback.on_fit_end()

[Codecov warning: added line sklearn/base.py#L768 was not covered by tests]

    @property
    def _repr_html_(self):
        """HTML representation of estimator.
@@ -1212,7 +1338,10 @@
                    prefer_skip_nested_validation or global_skip_validation
                )
            ):
-               return fit_method(estimator, *args, **kwargs)
+               try:
+                   return fit_method(estimator, *args, **kwargs)
+               finally:
+                   estimator._eval_callbacks_on_fit_end()

Member (on lines +1558 to +1561): This is becoming larger than just a validation wrapper. We could simplify the debugging and the magic by having a BaseEstimator.fit which calls self._fit(...) and does all the common stuff before and after. That seems a lot easier to understand and debug.

Member Author: The motivation when I introduced _fit_context was not only validation, but to have a generic context manager that handles everything we need to do before and after fit; that's why I gave it this generic name.

Although having a consistent framework where we'd have a BaseEstimator.fit and every child estimator implements _fit is appealing, I think it goes far beyond the scope of this PR and would require rewriting a lot of estimators.

Member Author: By the way, do you know why BaseEstimator does not implement fit in the first place?

Also, note that _fit_context also handles partial_fit, and I don't think we want BaseEstimator to implement partial_fit.

Member: BaseEstimator doesn't implement fit because we generally don't have methods that raise NotImplementedError; they're simply not there. But now that we have all this work, we could certainly have fit in BaseEstimator, with children only implementing a __sklearn_fit__ kind of method instead.

Member Author: I still think it's outside the scope of this PR. Using the existing context manager is just a one-line addition, whereas implementing __sklearn_fit__ means countless PRs :)
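
A minimal sketch of the pattern discussed above, purely illustrative since neither a base `fit` nor `__sklearn_fit__` exists in this PR; the levels argument is a placeholder:

    class BaseEstimator:
        def fit(self, X, y=None):
            # Common pre-fit work would live here: validation, building the
            # computation tree, calling on_fit_begin on the callbacks.
            self._eval_callbacks_on_fit_begin(
                levels=[{"descr": "fit", "max_iter": None}], X=X, y=y
            )
            try:
                # Each child estimator would implement only __sklearn_fit__.
                return self.__sklearn_fit__(X, y)
            finally:
                # Mirrors the try/finally added to _fit_context in this diff.
                self._eval_callbacks_on_fit_end()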


        return wrapper

13 changes: 13 additions & 0 deletions sklearn/callback/__init__.py
@@ -0,0 +1,13 @@
# License: BSD 3 clause
# Authors: the scikit-learn developers

from ._base import BaseCallback
from ._computation_tree import ComputationNode, build_computation_tree
from ._progressbar import ProgressBar

__all__ = [
    "BaseCallback",
    "build_computation_tree",
    "ComputationNode",
    "ProgressBar",
]
138 changes: 138 additions & 0 deletions sklearn/callback/_base.py
@@ -0,0 +1,138 @@
# License: BSD 3 clause
# Authors: the scikit-learn developers

from abc import ABC, abstractmethod


# Not a method of BaseEstimator because it might not be directly called from fit,
# but by a non-method function called by fit.
def _eval_callbacks_on_fit_iter_end(**kwargs):
    """Evaluate the on_fit_iter_end method of the callbacks.

    This function must be called at the end of each computation node.

    Parameters
    ----------
    kwargs : dict
        Arguments passed to the callback.

    Returns
    -------
    stop : bool
        Whether or not to stop the fit at this node.
    """
    estimator = kwargs.get("estimator")
    node = kwargs.get("node")

[Codecov warning: added lines sklearn/callback/_base.py#L24-L25 were not covered by tests]

    if not hasattr(estimator, "_callbacks") or node is None:
        return False

[Codecov warning: added line sklearn/callback/_base.py#L28 was not covered by tests]

    # stopping_criterion and reconstruction_attributes can be costly to compute.
    # They are passed as lambdas for lazy evaluation. We only actually
    # compute them if a callback requests it.
    # TODO: This is not used yet but will be necessary for next callbacks.
    # Uncomment when needed.
    # if any(cb.request_stopping_criterion for cb in estimator._callbacks):
    #     kwarg = kwargs.pop("stopping_criterion", lambda: None)()
    #     kwargs["stopping_criterion"] = kwarg

    # if any(cb.request_from_reconstruction_attributes for cb in estimator._callbacks):
    #     kwarg = kwargs.pop("from_reconstruction_attributes", lambda: None)()
    #     kwargs["from_reconstruction_attributes"] = kwarg

    return any(callback.on_fit_iter_end(**kwargs) for callback in estimator._callbacks)
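
A sketch of the intended call site inside a solver loop; how the node for the current iteration is obtained (`root.children[i]`) is an assumption, since the ComputationNode API is not part of this file:

    # Hypothetical training loop of an iterative estimator. `root` is the tree
    # root returned by _eval_callbacks_on_fit_begin.
    for i in range(self.max_iter):
        self._do_one_iteration(X, y)  # hypothetical solver step
        if _eval_callbacks_on_fit_iter_end(estimator=self, node=root.children[i]):
            break  # a callback requested early stopping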


class BaseCallback(ABC):
    """Abstract class for the callbacks."""

    @abstractmethod
    def on_fit_begin(self, estimator, *, X=None, y=None):
        """Method called at the beginning of the fit method of the estimator.

        Only called for callbacks that are not propagated from a meta-estimator.

        Parameters
        ----------
        estimator : estimator instance
            The estimator the callback is set on.

        X : ndarray or sparse matrix, default=None
            The training data.

        y : ndarray or sparse matrix, default=None
            The target.
        """

    @abstractmethod
    def on_fit_end(self):
        """Method called at the end of the fit method of the estimator."""

    @abstractmethod
    def on_fit_iter_end(self, estimator, node, **kwargs):
        """Method called at the end of each computation node of the estimator.

        Parameters
        ----------
        estimator : estimator instance
            The caller estimator. It might differ from the estimator passed to the
            `on_fit_begin` method for auto-propagated callbacks.

        node : ComputationNode instance
            The caller computation node.

        **kwargs : dict
            Arguments passed to the callback. Possible keys are

            - stopping_criterion : float
                Usually iterations stop when `stopping_criterion <= tol`.
                This is only provided at the innermost level of iterations.

            - tol : float
                Tolerance for the stopping criterion.
                This is only provided at the innermost level of iterations.

            - from_reconstruction_attributes : estimator instance
                A ready-to-predict, transform, etc. estimator, as if the fit had
                stopped at this node. Usually it's a copy of the caller estimator
                with the necessary attributes set, but it can sometimes be an
                instance of another class (e.g. LogisticRegressionCV ->
                LogisticRegression).

            - fit_state : dict
                Model-specific quantities updated during fit. This is not meant
                to be used by generic callbacks, but by callbacks designed for a
                specific estimator instead.

        Returns
        -------
        stop : bool or None
            Whether or not to stop the current level of iterations at this node.
        """

    @property
    def auto_propagate(self):
        """Whether or not this callback should be propagated to sub-estimators.

        An auto-propagated callback (from a meta-estimator to its sub-estimators)
        must be set on the meta-estimator. Its `on_fit_begin` and `on_fit_end`
        methods will only be called at the beginning and end of the fit method
        of the meta-estimator, while its `on_fit_iter_end` method will be called
        at each computation node of the meta-estimator and its sub-estimators.
        """
        return False

    def _is_propagated(self, estimator):
        """Check if this callback attached to estimator has been propagated from
        a meta-estimator.
        """
        return self.auto_propagate and hasattr(estimator, "_parent_node")

    # TODO: This is not used yet but will be necessary for next callbacks.
    # Uncomment when needed.
    # @property
    # def request_stopping_criterion(self):
    #     return False

    # @property
    # def request_from_reconstruction_attributes(self):
    #     return False
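
Putting the pieces together, a minimal concrete callback against this base class; the logging behavior is illustrative, and `auto_propagate` is overridden only to demonstrate the property:

    from sklearn.callback import BaseCallback

    class ConvergenceLogger(BaseCallback):
        """Hypothetical callback printing the stopping criterion at each node."""

        @property
        def auto_propagate(self):
            # Opt in: a meta-estimator will forward this callback to its
            # sub-estimators when _propagate_callbacks is called.
            return True

        def on_fit_begin(self, estimator, *, X=None, y=None):
            print(f"start fitting {estimator.__class__.__name__}")

        def on_fit_iter_end(self, estimator, node, **kwargs):
            # stopping_criterion and tol are only provided at the innermost
            # level of iterations (and not wired up yet, per the TODO above).
            criterion = kwargs.get("stopping_criterion")
            if criterion is not None:
                print(f"stopping_criterion={criterion:.3e} (tol={kwargs.get('tol')})")
            return False  # never request early stopping

        def on_fit_end(self):
            print("fit done")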