# Callbacks API #16925
Changes from 9 commits
**`sklearn/_callbacks.py`** (new file, `@@ -0,0 +1,58 @@`)

```python
# License: BSD 3 clause
from typing import List, Callable
from abc import ABC, abstractmethod

import numpy as np

CALLBACK_PARAM_TYPES = {
    "n_iter": int,
    "max_iter": int,
    "loss": (float, dict),
    "score": (float, dict),
    "validation_loss": (float, dict),
    "validation_score": (float, dict),
    "coef": np.ndarray,
    "intercept": (np.ndarray, float)
}


def _check_callback_params(**kwargs):
    invalid_params = []
    invalid_types = []
    for key, val in kwargs.items():
        if key not in CALLBACK_PARAM_TYPES:
            invalid_params.append(key)
        else:
            val_types = CALLBACK_PARAM_TYPES[key]
            if not isinstance(val, val_types):
                invalid_types.append(
                    f"{key}={val} is not of type {val_types}"
                )
    msg = ""
    if invalid_params:
        msg += ("Invalid callback parameters: {}, must be one of {}. ").format(
            ", ".join(invalid_params),
            ", ".join(CALLBACK_PARAM_TYPES.keys())
        )
    if invalid_types:
        msg += "Invalid callback parameters: " + ", ".join(invalid_types)
    if msg:
        raise ValueError(msg)
```

> **Review comment** (on `_check_callback_params`): Shouldn't we let each callback independently validate its data? My question might not make sense, but I don't see this being used anywhere except in the tests.
>
> **Reply:** Yes, absolutely, each callback validates its data. But we also need to enforce in tests that callbacks do follow the documented API — for instance, that no undocumented parameters are passed — which requires this function. Third-party callbacks could also use this validation function, similarly to how we expose […]
```python
def _eval_callbacks(callbacks: List[Callable], **kwargs) -> None:
    if callbacks is None:
        return

    for callback in callbacks:
        callback(**kwargs)
```

> **Review comment** (on lines +43 to +44, the `if callbacks is None: return` guard): Shouldn't we rely on […]? Or maybe you are anticipating a future where […]?
>
> **Reply:** Yes, we switched to […] the case when […]

```python
class BaseCallback(ABC):
    @abstractmethod
    def fit(self, estimator, X, y) -> None:
        pass

    @abstractmethod
    def __call__(self, **kwargs) -> None:
        pass
```
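To make the contract concrete, here is a hypothetical callback written against the `BaseCallback` / `_eval_callbacks` definitions above. The `ProgressLogger` class is invented for illustration; only the two abstract methods and the helper come from the diff.

```python
from abc import ABC, abstractmethod

# BaseCallback and _eval_callbacks as defined in the new module above.
class BaseCallback(ABC):
    @abstractmethod
    def fit(self, estimator, X, y) -> None:
        pass

    @abstractmethod
    def __call__(self, **kwargs) -> None:
        pass


def _eval_callbacks(callbacks, **kwargs):
    if callbacks is None:
        return
    for callback in callbacks:
        callback(**kwargs)


class ProgressLogger(BaseCallback):
    """Hypothetical callback: records every payload the solver reports."""

    def __init__(self):
        self.history = []

    def fit(self, estimator, X, y):
        # called once when fitting starts; reset state
        self.history.clear()

    def __call__(self, **kwargs):
        # called at each solver iteration
        self.history.append(kwargs)


logger = ProgressLogger()
logger.fit(estimator=None, X=None, y=None)
for n_iter in range(3):
    _eval_callbacks([logger], n_iter=n_iter, loss=1.0 / (n_iter + 1))

print(len(logger.history))           # → 3
print(logger.history[-1]["n_iter"])  # → 2
```

Splitting the lifecycle into `fit` (once, at the start) and `__call__` (per iteration) is what lets one callback instance be reused across successive fits.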
**`sklearn/base.py`**

```diff
@@ -84,6 +84,11 @@ def clone(estimator, *, safe=True):
     new_object = klass(**new_object_params)
     params_set = new_object.get_params(deep=False)
 
+    # copy callbacks
+    if hasattr(estimator, "_callbacks"):
+        # TODO: do we need to use the recursive setter here?
+        new_object._callbacks = estimator._callbacks
+
     # quick sanity check of the parameters of the clone
     for name in new_object_params:
         param1 = new_object_params[name]
```
```diff
@@ -406,6 +411,7 @@ def _validate_data(self, X, y=None, reset=True,
         out : {ndarray, sparse matrix} or tuple of these
             The validated input. A tuple is returned if `y` is not None.
         """
+        self._fit_callbacks(X, y)
 
         if y is None:
             if self._get_tags()['requires_y']:
```
```diff
@@ -433,6 +439,44 @@ def _validate_data(self, X, y=None, reset=True,
 
         return out
 
+    def _set_callbacks(self, callbacks):
+        """Set callbacks for the estimator.
+
+        In the case of meta-estimators, callbacks are also set recursively
+        for all child estimators.
+        """
+        from sklearn._callbacks import BaseCallback
+        if isinstance(callbacks, BaseCallback):
+            self._callbacks = [callbacks]
+        else:
+            self._callbacks = callbacks
+
+        for attr_name in getattr(self, "_required_parameters", []):
+            # likely a meta-estimator
+            if attr_name in ['steps', 'transformers']:
+                for attr in getattr(self, attr_name):
+                    if isinstance(attr, BaseEstimator):
+                        attr._set_callbacks(callbacks)
+                    elif (hasattr(attr, '__len__')
+                          and len(attr) >= 2
+                          and isinstance(attr[1], BaseEstimator)):
+                        attr[1]._set_callbacks(callbacks)
```

> **Review comment** (on the recursive propagation): Thoughts on doing this vs. letting users set callbacks on sub-estimator instances? What about e.g. early stopping when we ultimately support this?
>
> **Reply:** I added a […]. In most cases though, I don't see users manually setting callbacks for each individual estimator in a complex pipeline.

```diff
+    def _fit_callbacks(self, X, y):
+        """Send the signal to callbacks that the estimator is being fitted"""
+        callbacks = getattr(self, '_callbacks', [])
+
+        for callback in callbacks:
+            callback.fit(self, X, y)
+
+    def _eval_callbacks(self, **kwargs):
+        """Call callbacks, e.g. in each iteration of an iterative solver"""
+        from ._callbacks import _eval_callbacks
+
+        callbacks = getattr(self, '_callbacks', [])
+
+        _eval_callbacks(callbacks, **kwargs)
```

> **Review comment** (on `from ._callbacks import _eval_callbacks`): why lazy import?

```diff
     @property
     def _repr_html_(self):
         """HTML representation of estimator.
```
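The recursion in `_set_callbacks` can be illustrated with toy stand-ins. None of the class names below exist in scikit-learn; the walk over `steps` and the `(name, estimator)` tuple handling mirror the diff, with the `BaseCallback` isinstance check simplified to a list check.

```python
# Minimal stand-ins for BaseEstimator / Pipeline to show the recursion;
# these classes are invented for this sketch, not scikit-learn code.
class Estimator:
    _required_parameters = []

    def _set_callbacks(self, callbacks):
        # mirror the diff (simplified): a single callback is wrapped in a list
        if not isinstance(callbacks, list):
            callbacks = [callbacks]
        self._callbacks = callbacks
        for attr_name in getattr(self, "_required_parameters", []):
            # likely a meta-estimator
            if attr_name in ("steps", "transformers"):
                for attr in getattr(self, attr_name):
                    if isinstance(attr, Estimator):
                        attr._set_callbacks(callbacks)
                    # pipeline steps are (name, estimator) tuples
                    elif (hasattr(attr, "__len__") and len(attr) >= 2
                          and isinstance(attr[1], Estimator)):
                        attr[1]._set_callbacks(callbacks)


class MiniPipeline(Estimator):
    _required_parameters = ["steps"]

    def __init__(self, steps):
        self.steps = steps


def progress(**kwargs):
    pass  # a no-op callback, just to have something to propagate


scaler, clf = Estimator(), Estimator()
pipe = MiniPipeline([("scaler", scaler), ("clf", clf)])
pipe._set_callbacks(progress)

# the callback reached every child estimator, not just the pipeline
print(scaler._callbacks == [progress], clf._callbacks == [progress])  # → True True
```

Setting callbacks once at the top of a composite estimator and letting them trickle down is exactly the convenience the reply above argues for.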
**`sklearn/decomposition/_nmf.py`**

```diff
@@ -19,6 +19,7 @@
 from ..utils import check_random_state, check_array
 from ..utils.extmath import randomized_svd, safe_sparse_dot, squared_norm
 from ..utils.validation import check_is_fitted, check_non_negative
+from .._callbacks import _eval_callbacks
 from ..utils.validation import _deprecate_positional_args
 
 EPSILON = np.finfo(np.float32).eps
```
```diff
@@ -426,7 +427,8 @@ def _update_coordinate_descent(X, W, Ht, l1_reg, l2_reg, shuffle,
 
 def _fit_coordinate_descent(X, W, H, tol=1e-4, max_iter=200, l1_reg_W=0,
                             l1_reg_H=0, l2_reg_W=0, l2_reg_H=0, update_H=True,
-                            verbose=0, shuffle=False, random_state=None):
+                            verbose=0, shuffle=False, random_state=None,
+                            callbacks=None):
     """Compute Non-negative Matrix Factorization (NMF) with Coordinate Descent
 
     The objective function is minimized with an alternating minimization of W
```
```diff
@@ -522,6 +524,10 @@ def _fit_coordinate_descent(X, W, H, tol=1e-4, max_iter=200, l1_reg_W=0,
         if verbose:
             print("violation:", violation / violation_init)
 
+        _eval_callbacks(callbacks, n_iter=n_iter,
+                        tol=violation / violation_init,
+                        error=violation)
+
         if violation / violation_init <= tol:
             if verbose:
                 print("Converged at iteration", n_iter + 1)
```
```diff
@@ -710,7 +716,7 @@ def _multiplicative_update_h(X, W, H, beta_loss, l1_reg_H, l2_reg_H, gamma):
 def _fit_multiplicative_update(X, W, H, beta_loss='frobenius',
                                max_iter=200, tol=1e-4,
                                l1_reg_W=0, l1_reg_H=0, l2_reg_W=0, l2_reg_H=0,
-                               update_H=True, verbose=0):
+                               update_H=True, verbose=0, callbacks=None):
     """Compute Non-negative Matrix Factorization with Multiplicative Update
 
     The objective function is _beta_divergence(X, WH) and is minimized with an
```
```diff
@@ -828,6 +834,9 @@ def _fit_multiplicative_update(X, W, H, beta_loss='frobenius',
             print("Epoch %02d reached after %.3f seconds, error: %f" %
                   (n_iter, iter_time - start_time, error))
 
+        _eval_callbacks(callbacks, n_iter=n_iter, error=error,
+                        tol=(previous_error - error) / error_at_init)
+
         if (previous_error - error) / error_at_init < tol:
             break
         previous_error = error
```
```diff
@@ -847,7 +856,7 @@ def non_negative_factorization(X, W=None, H=None, n_components=None, *,
                                beta_loss='frobenius', tol=1e-4,
                                max_iter=200, alpha=0., l1_ratio=0.,
                                regularization=None, random_state=None,
-                               verbose=0, shuffle=False):
+                               verbose=0, shuffle=False, callbacks=None):
     r"""Compute Non-negative Matrix Factorization (NMF)
 
     Find two non-negative matrices (W, H) whose product approximates the non-
```
```diff
@@ -1062,12 +1071,13 @@ def non_negative_factorization(X, W=None, H=None, n_components=None, *,
                                                      update_H=update_H,
                                                      verbose=verbose,
                                                      shuffle=shuffle,
-                                                     random_state=random_state)
+                                                     random_state=random_state,
+                                                     callbacks=callbacks)
     elif solver == 'mu':
         W, H, n_iter = _fit_multiplicative_update(X, W, H, beta_loss, max_iter,
                                                   tol, l1_reg_W, l1_reg_H,
                                                   l2_reg_W, l2_reg_H, update_H,
-                                                  verbose)
+                                                  verbose, callbacks=callbacks)
 
     else:
         raise ValueError("Invalid solver parameter '%s'." % solver)
```
```diff
@@ -1286,7 +1296,7 @@ def fit_transform(self, X, y=None, W=None, H=None):
             tol=self.tol, max_iter=self.max_iter, alpha=self.alpha,
             l1_ratio=self.l1_ratio, regularization='both',
             random_state=self.random_state, verbose=self.verbose,
-            shuffle=self.shuffle)
+            shuffle=self.shuffle, callbacks=getattr(self, "_callbacks", []))
 
         self.reconstruction_err_ = _beta_divergence(X, W, H, self.beta_loss,
                                                     square_root=True)
```
```diff
@@ -1335,7 +1345,7 @@ def transform(self, X):
             beta_loss=self.beta_loss, tol=self.tol, max_iter=self.max_iter,
             alpha=self.alpha, l1_ratio=self.l1_ratio, regularization='both',
             random_state=self.random_state, verbose=self.verbose,
-            shuffle=self.shuffle)
+            shuffle=self.shuffle, callbacks=getattr(self, '_callbacks', []))
 
         return W
```

> **Review comment:** I feel like we should either decide that […], but it seems that the code is mixing both right now? If we ultimately plan on having […]
>
> **Reply:** Yes, you are right, let's go with […]
>
> **Reply:** Actually, here we would still need to provide the default option to […]
> **Review comment:** sneaky :p
>
> **Reply:** Yeah, I don't see the point in manually formatting code anymore for new files. It shouldn't hurt even if we are not using […] everywhere.
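Putting the pieces together, the intended end-to-end flow — `_set_callbacks` on the estimator, a `fit` signal when fitting starts, then per-iteration `__call__`s from the solver — can be sketched with a toy estimator. Nothing below is scikit-learn code; `ToyNMF` and `History` are invented to show how the wiring in this diff is meant to be used.

```python
class ToyNMF:
    """Invented stand-in for an estimator wired like NMF in this diff."""

    def __init__(self, max_iter=5):
        self.max_iter = max_iter

    def _set_callbacks(self, callbacks):
        # simplified: the real method also recurses into child estimators
        self._callbacks = callbacks

    def fit(self, X, y=None):
        callbacks = getattr(self, "_callbacks", [])
        # corresponds to BaseEstimator._fit_callbacks: signal that fit starts
        for cb in callbacks:
            cb.fit(self, X, y)
        error = 1.0
        for n_iter in range(self.max_iter):
            error *= 0.5  # pretend solver progress
            # corresponds to _eval_callbacks inside the solver loop
            for cb in callbacks:
                cb(n_iter=n_iter, error=error)
        return self


class History:
    """Invented callback: keeps every per-iteration payload."""

    def fit(self, estimator, X, y):
        self.records = []

    def __call__(self, **kwargs):
        self.records.append(kwargs)


h = History()
est = ToyNMF()
est._set_callbacks([h])
est.fit(X=[[1.0]])
print(len(h.records), h.records[-1]["n_iter"])  # → 5 4
```

The estimator never needs to know what the callback does; it only promises to call `fit` once and `__call__` each iteration with documented keyword arguments.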