FEA add TunedThresholdClassifier meta-estimator to post-tune the cut-off threshold #26120

Merged
merged 228 commits on May 3, 2024
Changes from 1 commit
Commits
228 commits
b44dd9d
MAINT refactor scorer using _get_response_values
glemaitre Mar 31, 2023
516f62f
Add __name__ for method of Mock
glemaitre Apr 1, 2023
d2fbee0
remove multiclass issue
glemaitre Apr 1, 2023
29e5e87
make response_method a mandatory arg
glemaitre Apr 3, 2023
b645ade
Update sklearn/metrics/_scorer.py
glemaitre Apr 3, 2023
3397c56
apply jeremie comments
glemaitre Apr 3, 2023
092689a
Merge branch 'main' into is/18589_restart
glemaitre Apr 3, 2023
200ec31
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 4, 2023
e871558
iter
glemaitre Apr 4, 2023
31aa1c0
Merge remote-tracking branch 'glemaitre/is/18589_restart' into cutoff…
glemaitre Apr 4, 2023
74614e8
FEA add CutOffClassifier to post-tune prediction threshold
glemaitre Apr 7, 2023
27713af
DOC add changelog entry
glemaitre Apr 7, 2023
ed1d9b3
refresh implementation
glemaitre Apr 7, 2023
8410317
add files
glemaitre Apr 7, 2023
c7d1fe4
remove random state for the moment
glemaitre Apr 7, 2023
c9d7a22
TST make sure to pass the common test
glemaitre Apr 15, 2023
9981f3a
TST metaestimator sample_weight
glemaitre Apr 15, 2023
b9c9d5e
API add prediction functions
glemaitre Apr 15, 2023
588f1c4
TST bypass the test for classification
glemaitre Apr 17, 2023
243d173
iter before another bug
glemaitre Apr 17, 2023
883e929
iter
glemaitre Apr 18, 2023
69333ed
TST add test for _fit_and_score
glemaitre Apr 19, 2023
8616da1
iter
glemaitre Apr 19, 2023
99a10b3
integrate refit
glemaitre Apr 19, 2023
0f6dce2
TST more test
glemaitre Apr 19, 2023
d6fb9f7
TST more test with sample_weight
glemaitre Apr 19, 2023
7ff3d0d
BUG fit_params split
glemaitre Apr 19, 2023
6985ae9
TST add test for fit_params
glemaitre Apr 19, 2023
239793a
TST check underlying response method for TNR/TPR
glemaitre Apr 20, 2023
92083ed
FEA add the possibility to provide a dict
glemaitre Apr 20, 2023
55d0844
TST check string and pos_label interation for cost-matrix
glemaitre Apr 20, 2023
7dfc4a6
TST add sample_weight test for cost-matrix
glemaitre Apr 20, 2023
729c9a8
iter
glemaitre Apr 20, 2023
787be21
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 24, 2023
146b170
change strategy for finding max
glemaitre Apr 24, 2023
03b1f7f
iter
glemaitre Apr 24, 2023
8a09a5f
add some test for precision-recall
glemaitre Apr 24, 2023
d56f57f
TST add invariance zeros weight
glemaitre Apr 24, 2023
cf164c5
DOC fix default n_thresholds
glemaitre Apr 24, 2023
c943f5e
DOC add a small example
glemaitre Apr 25, 2023
bf1462b
iter
glemaitre Apr 25, 2023
fa89431
iter
glemaitre Apr 25, 2023
862519d
bug fixes everywhere
glemaitre Apr 25, 2023
aa520da
iter
glemaitre Apr 25, 2023
5403cf6
Do not allow for single threshold
glemaitre Apr 26, 2023
cd37743
TST add random state checkingclassifier
glemaitre Apr 26, 2023
e7d07af
TST more test for _ContinuousScorer
glemaitre Apr 26, 2023
bc20a47
TST add test for pos_label
glemaitre Apr 26, 2023
bba2f97
TST add pos_label test for TNR/TPR
glemaitre Apr 26, 2023
f925503
some more
glemaitre Apr 26, 2023
d539235
avoid extrapolation
glemaitre Apr 27, 2023
c0acd44
FEA add all thresholds and score computed as attributes
glemaitre Apr 27, 2023
f87baa7
fix docstring
glemaitre Apr 28, 2023
e4dac09
EXA add example of cut-off tuning
glemaitre Apr 28, 2023
4da7cef
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 28, 2023
bd86595
solving the issue of unknown categories
glemaitre Apr 28, 2023
45e6e5a
fix
glemaitre Apr 28, 2023
402a1a7
EXA add hyperlink in the example
glemaitre Apr 29, 2023
6745afc
DOC add warning regarding overfitting
glemaitre Apr 29, 2023
4d557cc
some more doc
glemaitre Apr 29, 2023
2c6ee7e
some more doc
glemaitre Apr 29, 2023
91c8222
DOC more documentation
glemaitre May 2, 2023
9a96ae1
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre May 2, 2023
3d4ce81
fix import
glemaitre May 2, 2023
d7d8dac
fix import
glemaitre May 2, 2023
aa3e83d
iter
glemaitre May 2, 2023
ab97d63
fix
glemaitre May 2, 2023
acb6af8
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre May 4, 2023
486a2bd
Update sklearn/metrics/_scorer.py
glemaitre May 4, 2023
6d4c4aa
Apply suggestions from code review
glemaitre May 15, 2023
1d12e1f
Fix linter
ogrisel Jun 1, 2023
21e20e0
Merge branch 'main' into cutoff_classifier_again
ogrisel Jun 1, 2023
7952cce
Add routing to LogisticRegressionCV
Jun 7, 2023
66ad513
Add a test with enable_metadata_routing=False and fix an issue in sco…
Jun 7, 2023
7e8b824
Add metaestimator tests and fix passing routed params in score method
Jun 13, 2023
d7e50a6
PR suggestions
Jun 25, 2023
3844706
Merge branch 'main' into logistic_cv_routing
Jun 26, 2023
0866c42
Add changelog entry
Jun 26, 2023
43f971b
Add user and pr information
Jun 26, 2023
db63769
Changelog adjustment
Jun 26, 2023
a9b984f
Remove repr method from ConsumingScorer
Jun 26, 2023
97105a4
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 3, 2023
52f5921
handle the np.inf case in roc-curve
glemaitre Jul 3, 2023
637c18e
Merge branch 'main' into logistic_cv_routing
adrinjalali Jul 7, 2023
314bc83
Adjust changelog
Jul 7, 2023
9a8ef4e
Add tests for error when passing params when routing not enabled in L…
OmarManzoor Jul 10, 2023
5b723a0
Address PR suggestions partially
Jul 13, 2023
9ce463d
address comment Tim
glemaitre Jul 13, 2023
bba8f55
iter
glemaitre Jul 13, 2023
1c5487d
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 13, 2023
d302678
MAINT rename modules as per olivier comment
glemaitre Jul 13, 2023
2ade221
add missing module
glemaitre Jul 13, 2023
dca5770
update changelog
glemaitre Jul 13, 2023
8897533
more renaming
glemaitre Jul 13, 2023
75bd7ac
iter
glemaitre Jul 13, 2023
c07a980
Adjust and change the name of params in _check_method_params
OmarManzoor Jul 13, 2023
cc5ba48
Resolve conflict in changelog
OmarManzoor Jul 13, 2023
66c4c7f
iter
glemaitre Jul 13, 2023
378930e
iter
glemaitre Jul 13, 2023
c88ed94
iter
glemaitre Jul 13, 2023
4715e67
iter
glemaitre Jul 13, 2023
b3bb39f
iter
glemaitre Jul 13, 2023
915624a
Merge branch 'main' into logistic_cv_routing
glemaitre Jul 13, 2023
b72a72a
iter
glemaitre Jul 13, 2023
5108e43
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 13, 2023
95150b0
Merge branch 'pr/OmarManzoor/26525' into cutoff_with_metadata_routing
glemaitre Jul 13, 2023
5b66ab8
Merge remote-tracking branch 'origin/main' into cutoff_with_metadata_…
glemaitre Jul 14, 2023
b4e67fb
Add metadata routing
glemaitre Jul 14, 2023
005126a
Apply suggestions from code review
glemaitre Jul 14, 2023
767a05f
CLN clean up some repeated code related to SLEP006
adrinjalali Jul 14, 2023
080ba5c
iter
glemaitre Jul 14, 2023
05ec85d
iter
glemaitre Jul 14, 2023
63c32bd
ENH add new response_method in make_scorer
glemaitre Jul 15, 2023
1584c5b
add non-regression test
glemaitre Jul 15, 2023
1a5a247
update validation param
glemaitre Jul 15, 2023
4cc61b9
more coverage
glemaitre Jul 15, 2023
8f36235
TST add mulitlabel test
glemaitre Jul 15, 2023
9e6b384
Merge branch 'make_scorer_list_response' into cutoff_classifier_again
glemaitre Jul 15, 2023
5490ce4
simplify scorer
glemaitre Jul 15, 2023
8dad0a4
iter
glemaitre Jul 15, 2023
b918708
remove unecessary part in doc
glemaitre Jul 15, 2023
5e23523
iter
glemaitre Jul 15, 2023
d5578f9
iter
glemaitre Jul 16, 2023
f3f844e
address tim comments
glemaitre Jul 24, 2023
44ad195
Merge remote-tracking branch 'origin/main' into make_scorer_list_resp…
glemaitre Jul 24, 2023
e489eab
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 24, 2023
1cf5528
Merge branch 'make_scorer_list_response' into cutoff_classifier_again
glemaitre Jul 24, 2023
26dc94e
iter
glemaitre Jul 26, 2023
6a1a6c7
iter
glemaitre Jul 26, 2023
b17b59e
iter
glemaitre Jul 28, 2023
43c1da8
iter
glemaitre Jul 28, 2023
41a6d07
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 28, 2023
ca06717
iter
glemaitre Jul 28, 2023
ab8b466
iter
glemaitre Jul 30, 2023
235abf5
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Aug 7, 2023
d9ec528
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Aug 11, 2023
69f60a6
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Sep 28, 2023
45a8504
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Oct 17, 2023
8c4c88d
solve deprecation
glemaitre Oct 17, 2023
b97ebf4
Merge remote-tracking branch 'glemaitre/cutoff_classifier_again' into…
glemaitre Oct 17, 2023
e37f831
update changelog
glemaitre Oct 17, 2023
383937f
whoops
glemaitre Oct 17, 2023
d4ce3fb
Update sklearn/metrics/_scorer.py
glemaitre Oct 17, 2023
b6b3548
fix doc
glemaitre Oct 17, 2023
759d680
remove useless fitted attributes
glemaitre Oct 18, 2023
23e65e6
Merge branch 'main' into cutoff_classifier_again
ogrisel Dec 4, 2023
6904817
bump pandas to 1.1.5
glemaitre Dec 4, 2023
bee1ebe
update lock file
glemaitre Dec 4, 2023
4d86a36
iter
glemaitre Jan 13, 2024
48fd7cd
update doc-min lock file
glemaitre Jan 13, 2024
0854cd4
partial reviews
glemaitre Jan 13, 2024
ac75300
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Mar 18, 2024
2df616e
Apply suggestions from code review
glemaitre Mar 18, 2024
b14225c
update lock files
glemaitre Mar 18, 2024
98dcefd
Merge remote-tracking branch 'glemaitre/cutoff_classifier_again' into…
glemaitre Mar 18, 2024
7e3d7aa
iter
glemaitre Mar 18, 2024
b958bb0
simplify refit and do not allow cv == 1
glemaitre Mar 19, 2024
e7722f6
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Mar 19, 2024
98a1db8
check raise for multilabel
glemaitre Mar 19, 2024
c28a3e1
fix test name
glemaitre Mar 19, 2024
076fd29
test another to check if beta is forwarded
glemaitre Mar 19, 2024
e728f1d
iter
glemaitre Mar 20, 2024
c73b205
refit=True and cv is float
glemaitre Mar 20, 2024
a4890df
rename scorer to curve scorer internally
glemaitre Mar 22, 2024
f8a5a79
add a note regarding the abuse of the scorer API
glemaitre Mar 22, 2024
5dfa435
use None instead of highest
glemaitre Mar 22, 2024
d45a71b
use a closer CV API
glemaitre Mar 23, 2024
a32c151
fix example
glemaitre Mar 23, 2024
7592437
simplify model
glemaitre Mar 23, 2024
dc5346b
fix
glemaitre Mar 23, 2024
843ca04
fix docstring
glemaitre Mar 23, 2024
51ed9a8
Apply suggestions from code review
glemaitre Mar 30, 2024
3c89ab3
Apply suggestions from code review
glemaitre Apr 2, 2024
8cd5582
pep8
glemaitre Apr 2, 2024
a48487c
rephrase suggestions
glemaitre Apr 2, 2024
dd18549
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 4, 2024
27515ca
fix
glemaitre Apr 8, 2024
8a87b26
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 8, 2024
811dec9
include and discuss more about amount
glemaitre Apr 8, 2024
c83b4e1
iter
glemaitre Apr 8, 2024
d4e232f
Apply suggestions from code review
glemaitre Apr 25, 2024
92f6e05
Update examples/model_selection/plot_tuned_decision_threshold.py
glemaitre Apr 25, 2024
1c5c3f4
iter
glemaitre Apr 25, 2024
94160ba
iter
glemaitre Apr 25, 2024
85c8484
other comment
glemaitre Apr 25, 2024
4f86e9d
addressed comments
glemaitre Apr 25, 2024
d747098
Apply suggestions from code review
glemaitre Apr 27, 2024
6d0f418
rename TunedThresholdClassifier to TunedThresholdClassifierCV
glemaitre Apr 27, 2024
5671dd6
use meaningful values for check the thresholds values depending on po…
glemaitre Apr 27, 2024
f04085d
TST add more info regarding why not exactly 0 and 1
glemaitre Apr 27, 2024
d179b5f
DOC add documentation for base scorer
glemaitre Apr 27, 2024
2c375f8
DOC add more details regarding the curve scorer
glemaitre Apr 27, 2024
66ba8da
directly test curve_scorer instead to look for function anem
glemaitre Apr 27, 2024
a6b19c1
add required arguments
glemaitre Apr 27, 2024
b3b99ff
DOC add docstring for interpolated score
glemaitre Apr 27, 2024
dda0d2c
Update sklearn/model_selection/tests/test_classification_threshold.py
glemaitre Apr 29, 2024
48e7829
Update sklearn/model_selection/tests/test_classification_threshold.py
glemaitre Apr 29, 2024
17839e8
Apply suggestions from code review
glemaitre Apr 29, 2024
3f02bc3
remove duplicated check
glemaitre Apr 29, 2024
553cfce
remove duplicated check
glemaitre Apr 29, 2024
6ae6d27
check cv_results_ API
glemaitre Apr 29, 2024
d010096
clone classifier
glemaitre Apr 29, 2024
8bb8ca6
TST better comments
glemaitre Apr 29, 2024
fd971c7
iter
glemaitre Apr 29, 2024
bf57dac
FEA add a ConstantThresholdClassifier instead of strategy="constant" …
glemaitre Apr 30, 2024
0409932
make FixedThresholdClassifier appear in example
glemaitre Apr 30, 2024
66ea575
iter
glemaitre Apr 30, 2024
1c97dd4
Update doc/modules/classes.rst
glemaitre Apr 30, 2024
c8c1d0c
Update sklearn/model_selection/_classification_threshold.py
glemaitre Apr 30, 2024
8a52bc6
TST and fix default parameter
glemaitre Apr 30, 2024
ef668cf
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 30, 2024
fdbf68e
TST metadarouting FixedThresholdClassifier
glemaitre Apr 30, 2024
9c0c13d
rename n_thresholds to thresholds
glemaitre Apr 30, 2024
f419371
cover constant predictor error
glemaitre Apr 30, 2024
42eafe5
TST some tests for get_response_values_binary
glemaitre Apr 30, 2024
581133f
use conditional p(y|X) instead of posterior
glemaitre May 2, 2024
0f803d9
be more explicit that strings need to be provided to objective_metric
glemaitre May 2, 2024
eb0defc
factorize plotting into a function
glemaitre May 2, 2024
ffd5669
fix typo in code
glemaitre May 2, 2024
18abafe
use proper scoring rule and robust estimator to scale
glemaitre May 2, 2024
ce9464c
improve narrative
glemaitre May 2, 2024
89d67cf
use grid-search
glemaitre May 2, 2024
db3360b
Apply suggestions from code review
glemaitre May 2, 2024
1789cc0
remove constrainted metrics option
glemaitre May 3, 2024
e7c31b9
partial review
glemaitre May 3, 2024
0fd667c
rename objective_metric to scoring
glemaitre May 3, 2024
07e4387
fix typo
glemaitre May 3, 2024
9bd68e6
remove pos_label and delegate to make_scorer
glemaitre May 3, 2024
17 changes: 0 additions & 17 deletions doc/modules/classification_threshold.rst
@@ -117,23 +117,6 @@ a meaningful metric for their use case.
>>> model.best_score_
0.86...

A second strategy aims to maximize one metric while imposing constraints on another
metric. There are four pre-defined options that can be provided to the
`objective_metric` parameter: two use the Receiver Operating Characteristic (ROC)
statistics and two use the Precision-Recall statistics.

- `"max_tpr_at_tnr_constraint"`: maximizes the True Positive Rate (TPR) such that the
True Negative Rate (TNR) is the closest to a given value.
- `"max_tnr_at_tpr_constraint"`: maximizes the TNR such that the TPR is the closest to
a given value.
- `"max_precision_at_recall_constraint"`: maximizes the precision such that the recall
is the closest to a given value.
- `"max_recall_at_precision_constraint"`: maximizes the recall such that the precision
is the closest to a given value.

For these options, the `constraint_value` parameter needs to be defined. In addition,
you can use the `pos_label` parameter to indicate the label of the class of interest.
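
For a rough sense of how these options fit together, here is a minimal sketch based
only on the description above and on the example further down this page; the
estimator, dataset, and values are placeholders, and these parameters never shipped in
a release:

>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import TunedThresholdClassifierCV
>>> X, y = make_classification(random_state=0)
>>> model = TunedThresholdClassifierCV(
...     LogisticRegression(),
...     objective_metric="max_tpr_at_tnr_constraint",  # illustrative, later removed
...     constraint_value=0.95,  # keep the TNR close to 0.95, i.e. an FPR close to 0.05
...     pos_label=1,
... ).fit(X, y)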

Important notes regarding the internal cross-validation
-------------------------------------------------------

200 changes: 1 addition & 199 deletions examples/model_selection/plot_tuned_decision_threshold.py
@@ -11,7 +11,7 @@

This example shows how to use the
:class:`~sklearn.model_selection.TunedThresholdClassifierCV` to tune the decision
threshold, depending on a metric of interest as well as under a specific constraints.
threshold, depending on a metric of interest.
"""

# %%
@@ -184,201 +184,3 @@
# example entitled,
# :ref:`sphx_glr_auto_examples_model_selection_plot_cost_sensitive_learning.py`,
# for more details.
#
# Tuning the decision threshold under constraint
# ----------------------------------------------
#
# In some cases, we do not only want to maximize a given metric but instead to maximize
# a metric while satisfying a constraint on another metric. In the current example, we
# could imagine that the decision of our predictive model will be reviewed by a medical
# doctor. In this case, this doctor will only accept a ratio of false positives lower
# than a given value. Therefore, we are interested in maximizing the true positive rate
# while keeping the false positive rate lower than this value.
#
# The :class:`~sklearn.model_selection.TunedThresholdClassifierCV` allows tuning the
# decision threshold with such a specification. We illustrate this strategy together
# with a single train-test split to display the Receiver Operating Characteristic (ROC)
# curves to get better intuition.
#
# First, we split the data into a training and testing set.

# %%
from sklearn.model_selection import train_test_split

data_train, data_test, target_train, target_test = train_test_split(
    data, target, random_state=42
)

# %%
# Now, we will train both the vanilla and tuned model on the training set. We recall
# that the tuned model is internally maximizing the balanced accuracy for the moment.
model.fit(data_train, target_train)
tuned_model.fit(data_train, target_train)

# %%
# To show the benefit of optimizing a metric under a constraint, we will evaluate the
# models using the ROC curve statistics: the true positive rate (TPR) and the false
# positive rate (FPR).
#
# The FPR is not defined in scikit-learn and we define it below:
from sklearn.metrics import confusion_matrix, make_scorer, recall_score


def fpr_score(y, y_pred, neg_label, pos_label):
    cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label])
    tn, fp, _, _ = cm.ravel()
    tnr = tn / (tn + fp)
    return 1 - tnr


tpr_score = recall_score # TPR and recall are the same metric
scoring = {
    "fpr": make_scorer(fpr_score, neg_label=neg_label, pos_label=pos_label),
    "tpr": make_scorer(tpr_score, pos_label=pos_label),
}

# %%
# Now, we plot the ROC curve of both models and the FPR and TPR statistics for the
# decision thresholds of both models.
from sklearn.metrics import RocCurveDisplay

disp = RocCurveDisplay.from_estimator(
    model, data_test, target_test, name="Vanilla model", linestyle="--", alpha=0.5
)
RocCurveDisplay.from_estimator(
    tuned_model,
    data_test,
    target_test,
    name="Tuned model",
    linestyle="-.",
    alpha=0.5,
    ax=disp.ax_,
)
disp.ax_.plot(
    scoring["fpr"](model, data_test, target_test),
    scoring["tpr"](model, data_test, target_test),
    marker="o",
    markersize=10,
    color="tab:blue",
    label="Default cut-off point at a probability of 0.5",
)
disp.ax_.plot(
    scoring["fpr"](tuned_model, data_test, target_test),
    scoring["tpr"](tuned_model, data_test, target_test),
    marker=">",
    markersize=10,
    color="tab:orange",
    label=f"Cut-off point at probability of {tuned_model.best_threshold_:.2f}",
)
disp.ax_.legend()
_ = disp.ax_.set_title("ROC curves")

# %%
# As expected, both models have the same ROC curves since the tuned
# model is only a post-processing step of the vanilla model. The tuning step is only
# changing the decision threshold, as displayed by the blue and orange markers.
# To optimize the balanced accuracy, the tuned model moved the decision threshold
# from 0.5 to 0.22. By shifting this point, we increase the FPR while increasing
# the TPR: in short, we make more false positives but also more true positives. This
# is exactly what we concluded in the previous section when looking at the balanced
# accuracy score.
#
# However, this decision threshold might not be acceptable for our medical doctor. He
# might instead be interested in a low FPR, say lower than 5%. For this level of FPR,
# he would like our predictive model to maximize the TPR.
#
# The :class:`~sklearn.model_selection.TunedThresholdClassifierCV` allows specifying
# such a constraint by providing the name of the metric and the constraint value.
# Here, we use `max_tpr_at_tnr_constraint`, which is exactly what we want. Since the
# true negative rate (TNR) is equal to 1 - FPR, we can rewrite the constraint value as
# `1 - 0.05 = 0.95`.

# %%
constraint_value = 0.95
tuned_model.set_params(
    objective_metric="max_tpr_at_tnr_constraint",
    constraint_value=constraint_value,
    pos_label=pos_label,
    store_cv_results=True,
)
tuned_model.fit(data_train, target_train)

# %%
# Now, we can plot the ROC curves and analyse the results.
import matplotlib.pyplot as plt

_, axs = plt.subplots(ncols=2, figsize=(12, 5))

disp = RocCurveDisplay(
    fpr=1 - tuned_model.cv_results_["constrained_scores"],
    tpr=tuned_model.cv_results_["maximized_scores"],
    estimator_name="ROC of the tuned model",
    pos_label=pos_label,
)
axs[0].plot(
    1 - tuned_model.constrained_score_,
    tuned_model.best_score_,
    marker="o",
    markersize=10,
    color="tab:blue",
    label=f"Cut-off point at probability of {tuned_model.best_threshold_:.2f}",
)
axs[0].axvline(
    1 - constraint_value, 0, 1, color="tab:blue", linestyle="--", label="FPR constraint"
)
axs[0].set_title("Average ROC curve for the tuned model\nacross CV folds")
RocCurveDisplay.from_estimator(
    model,
    data_test,
    target_test,
    name="Vanilla model",
    linestyle="--",
    alpha=0.5,
    ax=axs[1],
)
RocCurveDisplay.from_estimator(
    tuned_model,
    data_test,
    target_test,
    name="Tuned model",
    linestyle="-.",
    alpha=0.5,
    ax=axs[1],
)
axs[1].plot(
    scoring["fpr"](model, data_test, target_test),
    scoring["tpr"](model, data_test, target_test),
    marker="o",
    markersize=10,
    color="tab:blue",
    label="Default cut-off point at a probability of 0.5",
)
axs[1].plot(
    1 - tuned_model.constrained_score_,
    tuned_model.best_score_,
    marker="^",
    markersize=10,
    color="tab:orange",
    label=f"Cut-off point at probability of {tuned_model.best_threshold_:.2f}",
)
axs[1].legend()
axs[1].set_title("ROC curves")
_ = disp.plot(ax=axs[0])

# %%
# We start with the right-hand side plot. It depicts the ROC curves as in the previous
# section. We observe that the control point of the tuned model moved to a low FPR
# that was defined by our constraint. To achieve this low FPR, the decision threshold
# was moved to a probability of 0.72.
#
# The left-hand side plot shows the averaged ROC curve on the internal validation set
# across the different cross-validation folds. This curve is used to define the decision
# threshold. The vertical dashed line represents the FPR constraint that we defined.
# The decision threshold corresponds to the maximum TPR on the left of this dashed line
# and is represented by a blue marker.
#
# An important point to note is that the decision threshold is defined from averaged
# statistics computed on an internal validation set. It means that the constraint is
# respected on the train/validation dataset but not necessarily on the test set,
# whenever the statistical performance of the model differs between the
# train/validation set and the test set (i.e. overfitting).