Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEA add TunedThresholdClassifier meta-estimator to post-tune the cut-off threshold #26120

Merged
merged 228 commits into from May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
228 commits
Select commit Hold shift + click to select a range
b44dd9d
MAINT refactor scorer using _get_response_values
glemaitre Mar 31, 2023
516f62f
Add __name__ for method of Mock
glemaitre Apr 1, 2023
d2fbee0
remove multiclass issue
glemaitre Apr 1, 2023
29e5e87
make response_method a mandatory arg
glemaitre Apr 3, 2023
b645ade
Update sklearn/metrics/_scorer.py
glemaitre Apr 3, 2023
3397c56
apply jeremie comments
glemaitre Apr 3, 2023
092689a
Merge branch 'main' into is/18589_restart
glemaitre Apr 3, 2023
200ec31
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 4, 2023
e871558
iter
glemaitre Apr 4, 2023
31aa1c0
Merge remote-tracking branch 'glemaitre/is/18589_restart' into cutoff…
glemaitre Apr 4, 2023
74614e8
FEA add CutOffClassifier to post-tune prediction threshold
glemaitre Apr 7, 2023
27713af
DOC add changelog entry
glemaitre Apr 7, 2023
ed1d9b3
refresh implementation
glemaitre Apr 7, 2023
8410317
add files
glemaitre Apr 7, 2023
c7d1fe4
remove random state for the moment
glemaitre Apr 7, 2023
c9d7a22
TST make sure to pass the common test
glemaitre Apr 15, 2023
9981f3a
TST metaestimator sample_weight
glemaitre Apr 15, 2023
b9c9d5e
API add prediction functions
glemaitre Apr 15, 2023
588f1c4
TST bypass the test for classification
glemaitre Apr 17, 2023
243d173
iter before another bug
glemaitre Apr 17, 2023
883e929
iter
glemaitre Apr 18, 2023
69333ed
TST add test for _fit_and_score
glemaitre Apr 19, 2023
8616da1
iter
glemaitre Apr 19, 2023
99a10b3
integrate refit
glemaitre Apr 19, 2023
0f6dce2
TST more test
glemaitre Apr 19, 2023
d6fb9f7
TST more test with sample_weight
glemaitre Apr 19, 2023
7ff3d0d
BUG fit_params split
glemaitre Apr 19, 2023
6985ae9
TST add test for fit_params
glemaitre Apr 19, 2023
239793a
TST check underlying response method for TNR/TPR
glemaitre Apr 20, 2023
92083ed
FEA add the possibility to provide a dict
glemaitre Apr 20, 2023
55d0844
TST check string and pos_label interation for cost-matrix
glemaitre Apr 20, 2023
7dfc4a6
TST add sample_weight test for cost-matrix
glemaitre Apr 20, 2023
729c9a8
iter
glemaitre Apr 20, 2023
787be21
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 24, 2023
146b170
change strategy for finding max
glemaitre Apr 24, 2023
03b1f7f
iter
glemaitre Apr 24, 2023
8a09a5f
add some test for precision-recall
glemaitre Apr 24, 2023
d56f57f
TST add invariance zeros weight
glemaitre Apr 24, 2023
cf164c5
DOC fix default n_thresholds
glemaitre Apr 24, 2023
c943f5e
DOC add a small example
glemaitre Apr 25, 2023
bf1462b
iter
glemaitre Apr 25, 2023
fa89431
iter
glemaitre Apr 25, 2023
862519d
bug fixes everywhere
glemaitre Apr 25, 2023
aa520da
iter
glemaitre Apr 25, 2023
5403cf6
Do not allow for single threshold
glemaitre Apr 26, 2023
cd37743
TST add random state checkingclassifier
glemaitre Apr 26, 2023
e7d07af
TST more test for _ContinuousScorer
glemaitre Apr 26, 2023
bc20a47
TST add test for pos_label
glemaitre Apr 26, 2023
bba2f97
TST add pos_label test for TNR/TPR
glemaitre Apr 26, 2023
f925503
some more
glemaitre Apr 26, 2023
d539235
avoid extrapolation
glemaitre Apr 27, 2023
c0acd44
FEA add all thresholds and score computed as attributes
glemaitre Apr 27, 2023
f87baa7
fix docstring
glemaitre Apr 28, 2023
e4dac09
EXA add example of cut-off tuning
glemaitre Apr 28, 2023
4da7cef
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 28, 2023
bd86595
solving the issue of unknown categories
glemaitre Apr 28, 2023
45e6e5a
fix
glemaitre Apr 28, 2023
402a1a7
EXA add hyperlink in the example
glemaitre Apr 29, 2023
6745afc
DOC add warning regarding overfitting
glemaitre Apr 29, 2023
4d557cc
some more doc
glemaitre Apr 29, 2023
2c6ee7e
some more doc
glemaitre Apr 29, 2023
91c8222
DOC more documentation
glemaitre May 2, 2023
9a96ae1
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre May 2, 2023
3d4ce81
fix import
glemaitre May 2, 2023
d7d8dac
fix import
glemaitre May 2, 2023
aa3e83d
iter
glemaitre May 2, 2023
ab97d63
fix
glemaitre May 2, 2023
acb6af8
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre May 4, 2023
486a2bd
Update sklearn/metrics/_scorer.py
glemaitre May 4, 2023
6d4c4aa
Apply suggestions from code review
glemaitre May 15, 2023
1d12e1f
Fix linter
ogrisel Jun 1, 2023
21e20e0
Merge branch 'main' into cutoff_classifier_again
ogrisel Jun 1, 2023
7952cce
Add routing to LogisticRegressionCV
Jun 7, 2023
66ad513
Add a test with enable_metadata_routing=False and fix an issue in sco…
Jun 7, 2023
7e8b824
Add metaestimator tests and fix passing routed params in score method
Jun 13, 2023
d7e50a6
PR suggestions
Jun 25, 2023
3844706
Merge branch 'main' into logistic_cv_routing
Jun 26, 2023
0866c42
Add changelog entry
Jun 26, 2023
43f971b
Add user and pr information
Jun 26, 2023
db63769
Changelog adjustment
Jun 26, 2023
a9b984f
Remove repr method from ConsumingScorer
Jun 26, 2023
97105a4
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 3, 2023
52f5921
handle the np.inf case in roc-curve
glemaitre Jul 3, 2023
637c18e
Merge branch 'main' into logistic_cv_routing
adrinjalali Jul 7, 2023
314bc83
Adjust changelog
Jul 7, 2023
9a8ef4e
Add tests for error when passing params when routing not enabled in L…
OmarManzoor Jul 10, 2023
5b723a0
Address PR suggestions partially
Jul 13, 2023
9ce463d
address comment Tim
glemaitre Jul 13, 2023
bba8f55
iter
glemaitre Jul 13, 2023
1c5487d
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 13, 2023
d302678
MAINT rename modules as per olivier comment
glemaitre Jul 13, 2023
2ade221
add missing module
glemaitre Jul 13, 2023
dca5770
update changelog
glemaitre Jul 13, 2023
8897533
more renaming
glemaitre Jul 13, 2023
75bd7ac
iter
glemaitre Jul 13, 2023
c07a980
Adjust and change the name of params in _check_method_params
OmarManzoor Jul 13, 2023
cc5ba48
Resolve conflict in changelog
OmarManzoor Jul 13, 2023
66c4c7f
iter
glemaitre Jul 13, 2023
378930e
iter
glemaitre Jul 13, 2023
c88ed94
iter
glemaitre Jul 13, 2023
4715e67
iter
glemaitre Jul 13, 2023
b3bb39f
iter
glemaitre Jul 13, 2023
915624a
Merge branch 'main' into logistic_cv_routing
glemaitre Jul 13, 2023
b72a72a
iter
glemaitre Jul 13, 2023
5108e43
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 13, 2023
95150b0
Merge branch 'pr/OmarManzoor/26525' into cutoff_with_metadata_routing
glemaitre Jul 13, 2023
5b66ab8
Merge remote-tracking branch 'origin/main' into cutoff_with_metadata_…
glemaitre Jul 14, 2023
b4e67fb
Add metadata routing
glemaitre Jul 14, 2023
005126a
Apply suggestions from code review
glemaitre Jul 14, 2023
767a05f
CLN clean up some repeated code related to SLEP006
adrinjalali Jul 14, 2023
080ba5c
iter
glemaitre Jul 14, 2023
05ec85d
iter
glemaitre Jul 14, 2023
63c32bd
ENH add new response_method in make_scorer
glemaitre Jul 15, 2023
1584c5b
add non-regression test
glemaitre Jul 15, 2023
1a5a247
update validation param
glemaitre Jul 15, 2023
4cc61b9
more coverage
glemaitre Jul 15, 2023
8f36235
TST add mulitlabel test
glemaitre Jul 15, 2023
9e6b384
Merge branch 'make_scorer_list_response' into cutoff_classifier_again
glemaitre Jul 15, 2023
5490ce4
simplify scorer
glemaitre Jul 15, 2023
8dad0a4
iter
glemaitre Jul 15, 2023
b918708
remove unecessary part in doc
glemaitre Jul 15, 2023
5e23523
iter
glemaitre Jul 15, 2023
d5578f9
iter
glemaitre Jul 16, 2023
f3f844e
address tim comments
glemaitre Jul 24, 2023
44ad195
Merge remote-tracking branch 'origin/main' into make_scorer_list_resp…
glemaitre Jul 24, 2023
e489eab
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 24, 2023
1cf5528
Merge branch 'make_scorer_list_response' into cutoff_classifier_again
glemaitre Jul 24, 2023
26dc94e
iter
glemaitre Jul 26, 2023
6a1a6c7
iter
glemaitre Jul 26, 2023
b17b59e
iter
glemaitre Jul 28, 2023
43c1da8
iter
glemaitre Jul 28, 2023
41a6d07
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 28, 2023
ca06717
iter
glemaitre Jul 28, 2023
ab8b466
iter
glemaitre Jul 30, 2023
235abf5
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Aug 7, 2023
d9ec528
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Aug 11, 2023
69f60a6
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Sep 28, 2023
45a8504
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Oct 17, 2023
8c4c88d
solve deprecation
glemaitre Oct 17, 2023
b97ebf4
Merge remote-tracking branch 'glemaitre/cutoff_classifier_again' into…
glemaitre Oct 17, 2023
e37f831
update changelog
glemaitre Oct 17, 2023
383937f
whoops
glemaitre Oct 17, 2023
d4ce3fb
Update sklearn/metrics/_scorer.py
glemaitre Oct 17, 2023
b6b3548
fix doc
glemaitre Oct 17, 2023
759d680
remove useless fitted attributes
glemaitre Oct 18, 2023
23e65e6
Merge branch 'main' into cutoff_classifier_again
ogrisel Dec 4, 2023
6904817
bump pandas to 1.1.5
glemaitre Dec 4, 2023
bee1ebe
update lock file
glemaitre Dec 4, 2023
4d86a36
iter
glemaitre Jan 13, 2024
48fd7cd
update doc-min lock file
glemaitre Jan 13, 2024
0854cd4
partial reviews
glemaitre Jan 13, 2024
ac75300
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Mar 18, 2024
2df616e
Apply suggestions from code review
glemaitre Mar 18, 2024
b14225c
update lock files
glemaitre Mar 18, 2024
98dcefd
Merge remote-tracking branch 'glemaitre/cutoff_classifier_again' into…
glemaitre Mar 18, 2024
7e3d7aa
iter
glemaitre Mar 18, 2024
b958bb0
simplify refit and do not allow cv == 1
glemaitre Mar 19, 2024
e7722f6
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Mar 19, 2024
98a1db8
check raise for multilabel
glemaitre Mar 19, 2024
c28a3e1
fix test name
glemaitre Mar 19, 2024
076fd29
test another to check if beta is forwarded
glemaitre Mar 19, 2024
e728f1d
iter
glemaitre Mar 20, 2024
c73b205
refit=True and cv is float
glemaitre Mar 20, 2024
a4890df
rename scorer to curve scorer internally
glemaitre Mar 22, 2024
f8a5a79
add a note regarding the abuse of the scorer API
glemaitre Mar 22, 2024
5dfa435
use None instead of highest
glemaitre Mar 22, 2024
d45a71b
use a closer CV API
glemaitre Mar 23, 2024
a32c151
fix example
glemaitre Mar 23, 2024
7592437
simplify model
glemaitre Mar 23, 2024
dc5346b
fix
glemaitre Mar 23, 2024
843ca04
fix docstring
glemaitre Mar 23, 2024
51ed9a8
Apply suggestions from code review
glemaitre Mar 30, 2024
3c89ab3
Apply suggestions from code review
glemaitre Apr 2, 2024
8cd5582
pep8
glemaitre Apr 2, 2024
a48487c
rephrase suggestions
glemaitre Apr 2, 2024
dd18549
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 4, 2024
27515ca
fix
glemaitre Apr 8, 2024
8a87b26
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 8, 2024
811dec9
include and discuss more about amount
glemaitre Apr 8, 2024
c83b4e1
iter
glemaitre Apr 8, 2024
d4e232f
Apply suggestions from code review
glemaitre Apr 25, 2024
92f6e05
Update examples/model_selection/plot_tuned_decision_threshold.py
glemaitre Apr 25, 2024
1c5c3f4
iter
glemaitre Apr 25, 2024
94160ba
iter
glemaitre Apr 25, 2024
85c8484
other comment
glemaitre Apr 25, 2024
4f86e9d
addressed comments
glemaitre Apr 25, 2024
d747098
Apply suggestions from code review
glemaitre Apr 27, 2024
6d0f418
rename TunedThresholdClassifier to TunedThresholdClassifierCV
glemaitre Apr 27, 2024
5671dd6
use meaningful values for check the thresholds values depending on po…
glemaitre Apr 27, 2024
f04085d
TST add more info regarding why not exactly 0 and 1
glemaitre Apr 27, 2024
d179b5f
DOC add documentation for base scorer
glemaitre Apr 27, 2024
2c375f8
DOC add more details regarding the curve scorer
glemaitre Apr 27, 2024
66ba8da
directly test curve_scorer instead to look for function anem
glemaitre Apr 27, 2024
a6b19c1
add required arguments
glemaitre Apr 27, 2024
b3b99ff
DOC add docstring for interpolated score
glemaitre Apr 27, 2024
dda0d2c
Update sklearn/model_selection/tests/test_classification_threshold.py
glemaitre Apr 29, 2024
48e7829
Update sklearn/model_selection/tests/test_classification_threshold.py
glemaitre Apr 29, 2024
17839e8
Apply suggestions from code review
glemaitre Apr 29, 2024
3f02bc3
remove duplicated check
glemaitre Apr 29, 2024
553cfce
remove duplicated check
glemaitre Apr 29, 2024
6ae6d27
check cv_results_ API
glemaitre Apr 29, 2024
d010096
clone classifier
glemaitre Apr 29, 2024
8bb8ca6
TST better comments
glemaitre Apr 29, 2024
fd971c7
iter
glemaitre Apr 29, 2024
bf57dac
FEA add a ConstantThresholdClassifier instead of strategy="constant" …
glemaitre Apr 30, 2024
0409932
make FixedThresholdClassifier appear in example
glemaitre Apr 30, 2024
66ea575
iter
glemaitre Apr 30, 2024
1c97dd4
Update doc/modules/classes.rst
glemaitre Apr 30, 2024
c8c1d0c
Update sklearn/model_selection/_classification_threshold.py
glemaitre Apr 30, 2024
8a52bc6
TST and fix default parameter
glemaitre Apr 30, 2024
ef668cf
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 30, 2024
fdbf68e
TST metadarouting FixedThresholdClassifier
glemaitre Apr 30, 2024
9c0c13d
rename n_thresholds to thresholds
glemaitre Apr 30, 2024
f419371
cover constant predictor error
glemaitre Apr 30, 2024
42eafe5
TST some tests for get_response_values_binary
glemaitre Apr 30, 2024
581133f
use conditional p(y|X) instead of posterior
glemaitre May 2, 2024
0f803d9
be more explicit that strings need to be provided to objective_metric
glemaitre May 2, 2024
eb0defc
factorize plotting into a function
glemaitre May 2, 2024
ffd5669
fix typo in code
glemaitre May 2, 2024
18abafe
use proper scoring rule and robust estimator to scale
glemaitre May 2, 2024
ce9464c
improve narrative
glemaitre May 2, 2024
89d67cf
use grid-search
glemaitre May 2, 2024
db3360b
Apply suggestions from code review
glemaitre May 2, 2024
1789cc0
remove constrainted metrics option
glemaitre May 3, 2024
e7c31b9
partial review
glemaitre May 3, 2024
0fd667c
rename objective_metric to scoring
glemaitre May 3, 2024
07e4387
fix typo
glemaitre May 3, 2024
9bd68e6
remove pos_label and delegate to make_scorer
glemaitre May 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/model_selection.rst
Expand Up @@ -14,5 +14,6 @@ Model selection and evaluation

modules/cross_validation
modules/grid_search
modules/classification_threshold
modules/model_evaluation
modules/learning_curve
11 changes: 11 additions & 0 deletions doc/modules/classes.rst
Expand Up @@ -1248,6 +1248,17 @@ Hyper-parameter optimizers
model_selection.RandomizedSearchCV
model_selection.HalvingRandomSearchCV

Post-fit model tuning
---------------------

.. currentmodule:: sklearn

.. autosummary::
:toctree: generated/
:template: class.rst

model_selection.FixedThresholdClassifier
model_selection.TunedThresholdClassifierCV

Model validation
----------------
Expand Down
156 changes: 156 additions & 0 deletions doc/modules/classification_threshold.rst
@@ -0,0 +1,156 @@
.. currentmodule:: sklearn.model_selection

.. _TunedThresholdClassifierCV:

==================================================
Tuning the decision threshold for class prediction
==================================================

Classification is best divided into two parts:

* the statistical problem of learning a model to predict, ideally, class probabilities;
* the decision problem to take concrete action based on those probability predictions.

Let's take a straightforward example related to weather forecasting: the first point is
related to answering "what is the chance that it will rain tomorrow?" while the second
point is related to answering "should I take an umbrella tomorrow?".

When it comes to the scikit-learn API, the first point is addressed providing scores
using :term:`predict_proba` or :term:`decision_function`. The former returns conditional
probability estimates :math:`P(y|X)` for each class, while the latter returns a decision
score for each class.

The decision corresponding to the labels are obtained with :term:`predict`. In binary
classification, a decision rule or action is then defined by thresholding the scores,
leading to the prediction of a single class label for each sample. For binary
classification in scikit-learn, class labels predictions are obtained by hard-coded
cut-off rules: a positive class is predicted when the conditional probability
:math:`P(y|X)` is greater than 0.5 (obtained with :term:`predict_proba`) or if the
decision score is greater than 0 (obtained with :term:`decision_function`).

Here, we show an example that illustrates the relation between conditional
probability estimates :math:`P(y|X)` and class labels::

>>> from sklearn.datasets import make_classification
>>> from sklearn.tree import DecisionTreeClassifier
>>> X, y = make_classification(random_state=0)
>>> classifier = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
>>> classifier.predict_proba(X[:4])
array([[0.94 , 0.06 ],
[0.94 , 0.06 ],
[0.0416..., 0.9583...],
[0.0416..., 0.9583...]])
>>> classifier.predict(X[:4])
array([0, 0, 1, 1])

While these hard-coded rules might at first seem reasonable as default behavior, they
are most certainly not ideal for most use cases. Let's illustrate with an example.

Consider a scenario where a predictive model is being deployed to assist
physicians in detecting tumors. In this setting, physicians will most likely be
interested in identifying all patients with cancer and not missing anyone with cancer so
that they can provide them with the right treatment. In other words, physicians
prioritize achieving a high recall rate. This emphasis on recall comes, of course, with
the trade-off of potentially more false-positive predictions, reducing the precision of
the model. That is a risk physicians are willing to take because the cost of a missed
cancer is much higher than the cost of further diagnostic tests. Consequently, when it
comes to deciding whether to classify a patient as having cancer or not, it may be more
beneficial to classify them as positive for cancer when the conditional probability
estimate is much lower than 0.5.

glemaitre marked this conversation as resolved.
Show resolved Hide resolved
Post-tuning the decision threshold
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I mentioned in the other example, I would introduce the roc curve earlier in the explanation.

==================================

One solution to address the problem stated in the introduction is to tune the decision
threshold of the classifier once the model has been trained. The
:class:`~sklearn.model_selection.TunedThresholdClassifierCV` tunes this threshold using
an internal cross-validation. The optimum threshold is chosen to maximize a given
metric.

The following image illustrates the tuning of the decision threshold for a gradient
boosting classifier. While the vanilla and tuned classifiers provide the same
:term:`predict_proba` outputs and thus the same Receiver Operating Characteristic (ROC)
and Precision-Recall curves, the class label predictions differ because of the tuned
decision threshold. The vanilla classifier predicts the class of interest for a
conditional probability greater than 0.5 while the tuned classifier predicts the class
of interest for a very low probability (around 0.02). This decision threshold optimizes
a utility metric defined by the business (in this case an insurance company).

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cost_sensitive_learning_002.png
:target: ../auto_examples/model_selection/plot_cost_sensitive_learning.html
:align: center

Options to tune the decision threshold
--------------------------------------

The decision threshold can be tuned through different strategies controlled by the
parameter `scoring`.

One way to tune the threshold is by maximizing a pre-defined scikit-learn metric. These
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would maybe mention that a common tuning is picking the top right point on the ROC curve which is the same as picking f2 score (I think?) here. Or maybe mention that that has a nice geometric explanation but doesn't really consider the application.

metrics can be found by calling the function :func:`~sklearn.metrics.get_scorer_names`.
By default, the balanced accuracy is the metric used but be aware that one should choose
a meaningful metric for their use case.

.. note::

It is important to notice that these metrics come with default parameters, notably
the label of the class of interest (i.e. `pos_label`). Thus, if this label is not
the right one for your application, you need to define a scorer and pass the right
`pos_label` (and additional parameters) using the
:func:`~sklearn.metrics.make_scorer`. Refer to :ref:`scoring` to get
information to define your own scoring function. For instance, we show how to pass
the information to the scorer that the label of interest is `0` when maximizing the
:func:`~sklearn.metrics.f1_score`::

>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import TunedThresholdClassifierCV
>>> from sklearn.metrics import make_scorer, f1_score
>>> X, y = make_classification(
... n_samples=1_000, weights=[0.1, 0.9], random_state=0)
>>> pos_label = 0
>>> scorer = make_scorer(f1_score, pos_label=pos_label)
>>> base_model = LogisticRegression()
>>> model = TunedThresholdClassifierCV(base_model, scoring=scorer)
>>> scorer(model.fit(X, y), X, y)
0.88...
>>> # compare it with the internal score found by cross-validation
>>> model.best_score_
0.86...

Important notes regarding the internal cross-validation
-------------------------------------------------------

By default :class:`~sklearn.model_selection.TunedThresholdClassifierCV` uses a 5-fold
stratified cross-validation to tune the decision threshold. The parameter `cv` allows to
control the cross-validation strategy. It is possible to bypass cross-validation by
setting `cv="prefit"` and providing a fitted classifier. In this case, the decision
threshold is tuned on the data provided to the `fit` method.

However, you should be extremely careful when using this option. You should never use
the same data for training the classifier and tuning the decision threshold due to the
risk of overfitting. Refer to the following example section for more details (cf.
:ref:`TunedThresholdClassifierCV_no_cv`). If you have limited resources, consider using
a float number for `cv` to limit to an internal single train-test split.

glemaitre marked this conversation as resolved.
Show resolved Hide resolved
The option `cv="prefit"` should only be used when the provided classifier was already
trained, and you just want to find the best decision threshold using a new validation
set.

.. _FixedThresholdClassifier:

Manually setting the decision threshold
---------------------------------------

The previous sections discussed strategies to find an optimal decision threshold. It is
also possible to manually set the decision threshold using the class
:class:`~sklearn.model_selection.FixedThresholdClassifier`.

Examples
--------

- See the example entitled
:ref:`sphx_glr_auto_examples_model_selection_plot_tuned_decision_threshold.py`,
to get insights on the post-tuning of the decision threshold.
- See the example entitled
:ref:`sphx_glr_auto_examples_model_selection_plot_cost_sensitive_learning.py`,
to learn about cost-sensitive learning and decision threshold tuning.
7 changes: 7 additions & 0 deletions doc/whats_new/v1.5.rst
Expand Up @@ -412,6 +412,13 @@ Changelog
:mod:`sklearn.model_selection`
..............................

- |MajorFeature| :class:`model_selection.TunedThresholdClassifierCV` finds
the decision threshold of a binary classifier that maximizes a
classification metric through cross-validation.
:class:`model_selection.FixedThresholdClassifier` is an alternative when one wants
to use a fixed decision threshold without any tuning scheme.
:pr:`26120` by :user:`Guillaume Lemaitre <glemaitre>`.

- |Enhancement| :term:`CV splitters <CV splitter>` that ignores the group parameter now
raises a warning when groups are passed in to :term:`split`. :pr:`28210` by
`Thomas Fan`_.
Expand Down