FEA add TunedThresholdClassifier meta-estimator to post-tune the cut-off threshold #26120

Merged
merged 228 commits into from May 3, 2024
Changes from 151 commits
Commits (228)
b44dd9d
MAINT refactor scorer using _get_response_values
glemaitre Mar 31, 2023
516f62f
Add __name__ for method of Mock
glemaitre Apr 1, 2023
d2fbee0
remove multiclass issue
glemaitre Apr 1, 2023
29e5e87
make response_method a mandatory arg
glemaitre Apr 3, 2023
b645ade
Update sklearn/metrics/_scorer.py
glemaitre Apr 3, 2023
3397c56
apply jeremie comments
glemaitre Apr 3, 2023
092689a
Merge branch 'main' into is/18589_restart
glemaitre Apr 3, 2023
200ec31
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 4, 2023
e871558
iter
glemaitre Apr 4, 2023
31aa1c0
Merge remote-tracking branch 'glemaitre/is/18589_restart' into cutoff…
glemaitre Apr 4, 2023
74614e8
FEA add CutOffClassifier to post-tune prediction threshold
glemaitre Apr 7, 2023
27713af
DOC add changelog entry
glemaitre Apr 7, 2023
ed1d9b3
refresh implementation
glemaitre Apr 7, 2023
8410317
add files
glemaitre Apr 7, 2023
c7d1fe4
remove random state for the moment
glemaitre Apr 7, 2023
c9d7a22
TST make sure to pass the common test
glemaitre Apr 15, 2023
9981f3a
TST metaestimator sample_weight
glemaitre Apr 15, 2023
b9c9d5e
API add prediction functions
glemaitre Apr 15, 2023
588f1c4
TST bypass the test for classification
glemaitre Apr 17, 2023
243d173
iter before another bug
glemaitre Apr 17, 2023
883e929
iter
glemaitre Apr 18, 2023
69333ed
TST add test for _fit_and_score
glemaitre Apr 19, 2023
8616da1
iter
glemaitre Apr 19, 2023
99a10b3
integrate refit
glemaitre Apr 19, 2023
0f6dce2
TST more test
glemaitre Apr 19, 2023
d6fb9f7
TST more test with sample_weight
glemaitre Apr 19, 2023
7ff3d0d
BUG fit_params split
glemaitre Apr 19, 2023
6985ae9
TST add test for fit_params
glemaitre Apr 19, 2023
239793a
TST check underlying response method for TNR/TPR
glemaitre Apr 20, 2023
92083ed
FEA add the possibility to provide a dict
glemaitre Apr 20, 2023
55d0844
TST check string and pos_label interation for cost-matrix
glemaitre Apr 20, 2023
7dfc4a6
TST add sample_weight test for cost-matrix
glemaitre Apr 20, 2023
729c9a8
iter
glemaitre Apr 20, 2023
787be21
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 24, 2023
146b170
change strategy for finding max
glemaitre Apr 24, 2023
03b1f7f
iter
glemaitre Apr 24, 2023
8a09a5f
add some test for precision-recall
glemaitre Apr 24, 2023
d56f57f
TST add invariance zeros weight
glemaitre Apr 24, 2023
cf164c5
DOC fix default n_thresholds
glemaitre Apr 24, 2023
c943f5e
DOC add a small example
glemaitre Apr 25, 2023
bf1462b
iter
glemaitre Apr 25, 2023
fa89431
iter
glemaitre Apr 25, 2023
862519d
bug fixes everywhere
glemaitre Apr 25, 2023
aa520da
iter
glemaitre Apr 25, 2023
5403cf6
Do not allow for single threshold
glemaitre Apr 26, 2023
cd37743
TST add random state checkingclassifier
glemaitre Apr 26, 2023
e7d07af
TST more test for _ContinuousScorer
glemaitre Apr 26, 2023
bc20a47
TST add test for pos_label
glemaitre Apr 26, 2023
bba2f97
TST add pos_label test for TNR/TPR
glemaitre Apr 26, 2023
f925503
some more
glemaitre Apr 26, 2023
d539235
avoid extrapolation
glemaitre Apr 27, 2023
c0acd44
FEA add all thresholds and score computed as attributes
glemaitre Apr 27, 2023
f87baa7
fix docstring
glemaitre Apr 28, 2023
e4dac09
EXA add example of cut-off tuning
glemaitre Apr 28, 2023
4da7cef
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 28, 2023
bd86595
solving the issue of unknown categories
glemaitre Apr 28, 2023
45e6e5a
fix
glemaitre Apr 28, 2023
402a1a7
EXA add hyperlink in the example
glemaitre Apr 29, 2023
6745afc
DOC add warning regarding overfitting
glemaitre Apr 29, 2023
4d557cc
some more doc
glemaitre Apr 29, 2023
2c6ee7e
some more doc
glemaitre Apr 29, 2023
91c8222
DOC more documentation
glemaitre May 2, 2023
9a96ae1
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre May 2, 2023
3d4ce81
fix import
glemaitre May 2, 2023
d7d8dac
fix import
glemaitre May 2, 2023
aa3e83d
iter
glemaitre May 2, 2023
ab97d63
fix
glemaitre May 2, 2023
acb6af8
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre May 4, 2023
486a2bd
Update sklearn/metrics/_scorer.py
glemaitre May 4, 2023
6d4c4aa
Apply suggestions from code review
glemaitre May 15, 2023
1d12e1f
Fix linter
ogrisel Jun 1, 2023
21e20e0
Merge branch 'main' into cutoff_classifier_again
ogrisel Jun 1, 2023
7952cce
Add routing to LogisticRegressionCV
Jun 7, 2023
66ad513
Add a test with enable_metadata_routing=False and fix an issue in sco…
Jun 7, 2023
7e8b824
Add metaestimator tests and fix passing routed params in score method
Jun 13, 2023
d7e50a6
PR suggestions
Jun 25, 2023
3844706
Merge branch 'main' into logistic_cv_routing
Jun 26, 2023
0866c42
Add changelog entry
Jun 26, 2023
43f971b
Add user and pr information
Jun 26, 2023
db63769
Changelog adjustment
Jun 26, 2023
a9b984f
Remove repr method from ConsumingScorer
Jun 26, 2023
97105a4
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 3, 2023
52f5921
handle the np.inf case in roc-curve
glemaitre Jul 3, 2023
637c18e
Merge branch 'main' into logistic_cv_routing
adrinjalali Jul 7, 2023
314bc83
Adjust changelog
Jul 7, 2023
9a8ef4e
Add tests for error when passing params when routing not enabled in L…
OmarManzoor Jul 10, 2023
5b723a0
Address PR suggestions partially
Jul 13, 2023
9ce463d
address comment Tim
glemaitre Jul 13, 2023
bba8f55
iter
glemaitre Jul 13, 2023
1c5487d
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 13, 2023
d302678
MAINT rename modules as per olivier comment
glemaitre Jul 13, 2023
2ade221
add missing module
glemaitre Jul 13, 2023
dca5770
update changelog
glemaitre Jul 13, 2023
8897533
more renaming
glemaitre Jul 13, 2023
75bd7ac
iter
glemaitre Jul 13, 2023
c07a980
Adjust and change the name of params in _check_method_params
OmarManzoor Jul 13, 2023
cc5ba48
Resolve conflict in changelog
OmarManzoor Jul 13, 2023
66c4c7f
iter
glemaitre Jul 13, 2023
378930e
iter
glemaitre Jul 13, 2023
c88ed94
iter
glemaitre Jul 13, 2023
4715e67
iter
glemaitre Jul 13, 2023
b3bb39f
iter
glemaitre Jul 13, 2023
915624a
Merge branch 'main' into logistic_cv_routing
glemaitre Jul 13, 2023
b72a72a
iter
glemaitre Jul 13, 2023
5108e43
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 13, 2023
95150b0
Merge branch 'pr/OmarManzoor/26525' into cutoff_with_metadata_routing
glemaitre Jul 13, 2023
5b66ab8
Merge remote-tracking branch 'origin/main' into cutoff_with_metadata_…
glemaitre Jul 14, 2023
b4e67fb
Add metadata routing
glemaitre Jul 14, 2023
005126a
Apply suggestions from code review
glemaitre Jul 14, 2023
767a05f
CLN clean up some repeated code related to SLEP006
adrinjalali Jul 14, 2023
080ba5c
iter
glemaitre Jul 14, 2023
05ec85d
iter
glemaitre Jul 14, 2023
63c32bd
ENH add new response_method in make_scorer
glemaitre Jul 15, 2023
1584c5b
add non-regression test
glemaitre Jul 15, 2023
1a5a247
update validation param
glemaitre Jul 15, 2023
4cc61b9
more coverage
glemaitre Jul 15, 2023
8f36235
TST add mulitlabel test
glemaitre Jul 15, 2023
9e6b384
Merge branch 'make_scorer_list_response' into cutoff_classifier_again
glemaitre Jul 15, 2023
5490ce4
simplify scorer
glemaitre Jul 15, 2023
8dad0a4
iter
glemaitre Jul 15, 2023
b918708
remove unecessary part in doc
glemaitre Jul 15, 2023
5e23523
iter
glemaitre Jul 15, 2023
d5578f9
iter
glemaitre Jul 16, 2023
f3f844e
address tim comments
glemaitre Jul 24, 2023
44ad195
Merge remote-tracking branch 'origin/main' into make_scorer_list_resp…
glemaitre Jul 24, 2023
e489eab
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 24, 2023
1cf5528
Merge branch 'make_scorer_list_response' into cutoff_classifier_again
glemaitre Jul 24, 2023
26dc94e
iter
glemaitre Jul 26, 2023
6a1a6c7
iter
glemaitre Jul 26, 2023
b17b59e
iter
glemaitre Jul 28, 2023
43c1da8
iter
glemaitre Jul 28, 2023
41a6d07
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Jul 28, 2023
ca06717
iter
glemaitre Jul 28, 2023
ab8b466
iter
glemaitre Jul 30, 2023
235abf5
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Aug 7, 2023
d9ec528
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Aug 11, 2023
69f60a6
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Sep 28, 2023
45a8504
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Oct 17, 2023
8c4c88d
solve deprecation
glemaitre Oct 17, 2023
b97ebf4
Merge remote-tracking branch 'glemaitre/cutoff_classifier_again' into…
glemaitre Oct 17, 2023
e37f831
update changelog
glemaitre Oct 17, 2023
383937f
whoops
glemaitre Oct 17, 2023
d4ce3fb
Update sklearn/metrics/_scorer.py
glemaitre Oct 17, 2023
b6b3548
fix doc
glemaitre Oct 17, 2023
759d680
remove useless fitted attributes
glemaitre Oct 18, 2023
23e65e6
Merge branch 'main' into cutoff_classifier_again
ogrisel Dec 4, 2023
6904817
bump pandas to 1.1.5
glemaitre Dec 4, 2023
bee1ebe
update lock file
glemaitre Dec 4, 2023
4d86a36
iter
glemaitre Jan 13, 2024
48fd7cd
update doc-min lock file
glemaitre Jan 13, 2024
0854cd4
partial reviews
glemaitre Jan 13, 2024
ac75300
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Mar 18, 2024
2df616e
Apply suggestions from code review
glemaitre Mar 18, 2024
b14225c
update lock files
glemaitre Mar 18, 2024
98dcefd
Merge remote-tracking branch 'glemaitre/cutoff_classifier_again' into…
glemaitre Mar 18, 2024
7e3d7aa
iter
glemaitre Mar 18, 2024
b958bb0
simplify refit and do not allow cv == 1
glemaitre Mar 19, 2024
e7722f6
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Mar 19, 2024
98a1db8
check raise for multilabel
glemaitre Mar 19, 2024
c28a3e1
fix test name
glemaitre Mar 19, 2024
076fd29
test another to check if beta is forwarded
glemaitre Mar 19, 2024
e728f1d
iter
glemaitre Mar 20, 2024
c73b205
refit=True and cv is float
glemaitre Mar 20, 2024
a4890df
rename scorer to curve scorer internally
glemaitre Mar 22, 2024
f8a5a79
add a note regarding the abuse of the scorer API
glemaitre Mar 22, 2024
5dfa435
use None instead of highest
glemaitre Mar 22, 2024
d45a71b
use a closer CV API
glemaitre Mar 23, 2024
a32c151
fix example
glemaitre Mar 23, 2024
7592437
simplify model
glemaitre Mar 23, 2024
dc5346b
fix
glemaitre Mar 23, 2024
843ca04
fix docstring
glemaitre Mar 23, 2024
51ed9a8
Apply suggestions from code review
glemaitre Mar 30, 2024
3c89ab3
Apply suggestions from code review
glemaitre Apr 2, 2024
8cd5582
pep8
glemaitre Apr 2, 2024
a48487c
rephrase suggestions
glemaitre Apr 2, 2024
dd18549
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 4, 2024
27515ca
fix
glemaitre Apr 8, 2024
8a87b26
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 8, 2024
811dec9
include and discuss more about amount
glemaitre Apr 8, 2024
c83b4e1
iter
glemaitre Apr 8, 2024
d4e232f
Apply suggestions from code review
glemaitre Apr 25, 2024
92f6e05
Update examples/model_selection/plot_tuned_decision_threshold.py
glemaitre Apr 25, 2024
1c5c3f4
iter
glemaitre Apr 25, 2024
94160ba
iter
glemaitre Apr 25, 2024
85c8484
other comment
glemaitre Apr 25, 2024
4f86e9d
addressed comments
glemaitre Apr 25, 2024
d747098
Apply suggestions from code review
glemaitre Apr 27, 2024
6d0f418
rename TunedThresholdClassifier to TunedThresholdClassifierCV
glemaitre Apr 27, 2024
5671dd6
use meaningful values for check the thresholds values depending on po…
glemaitre Apr 27, 2024
f04085d
TST add more info regarding why not exactly 0 and 1
glemaitre Apr 27, 2024
d179b5f
DOC add documentation for base scorer
glemaitre Apr 27, 2024
2c375f8
DOC add more details regarding the curve scorer
glemaitre Apr 27, 2024
66ba8da
directly test curve_scorer instead to look for function anem
glemaitre Apr 27, 2024
a6b19c1
add required arguments
glemaitre Apr 27, 2024
b3b99ff
DOC add docstring for interpolated score
glemaitre Apr 27, 2024
dda0d2c
Update sklearn/model_selection/tests/test_classification_threshold.py
glemaitre Apr 29, 2024
48e7829
Update sklearn/model_selection/tests/test_classification_threshold.py
glemaitre Apr 29, 2024
17839e8
Apply suggestions from code review
glemaitre Apr 29, 2024
3f02bc3
remove duplicated check
glemaitre Apr 29, 2024
553cfce
remove duplicated check
glemaitre Apr 29, 2024
6ae6d27
check cv_results_ API
glemaitre Apr 29, 2024
d010096
clone classifier
glemaitre Apr 29, 2024
8bb8ca6
TST better comments
glemaitre Apr 29, 2024
fd971c7
iter
glemaitre Apr 29, 2024
bf57dac
FEA add a ConstantThresholdClassifier instead of strategy="constant" …
glemaitre Apr 30, 2024
0409932
make FixedThresholdClassifier appear in example
glemaitre Apr 30, 2024
66ea575
iter
glemaitre Apr 30, 2024
1c97dd4
Update doc/modules/classes.rst
glemaitre Apr 30, 2024
c8c1d0c
Update sklearn/model_selection/_classification_threshold.py
glemaitre Apr 30, 2024
8a52bc6
TST and fix default parameter
glemaitre Apr 30, 2024
ef668cf
Merge remote-tracking branch 'origin/main' into cutoff_classifier_again
glemaitre Apr 30, 2024
fdbf68e
TST metadarouting FixedThresholdClassifier
glemaitre Apr 30, 2024
9c0c13d
rename n_thresholds to thresholds
glemaitre Apr 30, 2024
f419371
cover constant predictor error
glemaitre Apr 30, 2024
42eafe5
TST some tests for get_response_values_binary
glemaitre Apr 30, 2024
581133f
use conditional p(y|X) instead of posterior
glemaitre May 2, 2024
0f803d9
be more explicit that strings need to be provided to objective_metric
glemaitre May 2, 2024
eb0defc
factorize plotting into a function
glemaitre May 2, 2024
ffd5669
fix typo in code
glemaitre May 2, 2024
18abafe
use proper scoring rule and robust estimator to scale
glemaitre May 2, 2024
ce9464c
improve narrative
glemaitre May 2, 2024
89d67cf
use grid-search
glemaitre May 2, 2024
db3360b
Apply suggestions from code review
glemaitre May 2, 2024
1789cc0
remove constrainted metrics option
glemaitre May 3, 2024
e7c31b9
partial review
glemaitre May 3, 2024
0fd667c
rename objective_metric to scoring
glemaitre May 3, 2024
07e4387
fix typo
glemaitre May 3, 2024
9bd68e6
remove pos_label and delegate to make_scorer
glemaitre May 3, 2024
22 changes: 11 additions & 11 deletions build_tools/circle/doc_min_dependencies_linux-64_conda.lock
@@ -1,6 +1,6 @@
# Generated by conda-lock.
# platform: linux-64
# input_hash: a58a98732e5815c15757bc1def8ddc0d87f20f11edcf6e7b408594bf948cbb3e
# input_hash: 46b1818af4901a4b14e79dab7a99627a28da9815d13cdb73c40e4590b2bd6259
@EXPLICIT
https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2#d7c89558ba9fa0495403155b64376d81
https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2023.11.17-hbcca054_0.conda#01ffc8d36f9eba0ce0b3c1955fa780ee
@@ -48,7 +48,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.3.2-hd590300_0.co
https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda#5aa797f8787fe7a17d1b0821485b5adc
https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.13-hd590300_5.conda#f36c115f1ee199da648e0597ec2047ad
https://conda.anaconda.org/conda-forge/linux-64/lz4-c-1.9.4-hcb278e6_0.conda#318b08df404f9c9be5712aaa5a6f0bb0
https://conda.anaconda.org/conda-forge/linux-64/mpg123-1.32.3-h59595ed_0.conda#bdadff838d5437aea83607ced8b37f75
https://conda.anaconda.org/conda-forge/linux-64/mpg123-1.32.4-h59595ed_0.conda#3f1017b4141e943d9bc8739237f749e8
https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-h59595ed_2.conda#7dbaa197d7ba6032caf7ae7f32c1efa0
https://conda.anaconda.org/conda-forge/linux-64/nspr-4.35-h27087fc_0.conda#da0ec11a6454ae19bff5b02ed881a2b1
https://conda.anaconda.org/conda-forge/linux-64/openssl-3.2.0-hd590300_1.conda#603827b39ea2b835268adb8c821b8570
@@ -105,7 +105,7 @@ https://conda.anaconda.org/conda-forge/linux-64/xcb-util-keysyms-0.4.0-h8ee46fc_
https://conda.anaconda.org/conda-forge/linux-64/xcb-util-renderutil-0.3.9-hd590300_1.conda#e995b155d938b6779da6ace6c6b13816
https://conda.anaconda.org/conda-forge/linux-64/xcb-util-wm-0.4.1-h8ee46fc_1.conda#90108a432fb5c6150ccfee3f03388656
https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.7-h8ee46fc_0.conda#49e482d882669206653b095f5206c05b
https://conda.anaconda.org/conda-forge/noarch/alabaster-0.7.13-pyhd8ed1ab_0.conda#06006184e203b61d3525f90de394471e
https://conda.anaconda.org/conda-forge/noarch/alabaster-0.7.16-pyhd8ed1ab_0.conda#def531a3ac77b7fb8c21d17bb5d0badb
https://conda.anaconda.org/conda-forge/linux-64/brotli-python-1.1.0-py39h3d6467e_1.conda#c48418c8b35f1d59ae9ae1174812b40a
https://conda.anaconda.org/conda-forge/linux-64/c-compiler-1.7.0-hd590300_0.conda#fad1d0a651bf929c6c16fbf1f6ccfa7c
https://conda.anaconda.org/conda-forge/noarch/certifi-2023.11.17-pyhd8ed1ab_0.conda#2011bcf45376341dd1d690263fdbc789
@@ -117,7 +117,7 @@ https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda#5
https://conda.anaconda.org/conda-forge/linux-64/cython-0.29.33-py39h227be39_0.conda#34bab6ef3e8cdf86fe78c46a984d3217
https://conda.anaconda.org/conda-forge/linux-64/dbus-1.13.6-h5008d03_3.tar.bz2#ecfff944ba3960ecb334b9a2663d708d
https://conda.anaconda.org/conda-forge/linux-64/docutils-0.19-py39hf3d152e_1.tar.bz2#adb733ec2ee669f6d010758d054da60f
https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_0.conda#f6c211fee3c98229652b60a9a42ef363
https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_2.conda#8d652ea2ee8eaee02ed8dc820bc794aa
https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9
https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.14.2-h14ed4e7_0.conda#0f69b688f52ff6da70bccb7ff7001d1d
https://conda.anaconda.org/conda-forge/noarch/fsspec-2023.12.2-pyhca7485f_0.conda#bf40f2a8835b78b1f91083d306b493d2
@@ -175,7 +175,7 @@ https://conda.anaconda.org/conda-forge/linux-64/cytoolz-0.12.2-py39hd1e30aa_1.co
https://conda.anaconda.org/conda-forge/linux-64/fortran-compiler-1.7.0-heb67821_0.conda#7ef7c0f111dad1c8006504a0f1ccd820
https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.3-hfc55251_0.conda#e08e51acc7d1ae8dbe13255e7b4c64ac
https://conda.anaconda.org/conda-forge/noarch/importlib-metadata-7.0.1-pyha770c72_0.conda#746623a787e06191d80a2133e5daff17
https://conda.anaconda.org/conda-forge/noarch/jinja2-3.1.2-pyhd8ed1ab_1.tar.bz2#c8490ed5c70966d232fdd389d0dbed37
https://conda.anaconda.org/conda-forge/noarch/jinja2-3.1.3-pyhd8ed1ab_0.conda#e7d8df6509ba635247ff9aea31134262
https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc
https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-20_linux64_openblas.conda#36d486d72ab64ffea932329a1d3729a3
https://conda.anaconda.org/conda-forge/linux-64/libclang-15.0.7-default_hb11cfb5_4.conda#c90f4cbb57839c98fef8f830e4b9972f
@@ -201,7 +201,7 @@ https://conda.anaconda.org/conda-forge/linux-64/pyqt5-sip-12.12.2-py39h3d6467e_5
https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-3.5.0-pyhd8ed1ab_0.conda#d5f595da2daead898ca958ac62f0307b
https://conda.anaconda.org/conda-forge/noarch/requests-2.31.0-pyhd8ed1ab_0.conda#a30144e4156cdbb236f99ebb49828f8b
https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-20_linux64_openblas.conda#9932a1d4e9ecf2d35fb19475446e361e
https://conda.anaconda.org/conda-forge/noarch/dask-core-2023.12.1-pyhd8ed1ab_0.conda#bf6ad72d882bc3f04e6a0fb50fd2cce8
https://conda.anaconda.org/conda-forge/noarch/dask-core-2024.1.0-pyhd8ed1ab_0.conda#cab4cec272dc1e30086f7d32faa4f130
https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.8-h8e1006c_1.conda#3926dab94fe06d88ade0e716d77b8cf8
https://conda.anaconda.org/conda-forge/linux-64/imagecodecs-lite-2019.12.3-py39hd257fcd_5.tar.bz2#32dba66d6abc2b4b5b019c9e54307312
https://conda.anaconda.org/conda-forge/noarch/imageio-2.33.1-pyh8c1a49c_0.conda#1c34d58ac469a34e7e96832861368bce
@@ -226,10 +226,10 @@ https://conda.anaconda.org/conda-forge/noarch/numpydoc-1.2-pyhd8ed1ab_0.tar.bz2#
https://conda.anaconda.org/conda-forge/noarch/sphinx-copybutton-0.5.2-pyhd8ed1ab_0.conda#ac832cc43adc79118cf6e23f1f9b8995
https://conda.anaconda.org/conda-forge/noarch/sphinx-gallery-0.15.0-pyhd8ed1ab_0.conda#1a49ca9515ef9a96edff2eea06143dc6
https://conda.anaconda.org/conda-forge/noarch/sphinx-prompt-1.3.0-py_0.tar.bz2#9363002e2a134a287af4e32ff0f26cdc
https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-applehelp-1.0.7-pyhd8ed1ab_0.conda#aebfabcb60c33a89c1f9290cab49bc93
https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-devhelp-1.0.5-pyhd8ed1ab_0.conda#ebf08f5184d8eaa486697bc060031953
https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-htmlhelp-2.0.4-pyhd8ed1ab_0.conda#a9a89000dfd19656ad004b937eeb6828
https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-qthelp-1.0.6-pyhd8ed1ab_0.conda#cf5c9649272c677a964a7313279e3a9b
https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-applehelp-1.0.8-pyhd8ed1ab_0.conda#611a35a27914fac3aa37611a6fe40bb5
https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-devhelp-1.0.6-pyhd8ed1ab_0.conda#d7e4954df0d3aea2eacc7835ad12671d
https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-htmlhelp-2.0.5-pyhd8ed1ab_0.conda#7e1e7437273682ada2ed5e9e9714b140
https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-qthelp-1.0.7-pyhd8ed1ab_0.conda#26acae54b06f178681bfb551760f5dd1
https://conda.anaconda.org/conda-forge/noarch/sphinx-6.0.0-pyhd8ed1ab_2.conda#ac1d3b55da1669ee3a56973054fd7efb
https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.9-pyhd8ed1ab_0.conda#0612e497d7860728f2cda421ea2aec09
https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_0.conda#e507335cb4ca9cff4c3d0fa9cdab255e
# pip sphinxext-opengraph @ https://files.pythonhosted.org/packages/50/ac/c105ed3e0a00b14b28c0aa630935af858fd8a32affeff19574b16e2c6ae8/sphinxext_opengraph-0.4.2-py3-none-any.whl#sha256=a51f2604f9a5b6c0d25d3a88e694d5c02e20812dc0e482adf96c8628f9109357
1 change: 1 addition & 0 deletions doc/model_selection.rst
@@ -14,5 +14,6 @@ Model selection and evaluation

modules/cross_validation
modules/grid_search
modules/classification_threshold
modules/model_evaluation
modules/learning_curve
10 changes: 10 additions & 0 deletions doc/modules/classes.rst
@@ -1248,6 +1248,16 @@ Hyper-parameter optimizers
model_selection.RandomizedSearchCV
model_selection.HalvingRandomSearchCV

Model post-fit tuning
---------------------

.. currentmodule:: sklearn

.. autosummary::
:toctree: generated/
:template: class.rst

model_selection.TunedThresholdClassifier

Model validation
----------------
166 changes: 166 additions & 0 deletions doc/modules/classification_threshold.rst
@@ -0,0 +1,166 @@
.. currentmodule:: sklearn.model_selection

.. _tunedthresholdclassifier:

============================================================
Tuning the cut-off decision threshold for class prediction
============================================================

Classifiers are predictive models: they use statistical learning to predict outcomes.
The outcomes of a classifier are of two kinds: a score for each sample with respect to
each class, and a categorical prediction (class label). Scores are obtained from
:term:`predict_proba` or :term:`decision_function`. The former returns posterior
probability estimates for each class, while the latter returns a decision score for
each class. The decision score is a measure of how strongly the sample is predicted to
belong to the positive class (e.g., the distance to the decision boundary). In binary
classification, a decision rule is then defined by thresholding the scores, leading to
a single class label for each sample. Those labels are obtained with :term:`predict`.

For binary classification in scikit-learn, class labels are obtained by associating the
positive class with posterior probability estimates greater than 0.5 (obtained with
:term:`predict_proba`) or decision scores greater than 0 (obtained with
:term:`decision_function`).

Here, we show an example that illustrates the relation between posterior
probability estimates and class labels::

>>> from sklearn.datasets import make_classification
>>> from sklearn.tree import DecisionTreeClassifier
>>> X, y = make_classification(random_state=0)
>>> classifier = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
>>> classifier.predict_proba(X[:4])
array([[0.94 , 0.06 ],
[0.94 , 0.06 ],
[0.0416..., 0.9583...],
[0.0416..., 0.9583...]])
>>> classifier.predict(X[:4])
array([0, 0, 1, 1])
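
The class labels above correspond to thresholding the positive-class probability at
0.5, as described earlier. A minimal sketch of this equivalence, reusing the
`classifier` fitted above::

>>> proba_positive_class = classifier.predict_proba(X[:4])[:, 1]
>>> # thresholding the posterior probability estimates at 0.5 recovers `predict`
>>> (proba_positive_class > 0.5).astype(int)
array([0, 0, 1, 1])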

While these approaches are reasonable as default behaviors, they are not ideal for
all cases. Let's illustrate with an example.

Let's consider a scenario where a predictive model is being deployed to assist medical
doctors in detecting tumors. In this setting, doctors are most likely interested in
correctly identifying all patients with cancer so that they can provide them with the
right treatment. In other words, doctors prioritize achieving a high recall rate,
meaning they want to identify all cases of cancer without missing any patients who have
it. This emphasis on recall comes, of course, with the trade-off of potentially more
false-positive predictions, reducing the precision of the model, but that is a risk
doctors are willing to take. Consequently, when it comes to deciding whether to classify
a patient as having cancer or not, it may be more beneficial to classify them as
positive for cancer even when the posterior probability estimate is lower than 0.5.

Post-tuning the decision threshold
==================================

Review comment, member: As I mentioned in the other example, I would introduce the ROC
curve earlier in the explanation.

One solution to address the problem stated in the introduction is to tune the decision
threshold of the classifier once the model has been trained. The
:class:`~sklearn.model_selection.TunedThresholdClassifier` tunes this threshold using an
internal cross-validation. The optimum threshold is chosen to maximize a given metric,
with or without constraints.

Review comment, @jnothman (Dec 26, 2023): Do we need to justify using "internal
cross-validation" somewhere rather than tuning this like any other hyperparameter?
Indeed, ignoring `constraint_value`, is this tool just a workaround for a design flaw
wherein a parameter to `BaseClassifier` is not exposed to `*SearchCV` with warm start?

Reply, author: Today, I would indeed consider it as a design and methodological flaw. I
don't know what level of justification you think we should add on here:

    ... tunes this threshold using an internal cross-validation since the scikit-learn
    :term:`predict` API does not offer this flexibility.

Reply, member: I'm not sure I follow. Isn't the problem that we can't freeze the
estimator? Maybe I need to look at the code first...

Reply, member: OK, I realized there are two sort-of independent long-term API issues
going on here and you were talking about the other one. We need internal
cross-validation because we still don't have an API for path-like algorithms and
efficient CV, and that's what you were talking about. But there's also the "prefit"
part, which I was thinking about, and which is now actually fixed via the `__clone__`
protocol. Basically, if we're not fitting the underlying estimator, then we should use
the `__clone__` protocol to not clone it when the TunedThresholdClassifier is being
cloned, right?

Reply, member:

    we still don't have an API for path-like algorithms and efficient CV

@amueller What do you mean by that?

    is this tool just a workaround for a design flaw wherein a parameter to
    BaseClassifier is not exposed

and

    Today, I would indeed consider it as a design and methodological flaw.

@glemaitre @jnothman Shouldn't we fix this design flaw? Do you have suggestions?

Reply, author:

    Shouldn't we fix this design flaw? Do you have suggestions?

This is something that I would introduce in 2.0 because we will surely break some API.
In this case, we would have the freedom to think about a good API.

Reply, member: I guess what I am missing is a roadmap or a plan/vision.

Review comment, member: Should we note the contrast with what CalibratedClassifierCV is
for?

Reply, member: Yes for sure.

The following image illustrates the tuning of the cut-off point for a gradient boosting
classifier. While the vanilla and tuned classifiers provide the same Receiver Operating
Characteristic (ROC) and Precision-Recall curves, and thus the same
:term:`predict_proba` outputs, the class label predictions differ because of the tuned
decision threshold. The vanilla classifier predicts the class of interest for a
posterior probability greater than 0.5 while the tuned classifier predicts the class of
interest for a very low probability (around 0.02). This cut-off point optimizes a
utility metric defined by the business (in this case an insurance company).

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cost_sensitive_learning_002.png
:target: ../auto_examples/model_selection/plot_cost_sensitive_learning.html
:align: center

Options to tune the cut-off point
---------------------------------

The cut-off point can be tuned through different strategies controlled by the parameter
`objective_metric`.

One way to tune the threshold is by maximizing a pre-defined scikit-learn metric. These
metrics can be found by calling the function :func:`~sklearn.metrics.get_scorer_names`.

Review comment, member: I would maybe mention that a common tuning is picking the
top-right point on the ROC curve, which is the same as picking the f2 score (I think?)
here. Or maybe mention that that has a nice geometric explanation but doesn't really
consider the application.

In this example, we maximize the balanced accuracy.
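
A minimal sketch of this usage (the estimator and data below are illustrative
assumptions; the class and parameter names follow this revision of the documentation)::

>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import TunedThresholdClassifier
>>> X, y = make_classification(n_samples=1_000, weights=[0.9, 0.1], random_state=0)
>>> # tune the decision threshold with the default 5-fold stratified cross-validation
>>> tuned_model = TunedThresholdClassifier(
...     LogisticRegression(), objective_metric="balanced_accuracy"
... ).fit(X, y)
>>> y_pred = tuned_model.predict(X)  # class labels obtained with the tuned threshold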

.. note::

It is important to notice that these metrics come with default parameters, notably
the label of the class of interest (i.e. `pos_label`). Thus, if this label is not
the right one for your application, you need to define a scorer and pass the right
`pos_label` (and additional parameters) using
:func:`~sklearn.metrics.make_scorer`. Refer to :ref:`scoring` for
information on defining your own scoring function. For instance, we show how to pass
the information to the scorer that the label of interest is `0` when maximizing the
:func:`~sklearn.metrics.f1_score`:

>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import (
... TunedThresholdClassifier, train_test_split
... )
>>> from sklearn.metrics import make_scorer, f1_score
>>> X, y = make_classification(
... n_samples=1_000, weights=[0.1, 0.9], random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
>>> pos_label = 0
>>> scorer = make_scorer(f1_score, pos_label=pos_label)
>>> base_model = LogisticRegression()
>>> model = TunedThresholdClassifier(base_model, objective_metric=scorer).fit(
... X_train, y_train)
>>> scorer(model, X_test, y_test)
0.79...
>>> # compare it with the internal score found by cross-validation
>>> model.objective_score_
0.86...

A second strategy aims to maximize one metric while imposing constraints on another
metric. There are four pre-defined options: two use the Receiver Operating
Characteristic (ROC) statistics and two use the Precision-Recall statistics.

- `"max_tpr_at_tnr_constraint"`: maximizes the True Positive Rate (TPR) such that the
True Negative Rate (TNR) is the closest to a given value.
- `"max_tnr_at_tpr_constraint"`: maximizes the TNR such that the TPR is the closest to
a given value.
- `"max_precision_at_recall_constraint"`: maximizes the precision such that the recall
is the closest to a given value.
- `"max_recall_at_precision_constraint"`: maximizes the recall such that the precision
is the closest to a given value.

For these options, the `constraint_value` parameter needs to be defined. In addition,
you can use the `pos_label` parameter to indicate the label of the class of interest.
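
For instance, a sketch of maximizing the precision under a recall constraint (the
estimator, data, and constraint value below are illustrative assumptions)::

>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import TunedThresholdClassifier
>>> X, y = make_classification(n_samples=1_000, weights=[0.9, 0.1], random_state=0)
>>> # maximize precision such that recall stays the closest to 0.7
>>> model = TunedThresholdClassifier(
...     LogisticRegression(),
...     objective_metric="max_precision_at_recall_constraint",
...     constraint_value=0.7,
...     pos_label=1,
... ).fit(X, y)
>>> y_pred = model.predict(X)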

Important notes regarding the internal cross-validation
-------------------------------------------------------

By default, :class:`~sklearn.model_selection.TunedThresholdClassifier` uses a 5-fold
stratified cross-validation to tune the cut-off point. The parameter `cv` controls the
cross-validation strategy. It is possible to bypass cross-validation by setting
`cv="prefit"` and providing a fitted classifier. In this case, the cut-off point is
tuned on the data provided to the `fit` method.

However, you should be extremely careful when using this option. You should never use
the same data for training the classifier and tuning the cut-off point, due to the risk
of overfitting. Refer to the following example section for more details (cf.
:ref:`tunedthresholdclassifier_no_cv`). If you have limited resources, consider passing
a float to `cv` to use a single internal train-test split.

The option `cv="prefit"` should only be used when the provided classifier was already
trained, and you just want to find the best cut-off using a new validation set.
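
A minimal sketch of this workflow (the estimator and the data split below are
illustrative assumptions)::

>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import TunedThresholdClassifier, train_test_split
>>> X, y = make_classification(n_samples=1_000, weights=[0.9, 0.1], random_state=0)
>>> X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)
>>> already_fitted = LogisticRegression().fit(X_train, y_train)
>>> # with cv="prefit", the threshold is tuned on the data passed to `fit`, here a
>>> # validation set that was not used to train the classifier
>>> model = TunedThresholdClassifier(
...     already_fitted, cv="prefit", objective_metric="balanced_accuracy"
... ).fit(X_val, y_val)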

Manually setting the decision threshold
---------------------------------------

The previous sections discussed strategies to find an optimal decision threshold. It is
also possible to manually set the decision threshold in
:class:`~sklearn.model_selection.TunedThresholdClassifier` by setting the parameter
`strategy` to `"constant"` and providing the desired threshold using the parameter
`constant_threshold`.
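
A minimal sketch (the threshold value and estimator below are illustrative assumptions;
parameter names follow this revision of the documentation)::

>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import TunedThresholdClassifier
>>> X, y = make_classification(n_samples=1_000, weights=[0.9, 0.1], random_state=0)
>>> # predict the positive class only when its estimated probability exceeds 0.75
>>> model = TunedThresholdClassifier(
...     LogisticRegression(), strategy="constant", constant_threshold=0.75
... ).fit(X, y)
>>> y_pred = model.predict(X)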

Examples
--------

- See the example entitled
:ref:`sphx_glr_auto_examples_model_selection_plot_tuned_decision_threshold.py`,
to get insights on the post-tuning of the decision threshold.
- See the example entitled
:ref:`sphx_glr_auto_examples_model_selection_plot_cost_sensitive_learning.py`,
to learn about cost-sensitive learning and decision threshold tuning.
8 changes: 8 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -31,6 +31,14 @@ Changelog
by passing a function in place of a strategy name.
:pr:`28053` by :user:`Mark Elliot <mark-thm>`.

:mod:`sklearn.model_selection`
..............................

- |MajorFeature| :class:`model_selection.TunedThresholdClassifier` post-tunes
the decision threshold of a binary classifier by maximizing a
classification metric through cross-validation.
:pr:`26120` by :user:`Guillaume Lemaitre <glemaitre>`.

Code and Documentation Contributors
-----------------------------------
