FEA add TunedThresholdClassifier meta-estimator to post-tune the cut-off threshold #26120

Merged
merged 228 commits on May 3, 2024

Changes from 69 commits

Commits
b44dd9d
MAINT refactor scorer using _get_response_values
glemaitre Mar 31, 2023
516f62f
Add __name__ for method of Mock
glemaitre Apr 1, 2023
d2fbee0
remove multiclass issue
glemaitre Apr 1, 2023
29e5e87
make response_method a mandatory arg
glemaitre Apr 3, 2023
b645ade
Update sklearn/metrics/_scorer.py
glemaitre Apr 3, 2023
3397c56
apply jeremie comments
glemaitre Apr 3, 2023
092689a
Merge branch 'main' into is/18589_restart
glemaitre Apr 3, 2023
200ec31
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 4, 2023
e871558
iter
glemaitre Apr 4, 2023
31aa1c0
Merge remote-tracking branch 'glemaitre/is/18589_restart' into cutoff…
glemaitre Apr 4, 2023
74614e8
FEA add CutOffClassifier to post-tune prediction threshold
glemaitre Apr 7, 2023
27713af
DOC add changelog entry
glemaitre Apr 7, 2023
ed1d9b3
refresh implementation
glemaitre Apr 7, 2023
8410317
add files
glemaitre Apr 7, 2023
c7d1fe4
remove random state for the moment
glemaitre Apr 7, 2023
c9d7a22
TST make sure to pass the common test
glemaitre Apr 15, 2023
9981f3a
TST metaestimator sample_weight
glemaitre Apr 15, 2023
b9c9d5e
API add prediction functions
glemaitre Apr 15, 2023
588f1c4
TST bypass the test for classification
glemaitre Apr 17, 2023
243d173
iter before another bug
glemaitre Apr 17, 2023
883e929
iter
glemaitre Apr 18, 2023
69333ed
TST add test for _fit_and_score
glemaitre Apr 19, 2023
8616da1
iter
glemaitre Apr 19, 2023
99a10b3
integrate refit
glemaitre Apr 19, 2023
0f6dce2
TST more test
glemaitre Apr 19, 2023
d6fb9f7
TST more test with sample_weight
glemaitre Apr 19, 2023
7ff3d0d
BUG fit_params split
glemaitre Apr 19, 2023
6985ae9
TST add test for fit_params
glemaitre Apr 19, 2023
239793a
TST check underlying response method for TNR/TPR
glemaitre Apr 20, 2023
92083ed
FEA add the possibility to provide a dict
glemaitre Apr 20, 2023
55d0844
TST check string and pos_label interation for cost-matrix
glemaitre Apr 20, 2023
7dfc4a6
TST add sample_weight test for cost-matrix
glemaitre Apr 20, 2023
729c9a8
iter
glemaitre Apr 20, 2023
787be21
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 24, 2023
146b170
change strategy for finding max
glemaitre Apr 24, 2023
03b1f7f
iter
glemaitre Apr 24, 2023
8a09a5f
add some test for precision-recall
glemaitre Apr 24, 2023
d56f57f
TST add invariance zeros weight
glemaitre Apr 24, 2023
cf164c5
DOC fix default n_thresholds
glemaitre Apr 24, 2023
c943f5e
DOC add a small example
glemaitre Apr 25, 2023
bf1462b
iter
glemaitre Apr 25, 2023
fa89431
iter
glemaitre Apr 25, 2023
862519d
bug fixes everywhere
glemaitre Apr 25, 2023
aa520da
iter
glemaitre Apr 25, 2023
5403cf6
Do not allow for single threshold
glemaitre Apr 26, 2023
cd37743
TST add random state checkingclassifier
glemaitre Apr 26, 2023
e7d07af
TST more test for _ContinuousScorer
glemaitre Apr 26, 2023
bc20a47
TST add test for pos_label
glemaitre Apr 26, 2023
bba2f97
TST add pos_label test for TNR/TPR
glemaitre Apr 26, 2023
f925503
some more
glemaitre Apr 26, 2023
d539235
avoid extrapolation
glemaitre Apr 27, 2023
c0acd44
FEA add all thresholds and score computed as attributes
glemaitre Apr 27, 2023
f87baa7
fix docstring
glemaitre Apr 28, 2023
e4dac09
EXA add example of cut-off tuning
glemaitre Apr 28, 2023
4da7cef
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 28, 2023
bd86595
solving the issue of unknown categories
glemaitre Apr 28, 2023
45e6e5a
fix
glemaitre Apr 28, 2023
402a1a7
EXA add hyperlink in the example
glemaitre Apr 29, 2023
6745afc
DOC add warning regarding overfitting
glemaitre Apr 29, 2023
4d557cc
some more doc
glemaitre Apr 29, 2023
2c6ee7e
some more doc
glemaitre Apr 29, 2023
91c8222
DOC more documentation
glemaitre May 2, 2023
9a96ae1
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre May 2, 2023
3d4ce81
fix import
glemaitre May 2, 2023
d7d8dac
fix import
glemaitre May 2, 2023
aa3e83d
iter
glemaitre May 2, 2023
ab97d63
fix
glemaitre May 2, 2023
acb6af8
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre May 4, 2023
486a2bd
Update sklearn/metrics/_scorer.py
glemaitre May 4, 2023
6d4c4aa
Apply suggestions from code review
glemaitre May 15, 2023
1d12e1f
Fix linter
ogrisel Jun 1, 2023
21e20e0
Merge branch 'main' into cutoff_classifier_again
ogrisel Jun 1, 2023
7952cce
Add routing to LogisticRegressionCV
Jun 7, 2023
66ad513
Add a test with enable_metadata_routing=False and fix an issue in sco…
Jun 7, 2023
7e8b824
Add metaestimator tests and fix passing routed params in score method
Jun 13, 2023
d7e50a6
PR suggestions
Jun 25, 2023
3844706
Merge branch 'main' into logistic_cv_routing
Jun 26, 2023
0866c42
Add changelog entry
Jun 26, 2023
43f971b
Add user and pr information
Jun 26, 2023
db63769
Changelog adjustment
Jun 26, 2023
a9b984f
Remove repr method from ConsumingScorer
Jun 26, 2023
97105a4
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 3, 2023
52f5921
handle the np.inf case in roc-curve
glemaitre Jul 3, 2023
637c18e
Merge branch 'main' into logistic_cv_routing
adrinjalali Jul 7, 2023
314bc83
Adjust changelog
Jul 7, 2023
9a8ef4e
Add tests for error when passing params when routing not enabled in L…
OmarManzoor Jul 10, 2023
5b723a0
Address PR suggestions partially
Jul 13, 2023
9ce463d
address comment Tim
glemaitre Jul 13, 2023
bba8f55
iter
glemaitre Jul 13, 2023
1c5487d
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 13, 2023
d302678
MAINT rename modules as per olivier comment
glemaitre Jul 13, 2023
2ade221
add missing module
glemaitre Jul 13, 2023
dca5770
update changelog
glemaitre Jul 13, 2023
8897533
more renaming
glemaitre Jul 13, 2023
75bd7ac
iter
glemaitre Jul 13, 2023
c07a980
Adjust and change the name of params in _check_method_params
OmarManzoor Jul 13, 2023
cc5ba48
Resolve conflict in changelog
OmarManzoor Jul 13, 2023
66c4c7f
iter
glemaitre Jul 13, 2023
378930e
iter
glemaitre Jul 13, 2023
c88ed94
iter
glemaitre Jul 13, 2023
4715e67
iter
glemaitre Jul 13, 2023
b3bb39f
iter
glemaitre Jul 13, 2023
915624a
Merge branch 'main' into logistic_cv_routing
glemaitre Jul 13, 2023
b72a72a
iter
glemaitre Jul 13, 2023
5108e43
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 13, 2023
95150b0
Merge branch 'pr/OmarManzoor/26525' into cutoff_with_metadata_routing
glemaitre Jul 13, 2023
5b66ab8
Merge remote-tracking branch 'origin/main' into cutoff_with_metadata_…
glemaitre Jul 14, 2023
b4e67fb
Add metadata routing
glemaitre Jul 14, 2023
005126a
Apply suggestions from code review
glemaitre Jul 14, 2023
767a05f
CLN clean up some repeated code related to SLEP006
adrinjalali Jul 14, 2023
080ba5c
iter
glemaitre Jul 14, 2023
05ec85d
iter
glemaitre Jul 14, 2023
63c32bd
ENH add new response_method in make_scorer
glemaitre Jul 15, 2023
1584c5b
add non-regression test
glemaitre Jul 15, 2023
1a5a247
update validation param
glemaitre Jul 15, 2023
4cc61b9
more coverage
glemaitre Jul 15, 2023
8f36235
TST add mulitlabel test
glemaitre Jul 15, 2023
9e6b384
Merge branch 'make_scorer_list_response' into cutoff_classifier_again
glemaitre Jul 15, 2023
5490ce4
simplify scorer
glemaitre Jul 15, 2023
8dad0a4
iter
glemaitre Jul 15, 2023
b918708
remove unecessary part in doc
glemaitre Jul 15, 2023
5e23523
iter
glemaitre Jul 15, 2023
d5578f9
iter
glemaitre Jul 16, 2023
f3f844e
address tim comments
glemaitre Jul 24, 2023
44ad195
Merge remote-tracking branch 'origin/main' into make_scorer_list_resp…
glemaitre Jul 24, 2023
e489eab
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 24, 2023
1cf5528
Merge branch 'make_scorer_list_response' into cutoff_classifier_again
glemaitre Jul 24, 2023
26dc94e
iter
glemaitre Jul 26, 2023
6a1a6c7
iter
glemaitre Jul 26, 2023
b17b59e
iter
glemaitre Jul 28, 2023
43c1da8
iter
glemaitre Jul 28, 2023
41a6d07
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 28, 2023
ca06717
iter
glemaitre Jul 28, 2023
ab8b466
iter
glemaitre Jul 30, 2023
235abf5
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Aug 7, 2023
d9ec528
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Aug 11, 2023
69f60a6
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Sep 28, 2023
45a8504
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Oct 17, 2023
8c4c88d
solve deprecation
glemaitre Oct 17, 2023
b97ebf4
Merge remote-tracking branch 'glemaitre/cutoff_classifier_again' into…
glemaitre Oct 17, 2023
e37f831
update changelog
glemaitre Oct 17, 2023
383937f
whoops
glemaitre Oct 17, 2023
d4ce3fb
Update sklearn/metrics/_scorer.py
glemaitre Oct 17, 2023
b6b3548
fix doc
glemaitre Oct 17, 2023
759d680
remove useless fitted attributes
glemaitre Oct 18, 2023
23e65e6
Merge branch 'main' into cutoff_classifier_again
ogrisel Dec 4, 2023
6904817
bump pandas to 1.1.5
glemaitre Dec 4, 2023
bee1ebe
update lock file
glemaitre Dec 4, 2023
4d86a36
iter
glemaitre Jan 13, 2024
48fd7cd
update doc-min lock file
glemaitre Jan 13, 2024
0854cd4
partial reviews
glemaitre Jan 13, 2024
ac75300
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Mar 18, 2024
2df616e
Apply suggestions from code review
glemaitre Mar 18, 2024
b14225c
update lock files
glemaitre Mar 18, 2024
98dcefd
Merge remote-tracking branch 'glemaitre/cutoff_classifier_again' into…
glemaitre Mar 18, 2024
7e3d7aa
iter
glemaitre Mar 18, 2024
b958bb0
simplify refit and do not allow cv == 1
glemaitre Mar 19, 2024
e7722f6
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Mar 19, 2024
98a1db8
check raise for multilabel
glemaitre Mar 19, 2024
c28a3e1
fix test name
glemaitre Mar 19, 2024
076fd29
test another to check if beta is forwarded
glemaitre Mar 19, 2024
e728f1d
iter
glemaitre Mar 20, 2024
c73b205
refit=True and cv is float
glemaitre Mar 20, 2024
a4890df
rename scorer to curve scorer internally
glemaitre Mar 22, 2024
f8a5a79
add a note regarding the abuse of the scorer API
glemaitre Mar 22, 2024
5dfa435
use None instead of highest
glemaitre Mar 22, 2024
d45a71b
use a closer CV API
glemaitre Mar 23, 2024
a32c151
fix example
glemaitre Mar 23, 2024
7592437
simplify model
glemaitre Mar 23, 2024
dc5346b
fix
glemaitre Mar 23, 2024
843ca04
fix docstring
glemaitre Mar 23, 2024
51ed9a8
Apply suggestions from code review
glemaitre Mar 30, 2024
3c89ab3
Apply suggestions from code review
glemaitre Apr 2, 2024
8cd5582
pep8
glemaitre Apr 2, 2024
a48487c
rephrase suggestions
glemaitre Apr 2, 2024
dd18549
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 4, 2024
27515ca
fix
glemaitre Apr 8, 2024
8a87b26
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 8, 2024
811dec9
include and discuss more about amount
glemaitre Apr 8, 2024
c83b4e1
iter
glemaitre Apr 8, 2024
d4e232f
Apply suggestions from code review
glemaitre Apr 25, 2024
92f6e05
Update examples/model_selection/plot_tuned_decision_threshold.py
glemaitre Apr 25, 2024
1c5c3f4
iter
glemaitre Apr 25, 2024
94160ba
iter
glemaitre Apr 25, 2024
85c8484
other comment
glemaitre Apr 25, 2024
4f86e9d
addressed comments
glemaitre Apr 25, 2024
d747098
Apply suggestions from code review
glemaitre Apr 27, 2024
6d0f418
rename TunedThresholdClassifier to TunedThresholdClassifierCV
glemaitre Apr 27, 2024
5671dd6
use meaningful values for check the thresholds values depending on po…
glemaitre Apr 27, 2024
f04085d
TST add more info regarding why not exactly 0 and 1
glemaitre Apr 27, 2024
d179b5f
DOC add documentation for base scorer
glemaitre Apr 27, 2024
2c375f8
DOC add more details regarding the curve scorer
glemaitre Apr 27, 2024
66ba8da
directly test curve_scorer instead to look for function anem
glemaitre Apr 27, 2024
a6b19c1
add required arguments
glemaitre Apr 27, 2024
b3b99ff
DOC add docstring for interpolated score
glemaitre Apr 27, 2024
dda0d2c
Update sklearn/model_selection/tests/test_classification_threshold.py
glemaitre Apr 29, 2024
48e7829
Update sklearn/model_selection/tests/test_classification_threshold.py
glemaitre Apr 29, 2024
17839e8
Apply suggestions from code review
glemaitre Apr 29, 2024
3f02bc3
remove duplicated check
glemaitre Apr 29, 2024
553cfce
remove duplicated check
glemaitre Apr 29, 2024
6ae6d27
check cv_results_ API
glemaitre Apr 29, 2024
d010096
clone classifier
glemaitre Apr 29, 2024
8bb8ca6
TST better comments
glemaitre Apr 29, 2024
fd971c7
iter
glemaitre Apr 29, 2024
bf57dac
FEA add a ConstantThresholdClassifier instead of strategy="constant" …
glemaitre Apr 30, 2024
0409932
make FixedThresholdClassifier appear in example
glemaitre Apr 30, 2024
66ea575
iter
glemaitre Apr 30, 2024
1c97dd4
Update doc/modules/classes.rst
glemaitre Apr 30, 2024
c8c1d0c
Update sklearn/model_selection/_classification_threshold.py
glemaitre Apr 30, 2024
8a52bc6
TST and fix default parameter
glemaitre Apr 30, 2024
ef668cf
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 30, 2024
fdbf68e
TST metadarouting FixedThresholdClassifier
glemaitre Apr 30, 2024
9c0c13d
rename n_thresholds to thresholds
glemaitre Apr 30, 2024
f419371
cover constant predictor error
glemaitre Apr 30, 2024
42eafe5
TST some tests for get_response_values_binary
glemaitre Apr 30, 2024
581133f
use conditional p(y|X) instead of posterior
glemaitre May 2, 2024
0f803d9
be more explicit that strings need to be provided to objective_metric
glemaitre May 2, 2024
eb0defc
factorize plotting into a function
glemaitre May 2, 2024
ffd5669
fix typo in code
glemaitre May 2, 2024
18abafe
use proper scoring rule and robust estimator to scale
glemaitre May 2, 2024
ce9464c
improve narrative
glemaitre May 2, 2024
89d67cf
use grid-search
glemaitre May 2, 2024
db3360b
Apply suggestions from code review
glemaitre May 2, 2024
1789cc0
remove constrainted metrics option
glemaitre May 3, 2024
e7c31b9
partial review
glemaitre May 3, 2024
0fd667c
rename objective_metric to scoring
glemaitre May 3, 2024
07e4387
fix typo
glemaitre May 3, 2024
9bd68e6
remove pos_label and delegate to make_scorer
glemaitre May 3, 2024
1 change: 1 addition & 0 deletions doc/model_selection.rst
@@ -14,5 +14,6 @@ Model selection and evaluation

modules/cross_validation
modules/grid_search
modules/prediction
modules/model_evaluation
modules/learning_curve
10 changes: 10 additions & 0 deletions doc/modules/classes.rst
@@ -1218,6 +1218,16 @@ Hyper-parameter optimizers
model_selection.RandomizedSearchCV
model_selection.HalvingRandomSearchCV

Model post-fit tuning
---------------------

.. currentmodule:: sklearn

.. autosummary::
:toctree: generated/
:template: class.rst

model_selection.CutOffClassifier

Model validation
----------------
155 changes: 155 additions & 0 deletions doc/modules/prediction.rst
@@ -0,0 +1,155 @@
.. currentmodule:: sklearn.model_selection

.. _cutoffclassifier:

==========================================================
Tuning the cut-off decision threshold for class prediction
==========================================================

Classifiers are predictive models: they use statistical learning to predict
outcomes. The outcomes of a classifier takes two forms: a "soft" score for each
sample in relation to each class, and a "hard" categorical prediction (i.e.
class label). Soft predictions are obtained using :term:`predict_proba` or
:term:`decision_function` while hard predictions are obtained using
:term:`predict`.
Member

Trying out some slightly different wording. Doing two comments one for this paragraph and one for the one below. I somehow don't like the "soft" and "hard". The way I think about it is that a classification model predicts a score and then we threshold that score to assign class labels.

WDYT?

Suggested change:
- outcomes. The outcomes of a classifier takes two forms: a "soft" score for each
- sample in relation to each class, and a "hard" categorical prediction (i.e.
- class label). Soft predictions are obtained using :term:`predict_proba` or
- :term:`decision_function` while hard predictions are obtained using
- :term:`predict`.
+ outcomes. The output of a classifier takes two forms: a score for each
+ sample in relation to each class and a categorical prediction (class label).
+ Scores are obtained using :term:`predict_proba` or
+ :term:`decision_function` while class labels are obtained using
+ :term:`predict`.

Member

if you accept the change from "soft" and "hard" we need to fix up some later parts of the docs as well which refer to it. I can make suggestions for that if you like it, otherwise I won't spam this with useless suggestions

Member

I also don't like soft and hard. The statistical terms are

  • predict: deterministic classification or decision rule
  • predict_proba: probabilistic classification

Member Author

I see the point here. I will reformulate accordingly.

@lorentzenchr In statistical terms, how would you refer to the output of decision_function?

Member

Beware that I'm not the statistical reference. Here is my understanding:
decision_function is often nothing more than a technical detail; what matters in statistical terms is predict_proba. For estimators with predict_proba, decision_function often gives the predictions in link space (the linear predictor for linear models, raw predictions in our HGBT code). ML literature often calls it a "score", which confuses me every time.

TLDR: I don't have a good universal name for it, but it is certainly not a "decision function", our predict is a decision function in statistical terms.


In scikit-learn, there is a connection between soft and hard predictions. In the
case of binary classification, hard predictions are obtained by associating
the positive class with a probability value greater than 0.5 (obtained with
:term:`predict_proba`) or a decision function value greater than 0 (obtained with
:term:`decision_function`).

>>> from sklearn.datasets import make_classification
>>> from sklearn.tree import DecisionTreeClassifier
>>> X, y = make_classification(random_state=0)
>>> classifier = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
>>> classifier.predict_proba(X[:4])
array([[0.94 , 0.06 ],
[0.94 , 0.06 ],
[0.04..., 0.95...],
[0.04..., 0.95...]])
>>> classifier.predict(X[:4])
array([0, 0, 1, 1])
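
To make the connection explicit, here is a minimal sketch: thresholding the
positive-class probability at 0.5 by hand recovers the same labels as
:term:`predict`:

>>> (classifier.predict_proba(X[:4])[:, 1] > 0.5).astype(int)
array([0, 0, 1, 1])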


Similar rules apply for other classification problems.

While these approaches are reasonable as default behaviors, they might not be
suited to every case: the context and nature of the use case define the
expected behavior of the classifier and thus the strategy used to convert soft
predictions into hard predictions. We illustrate this point with an example.

Let's imagine the deployment of a predictive model helping medical doctors
detect cancer. In a setting where this model is a tool to discard obvious
cases, doctors might be interested in a high recall (all cancer cases should
be tagged as such) so as not to miss any patient with cancer. However, this
comes at the cost of more false positive predictions (i.e. lower precision).
Thus, in terms of decision threshold, it would be better to classify a patient
as having cancer at a probability lower than 0.5.

Post-tuning of the decision threshold
=====================================

One solution to address the problem stated in the introduction is to tune the decision
threshold of the classifier once this model has been trained. The
:class:`~sklearn.model_selection.CutOffClassifier` tunes this threshold using
an internal cross-validation. The optimum threshold is chosen to maximize a given
metric, with or without constraints.

The following image illustrates the tuning of the cut-off point for a gradient
boosting classifier. While the vanilla and tuned classifiers provide the same
Receiver Operating Characteristic (ROC) and Precision-Recall curves, and thus
the same :term:`predict_proba` outputs, the "hard" predictions differ because of
the tuned cut-off point. The vanilla classifier predicts the class of interest
for a probability greater than 0.5 while the tuned classifier predicts the
class of interest for a very low probability (around 0.02). This cut-off point
maximizes a utility metric defined by the business case (in this case an
insurance company).

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cutoff_tuning_002.png
:target: ../auto_examples/model_selection/plot_cutoff_tuning.html
:align: center

Available options to tune the cut-off point
-------------------------------------------

The cut-off point can be tuned with different strategies controlled by the parameter
`objective_metric`.

A straightforward use case is to maximize a pre-defined scikit-learn metric. These
metrics can be found by calling the function :func:`~sklearn.metrics.get_scorer_names`.
We provide an example where we maximize the balanced accuracy.
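
As a minimal sketch (reusing `X` and `y` from the snippet above; the choice of
classifier is arbitrary for illustration):

>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import CutOffClassifier
>>> model = CutOffClassifier(
...     LogisticRegression(), objective_metric="balanced_accuracy"
... ).fit(X, y)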

.. note::

It is important to notice that these metrics come with default parameters, notably
the label of the class of interest (i.e. `pos_label`). Thus, if this label is not
the right one for your application, you need to define a scorer and pass the right
`pos_label` (and additional parameters) using
:func:`~sklearn.metrics.make_scorer`. You should refer to :ref:`scoring` for all
the information needed to define your own scoring function. For instance, we show
how to tell the scorer that the label of interest is `0` when maximizing the
:func:`~sklearn.metrics.f1_score`:

>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import CutOffClassifier, train_test_split
>>> from sklearn.metrics import make_scorer, f1_score
>>> X, y = make_classification(
... n_samples=1_000, weights=[0.1, 0.9], random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
>>> pos_label = 0
>>> scorer = make_scorer(f1_score, pos_label=pos_label)
>>> model = CutOffClassifier(LogisticRegression(), objective_metric=scorer).fit(
... X_train, y_train)
>>> scorer(model, X_test, y_test)
0.82...
>>> # compare it with the internal score found by cross-validation
>>> model.objective_score_
0.86...

A second strategy aims at maximizing a metric while imposing constraints on another
metric. Four pre-defined options exist: two that use the Receiver Operating
Characteristic (ROC) statistics and two that use the Precision-Recall statistics.

- `"max_tpr_at_tnr_constraint"`: maximizes the True Positive Rate (TPR) such that the
True Negative Rate (TNR) is the closest to a given value.
- `"max_tnr_at_tpr_constraint"`: maximizes the TNR such that the TPR is the closest to
a given value.
- `"max_precision_at_recall_constraint"`: maximizes the precision such that the recall
is the closest to a given value.
- `"max_recall_at_precision_constraint"`: maximizes the recall such that the precision
is the closest to a given value.

For these options, the `constraint_value` parameter needs to be defined. In addition,
you can use the `pos_label` parameter to indicate the label of the class of interest.
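
As a minimal sketch (reusing the train-test split from the note above; the
constraint value of 0.7 is an arbitrary choice for illustration):

>>> model = CutOffClassifier(
...     LogisticRegression(),
...     objective_metric="max_tpr_at_tnr_constraint",
...     constraint_value=0.7,
... ).fit(X_train, y_train)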

The final strategy maximizes a custom utility function. This problem is also known as
cost-sensitive learning. The utility function is defined by providing a dictionary
containing the cost-gain associated with the entries of the confusion matrix. The keys
are defined as `{"tn", "fp", "fn", "tp"}`. The class of interest is defined using the
`pos_label` parameter. Refer to :ref:`cost_sensitive_learning_example` for an example
depicting the use of such a utility function.
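
As a minimal sketch (the cost-gain values below are made up for illustration;
negative entries encode costs and positive entries gains):

>>> cost_gain = {"tn": 0, "fp": -1, "fn": -5, "tp": 1}
>>> model = CutOffClassifier(
...     LogisticRegression(), objective_metric=cost_gain
... ).fit(X_train, y_train)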

Important notes regarding the internal cross-validation
-------------------------------------------------------

By default, :class:`~sklearn.model_selection.CutOffClassifier` uses a 5-fold stratified
cross-validation to tune the cut-off point. The parameter `cv` controls the
cross-validation strategy. It is possible to bypass cross-validation by passing
`cv="prefit"` and providing an already fitted classifier. In this case, the cut-off
point is tuned on the data provided to the `fit` method.

However, you should be extremely careful when using this option. You should never use
the same data for training the classifier and tuning the cut-off point, at the risk of
overfitting. Refer to :ref:`cutoffclassifier_no_cv`, which shows such overfitting. If
you are in a situation where you have limited resources, you can consider using
a float number, which will use a single split internally.

The option `cv="prefit"` should only be used when the provided classifier was already
trained on some data and you want to tune (or re-tune) the cut-off point on a new
validation set.
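
As a minimal sketch (here, `X_test` and `y_test` from the note above stand in
for a validation set disjoint from the training data):

>>> prefit_classifier = LogisticRegression().fit(X_train, y_train)
>>> model = CutOffClassifier(prefit_classifier, cv="prefit").fit(X_test, y_test)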

Examples
--------

- See :ref:`sphx_glr_auto_examples_model_selection_plot_cutoff_tuning.py` for an
  example of tuning the decision threshold of a classifier.
5 changes: 5 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -413,6 +413,11 @@ Changelog
`return_indices` to return the train-test indices of each cv split.
:pr:`25659` by :user:`Guillaume Lemaitre <glemaitre>`.

- |MajorFeature| :class:`model_selection.CutOffClassifier` calibrates the decision
threshold of a binary classifier by maximizing a classification metric through
cross-validation.
:pr:`26120` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.naive_bayes`
..........................
