New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Callback API continued #22000
Draft
jeremiedbb
wants to merge
26
commits into
scikit-learn:main
Choose a base branch
from
jeremiedbb:callback-api
base: main
Could not load branches
Branch not found: {{ refName }}
Could not load tags
Nothing to show
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
[WIP] Callback API continued #22000
Changes from 2 commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
272e75f
callback API
jeremiedbb 584bdf7
cln nmf and test reconstruction attributes
jeremiedbb bb32ff3
cln snapshot + test snapshot + uuid for computation tree
jeremiedbb 7a1825d
cln
jeremiedbb 3e3b25f
black
jeremiedbb 26dbb69
lint
jeremiedbb eb7b824
wip
jeremiedbb 9b913fd
Merge branch 'master' into callback-api
jeremiedbb f78442e
class
jeremiedbb 34bab15
more tests
jeremiedbb 596a58e
cln
jeremiedbb 4f9363c
wip
jeremiedbb 030f68b
Merge remote-tracking branch 'upstream/main' into callback-api
jeremiedbb 35c5284
wip
jeremiedbb 115e184
wip
jeremiedbb bdb4990
wip
jeremiedbb d1bb5eb
Merge remote-tracking branch 'upstream/main' into callback-api
jeremiedbb 7a43c30
wip
jeremiedbb 573fd5d
Merge remote-tracking branch 'upstream/main' into callback-api
jeremiedbb a218068
wip
jeremiedbb f794694
update poor_score
jeremiedbb ab74f19
Merge remote-tracking branch 'upstream/main' into pr/jeremiedbb/22000
jeremiedbb 37e569b
wip
jeremiedbb d7208fa
wip
jeremiedbb 774ff69
Merge remote-tracking branch 'upstream/main' into pr/jeremiedbb/22000
jeremiedbb b8ac1a5
cln
jeremiedbb File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -84,6 +84,7 @@ | |
|
||
__all__ = [ | ||
"calibration", | ||
"callback", | ||
"cluster", | ||
"covariance", | ||
"cross_decomposition", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# License: BSD 3 clause | ||
|
||
from ._base import AutoPropagatedMixin | ||
from ._base import BaseCallback | ||
from ._computation_tree import ComputationNode | ||
from ._computation_tree import ComputationTree | ||
from ._computation_tree import load_computation_tree | ||
from ._convergence_monitor import ConvergenceMonitor | ||
from ._early_stopping import EarlyStopping | ||
from ._progressbar import ProgressBar | ||
from ._snapshot import Snapshot | ||
from ._text_verbose import TextVerbose | ||
|
||
__all__ = [ | ||
"AutoPropagatedMixin", | ||
"Basecallback", | ||
"ComputationNode", | ||
"ComputationTree", | ||
"load_computation_tree", | ||
"ConvergenceMonitor", | ||
"EarlyStopping", | ||
"ProgressBar", | ||
"Snapshot", | ||
"TextVerbose", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# License: BSD 3 clause | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
|
||
# Not a method of BaseEstimator because it might be called from an extern function | ||
def _eval_callbacks_on_fit_iter_end(**kwargs): | ||
"""Evaluate the on_fit_iter_end method of the callbacks | ||
|
||
This function should be called at the end of each computation node. | ||
|
||
Parameters | ||
---------- | ||
kwargs : dict | ||
arguments passed to the callback. | ||
|
||
Returns | ||
------- | ||
stop : bool | ||
Whether or not to stop the fit at this node. | ||
""" | ||
estimator = kwargs.get("estimator") | ||
node = kwargs.get("node") | ||
|
||
if not hasattr(estimator, "_callbacks") or node is None: | ||
return False | ||
|
||
estimator._computation_tree._tree_status[node.tree_status_idx] = True | ||
|
||
# stopping_criterion and reconstruction_attributes can be costly to compute. They | ||
# are passed as lambdas for lazy evaluation. We only actually compute them if a | ||
# callback requests it. | ||
if any( | ||
getattr(callback, "request_stopping_criterion", False) | ||
for callback in estimator._callbacks | ||
): | ||
kwarg = kwargs.pop("stopping_criterion", lambda: None)() | ||
kwargs["stopping_criterion"] = kwarg | ||
|
||
if any( | ||
getattr(callback, "request_reconstruction_attributes", False) | ||
for callback in estimator._callbacks | ||
): | ||
kwarg = kwargs.pop("reconstruction_attributes", lambda: None)() | ||
kwargs["reconstruction_attributes"] = kwarg | ||
|
||
return any(callback.on_fit_iter_end(**kwargs) for callback in estimator._callbacks) | ||
|
||
|
||
class BaseCallback(ABC): | ||
"""Abstract class for the callbacks""" | ||
|
||
@abstractmethod | ||
def on_fit_begin(self, estimator, *, X=None, y=None): | ||
"""Method called at the beginning of the fit method of the estimator | ||
|
||
Parameters | ||
---------- | ||
estimator: estimator instance | ||
The estimator the callback is set on. | ||
X: ndarray or sparse matrix, default=None | ||
The training data. | ||
y: ndarray, default=None | ||
The target. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def on_fit_end(self): | ||
"""Method called at the end of the fit method of the estimator""" | ||
pass | ||
|
||
@abstractmethod | ||
def on_fit_iter_end(self, estimator, node, **kwargs): | ||
"""Method called at the end of each computation node of the estimator | ||
|
||
Parameters | ||
---------- | ||
estimator : estimator instance | ||
The caller estimator. It might differ from the estimator passed to the | ||
`on_fit_begin` method for auto-propagated callbacks. | ||
|
||
node : ComputationNode instance | ||
The caller computation node. | ||
|
||
kwargs : dict | ||
arguments passed to the callback. Possible keys are | ||
|
||
- stopping_criterion: float | ||
Usually iterations stop when `stopping_criterion <= tol`. | ||
This is only provided at the innermost level of iterations. | ||
|
||
- tol: float | ||
Tolerance for the stopping criterion. | ||
This is only provided at the innermost level of iterations. | ||
|
||
- reconstruction_attributes: dict | ||
Necessary attributes to construct an estimator (by copying this | ||
estimator and setting these as attributes) which will behave as if | ||
the fit stopped at this node. | ||
This is only provided at the outermost level of iterations. | ||
|
||
- fit_state: dict | ||
Model specific quantities updated during fit. This is not meant to be | ||
used by generic callbacks but by a callback designed for a specific | ||
estimator instead. | ||
|
||
Returns | ||
------- | ||
stop : bool or None | ||
Whether or not to stop the current level of iterations at this node. | ||
""" | ||
pass | ||
|
||
|
||
class AutoPropagatedMixin: | ||
"""Mixin for auto-propagated callbacks | ||
|
||
An auto-propagated callback (from a meta-estimator to its sub-estimators) must be | ||
set on the meta-estimator. Its `on_fit_begin` and `on_fit_end` methods will only be | ||
called at the beginning and end of the fit method of the meta-estimator, while its | ||
`on_fit_iter_end` method will be called at each computation node of the | ||
meta-estimator and its sub-estimators. | ||
""" | ||
|
||
pass |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be too magical to have only something call
_eval_callbacks_begin
and inspect internal the stack of calls to infer which method called this function.Of course it would make sense only if the same methods are expected to be called for fit/predict/etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well that's not obvious at all and I've not really thought about that. This first iteration is all about fit. I think it will be easier to not try to be too magical for now