Callbacks API #16925

Closed
wants to merge 14 commits into from
75 changes: 68 additions & 7 deletions doc/developers/develop.rst
@@ -5,9 +5,9 @@ Developing scikit-learn estimators
==================================

Whether you are proposing an estimator for inclusion in scikit-learn,
developing a separate package compatible with scikit-learn, or
implementing custom components for your own projects, this chapter
details how to develop objects that safely interact with scikit-learn
Pipelines and model selection tools.

.. currentmodule:: sklearn
@@ -576,10 +576,10 @@ closed-form solutions.
Coding guidelines
=================

The following are some guidelines on how new code should be written for
inclusion in scikit-learn, and which may be appropriate to adopt in external
projects. Of course, there are special cases and there will be exceptions to
these rules. However, following these rules when submitting new code makes
the review easier so new code can be integrated in less time.

Uniformly formatted code makes it easier to share code ownership. The
@@ -709,3 +709,64 @@ The reason for this setup is reproducibility:
when an estimator is ``fit`` twice to the same data,
it should produce an identical model both times,
hence the validation in ``fit``, not ``__init__``.

Estimator callbacks
===================

To add (optional) support for callbacks, for instance to display progress
bars or monitor convergence, the estimator must implement the following
points (a minimal sketch is given after the list):

- At the beginning of ``fit``, either explicitly call
  ``self._fit_callbacks(X, y)`` or use ``self._validate_data(X, y)``, which
  calls ``self._fit_callbacks`` internally.
- For iterative solvers, call ``self._eval_callbacks(n_iter=.., **kwargs)``
  at each iteration, where the ``kwargs`` keys must be among the supported
  callback arguments (cf. list below).
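
For illustration, a minimal iterative estimator wired for callbacks could
look like the sketch below (the gradient-descent solver, its parameter
names, and the class name are made up for this example and are not part of
scikit-learn):

.. code::

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin

class GradientDescentRegressor(RegressorMixin, BaseEstimator):
    def __init__(self, max_iter=100, learning_rate=0.01):
        self.max_iter = max_iter
        self.learning_rate = learning_rate

    def fit(self, X, y):
        # _validate_data calls self._fit_callbacks(X, y) internally
        X, y = self._validate_data(X, y)
        coef = np.zeros(X.shape[1])
        for n_iter in range(self.max_iter):
            grad = X.T @ (X @ coef - y) / X.shape[0]
            coef -= self.learning_rate * grad
            loss = 0.5 * np.mean((X @ coef - y) ** 2)
            # report progress to the registered callbacks
            self._eval_callbacks(n_iter=n_iter, max_iter=self.max_iter,
                                 loss=loss, coef=coef)
        self.coef_ = coef
        return self

    def predict(self, X):
        return X @ self.coef_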

User-defined callbacks must extend the ``sklearn._callbacks.BaseCallback``
abstract base class. For instance, some callbacks are implemented in the
`sklearn-callbacks <https://github.com/rth/sklearn-callbacks>`_ package
and can be used as follows:

.. code::

from sklearn.linear_model import LogisticRegression
from sklearn_callbacks import ProgressBar

est = LogisticRegression()
pbar = ProgressBar()
est._set_callbacks(pbar)

est.fit(X, y) # will display a progress bar
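
A user-defined callback only needs to implement the two abstract methods of
``BaseCallback``. As a rough sketch (the ``ConvergenceLogger`` class and its
logging behaviour are purely illustrative and not part of any existing
package):

.. code::

from sklearn.linear_model import LogisticRegression
from sklearn._callbacks import BaseCallback

class ConvergenceLogger(BaseCallback):
    def fit(self, estimator, X, y):
        # called once when fitting starts
        print(f"Fitting {estimator.__class__.__name__} on {X.shape[0]} samples")

    def __call__(self, **kwargs):
        # called at each iteration of an iterative solver
        if "loss" in kwargs:
            print(f"iteration {kwargs.get('n_iter')}: loss={kwargs['loss']:.4f}")

est = LogisticRegression()
est._set_callbacks(ConvergenceLogger())
est.fit(X, y)  # logs whatever the solver reports at each iteration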


**Callback arguments**

The following input parameters are supported:

n_iter, int
current iteration number for iterative solvers.

max_iter, int
maximum number of iterations for iterative solvers. If the estimator
has a ``max_iter`` init parameter, this will be inferred.

loss, float or ordered dict
cost function value or error at a given iteration. When ordered dict,
multiple loss functions can be given, with the default loss being the first
element. Lower is better.

score, float or ordered dict
same as ``loss`` parameter, but for evaluation metrics. Higher is better.

validation_loss, float or ordered dict
cost function value or error at a given iteration, evaluated on the
validation set.

validation_score, float or ordered dict
same as ``validation_loss`` parameter, but for evaluation metrics. Higher is
better.

coef, ndarray
coefficients of linear models.
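
Estimator and callback authors can check that only documented arguments
(with the expected types) are passed by using the ``_check_callback_params``
helper added in ``sklearn/_callbacks.py``; a quick sketch:

.. code::

from sklearn._callbacks import _check_callback_params

# valid: all keys and value types follow the list above
_check_callback_params(n_iter=3, max_iter=100, loss=0.5)

# raises ValueError: "learning_rate" is not a documented callback parameter
_check_callback_params(learning_rate=0.1)
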
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -13,3 +13,6 @@ requires = [
"numpy==1.17.3; python_version>='3.8' and platform_system=='AIX'",
"scipy>=0.19.1",
]

[tool.black]
line-length = 79
Comment on lines +17 to +18

Member:

sneaky :p

Member Author:

Yeah, I don't see the point in manually formatting code anymore for new files. It shouldn't hurt even if we are not using it everywhere.

58 changes: 58 additions & 0 deletions sklearn/_callbacks.py
@@ -0,0 +1,58 @@
# License: BSD 3 clause
from typing import List, Callable
from abc import ABC, abstractmethod

import numpy as np

CALLBACK_PARAM_TYPES = {
"n_iter": int,
"max_iter": int,
"loss": (float, dict),
"score": (float, dict),
"validation_loss": (float, dict),
"validation_score": (float, dict),
"coef": np.ndarray,
"intercept": (np.ndarray, float)
}


def _check_callback_params(**kwargs):
Member:

Shouldn't we let each callback independently validate its data?

My question might not make sense, but I don't see this being used anywhere except in the tests.

Member Author:

Yes, absolutely, each callback validates its own data. But we also need to enforce in the tests that callbacks follow the documented API, for instance that no undocumented parameters are passed, which requires this function.

Third-party callbacks could also use this validation function, similarly to how we expose check_array.

invalid_params = []
invalid_types = []
for key, val in kwargs.items():
if key not in CALLBACK_PARAM_TYPES:
invalid_params.append(key)
else:
val_types = CALLBACK_PARAM_TYPES[key]
if not isinstance(val, val_types):
invalid_types.append(
f"{key}={val} is not of type {val_types}"
)
msg = ""
if invalid_params:
msg += ("Invalid callback parameters: {}, must be one of {}. ").format(
", ".join(invalid_params),
", ".join(CALLBACK_PARAM_TYPES.keys())
)
if invalid_types:
msg += "Invalid callback parameters: " + ", ".join(invalid_types)
if msg:
raise ValueError(msg)


def _eval_callbacks(callbacks: List[Callable], **kwargs) -> None:
if callbacks is None:
return
Comment on lines +43 to +44

Member:

Shouldn't we rely on callbacks being an empty list instead of special-casing None?

Or maybe you are anticipating a future where callbacks=None would be a default argument to estimators?

Member Author:

Yes, we switched to using callbacks=None when callbacks are missing.

for callback in callbacks:
callback(**kwargs)


class BaseCallback(ABC):
@abstractmethod
def fit(self, estimator, X, y) -> None:
pass

@abstractmethod
def __call__(self, **kwargs) -> None:
pass
44 changes: 44 additions & 0 deletions sklearn/base.py
@@ -84,6 +84,11 @@ def clone(estimator, *, safe=True):
new_object = klass(**new_object_params)
params_set = new_object.get_params(deep=False)

# copy callbacks
if hasattr(estimator, "_callbacks"):
# TODO: do we need to use the recursive setter here?
new_object._callbacks = estimator._callbacks

# quick sanity check of the parameters of the clone
for name in new_object_params:
param1 = new_object_params[name]
@@ -406,6 +411,7 @@ def _validate_data(self, X, y=None, reset=True,
out : {ndarray, sparse matrix} or tuple of these
The validated input. A tuple is returned if `y` is not None.
"""
self._fit_callbacks(X, y)

if y is None:
if self._get_tags()['requires_y']:
@@ -433,6 +439,44 @@

return out

def _set_callbacks(self, callbacks):
"""Set callbacks for the estimator.

In the case of meta-estimators, callbacks are also set recursively
for all child estimators.
Member:

Thoughts on doing this vs letting users set callbacks on sub-estimator instances?

What about e.g. early stopping when we ultimately support this?

Member Author:

I added a deep=True option to allow disabling recursion for meta-estimators, which can certainly be useful in some cases.

In most cases though, I don't see users manually setting callbacks for each individual estimator in a complex pipeline.

"""
from sklearn._callbacks import BaseCallback
if isinstance(callbacks, BaseCallback):
self._callbacks = [callbacks]
else:
self._callbacks = callbacks

for attr_name in getattr(self, "_required_parameters", []):
# likely a meta-estimator
if attr_name in ['steps', 'transformers']:
for attr in getattr(self, attr_name):
if isinstance(attr, BaseEstimator):
attr._set_callbacks(callbacks)
elif (hasattr(attr, '__len__')
and len(attr) >= 2
and isinstance(attr[1], BaseEstimator)):
attr[1]._set_callbacks(callbacks)

def _fit_callbacks(self, X, y):
"""Send the signal to callbacks that the estimator is being fitted"""
callbacks = getattr(self, '_callbacks', [])

for callback in callbacks:
callback.fit(self, X, y)

def _eval_callbacks(self, **kwargs):
"""Call callbacks, e.g. in each iteration of an iterative solver"""
from ._callbacks import _eval_callbacks
Member:

why lazy import?


callbacks = getattr(self, '_callbacks', [])

_eval_callbacks(callbacks, **kwargs)

@property
def _repr_html_(self):
"""HTML representation of estimator.
1 change: 1 addition & 0 deletions sklearn/compose/_column_transformer.py
@@ -516,6 +516,7 @@ def fit_transform(self, X, y=None):
sparse matrices.

"""
self._fit_callbacks(X, y)
# TODO: this should be `feature_names_in_` when we start having it
if hasattr(X, "columns"):
self._feature_names_in = np.asarray(X.columns)
1 change: 1 addition & 0 deletions sklearn/decomposition/_factor_analysis.py
@@ -236,6 +236,7 @@ def my_svd(X):
old_ll = ll

psi = np.maximum(var - np.sum(W ** 2, axis=0), SMALL)
self._eval_callbacks(n_iter=i)
else:
warnings.warn('FactorAnalysis did not converge.' +
' You might want' +
6 changes: 4 additions & 2 deletions sklearn/decomposition/_incremental_pca.py
@@ -206,12 +206,14 @@ def fit(self, X, y=None):
else:
self.batch_size_ = self.batch_size

for batch in gen_batches(n_samples, self.batch_size_,
min_batch_size=self.n_components or 0):
for n_batch, batch in enumerate(
gen_batches(n_samples, self.batch_size_,
min_batch_size=self.n_components or 0)):
X_batch = X[batch]
if sparse.issparse(X_batch):
X_batch = X_batch.toarray()
self.partial_fit(X_batch, check_input=False)
self._eval_callbacks(n_iter=n_batch)

return self

1 change: 1 addition & 0 deletions sklearn/decomposition/_lda.py
@@ -464,6 +464,7 @@ def _em_step(self, X, total_samples, batch_update, parallel=None):
self.exp_dirichlet_component_ = np.exp(
_dirichlet_expectation_2d(self.components_))
self.n_batch_iter_ += 1
self._eval_callbacks()
return

def _more_tags(self):
24 changes: 17 additions & 7 deletions sklearn/decomposition/_nmf.py
@@ -19,6 +19,7 @@
from ..utils import check_random_state, check_array
from ..utils.extmath import randomized_svd, safe_sparse_dot, squared_norm
from ..utils.validation import check_is_fitted, check_non_negative
from .._callbacks import _eval_callbacks
from ..utils.validation import _deprecate_positional_args

EPSILON = np.finfo(np.float32).eps
@@ -426,7 +427,8 @@ def _update_coordinate_descent(X, W, Ht, l1_reg, l2_reg, shuffle,

def _fit_coordinate_descent(X, W, H, tol=1e-4, max_iter=200, l1_reg_W=0,
l1_reg_H=0, l2_reg_W=0, l2_reg_H=0, update_H=True,
verbose=0, shuffle=False, random_state=None):
verbose=0, shuffle=False, random_state=None,
callbacks=None):
"""Compute Non-negative Matrix Factorization (NMF) with Coordinate Descent

The objective function is minimized with an alternating minimization of W
@@ -522,6 +524,10 @@ def _fit_coordinate_descent(X, W, H, tol=1e-4, max_iter=200, l1_reg_W=0,
if verbose:
print("violation:", violation / violation_init)

_eval_callbacks(callbacks, n_iter=n_iter,
tol=violation/violation_init,
error=violation)

if violation / violation_init <= tol:
if verbose:
print("Converged at iteration", n_iter + 1)
@@ -710,7 +716,7 @@ def _multiplicative_update_h(X, W, H, beta_loss, l1_reg_H, l2_reg_H, gamma):
def _fit_multiplicative_update(X, W, H, beta_loss='frobenius',
max_iter=200, tol=1e-4,
l1_reg_W=0, l1_reg_H=0, l2_reg_W=0, l2_reg_H=0,
update_H=True, verbose=0):
update_H=True, verbose=0, callbacks=None):
"""Compute Non-negative Matrix Factorization with Multiplicative Update

The objective function is _beta_divergence(X, WH) and is minimized with an
@@ -828,6 +834,9 @@ def _fit_multiplicative_update(X, W, H, beta_loss='frobenius',
print("Epoch %02d reached after %.3f seconds, error: %f" %
(n_iter, iter_time - start_time, error))

_eval_callbacks(callbacks, n_iter=n_iter, error=error,
tol=(previous_error - error) / error_at_init)

if (previous_error - error) / error_at_init < tol:
break
previous_error = error
@@ -847,7 +856,7 @@ def non_negative_factorization(X, W=None, H=None, n_components=None, *,
beta_loss='frobenius', tol=1e-4,
max_iter=200, alpha=0., l1_ratio=0.,
regularization=None, random_state=None,
verbose=0, shuffle=False):
verbose=0, shuffle=False, callbacks=None):
r"""Compute Non-negative Matrix Factorization (NMF)

Find two non-negative matrices (W, H) whose product approximates the non-
@@ -1062,12 +1071,13 @@ def non_negative_factorization(X, W=None, H=None, n_components=None, *,
update_H=update_H,
verbose=verbose,
shuffle=shuffle,
random_state=random_state)
random_state=random_state,
callbacks=callbacks)
elif solver == 'mu':
W, H, n_iter = _fit_multiplicative_update(X, W, H, beta_loss, max_iter,
tol, l1_reg_W, l1_reg_H,
l2_reg_W, l2_reg_H, update_H,
verbose)
verbose, callbacks=callbacks)

else:
raise ValueError("Invalid solver parameter '%s'." % solver)
@@ -1286,7 +1296,7 @@ def fit_transform(self, X, y=None, W=None, H=None):
tol=self.tol, max_iter=self.max_iter, alpha=self.alpha,
l1_ratio=self.l1_ratio, regularization='both',
random_state=self.random_state, verbose=self.verbose,
shuffle=self.shuffle)
shuffle=self.shuffle, callbacks=getattr(self, "_callbacks", []))

self.reconstruction_err_ = _beta_divergence(X, W, H, self.beta_loss,
square_root=True)
@@ -1335,7 +1345,7 @@ def transform(self, X):
beta_loss=self.beta_loss, tol=self.tol, max_iter=self.max_iter,
alpha=self.alpha, l1_ratio=self.l1_ratio, regularization='both',
random_state=self.random_state, verbose=self.verbose,
shuffle=self.shuffle)
shuffle=self.shuffle, callbacks=getattr(self, '_callbacks', []))
Member:

I feel like we should either decide that

  • no callbacks means an empty list
  • no callbacks means None and having callbacks means a non-empty list

but it seems that the code is mixing both right now?

If we ultimately plan on having callbacks=None as a default param then the latter would be more appropriate?

Member Author:

Yes, you are right, let's go with callbacks=None everywhere, and just let the eval function handle it.

Member Author:

Actually, here we would still need to provide the default option to getattr: getattr(self, '_callbacks', None), and that's not really much more readable than getattr(self, '_callbacks', []), so both would work.


return W

15 changes: 11 additions & 4 deletions sklearn/decomposition/_pca.py
@@ -491,6 +491,8 @@ def _fit_full(self, X, n_components):
explained_variance_ratio_[:n_components]
self.singular_values_ = singular_values_[:n_components]

self._eval_callbacks()

return U, S, Vt

def _fit_truncated(self, X, n_components, svd_solver):
Expand Down Expand Up @@ -537,12 +539,17 @@ def _fit_truncated(self, X, n_components, svd_solver):
# flip eigenvectors' sign to enforce deterministic output
U, Vt = svd_flip(U[:, ::-1], Vt[::-1])

self._eval_callbacks()

elif svd_solver == 'randomized':
# sign flipping is done inside
U, S, Vt = randomized_svd(X, n_components=n_components,
n_iter=self.iterated_power,
flip_sign=True,
random_state=random_state)
U, S, Vt = randomized_svd(
X, n_components=n_components,
n_iter=self.iterated_power,
flip_sign=True,
random_state=random_state,
callbacks=getattr(self, '_callbacks', [])
)

self.n_samples_, self.n_features_ = n_samples, n_features
self.components_ = Vt