[WIP] Callback API continued #22000

Draft · wants to merge 26 commits into main

Conversation

jeremiedbb (Member) commented on Dec 16, 2021

Fixes #78 #7574 #10973
Continuation of the work started in #16925 by @rth.

Goal

The goal of this PR is to propose a callback API that can handle the most important and most requested use cases.

Challenges

  • Supporting all these features and making each of these callbacks available is not easy, and will probably require some refactoring in many estimators.

    The proposed API makes it possible to enable the callbacks one estimator at a time: setting callbacks on estimators that don't support them yet has no effect. We can thus roll out support incrementally in subsequent dedicated PRs. Here I only did NMF, LogisticRegression and Pipeline, to show the necessary changes in the code base.

    The proposed API also makes it possible to enable only a subset of the features for an estimator and add the remaining ones later. For LogisticRegression, for instance, I only wired up the minimum.

  • Callbacks should not impact the performance of the estimators. Some quantities passed to the callbacks might be costly to compute, and we don't want to spend time computing them if the only callback is a progress bar, for instance.

    The solution I found is to evaluate them lazily using lambdas, and to only actually compute them if at least one callback requests them. For now callbacks request these quantities by defining specific class attributes, but maybe there's a better way (mixins?). A minimal sketch of this pattern is shown right after this list.

  • The callbacks described above are not all meant to be evaluated at the same fitting step of an estimator.

    When an estimator has several nested loops (LogisticRegressionCV(multi_class="ovr"), for instance, has a loop over the Cs, a loop over the classes, and then the final loop over the iterations on the dataset), the snapshot callback can only be evaluated at the end of an outer loop, while EarlyStopping would be evaluated at the end of an innermost loop, and the ProgressBar could be evaluated at each level of nesting.

    In this PR I propose that each estimator holds a computation tree as a private attribute representing these nested loops, the root being the beginning of fit and each node being one step of a loop. This structure is defined in _computation_tree.py. It gives a simple way to know exactly at which step of the fit we are at each evaluation of the callbacks, and is the best solution I found for the challenges described here. It also imposes the main changes to the code base, i.e. passing the parent node around. A simplified illustration of such a tree is shown after this list.

  • Dealing with parallelism, and especially multiprocessing, is the main challenge for me.

    Typically with a callback you might want to accumulate information during fit and recover it at the end. The issue is that the callback object is not shared between sub-processes: modifying its state in a sub-process (e.g. setting an attribute) will not be visible from the main process. The joblib API doesn't provide the inter-process communication that would be needed to overcome this.

    The solution we found is to have the callbacks persist the information they want to keep (in files in this first implementation, but we might consider sockets or another mechanism). It's relatively easy to avoid race conditions with this design.
    This is necessary, for example, to report progress in real time. In an estimator running in parallel there is no single "current" computation node; we are at different nodes at the same time. But having the status of each node written to a file, updated at each call to the callbacks, lets the main process know the current overall progress (there are other difficulties, described later).

  • The last main challenge is meta-estimators. We'd like some callbacks to be set on the meta-estimator, like progress bars, but others to be set on the underlying estimator(s), like early stopping. Moreover, we encounter the parallelism issue again if the meta-estimator fits clones of the underlying estimator in parallel, like GridSearchCV does.

    For that, I propose a mixin that marks a callback as one that should be propagated to sub-estimators. This way the meta-estimator only propagates the appropriate callbacks to its sub-estimators, and these sub-estimators can also have their own regular callbacks.
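
Below is a minimal, hypothetical sketch of the lazy-evaluation idea mentioned above. The names (ProgressLike, SnapshotLike, request_reconstruction_attributes, eval_callbacks_on_fit_iter_end) are illustrative only, not the actual identifiers used in this PR:

    # Hypothetical sketch (not the actual PR code): costly quantities are wrapped
    # in a lambda and only computed if at least one callback declares it needs them.

    class ProgressLike:
        request_reconstruction_attributes = False  # cheap: never needs the costly dict

        def on_fit_iter_end(self, *, reconstruction_attributes=None, **kwargs):
            return False  # False means "do not stop fitting"

    class SnapshotLike:
        request_reconstruction_attributes = True  # needs the costly dict

        def on_fit_iter_end(self, *, reconstruction_attributes=None, **kwargs):
            self.last_state = reconstruction_attributes
            return False

    def eval_callbacks_on_fit_iter_end(callbacks, *, reconstruction_attributes):
        # `reconstruction_attributes` is a zero-argument callable, evaluated lazily.
        if any(getattr(cb, "request_reconstruction_attributes", False) for cb in callbacks):
            reconstruction_attributes = reconstruction_attributes()  # pay the cost once
        else:
            reconstruction_attributes = None  # skip the computation entirely
        # Stop the loop if any callback asks for it.
        return any(
            [cb.on_fit_iter_end(reconstruction_attributes=reconstruction_attributes)
             for cb in callbacks]
        )

    # Inside an estimator's fit loop, the costly dict is only built on demand:
    callbacks = [ProgressLike(), SnapshotLike()]
    for i in range(10):
        if eval_callbacks_on_fit_iter_end(
            callbacks, reconstruction_attributes=lambda: {"n_iter_": i + 1}
        ):
            break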
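
Similarly, here is a simplified, hypothetical illustration of how nested fit loops could map to a computation tree. This is not the actual _computation_tree.py implementation, just a picture of the idea:

    # Hypothetical sketch: the root is the start of fit, each node is one step
    # of one loop level, and each node knows its parent.
    from dataclasses import dataclass, field

    @dataclass
    class Node:
        description: str
        parent: "Node | None" = None
        children: list = field(default_factory=list)

        def add_child(self, description):
            child = Node(description, parent=self)
            self.children.append(child)
            return child

        def depth(self):
            return 0 if self.parent is None else self.parent.depth() + 1

    # LogisticRegressionCV(multi_class="ovr")-like nesting: Cs -> classes -> iterations.
    root = Node("fit")
    for C in (0.1, 1.0):
        c_node = root.add_child(f"C={C}")
        for klass in (0, 1):
            class_node = c_node.add_child(f"class={klass}")
            for it in range(3):
                iter_node = class_node.add_child(f"iter={it}")
                # A callback evaluated here knows exactly where it is in the fit:
                # an innermost-loop node, at depth 3, with a full path to the root.
                assert iter_node.depth() == 3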

The API

This PR adds a new module, sklearn.callback, which exposes BaseCallback, the abstract base class for callbacks. All callbacks must inherit from BaseCallback. It also exposes AutoPropagatedMixin: callbacks that should be propagated to sub-estimators by meta-estimators must inherit from this as well.

BaseCallback has 3 abstract methods:

  • on_fit_begin. Called at the beginning of fit, after all validations. We pass a reference to the estimator, X_train and y_train.
  • on_fit_iter_end. Called at the end of each node of the computation tree, i.e. each step of each nested loop. We pass a reference to the estimator (which at this point might be different from the one passed to on_fit_begin for propagated callbacks) and the computation node where it was called. We also pass some of these:
    • stopping_criterion: when the estimator has a stopping criterion such that the iterations stop when stopping_criterion <= tol.
    • tol: tolerance for the stopping criterion.
    • reconstruction_attributes: the attributes needed to construct an estimator (by copying the estimator and setting these as attributes) that behaves as if fit had stopped at this node, so that we can then call predict, transform, etc. on it.
    • fit_state: model-specific quantities updated during fit. This is not meant to be used by generic callbacks, but rather by a callback designed for a specific estimator. This argument is not used in any of the use cases described above, but I think it's important to have for custom callbacks. It's the role of each estimator to decide what is interesting to pass to the callback. We could later think of a new field in the docstring of the estimators describing which keys they pass in this argument.
  • on_fit_end. Called at the end of fit. Takes no argument. It allows the callback to do some clean-up.
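
To make the interface concrete, here is a minimal sketch of a custom callback written against the three methods described above. The import of BaseCallback from sklearn.callback follows the PR description; the exact signatures are illustrative and may differ from the final API:

    # Minimal sketch of a custom callback; signatures are illustrative.
    from sklearn.callback import BaseCallback

    class IterationCounter(BaseCallback):
        """Count how many computation-tree nodes are visited during fit."""

        def on_fit_begin(self, estimator, X=None, y=None):
            self.n_nodes_ = 0

        def on_fit_iter_end(self, estimator, node, **kwargs):
            self.n_nodes_ += 1
            return False  # returning True would request early stopping

        def on_fit_end(self):
            print(f"visited {self.n_nodes_} nodes")

    # Usage, mirroring the examples below:
    # nmf._set_callbacks(IterationCounter())
    # nmf.fit(X)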

Examples

  • Progress bars.


    Here's an example of progress monitoring using rich. I used custom estimators to simulate a complex setting: a meta-estimator (like a GridSearchCV) running in parallel, with a sub-estimator that also runs in parallel.

    simplescreenrecorder-2021-12-16_19.17.35.mp4
  • Convergence Monitoring

    from sklearn.decomposition import NMF
    import numpy as np
    X = np.random.random_sample((1100, 100))
    X_val = X[-100:]
    nmf = NMF(n_components=20, solver="mu")
    callback = ConvergenceMonitor(X_val=X_val)
    nmf._set_callbacks(callback)
    nmf.fit(X[:1000])
    callback.plot()

    callback_ex1

  • Snapshot

    from sklearn.decomposition import NMF
    import numpy as np
    X = np.random.random_sample((1100, 100))
    nmf = NMF(n_components=20, solver="mu")
    callback = Snapshot()
    nmf._set_callbacks(callback)
    nmf.fit(X[:1000])
    # interrupt fit. Ctrl-C for instance
    # [...]
    KeyboardInterrupt:
    
    import pickle
    with open(callback.directory / "2021-12-16_19-33-15-083014.pkl", "rb") as f:
        new_nmf = pickle.load(f)
    W = new_nmf.transform(X[-100:])
  • EarlyStopping


    If the on_fit_iter_end method of a callback returns True, the iteration loop breaks.

    from sklearn.decomposition import NMF
    import numpy as np
    X = np.random.random_sample((1100, 100))
    X_val = X[-100:]
    nmf = NMF(n_components=20, solver="mu")
    callback = EarlyStopping(monitor="objective_function", X_val=X_val, max_no_improvement=10, tol=1e-4)
    nmf._set_callbacks(callback)
    nmf.fit(X[:1000])
  • Verbose

    from sklearn.decomposition import NMF
    import numpy as np
    X = np.random.random_sample((1100, 100))
    nmf = NMF(n_components=20, solver="mu", max_iter=20)
    nmf._set_callbacks(TextVerbose())
    nmf.fit(X)
    [NMF] iter 0 | time 0.02493s | stopping_criterion=8.730E-01 | tol=1.000E-04
    [NMF] iter 1 | time 0.02634s | stopping_criterion=8.737E-01 | tol=1.000E-04
    [NMF] iter 2 | time 0.02768s | stopping_criterion=8.743E-01 | tol=1.000E-04
    [NMF] iter 3 | time 0.02893s | stopping_criterion=8.749E-01 | tol=1.000E-04
    [NMF] iter 4 | time 0.03016s | stopping_criterion=8.755E-01 | tol=1.000E-04
    [NMF] iter 5 | time 0.03136s | stopping_criterion=8.760E-01 | tol=1.000E-04
    [NMF] iter 6 | time 0.03255s | stopping_criterion=8.766E-01 | tol=1.000E-04
    [NMF] iter 7 | time 0.03375s | stopping_criterion=8.772E-01 | tol=1.000E-04
    [NMF] iter 8 | time 0.03496s | stopping_criterion=8.777E-01 | tol=1.000E-04
    [NMF] iter 9 | time 0.03691s | stopping_criterion=8.782E-01 | tol=1.000E-04
    [NMF] iter 10 | time 0.03841s | stopping_criterion=5.307E-04 | tol=1.000E-04
    [NMF] iter 11 | time 0.03966s | stopping_criterion=1.049E-03 | tol=1.000E-04
    [NMF] iter 12 | time 0.04087s | stopping_criterion=1.552E-03 | tol=1.000E-04
    [NMF] iter 13 | time 0.04209s | stopping_criterion=2.036E-03 | tol=1.000E-04
    [NMF] iter 14 | time 0.04327s | stopping_criterion=2.498E-03 | tol=1.000E-04
    [NMF] iter 15 | time 0.04447s | stopping_criterion=2.936E-03 | tol=1.000E-04
    [NMF] iter 16 | time 0.04565s | stopping_criterion=3.349E-03 | tol=1.000E-04
    [NMF] iter 17 | time 0.04686s | stopping_criterion=3.734E-03 | tol=1.000E-04
    [NMF] iter 18 | time 0.04804s | stopping_criterion=4.093E-03 | tol=1.000E-04
    [NMF] iter 19 | time 0.04923s | stopping_criterion=4.425E-03 | tol=1.000E-04 
    

TODO

This PR is still WIP.

  • The documentation of the callback module is still missing: it should describe the API, explain how to use and write callbacks, and include an example.
  • I started adding tests for the computation tree, but we need more, and I still need to add tests for the callback API and for each of the implemented callbacks.
  • Finalize and document the implemented callbacks; a few issues still need to be fixed in them.
  • Think about how callbacks should be reinitialized when reused, e.g. when refitting an estimator.

rth mentioned this pull request on Dec 17, 2021
ogrisel (Member) commented on Dec 17, 2021

Thanks! Another use case I see is structured logging: instead of generating lines in a text file, generate an event log in a JSON file, records in a database (e.g. MongoDB or PostgreSQL, possibly via a JSON column type), a Kafka stream, or events sent to ML tracking platforms, for instance MLflow's tracking features or Weights & Biases' wandb.log.
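
To illustrate, a callback along these lines could emit one JSON record per fit iteration. This is only a hypothetical sketch built on the on_fit_* API described in the PR description, not something included in the PR:

    # Hypothetical structured-logging callback writing one JSON line per
    # computation-tree node visited during fit.
    import json
    import time

    from sklearn.callback import BaseCallback

    class JsonLinesLogger(BaseCallback):
        def __init__(self, path="fit_events.jsonl"):
            self.path = path

        def on_fit_begin(self, estimator, X=None, y=None):
            self._file = open(self.path, "w")

        def on_fit_iter_end(self, estimator, node, *, stopping_criterion=None, **kwargs):
            record = {
                "time": time.time(),
                "estimator": estimator.__class__.__name__,
                "stopping_criterion": stopping_criterion,
            }
            self._file.write(json.dumps(record) + "\n")
            return False  # never request early stopping

        def on_fit_end(self):
            self._file.close()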

thomasjpfan (Member) left a comment

Thanks for working on this!

@@ -515,6 +518,22 @@ def sag{{name_suffix}}(SequentialDataset{{name_suffix}} dataset,
fabs(weights[idx] -
previous_weights[idx]))
previous_weights[idx] = weights[idx]

with gil:
if _eval_callbacks_on_fit_iter_end(

How does the overhead of taking the GIL compare to early stopping directly using the stopping_criterion?

jeremiedbb (Author) replied:

It has an impact on performance for sure, but if we want to enable callbacks at this step of the fit there's no way around it.
What we can do, however, is check whether the estimator has callbacks before entering the nogil section, and execute this part only if it does. Let me try something like that. We might encounter the same issue as in #13389.

sklearn/callback/_computation_tree.py (resolved review thread)
else:
# node is a leaf, look for tasks of its sub computation tree before
# going to the next node
child_dir = this_dir / str(node.tree_status_idx)

I think we should abstract away the filesystem that backs the computation trees because:

  1. I do not think we want third party developers writing Callbacks to worry about the filesystem.
  2. It will be easier to switch to another inter-process communication method in the future.

jeremiedbb (Author) replied:

That's probably better, yes. I'll try to come up with a friendlier solution.

@@ -0,0 +1,268 @@
# License: BSD 3 clause

@adrinjalali Discussing with @jeremiedbb IRL while I was explaining the sample-props PR, he was under the impression that the MetaDataRequest class would be similar to the ComputationTree in some regards. Maybe you could have a look for some inspiration :)

else:
sub_estimator._callbacks.extend(propagated_callbacks)

def _eval_callbacks_on_fit_begin(self, *, levels, X=None, y=None):

Would it be too magical to have a single call to _eval_callbacks_begin that internally inspects the call stack to infer which method called it?

Of course it would make sense only if the same methods are expected to be called for fit/predict/etc.

jeremiedbb (Author) replied:

Of course it would make sense only if the same methods are expected to be called for fit/predict/etc.

Well, that's not obvious at all, and I haven't really thought about it. This first iteration is all about fit. I think it will be easier not to try to be too magical for now.

chritter (Contributor) commented on Dec 29, 2021

@jeremiedbb Would this PR cover early stopping for RandomizedSearchCV, where a time budget would be beneficial (e.g. stop after X seconds instead of an iteration limit)? Snapshots seem to apply to a single estimator. Maybe it is out of scope. Thanks!

jeremiedbb (Author) replied:

@chritter For now, EarlyStopping based on a time budget in SearchCV estimators doesn't seem possible due to joblib (it might become possible at some point if the ability to return a generator is merged, joblib/joblib#588).

adampl commented on Sep 22, 2022

@jeremiedbb What is the current status of this feature? Is it abandoned? :(

jeremiedbb (Author) replied:

Is it abandoned? :(

No, it's not :) I haven't been working on it for some time, but I started working on it again a few weeks ago. There's still a lot of work to do though.

ogrisel (Member) commented on Sep 22, 2022

Maybe you could keep this WIP branch up to date ;)

github-actions (bot) commented:

❌ Linting issues

This PR is introducing linting issues. Here's a summary of the issues. Note that you can avoid having linting issues by enabling pre-commit hooks. Instructions to enable them can be found here.

You can see the details of the linting issues under the lint job here


mypy

mypy detected issues. Please fix them locally and push the changes. Here you can see the detected issues. Note that the installed mypy version is mypy=1.3.0.


sklearn/externals/_arff.py:782: note: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs  [annotation-unchecked]
sklearn/callback/tests/_utils.py:37: error: Need type annotation for "_parameter_constraints" (hint: "_parameter_constraints: Dict[<type>, <type>] = ...")  [var-annotated]
sklearn/callback/tests/_utils.py:74: error: Need type annotation for "_parameter_constraints" (hint: "_parameter_constraints: Dict[<type>, <type>] = ...")  [var-annotated]
Found 2 errors in 1 file (checked 553 source files)

Generated for commit: b8ac1a5. Link to the linter CI: here

jondo commented on Nov 29, 2023

Remark: #27663 implements a smaller portion of this.

amueller (Member) commented on Feb 9, 2024

I think I'm -1 on using callbacks for early stopping since I don't see a way of making it work within pipelines.

Successfully merging this pull request may close these issues.

Use python logging to report on convergence progress at level info for long running tasks
8 participants