FEA add TunedThresholdClassifier meta-estimator to post-tune the cut-off threshold #26120

Merged — 228 commits, May 3, 2024

Commits
b44dd9d
MAINT refactor scorer using _get_response_values
glemaitre Mar 31, 2023
516f62f
Add __name__ for method of Mock
glemaitre Apr 1, 2023
d2fbee0
remove multiclass issue
glemaitre Apr 1, 2023
29e5e87
make response_method a mandatory arg
glemaitre Apr 3, 2023
b645ade
Update sklearn/metrics/_scorer.py
glemaitre Apr 3, 2023
3397c56
apply jeremie comments
glemaitre Apr 3, 2023
092689a
Merge branch 'main' into is/18589_restart
glemaitre Apr 3, 2023
200ec31
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 4, 2023
e871558
iter
glemaitre Apr 4, 2023
31aa1c0
Merge remote-tracking branch 'glemaitre/is/18589_restart' into cutoff…
glemaitre Apr 4, 2023
74614e8
FEA add CutOffClassifier to post-tune prediction threshold
glemaitre Apr 7, 2023
27713af
DOC add changelog entry
glemaitre Apr 7, 2023
ed1d9b3
refresh implementation
glemaitre Apr 7, 2023
8410317
add files
glemaitre Apr 7, 2023
c7d1fe4
remove random state for the moment
glemaitre Apr 7, 2023
c9d7a22
TST make sure to pass the common test
glemaitre Apr 15, 2023
9981f3a
TST metaestimator sample_weight
glemaitre Apr 15, 2023
b9c9d5e
API add prediction functions
glemaitre Apr 15, 2023
588f1c4
TST bypass the test for classification
glemaitre Apr 17, 2023
243d173
iter before another bug
glemaitre Apr 17, 2023
883e929
iter
glemaitre Apr 18, 2023
69333ed
TST add test for _fit_and_score
glemaitre Apr 19, 2023
8616da1
iter
glemaitre Apr 19, 2023
99a10b3
integrate refit
glemaitre Apr 19, 2023
0f6dce2
TST more test
glemaitre Apr 19, 2023
d6fb9f7
TST more test with sample_weight
glemaitre Apr 19, 2023
7ff3d0d
BUG fit_params split
glemaitre Apr 19, 2023
6985ae9
TST add test for fit_params
glemaitre Apr 19, 2023
239793a
TST check underlying response method for TNR/TPR
glemaitre Apr 20, 2023
92083ed
FEA add the possibility to provide a dict
glemaitre Apr 20, 2023
55d0844
TST check string and pos_label interation for cost-matrix
glemaitre Apr 20, 2023
7dfc4a6
TST add sample_weight test for cost-matrix
glemaitre Apr 20, 2023
729c9a8
iter
glemaitre Apr 20, 2023
787be21
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 24, 2023
146b170
change strategy for finding max
glemaitre Apr 24, 2023
03b1f7f
iter
glemaitre Apr 24, 2023
8a09a5f
add some test for precision-recall
glemaitre Apr 24, 2023
d56f57f
TST add invariance zeros weight
glemaitre Apr 24, 2023
cf164c5
DOC fix default n_thresholds
glemaitre Apr 24, 2023
c943f5e
DOC add a small example
glemaitre Apr 25, 2023
bf1462b
iter
glemaitre Apr 25, 2023
fa89431
iter
glemaitre Apr 25, 2023
862519d
bug fixes everywhere
glemaitre Apr 25, 2023
aa520da
iter
glemaitre Apr 25, 2023
5403cf6
Do not allow for single threshold
glemaitre Apr 26, 2023
cd37743
TST add random state checkingclassifier
glemaitre Apr 26, 2023
e7d07af
TST more test for _ContinuousScorer
glemaitre Apr 26, 2023
bc20a47
TST add test for pos_label
glemaitre Apr 26, 2023
bba2f97
TST add pos_label test for TNR/TPR
glemaitre Apr 26, 2023
f925503
some more
glemaitre Apr 26, 2023
d539235
avoid extrapolation
glemaitre Apr 27, 2023
c0acd44
FEA add all thresholds and score computed as attributes
glemaitre Apr 27, 2023
f87baa7
fix docstring
glemaitre Apr 28, 2023
e4dac09
EXA add example of cut-off tuning
glemaitre Apr 28, 2023
4da7cef
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 28, 2023
bd86595
solving the issue of unknown categories
glemaitre Apr 28, 2023
45e6e5a
fix
glemaitre Apr 28, 2023
402a1a7
EXA add hyperlink in the example
glemaitre Apr 29, 2023
6745afc
DOC add warning regarding overfitting
glemaitre Apr 29, 2023
4d557cc
some more doc
glemaitre Apr 29, 2023
2c6ee7e
some more doc
glemaitre Apr 29, 2023
91c8222
DOC more documentation
glemaitre May 2, 2023
9a96ae1
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre May 2, 2023
3d4ce81
fix import
glemaitre May 2, 2023
d7d8dac
fix import
glemaitre May 2, 2023
aa3e83d
iter
glemaitre May 2, 2023
ab97d63
fix
glemaitre May 2, 2023
acb6af8
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre May 4, 2023
486a2bd
Update sklearn/metrics/_scorer.py
glemaitre May 4, 2023
6d4c4aa
Apply suggestions from code review
glemaitre May 15, 2023
1d12e1f
Fix linter
ogrisel Jun 1, 2023
21e20e0
Merge branch 'main' into cutoff_classifier_again
ogrisel Jun 1, 2023
7952cce
Add routing to LogisticRegressionCV
Jun 7, 2023
66ad513
Add a test with enable_metadata_routing=False and fix an issue in sco…
Jun 7, 2023
7e8b824
Add metaestimator tests and fix passing routed params in score method
Jun 13, 2023
d7e50a6
PR suggestions
Jun 25, 2023
3844706
Merge branch 'main' into logistic_cv_routing
Jun 26, 2023
0866c42
Add changelog entry
Jun 26, 2023
43f971b
Add user and pr information
Jun 26, 2023
db63769
Changelog adjustment
Jun 26, 2023
a9b984f
Remove repr method from ConsumingScorer
Jun 26, 2023
97105a4
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 3, 2023
52f5921
handle the np.inf case in roc-curve
glemaitre Jul 3, 2023
637c18e
Merge branch 'main' into logistic_cv_routing
adrinjalali Jul 7, 2023
314bc83
Adjust changelog
Jul 7, 2023
9a8ef4e
Add tests for error when passing params when routing not enabled in L…
OmarManzoor Jul 10, 2023
5b723a0
Address PR suggestions partially
Jul 13, 2023
9ce463d
address comment Tim
glemaitre Jul 13, 2023
bba8f55
iter
glemaitre Jul 13, 2023
1c5487d
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 13, 2023
d302678
MAINT rename modules as per olivier comment
glemaitre Jul 13, 2023
2ade221
add missing module
glemaitre Jul 13, 2023
dca5770
update changelog
glemaitre Jul 13, 2023
8897533
more renaming
glemaitre Jul 13, 2023
75bd7ac
iter
glemaitre Jul 13, 2023
c07a980
Adjust and change the name of params in _check_method_params
OmarManzoor Jul 13, 2023
cc5ba48
Resolve conflict in changelog
OmarManzoor Jul 13, 2023
66c4c7f
iter
glemaitre Jul 13, 2023
378930e
iter
glemaitre Jul 13, 2023
c88ed94
iter
glemaitre Jul 13, 2023
4715e67
iter
glemaitre Jul 13, 2023
b3bb39f
iter
glemaitre Jul 13, 2023
915624a
Merge branch 'main' into logistic_cv_routing
glemaitre Jul 13, 2023
b72a72a
iter
glemaitre Jul 13, 2023
5108e43
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 13, 2023
95150b0
Merge branch 'pr/OmarManzoor/26525' into cutoff_with_metadata_routing
glemaitre Jul 13, 2023
5b66ab8
Merge remote-tracking branch 'origin/main' into cutoff_with_metadata_…
glemaitre Jul 14, 2023
b4e67fb
Add metadata routing
glemaitre Jul 14, 2023
005126a
Apply suggestions from code review
glemaitre Jul 14, 2023
767a05f
CLN clean up some repeated code related to SLEP006
adrinjalali Jul 14, 2023
080ba5c
iter
glemaitre Jul 14, 2023
05ec85d
iter
glemaitre Jul 14, 2023
63c32bd
ENH add new response_method in make_scorer
glemaitre Jul 15, 2023
1584c5b
add non-regression test
glemaitre Jul 15, 2023
1a5a247
update validation param
glemaitre Jul 15, 2023
4cc61b9
more coverage
glemaitre Jul 15, 2023
8f36235
TST add mulitlabel test
glemaitre Jul 15, 2023
9e6b384
Merge branch 'make_scorer_list_response' into cutoff_classifier_again
glemaitre Jul 15, 2023
5490ce4
simplify scorer
glemaitre Jul 15, 2023
8dad0a4
iter
glemaitre Jul 15, 2023
b918708
remove unecessary part in doc
glemaitre Jul 15, 2023
5e23523
iter
glemaitre Jul 15, 2023
d5578f9
iter
glemaitre Jul 16, 2023
f3f844e
address tim comments
glemaitre Jul 24, 2023
44ad195
Merge remote-tracking branch 'origin/main' into make_scorer_list_resp…
glemaitre Jul 24, 2023
e489eab
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 24, 2023
1cf5528
Merge branch 'make_scorer_list_response' into cutoff_classifier_again
glemaitre Jul 24, 2023
26dc94e
iter
glemaitre Jul 26, 2023
6a1a6c7
iter
glemaitre Jul 26, 2023
b17b59e
iter
glemaitre Jul 28, 2023
43c1da8
iter
glemaitre Jul 28, 2023
41a6d07
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 28, 2023
ca06717
iter
glemaitre Jul 28, 2023
ab8b466
iter
glemaitre Jul 30, 2023
235abf5
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Aug 7, 2023
d9ec528
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Aug 11, 2023
69f60a6
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Sep 28, 2023
45a8504
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Oct 17, 2023
8c4c88d
solve deprecation
glemaitre Oct 17, 2023
b97ebf4
Merge remote-tracking branch 'glemaitre/cutoff_classifier_again' into…
glemaitre Oct 17, 2023
e37f831
update changelog
glemaitre Oct 17, 2023
383937f
whoops
glemaitre Oct 17, 2023
d4ce3fb
Update sklearn/metrics/_scorer.py
glemaitre Oct 17, 2023
b6b3548
fix doc
glemaitre Oct 17, 2023
759d680
remove useless fitted attributes
glemaitre Oct 18, 2023
23e65e6
Merge branch 'main' into cutoff_classifier_again
ogrisel Dec 4, 2023
6904817
bump pandas to 1.1.5
glemaitre Dec 4, 2023
bee1ebe
update lock file
glemaitre Dec 4, 2023
4d86a36
iter
glemaitre Jan 13, 2024
48fd7cd
update doc-min lock file
glemaitre Jan 13, 2024
0854cd4
partial reviews
glemaitre Jan 13, 2024
ac75300
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Mar 18, 2024
2df616e
Apply suggestions from code review
glemaitre Mar 18, 2024
b14225c
update lock files
glemaitre Mar 18, 2024
98dcefd
Merge remote-tracking branch 'glemaitre/cutoff_classifier_again' into…
glemaitre Mar 18, 2024
7e3d7aa
iter
glemaitre Mar 18, 2024
b958bb0
simplify refit and do not allow cv == 1
glemaitre Mar 19, 2024
e7722f6
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Mar 19, 2024
98a1db8
check raise for multilabel
glemaitre Mar 19, 2024
c28a3e1
fix test name
glemaitre Mar 19, 2024
076fd29
test another to check if beta is forwarded
glemaitre Mar 19, 2024
e728f1d
iter
glemaitre Mar 20, 2024
c73b205
refit=True and cv is float
glemaitre Mar 20, 2024
a4890df
rename scorer to curve scorer internally
glemaitre Mar 22, 2024
f8a5a79
add a note regarding the abuse of the scorer API
glemaitre Mar 22, 2024
5dfa435
use None instead of highest
glemaitre Mar 22, 2024
d45a71b
use a closer CV API
glemaitre Mar 23, 2024
a32c151
fix example
glemaitre Mar 23, 2024
7592437
simplify model
glemaitre Mar 23, 2024
dc5346b
fix
glemaitre Mar 23, 2024
843ca04
fix docstring
glemaitre Mar 23, 2024
51ed9a8
Apply suggestions from code review
glemaitre Mar 30, 2024
3c89ab3
Apply suggestions from code review
glemaitre Apr 2, 2024
8cd5582
pep8
glemaitre Apr 2, 2024
a48487c
rephrase suggestions
glemaitre Apr 2, 2024
dd18549
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 4, 2024
27515ca
fix
glemaitre Apr 8, 2024
8a87b26
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 8, 2024
811dec9
include and discuss more about amount
glemaitre Apr 8, 2024
c83b4e1
iter
glemaitre Apr 8, 2024
d4e232f
Apply suggestions from code review
glemaitre Apr 25, 2024
92f6e05
Update examples/model_selection/plot_tuned_decision_threshold.py
glemaitre Apr 25, 2024
1c5c3f4
iter
glemaitre Apr 25, 2024
94160ba
iter
glemaitre Apr 25, 2024
85c8484
other comment
glemaitre Apr 25, 2024
4f86e9d
addressed comments
glemaitre Apr 25, 2024
d747098
Apply suggestions from code review
glemaitre Apr 27, 2024
6d0f418
rename TunedThresholdClassifier to TunedThresholdClassifierCV
glemaitre Apr 27, 2024
5671dd6
use meaningful values for check the thresholds values depending on po…
glemaitre Apr 27, 2024
f04085d
TST add more info regarding why not exactly 0 and 1
glemaitre Apr 27, 2024
d179b5f
DOC add documentation for base scorer
glemaitre Apr 27, 2024
2c375f8
DOC add more details regarding the curve scorer
glemaitre Apr 27, 2024
66ba8da
directly test curve_scorer instead to look for function anem
glemaitre Apr 27, 2024
a6b19c1
add required arguments
glemaitre Apr 27, 2024
b3b99ff
DOC add docstring for interpolated score
glemaitre Apr 27, 2024
dda0d2c
Update sklearn/model_selection/tests/test_classification_threshold.py
glemaitre Apr 29, 2024
48e7829
Update sklearn/model_selection/tests/test_classification_threshold.py
glemaitre Apr 29, 2024
17839e8
Apply suggestions from code review
glemaitre Apr 29, 2024
3f02bc3
remove duplicated check
glemaitre Apr 29, 2024
553cfce
remove duplicated check
glemaitre Apr 29, 2024
6ae6d27
check cv_results_ API
glemaitre Apr 29, 2024
d010096
clone classifier
glemaitre Apr 29, 2024
8bb8ca6
TST better comments
glemaitre Apr 29, 2024
fd971c7
iter
glemaitre Apr 29, 2024
bf57dac
FEA add a ConstantThresholdClassifier instead of strategy="constant" …
glemaitre Apr 30, 2024
0409932
make FixedThresholdClassifier appear in example
glemaitre Apr 30, 2024
66ea575
iter
glemaitre Apr 30, 2024
1c97dd4
Update doc/modules/classes.rst
glemaitre Apr 30, 2024
c8c1d0c
Update sklearn/model_selection/_classification_threshold.py
glemaitre Apr 30, 2024
8a52bc6
TST and fix default parameter
glemaitre Apr 30, 2024
ef668cf
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 30, 2024
fdbf68e
TST metadarouting FixedThresholdClassifier
glemaitre Apr 30, 2024
9c0c13d
rename n_thresholds to thresholds
glemaitre Apr 30, 2024
f419371
cover constant predictor error
glemaitre Apr 30, 2024
42eafe5
TST some tests for get_response_values_binary
glemaitre Apr 30, 2024
581133f
use conditional p(y|X) instead of posterior
glemaitre May 2, 2024
0f803d9
be more explicit that strings need to be provided to objective_metric
glemaitre May 2, 2024
eb0defc
factorize plotting into a function
glemaitre May 2, 2024
ffd5669
fix typo in code
glemaitre May 2, 2024
18abafe
use proper scoring rule and robust estimator to scale
glemaitre May 2, 2024
ce9464c
improve narrative
glemaitre May 2, 2024
89d67cf
use grid-search
glemaitre May 2, 2024
db3360b
Apply suggestions from code review
glemaitre May 2, 2024
1789cc0
remove constrainted metrics option
glemaitre May 3, 2024
e7c31b9
partial review
glemaitre May 3, 2024
0fd667c
rename objective_metric to scoring
glemaitre May 3, 2024
07e4387
fix typo
glemaitre May 3, 2024
9bd68e6
remove pos_label and delegate to make_scorer
glemaitre May 3, 2024
1 change: 1 addition & 0 deletions doc/model_selection.rst
@@ -14,5 +14,6 @@ Model selection and evaluation

modules/cross_validation
modules/grid_search
modules/classification_threshold
modules/model_evaluation
modules/learning_curve
10 changes: 10 additions & 0 deletions doc/modules/classes.rst
@@ -1221,6 +1221,16 @@ Hyper-parameter optimizers
model_selection.RandomizedSearchCV
model_selection.HalvingRandomSearchCV

Model post-fit tuning
---------------------

.. currentmodule:: sklearn

.. autosummary::
:toctree: generated/
:template: class.rst

model_selection.TunedThresholdClassifier

Model validation
----------------
171 changes: 171 additions & 0 deletions doc/modules/classification_threshold.rst
@@ -0,0 +1,171 @@
.. currentmodule:: sklearn.model_selection

.. _tunedthresholdclassifier:

==================================================
Tuning the decision threshold for class prediction
==================================================

Classifiers are predictive models: they use statistical learning to predict outcomes.
The outcomes of a classifier are scores for each sample in relation to each class and
a categorical prediction (class label). Scores are obtained from :term:`predict_proba`
or :term:`decision_function`. The former returns posterior probability estimates for
each class while the latter returns a decision score for each class. The decision
score is a measure of how strongly the sample is predicted to belong to the positive
class (e.g. the distance to the decision boundary). A decision rule is then defined by
thresholding the scores, yielding a class label for each sample. Those labels are
obtained with :term:`predict`.

For binary classification in scikit-learn, class labels are obtained by associating the
positive class with posterior probability estimates greater than 0.5 (obtained with
:term:`predict_proba`) or decision scores greater than 0 (obtained with
:term:`decision_function`).

Here is an example that illustrates the relation between posterior probability
estimates and class labels::

    >>> from sklearn.datasets import make_classification
    >>> from sklearn.tree import DecisionTreeClassifier
    >>> X, y = make_classification(random_state=0)
    >>> classifier = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
    >>> classifier.predict_proba(X[:4])
    array([[0.94 , 0.06 ],
           [0.94 , 0.06 ],
           [0.04..., 0.95...],
           [0.04..., 0.95...]])
    >>> classifier.predict(X[:4])
    array([0, 0, 1, 1])
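For a model exposing both scoring methods, the default rule can be checked directly: thresholding `predict_proba` at 0.5, or `decision_function` at 0, reproduces `predict`. A minimal sketch of this equivalence (using a logistic regression, which is not part of the example above):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)

# Hard labels from `predict` ...
labels = clf.predict(X)
# ... match thresholding the positive-class probability at 0.5 ...
from_proba = (clf.predict_proba(X)[:, 1] > 0.5).astype(int)
# ... and thresholding the decision score at 0.
from_score = (clf.decision_function(X) > 0).astype(int)

assert np.array_equal(labels, from_proba)
assert np.array_equal(labels, from_score)
```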

While these approaches are reasonable as default behaviors, they may not be ideal for
all cases. The context and nature of the use case define the expected behavior of the
classifier and thus the strategy to convert soft predictions into hard predictions. We
illustrate this point with an example.

Let's imagine the deployment of a predictive model helping medical doctors to detect
tumours. In a setting where this model is a tool to discard obvious cases and false
positives don't lead to potentially harmful treatments, doctors might be interested in
having a high recall (all cancer cases should be tagged as such) so as not to miss any
patient with cancer. However, this comes at the cost of more false positive predictions
(i.e. lower precision). Thus, in terms of decision threshold, it may be better to
classify a patient as having cancer at a posterior probability estimate lower than
0.5.
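The trade-off described above can be demonstrated numerically. A minimal sketch on imbalanced synthetic data (a stand-in for the medical setting, not real clinical data): lowering the cut-off from 0.5 to 0.1 can only add positive predictions, so recall never drops, at the price of precision:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import recall_score
from sklearn.model_selection import train_test_split

# Imbalanced synthetic data: the positive class plays the role of "has a tumour".
X, y = make_classification(n_samples=2_000, weights=[0.9, 0.1], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)

clf = LogisticRegression().fit(X_train, y_train)
proba = clf.predict_proba(X_test)[:, 1]

# Default cut-off at 0.5 versus a deliberately low cut-off at 0.1.
recall_default = recall_score(y_test, (proba >= 0.5).astype(int))
recall_low_cut = recall_score(y_test, (proba >= 0.1).astype(int))

# Lowering the threshold can only add positive predictions, so recall never drops.
assert recall_low_cut >= recall_default
```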

Post-tuning of the decision threshold
=====================================

One solution to address the problem stated in the introduction is to tune the decision
threshold of the classifier once the model has been trained. The
:class:`~sklearn.model_selection.TunedThresholdClassifier` tunes this threshold using
an internal cross-validation. The optimum threshold is chosen to maximize a given
metric, with or without constraints.

Review discussion on this paragraph:

- **@jnothman** (Dec 26, 2023): Do we need to justify using "internal
  cross-validation" somewhere rather than tuning this like any other hyperparameter?
  Indeed, ignoring `constraint_value`, is this tool just a workaround for a design
  flaw wherein a parameter to `BaseClassifier` is not exposed to `*SearchCV` with warm
  start?
- **Author (@glemaitre)**: Today, I would indeed consider it a design and
  methodological flaw. I don't know what level of justification we should add here:
  "... tunes this threshold using an internal cross-validation since the scikit-learn
  :term:`predict` API does not offer this flexibility."
- **Member**: I'm not sure I follow. Isn't the problem that we can't freeze the
  estimator? Maybe I need to look at the code first...
- **Member**: OK, I realized there are two sort-of independent long-term API issues
  going on here and you were talking about the other one. We need internal
  cross-validation because we still don't have an API for path-like algorithms and
  efficient CV, and that's what you were talking about. But there's also the "prefit"
  part, which I was thinking about and which is now actually fixed via the `__clone__`
  protocol. Basically, if we're not fitting the underlying estimator, then we should
  use the `__clone__` protocol to not clone it when the `TunedThresholdClassifier` is
  being cloned, right?
- **Member**: "we still don't have an API for path-like algorithms and efficient CV":
  @amueller, what do you mean by that? Also, given the design-flaw remarks above,
  @glemaitre @jnothman, shouldn't we fix this design flaw? Do you have suggestions?
- **Author (@glemaitre)**: This is something that I would introduce in 2.0 because we
  will surely break some API. In that case, we would have the freedom to think about a
  good API.
- **Member**: I guess what I am missing is a roadmap or a plan/vision.
- **Member**: Should we note the contrast with what `CalibratedClassifierCV` is for?
- **Member**: Yes, for sure.

The following image illustrates the tuning of the cut-off point for a gradient
boosting classifier. While the vanilla and tuned classifiers provide the same
Receiver Operating Characteristic (ROC) and Precision-Recall curves, and thus the same
:term:`predict_proba` outputs, the class label predictions differ because of the tuned
decision threshold. The vanilla classifier predicts the class of interest for a
posterior probability greater than 0.5 while the tuned classifier predicts the class
of interest for a very low probability (around 0.02). This cut-off point optimizes a
utility metric defined by the business case (in this case an insurance company).

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_tuned_threshold_classifier_002.png
   :target: ../auto_examples/model_selection/plot_tuned_threshold_classifier.html
   :align: center

Available options to tune the cut-off point
-------------------------------------------

The cut-off point can be tuned with different strategies controlled by the parameter
`objective_metric`.

A straightforward use case is to maximize a pre-defined scikit-learn metric. These
metrics can be found by calling the function :func:`~sklearn.metrics.get_scorer_names`.
We provide an example where we maximize the balanced accuracy.
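As a concrete illustration of what the meta-estimator automates, maximizing the balanced accuracy can be sketched by hand on a single held-out split. This is a simplified stand-in, not the class's actual implementation, which uses cross-validation:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1_000, weights=[0.8, 0.2], random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, stratify=y, random_state=0)

clf = LogisticRegression().fit(X_train, y_train)
proba = clf.predict_proba(X_val)[:, 1]

# Scan a grid of candidate cut-offs and keep the best-scoring one.
thresholds = np.linspace(0.01, 0.99, 99)
scores = [balanced_accuracy_score(y_val, (proba >= t).astype(int)) for t in thresholds]
best_threshold = thresholds[int(np.argmax(scores))]

# The tuned cut-off is at least as good as the default on the validation split.
assert max(scores) >= scores[49]  # thresholds[49] == 0.5
```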

.. note::

    It is important to notice that these metrics come with default parameters, notably
    the label of the class of interest (i.e. `pos_label`). Thus, if this label is not
    the right one for your application, you need to define a scorer and pass the right
    `pos_label` (and additional parameters) using
    :func:`~sklearn.metrics.make_scorer`. Refer to :ref:`scoring` for the information
    needed to define your own scoring function. For instance, we show how to pass the
    information to the scorer that the label of interest is `0` when maximizing the
    :func:`~sklearn.metrics.f1_score`:

    >>> from sklearn.datasets import make_classification
    >>> from sklearn.linear_model import LogisticRegression
    >>> from sklearn.model_selection import (
    ...     TunedThresholdClassifier, train_test_split
    ... )
    >>> from sklearn.metrics import make_scorer, f1_score
    >>> X, y = make_classification(
    ...     n_samples=1_000, weights=[0.1, 0.9], random_state=0)
    >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    >>> pos_label = 0
    >>> scorer = make_scorer(f1_score, pos_label=pos_label)
    >>> base_model = LogisticRegression()
    >>> model = TunedThresholdClassifier(base_model, objective_metric=scorer).fit(
    ...     X_train, y_train)
    >>> scorer(model, X_test, y_test)
    0.82...
    >>> # compare it with the internal score found by cross-validation
    >>> model.objective_score_
    0.86...

A second strategy aims at maximizing a metric while imposing constraints on another
metric. Four pre-defined options exist, two that use the Receiver Operating
Characteristic (ROC) statistics and two that use the Precision-Recall statistics.

- `"max_tpr_at_tnr_constraint"`: maximizes the True Positive Rate (TPR) such that the
  True Negative Rate (TNR) is the closest to a given value.
- `"max_tnr_at_tpr_constraint"`: maximizes the TNR such that the TPR is the closest to
  a given value.
- `"max_precision_at_recall_constraint"`: maximizes the precision such that the recall
  is the closest to a given value.
- `"max_recall_at_precision_constraint"`: maximizes the recall such that the precision
  is the closest to a given value.

For these options, the `constraint_value` parameter needs to be defined. In addition,
you can use the `pos_label` parameter to indicate the label of the class of interest.
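To illustrate the mechanics behind these constrained options, `"max_precision_at_recall_constraint"` can be emulated by hand with :func:`~sklearn.metrics.precision_recall_curve`. This manual sketch is not the estimator's actual code:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve

X, y = make_classification(n_samples=1_000, weights=[0.8, 0.2], random_state=0)
clf = LogisticRegression().fit(X, y)
proba = clf.predict_proba(X)[:, 1]

precision, recall, thresholds = precision_recall_curve(y, proba)
# The last (precision, recall) point has no associated threshold: drop it.
precision, recall = precision[:-1], recall[:-1]

target_recall = 0.8
# Keep the cut-offs whose recall is closest to the target ...
gap = np.abs(recall - target_recall)
candidates = np.flatnonzero(gap == gap.min())
# ... and among those, pick the one with the highest precision.
best = candidates[np.argmax(precision[candidates])]
chosen_threshold = thresholds[best]
```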

The final strategy maximizes a custom utility function. This problem is also known as
cost-sensitive learning. The utility function is defined by providing a dictionary
containing the cost-gain associated with the entries of the confusion matrix. The keys
are defined as `{"tn", "fp", "fn", "tp"}`. The class of interest is defined using the
`pos_label` parameter. Refer to :ref:`cost_sensitive_learning_example` for an example
depicting the use of such a utility function.

Review discussion on this paragraph:

- **Contributor**: Is this MetaCost? If that is the case, instead of saying "utility
  function", which is unclear, we may say "conditional risk" and then quote the
  reference: https://dl.acm.org/doi/10.1145/312129.312220
- **Author (@glemaitre)**: This is not really the MetaCost procedure here. We should
  revisit the statistical term indeed.
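Such a cost-sensitive tuning can be sketched by hand as follows. The cost-gain values below are hypothetical, chosen only to make false negatives expensive:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix

X, y = make_classification(n_samples=1_000, weights=[0.8, 0.2], random_state=0)
clf = LogisticRegression().fit(X, y)
proba = clf.predict_proba(X)[:, 1]

# Hypothetical cost-gain values for each confusion-matrix entry.
gain = {"tn": 0, "fp": -1, "fn": -5, "tp": 4}

def total_gain(y_true, y_pred):
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return gain["tn"] * tn + gain["fp"] * fp + gain["fn"] * fn + gain["tp"] * tp

thresholds = np.linspace(0.01, 0.99, 99)
gains = [total_gain(y, (proba >= t).astype(int)) for t in thresholds]
best_threshold = thresholds[int(np.argmax(gains))]

# The tuned cut-off is at least as profitable as the default 0.5.
assert max(gains) >= gains[49]  # thresholds[49] == 0.5
```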

Important notes regarding the internal cross-validation
-------------------------------------------------------

By default, :class:`~sklearn.model_selection.TunedThresholdClassifier` uses a
5-fold stratified cross-validation to tune the cut-off point. The parameter
`cv` controls the cross-validation strategy. It is possible to bypass
cross-validation by passing `cv="prefit"` and providing an already fitted
classifier. In this case, the cut-off point is tuned on the data provided to
the `fit` method.

However, you should be extremely careful when using this option. You should never use
the same data for training the classifier and tuning the cut-off point, at the risk of
overfitting. Refer to :ref:`tunedthresholdclassifier_no_cv`, which shows such
overfitting. If you have limited resources, consider passing a float number for `cv`,
which will use a single train-test split internally.

The option `cv="prefit"` should only be used when the provided classifier was already
trained on some data and you want to tune (or re-tune) the cut-off point on a new
validation set.
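The safe pattern described above (a classifier already fitted on training data, with the cut-off then tuned on a separate validation set) can be sketched manually:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=2_000, weights=[0.8, 0.2], random_state=0)
# The validation set is never seen during training, avoiding the overfitting risk.
X_train, X_val, y_train, y_val = train_test_split(X, y, stratify=y, random_state=0)

# Stand-in for an already fitted ("prefit") classifier.
prefit_clf = LogisticRegression().fit(X_train, y_train)

# Tune the cut-off on the held-out validation data only.
proba_val = prefit_clf.predict_proba(X_val)[:, 1]
thresholds = np.linspace(0.01, 0.99, 99)
scores = [f1_score(y_val, (proba_val >= t).astype(int)) for t in thresholds]
best_threshold = thresholds[int(np.argmax(scores))]
```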

Manually setting the decision threshold
---------------------------------------

The previous sections discussed strategies to find an optimal decision threshold. It is
also possible to manually set the decision threshold in
:class:`~sklearn.model_selection.TunedThresholdClassifier` by setting the parameter
`strategy` to `"constant"` and providing the desired threshold through the parameter
`constant_threshold`.
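Setting the threshold manually amounts to replacing :term:`predict` by a fixed comparison on the positive-class score, which can be sketched as follows (the 0.3 cut-off is an arbitrary illustration):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)

constant_threshold = 0.3  # manually chosen cut-off
proba = clf.predict_proba(X)[:, 1]
manual_pred = (proba >= constant_threshold).astype(int)

# A lower fixed cut-off can only flip predictions toward the positive class:
# every sample predicted positive at the default 0.5 stays positive at 0.3.
assert np.all(manual_pred[clf.predict(X) == 1] == 1)
```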

Examples
--------

- See
  :ref:`sphx_glr_auto_examples_model_selection_plot_tuned_threshold_classifier.py`
  for an example of tuning the decision threshold of a classifier.
8 changes: 8 additions & 0 deletions doc/whats_new/v1.4.rst
@@ -78,6 +78,14 @@ Changelog
- |Fix| :func:`feature_selection.mutual_info_regression` now correctly computes the
result when `X` is of integer dtype. :pr:`26748` by :user:`Yao Xiao <Charlie-XIAO>`.

:mod:`sklearn.model_selection`
..............................

- |MajorFeature| :class:`model_selection.TunedThresholdClassifier` post-tunes
  the decision threshold of a binary classifier by maximizing a
  classification metric through cross-validation.
  :pr:`26120` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.pipeline`
.......................
