FEA Callbacks base infrastructure + progress bars #27663
base: callbacks
Changes from 36 commits
@@ -89,6 +89,7 @@
__all__ = [
    "calibration",
    "callback",
    "cluster",
    "covariance",
    "cross_decomposition",
@@ -15,6 +15,7 @@
from . import __version__
from ._config import config_context, get_config
from .callback import BaseCallback, build_computation_tree
from .exceptions import InconsistentVersionWarning
from .utils import _IS_32BIT
from .utils._estimator_html_repr import _HTMLDocumentationLinkMixin, estimator_html_repr
@@ -115,6 +116,10 @@

    params_set = new_object.get_params(deep=False)

    # copy callbacks
    if hasattr(estimator, "_callbacks"):
BTW, this might break something for https://github.com/microsoft/FLAML/blob/0415638dd1e1d3149fb17fb8760520af975d16f6/flaml/automl/model.py#L1587, which also adds this attribute to scikit-learn's base estimator in their library. But there are no reserved private namespaces, so they could probably adapt if it does.

This can also break quite easily if the callback object keeps references to attributes of the old estimator. Why aren't we creating a copy here?
I can change the name to _sk_callbacks or any derivative of that.
The motivation is a use case like this:

    monitoring = Monitoring()
    lr = LogisticRegression()._set_callbacks(monitoring)
    GridSearchCV(lr, param_grid).fit(X, y)
    monitoring.plot()

The monitoring callback will gather information across all copies of the logistic regression made in the grid search. If we made a copy of the callback in …

The object itself can disable copy by implementing … I think the kind of state you're referring to here is something which can be outside the callback object, like a file / a database / an external singleton object, and the callback method just writes into that storage; at the end one can use that data to plot/investigate/etc.

Good point, and actually in my latest design of the callbacks I didn't rely on a shared state, so you can ignore my previous comment :) And we can't get around copies anyway since the clones can happen in subprocesses (in a grid search for instance). I updated to clone the callbacks as any other param.
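The shared-reference behavior discussed above can be sketched standalone. This is a toy illustration, not the PR's API: `Monitoring` and the "clones" below are hypothetical stand-ins showing why assigning `estimator._callbacks` by reference lets one callback object accumulate events from every clone.

```python
class Monitoring:
    """Toy stand-in for the monitoring callback discussed above."""

    def __init__(self):
        self.records = []

    def on_fit_iter_end(self, estimator_name, iteration):
        self.records.append((estimator_name, iteration))


monitoring = Monitoring()

# Two "clones" of an estimator holding a reference to the SAME callback
# object (mirroring `new_object._callbacks = estimator._callbacks`).
clone_a_callbacks = [monitoring]
clone_b_callbacks = [monitoring]

for cb in clone_a_callbacks:
    cb.on_fit_iter_end("clone_a", iteration=0)
for cb in clone_b_callbacks:
    cb.on_fit_iter_end("clone_b", iteration=0)

# Events from both clones accumulate in the single shared instance,
# which is what makes monitoring.plot() useful after a grid search.
```

If the callback were deep-copied at clone time instead, each clone would write into its own private copy and the user-held object would stay empty, which is the trade-off debated in this thread.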
        new_object._callbacks = estimator._callbacks

    # quick sanity check of the parameters of the clone
    for name in new_object_params:
        param1 = new_object_params[name]
@@ -641,6 +646,127 @@

            caller_name=self.__class__.__name__,
        )
    def _set_callbacks(self, callbacks):
        """Set callbacks for the estimator.

        Parameters
        ----------
        callbacks : callback or list of callbacks
            The callbacks to set.

        Returns
        -------
        self : estimator instance
            The estimator instance itself.
        """
        if not isinstance(callbacks, list):
            callbacks = [callbacks]

        if not all(isinstance(callback, BaseCallback) for callback in callbacks):
            raise TypeError("callbacks must be subclasses of BaseCallback.")

        self._callbacks = callbacks

        return self
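The normalization and validation logic of `_set_callbacks` can be exercised in isolation. A minimal sketch, assuming stand-in `BaseCallback` and `Estimator` classes (they are local stubs here, not the PR's real classes):

```python
from abc import ABC


class BaseCallback(ABC):
    """Local stand-in for the abstract base class added by this PR."""


class SomeCallback(BaseCallback):
    pass


class Estimator:
    def _set_callbacks(self, callbacks):
        # A single callback is wrapped in a list so downstream code can
        # always iterate.
        if not isinstance(callbacks, list):
            callbacks = [callbacks]
        if not all(isinstance(cb, BaseCallback) for cb in callbacks):
            raise TypeError("callbacks must be subclasses of BaseCallback.")
        self._callbacks = callbacks
        return self  # returning self enables call chaining


# Passing one callback or a list of callbacks both work:
est = Estimator()._set_callbacks(SomeCallback())
```

Returning `self` is what makes the one-liner in the motivating example possible: `LogisticRegression()._set_callbacks(monitoring)`.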
    # XXX should be a method of MetaEstimatorMixin but this mixin can't handle all
    # meta-estimators.
    def _propagate_callbacks(self, sub_estimator, *, parent_node):
        """Propagate the auto-propagated callbacks to a sub-estimator.

        Parameters
        ----------
        sub_estimator : estimator instance
            The sub-estimator to propagate the callbacks to.

        parent_node : ComputationNode instance
            The computation node in this estimator to set as parent_node to the
            computation tree of the sub-estimator. It must be the node where the
            fit method of the sub-estimator is called.
        """
        if hasattr(sub_estimator, "_callbacks") and any(
            callback.auto_propagate for callback in sub_estimator._callbacks
        ):
What happens if you call this method twice? Wouldn't this be raised on the second run? If so, it feels a bit fragile. What's the harm of not checking this?

It should never be called twice. The meta-estimator only propagates its callbacks once to its sub-estimator(s). This check is important to give an informative error message if a user tries something like setting an auto-propagated callback directly on a sub-estimator. It would crash without telling the user what's the right way to do it.
            bad_callbacks = [
                callback.__class__.__name__
                for callback in sub_estimator._callbacks
                if callback.auto_propagate
            ]
            raise TypeError(
                f"The sub-estimators ({sub_estimator.__class__.__name__}) of a"
                f" meta-estimator ({self.__class__.__name__}) can't have"
                f" auto-propagated callbacks ({bad_callbacks})."
                " Set them directly on the meta-estimator."
            )
        if not hasattr(self, "_callbacks"):
            return

        propagated_callbacks = [
            callback for callback in self._callbacks if callback.auto_propagate
        ]

        if not propagated_callbacks:
            return

        sub_estimator._parent_node = parent_node

        sub_estimator._set_callbacks(
            getattr(sub_estimator, "_callbacks", []) + propagated_callbacks
        )
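The happy path of the propagation logic can be sketched with simplified stand-ins (the names below are illustrative; the real method also raises `TypeError` when an auto-propagated callback is found on the sub-estimator, as shown in the diff above):

```python
class Callback:
    """Stand-in callback with just the auto_propagate flag."""

    def __init__(self, auto_propagate=False):
        self.auto_propagate = auto_propagate


class Estimator:
    def _set_callbacks(self, callbacks):
        self._callbacks = callbacks
        return self

    def _propagate_callbacks(self, sub_estimator, *, parent_node):
        if not hasattr(self, "_callbacks"):
            return
        # Only auto-propagated callbacks travel to the sub-estimator.
        propagated = [cb for cb in self._callbacks if cb.auto_propagate]
        if not propagated:
            return
        sub_estimator._parent_node = parent_node
        sub_estimator._set_callbacks(
            getattr(sub_estimator, "_callbacks", []) + propagated
        )


meta = Estimator()._set_callbacks([Callback(auto_propagate=True), Callback()])
sub = Estimator()
meta._propagate_callbacks(sub, parent_node="fit-node")
```

Only the auto-propagated callback ends up on `sub`, and `sub._parent_node` links the sub-estimator's computation tree to the meta-estimator's.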
    def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None):
        """Evaluate the on_fit_begin method of the callbacks.

        The computation tree is also built at this point.

        This method should be called after all data and parameter validation.

        Parameters
        ----------
        levels : list of dict
            A description of the nested levels of computation of the estimator to
            build the computation tree. It's a list of dicts with "descr" and
            "max_iter" keys.

        X : ndarray or sparse matrix, default=None
            The training data.

        y : ndarray, default=None
            The target.

        Returns
        -------
        root : ComputationNode instance
            The root of the computation tree.
        """
        self._computation_tree = build_computation_tree(
            estimator_name=self.__class__.__name__,
            levels=levels,
            parent=getattr(self, "_parent_node", None),
        )

        if not hasattr(self, "_callbacks"):
            return self._computation_tree

        # Only call the on_fit_begin method of callbacks that are not
        # propagated from a meta-estimator.
        for callback in self._callbacks:
            if not callback._is_propagated(estimator=self):
                callback.on_fit_begin(estimator=self, X=X, y=y)

        return self._computation_tree
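To make the `levels` argument concrete, here is a drastically simplified stand-in for `build_computation_tree` (the real implementation lives in `_computation_tree.py` and is not shown in this diff; this sketch only chains one node per level and ignores per-iteration children):

```python
class ComputationNode:
    """Toy stand-in for the PR's ComputationNode."""

    def __init__(self, descr, max_iter=None, parent=None):
        self.descr = descr
        self.max_iter = max_iter
        self.parent = parent
        self.children = []


def build_tree_sketch(estimator_name, levels, parent=None):
    # One node per nested level of computation, chained root -> leaf.
    root = ComputationNode(levels[0]["descr"], levels[0]["max_iter"], parent)
    node = root
    for level in levels[1:]:
        child = ComputationNode(level["descr"], level["max_iter"], parent=node)
        node.children.append(child)
        node = child
    return root


# e.g. an estimator whose fit runs an outer loop of 10 iterations, each
# running an inner solver loop of at most 100 iterations:
levels = [
    {"descr": "fit", "max_iter": None},
    {"descr": "outer", "max_iter": 10},
    {"descr": "inner", "max_iter": 100},
]
root = build_tree_sketch("MyEstimator", levels)
```

The "descr"/"max_iter" keys match the docstring above; everything else about the tree's shape here is an assumption for illustration.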
    def _eval_callbacks_on_fit_end(self):
        """Evaluate the on_fit_end method of the callbacks."""
        if not hasattr(self, "_callbacks") or not hasattr(self, "_computation_tree"):
            return

        # Only call the on_fit_end method of callbacks that are not
        # propagated from a meta-estimator.
        for callback in self._callbacks:
            if not callback._is_propagated(estimator=self):
                callback.on_fit_end()
    @property
    def _repr_html_(self):
        """HTML representation of estimator.

@@ -1212,7 +1338,10 @@

                prefer_skip_nested_validation or global_skip_validation
            )
        ):
            try:
                return fit_method(estimator, *args, **kwargs)
            finally:
                estimator._eval_callbacks_on_fit_end()
Comment on lines +1558 to +1561:

this is becoming larger than just a validation wrapper. We can simplify debugging and magic by having a …

The motivation when I introduced … Although having a consistent framework where we'd have a …

Btw, do you know why … Also, note that …

I still think it's outside the scope of this PR. Using the existing context manager is just a 1 line addition, whereas implementing …
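The guarantee the try/finally above provides can be checked in isolation: `on_fit_end` callbacks run whether fit returns normally or raises. The classes below are toy stand-ins, not the actual wrapper:

```python
class Estimator:
    def __init__(self):
        self.fit_end_called = False

    def _eval_callbacks_on_fit_end(self):
        self.fit_end_called = True


def failing_fit(estimator):
    raise ValueError("fit failed")


est = Estimator()
try:
    try:
        failing_fit(est)
    finally:
        # mirrors the wrapper above: callbacks are finalized even when
        # the wrapped fit method raises
        est._eval_callbacks_on_fit_end()
except ValueError:
    pass  # the original exception still propagates past the finally block
```

This matters for callbacks like progress bars, which would otherwise leave a dangling display when fit errors out.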
        return wrapper
@@ -0,0 +1,13 @@
# License: BSD 3 clause
# Authors: the scikit-learn developers
from ._base import BaseCallback
from ._computation_tree import ComputationNode, build_computation_tree
from ._progressbar import ProgressBar

__all__ = [
    "BaseCallback",
    "build_computation_tree",
    "ComputationNode",
    "ProgressBar",
]
@@ -0,0 +1,138 @@
# License: BSD 3 clause
# Authors: the scikit-learn developers

from abc import ABC, abstractmethod
# Not a method of BaseEstimator because it might not be directly called from fit
# but by a non-method function called by fit.
def _eval_callbacks_on_fit_iter_end(**kwargs):
    """Evaluate the on_fit_iter_end method of the callbacks.

    This function must be called at the end of each computation node.

    Parameters
    ----------
    kwargs : dict
        Arguments passed to the callback.

    Returns
    -------
    stop : bool
        Whether or not to stop the fit at this node.
    """
    estimator = kwargs.get("estimator")
    node = kwargs.get("node")

    if not hasattr(estimator, "_callbacks") or node is None:
        return False

    # stopping_criterion and reconstruction_attributes can be costly to compute.
    # They are passed as lambdas for lazy evaluation. We only actually
    # compute them if a callback requests it.
    # TODO: This is not used yet but will be necessary for next callbacks.
    # Uncomment when needed.
    # if any(cb.request_stopping_criterion for cb in estimator._callbacks):
    #     kwarg = kwargs.pop("stopping_criterion", lambda: None)()
    #     kwargs["stopping_criterion"] = kwarg

    # if any(cb.request_from_reconstruction_attributes for cb in estimator._callbacks):
    #     kwarg = kwargs.pop("from_reconstruction_attributes", lambda: None)()
    #     kwargs["from_reconstruction_attributes"] = kwarg

    return any(callback.on_fit_iter_end(**kwargs) for callback in estimator._callbacks)
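The lazy-evaluation pattern described in the comments above can be isolated as follows. The names are illustrative; the point is that the expensive value is wrapped in a zero-argument lambda and only computed if some callback requested it:

```python
calls = {"n": 0}


def expensive_stopping_criterion():
    # stands in for a costly computation (e.g. an objective function)
    calls["n"] += 1
    return 0.001


# The caller passes the value as a zero-argument lambda...
kwargs = {"stopping_criterion": lambda: expensive_stopping_criterion()}

request_stopping_criterion = False  # here, no callback asked for it

# ...and the dispatcher only evaluates it on request, replacing the
# lambda by its result before forwarding kwargs to the callbacks.
if request_stopping_criterion:
    kwargs["stopping_criterion"] = kwargs.pop("stopping_criterion")()
```

With the flag left `False`, the expensive function is never called, which is exactly why the commented-out `request_*` properties exist on `BaseCallback`.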
class BaseCallback(ABC):
    """Abstract class for the callbacks"""

    @abstractmethod
    def on_fit_begin(self, estimator, *, X=None, y=None):
        """Method called at the beginning of the fit method of the estimator.

        Only called on the estimator on which the callback was set, not on
        estimators it was auto-propagated to.

        Parameters
        ----------
        estimator : estimator instance
            The estimator the callback is set on.

        X : ndarray or sparse matrix, default=None
            The training data.

        y : ndarray or sparse matrix, default=None
            The target.
        """
    @abstractmethod
    def on_fit_end(self):
        """Method called at the end of the fit method of the estimator"""

    @abstractmethod
    def on_fit_iter_end(self, estimator, node, **kwargs):
        """Method called at the end of each computation node of the estimator.

        Parameters
        ----------
        estimator : estimator instance
            The caller estimator. It might differ from the estimator passed to the
            `on_fit_begin` method for auto-propagated callbacks.

        node : ComputationNode instance
            The caller computation node.

        **kwargs : dict
            Arguments passed to the callback. Possible keys are:

            - stopping_criterion : float
                Usually iterations stop when `stopping_criterion <= tol`.
                This is only provided at the innermost level of iterations.

            - tol : float
                Tolerance for the stopping criterion.
                This is only provided at the innermost level of iterations.

            - from_reconstruction_attributes : estimator instance
                A ready-to-predict, transform, etc. estimator as if the fit had
                stopped at this node. Usually it's a copy of the caller estimator
                with the necessary attributes set, but it can sometimes be an
                instance of another class (e.g. LogisticRegressionCV ->
                LogisticRegression).

            - fit_state : dict
                Model-specific quantities updated during fit. This is not meant
                to be used by generic callbacks but by a callback designed for a
                specific estimator instead.

        Returns
        -------
        stop : bool or None
            Whether or not to stop the current level of iterations at this node.
        """
    @property
    def auto_propagate(self):
        """Whether or not this callback should be propagated to sub-estimators.

        An auto-propagated callback (from a meta-estimator to its
        sub-estimators) must be set on the meta-estimator. Its `on_fit_begin`
        and `on_fit_end` methods will only be called at the beginning and end
        of the fit method of the meta-estimator, while its `on_fit_iter_end`
        method will be called at each computation node of the meta-estimator
        and its sub-estimators.
        """
        return False

    def _is_propagated(self, estimator):
        """Check if this callback attached to estimator has been propagated
        from a meta-estimator.
        """
        return self.auto_propagate and hasattr(estimator, "_parent_node")
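A minimal concrete subclass might look like this. It is a self-contained sketch: `IterationCounter` is hypothetical, and `BaseCallback` below is a local stub mirroring the abstract class above (the real one is imported from `sklearn.callback` in this PR):

```python
from abc import ABC, abstractmethod


class BaseCallback(ABC):
    """Local stub mirroring the abstract class defined above."""

    @abstractmethod
    def on_fit_begin(self, estimator, *, X=None, y=None): ...

    @abstractmethod
    def on_fit_end(self): ...

    @abstractmethod
    def on_fit_iter_end(self, estimator, node, **kwargs): ...

    @property
    def auto_propagate(self):
        return False


class IterationCounter(BaseCallback):
    """Hypothetical callback that counts visited computation nodes."""

    def on_fit_begin(self, estimator, *, X=None, y=None):
        self.n_nodes = 0

    def on_fit_iter_end(self, estimator, node, **kwargs):
        self.n_nodes += 1
        return False  # never request early stopping

    def on_fit_end(self):
        pass


# Simulate the lifecycle an estimator would drive during fit:
cb = IterationCounter()
cb.on_fit_begin(estimator=None)
for node in range(5):
    cb.on_fit_iter_end(estimator=None, node=node)
cb.on_fit_end()
```

All three abstract methods must be overridden, otherwise instantiating the subclass raises `TypeError`; `auto_propagate` defaults to `False` so the callback stays local to the estimator it is set on.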
    # TODO: This is not used yet but will be necessary for next callbacks.
    # Uncomment when needed.
    # @property
    # def request_stopping_criterion(self):
    #     return False

    # @property
    # def request_from_reconstruction_attributes(self):
    #     return False
This is not really a copy here, is it? (referring to the `# copy callbacks` comment above)