catboost with 'repeated_cv' resampling and predict_type = 'prob' gives 0s #353

Open
bkmontgom opened this issue May 11, 2024 · 0 comments

Description

When evaluating a classif.catboost learner with repeated cross-validation and predict_type = 'prob', all predicted probabilities are 0. Other learners, such as classif.lightgbm, are not affected. regr.catboost works as expected, and so does classif.catboost with predict_type = 'response'. Training classif.catboost on the whole dataset and then predicting probabilities from that model also works.

Reproducible example

library(mlr3) # ver 0.19.0
library(mlr3extralearners) # ver 0.8.0
# catboost ver 1.2.5

iris_task <- tsk("iris")
lrn_catboost_classif <- lrn("classif.catboost", predict_type = "prob")
resamp <- rsmp("repeated_cv", repeats = 1, folds = 10)
rr <- resample(iris_task,
  learner = lrn_catboost_classif,
  resampling = resamp
)
#> INFO  [12:21:56.406] [mlr3] Applying learner 'classif.catboost' on task 'iris' (iter 1/10)
#> INFO  [12:21:57.425] [mlr3] Applying learner 'classif.catboost' on task 'iris' (iter 2/10)
#> INFO  [12:21:58.257] [mlr3] Applying learner 'classif.catboost' on task 'iris' (iter 3/10)
#> INFO  [12:21:59.080] [mlr3] Applying learner 'classif.catboost' on task 'iris' (iter 4/10)
#> INFO  [12:21:59.918] [mlr3] Applying learner 'classif.catboost' on task 'iris' (iter 5/10)
#> INFO  [12:22:00.726] [mlr3] Applying learner 'classif.catboost' on task 'iris' (iter 6/10)
#> INFO  [12:22:01.539] [mlr3] Applying learner 'classif.catboost' on task 'iris' (iter 7/10)
#> INFO  [12:22:02.373] [mlr3] Applying learner 'classif.catboost' on task 'iris' (iter 8/10)
#> INFO  [12:22:03.171] [mlr3] Applying learner 'classif.catboost' on task 'iris' (iter 9/10)
#> INFO  [12:22:03.994] [mlr3] Applying learner 'classif.catboost' on task 'iris' (iter 10/10)
rr$score(measures = msr("classif.logloss")) # identical and very poor in every fold
#>     task_id       learner_id resampling_id iteration classif.logloss
#>      <char>           <char>        <char>     <int>           <num>
#>  1:    iris classif.catboost   repeated_cv         1        34.53878
#>  2:    iris classif.catboost   repeated_cv         2        34.53878
#>  3:    iris classif.catboost   repeated_cv         3        34.53878
#>  4:    iris classif.catboost   repeated_cv         4        34.53878
#>  5:    iris classif.catboost   repeated_cv         5        34.53878
#>  6:    iris classif.catboost   repeated_cv         6        34.53878
#>  7:    iris classif.catboost   repeated_cv         7        34.53878
#>  8:    iris classif.catboost   repeated_cv         8        34.53878
#>  9:    iris classif.catboost   repeated_cv         9        34.53878
#> 10:    iris classif.catboost   repeated_cv        10        34.53878
#> Hidden columns: task, learner, resampling, prediction
rr$prediction() # responses appear random, and all probs are 0
#> <PredictionClassif> for 150 observations:
#>  row_ids     truth   response prob.setosa prob.versicolor prob.virginica
#>        7    setosa  virginica           0               0              0
#>       23    setosa versicolor           0               0              0
#>       42    setosa  virginica           0               0              0
#>      ---       ---        ---         ---             ---            ---
#>      108 virginica     setosa           0               0              0
#>      133 virginica  virginica           0               0              0
#>      150 virginica  virginica           0               0              0

model <- lrn_catboost_classif$train(task = iris_task)
predict(model, predict_type = "prob", newdata = iris[c(1, 75, 150), ]) # works fine
#>            setosa   versicolor    virginica
#> [1,] 0.9994297973 0.0003649189 0.0002052838
#> [2,] 0.0004712565 0.9992043693 0.0003243742
#> [3,] 0.0013331991 0.0027428026 0.9959239983

Created on 2024-05-11 with reprex v2.1.0
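
A side note on the constant score above: 34.53878 matches -log(1e-15), which is what a log loss would report if the probability of the true class were 0 in every row and got clipped to a small floor before taking the logarithm. The floor of 1e-15 is an assumption here (I believe it is the default `eps` of `mlr3measures::logloss`, but I have not checked that). The sketch below is plain base R arithmetic and does not call mlr3; `prob_true_class` and `clipped` are illustrative names.

```r
# Assumed clipping floor; believed (not verified) to match the log loss default
eps <- 1e-15

# Predicted probability of the true class, as seen in rr$prediction()
prob_true_class <- 0

# Clip away from 0 so the logarithm stays finite
clipped <- max(prob_true_class, eps)

# Per-observation log loss; the fold mean is the same if every row has prob 0
loss <- -log(clipped)
print(round(loss, 5))
#> [1] 34.53878
```

If that assumption holds, the score is just a consequence of the all-zero probabilities, so the thing to track down is the zeros themselves rather than the measure.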

Session info
sessionInfo()
#> R version 4.4.0 (2024-04-24 ucrt)
#> Platform: x86_64-w64-mingw32/x64
#> Running under: Windows 10 x64 (build 19045)
#> 
#> Matrix products: default
#> 
#> 
#> locale:
#> [1] LC_COLLATE=English_United States.utf8 
#> [2] LC_CTYPE=English_United States.utf8   
#> [3] LC_MONETARY=English_United States.utf8
#> [4] LC_NUMERIC=C                          
#> [5] LC_TIME=English_United States.utf8    
#> 
#> time zone: America/Denver
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] mlr3extralearners_0.8.0 mlr3_0.19.0            
#> 
#> loaded via a namespace (and not attached):
#>  [1] jsonlite_1.8.8       future.apply_1.11.2  compiler_4.4.0      
#>  [4] crayon_1.5.2         reprex_2.1.0         parallel_4.4.0      
#>  [7] globals_0.16.3       uuid_1.2-0           RhpcBLASctl_0.23-42 
#> [10] yaml_2.3.8           fastmap_1.1.1        R6_2.5.1            
#> [13] knitr_1.46           palmerpenguins_0.1.1 backports_1.4.1     
#> [16] checkmate_2.3.1      future_1.33.2        paradox_0.11.1      
#> [19] R.cache_0.16.0       mlr3measures_0.5.0   R.utils_2.12.3      
#> [22] rlang_1.1.3          lgr_0.4.4            xfun_0.43           
#> [25] fs_1.6.4             mlr3misc_0.15.0      cli_3.6.2           
#> [28] withr_3.0.0          magrittr_2.0.3       digest_0.6.35       
#> [31] rstudioapi_0.16.0    catboost_1.2.5       lifecycle_1.0.4     
#> [34] R.methodsS3_1.8.2    R.oo_1.26.0          vctrs_0.6.5         
#> [37] evaluate_0.23        glue_1.7.0           data.table_1.15.99  
#> [40] listenv_0.9.1        styler_1.10.3        codetools_0.2-20    
#> [43] parallelly_1.37.1    rmarkdown_2.26       purrr_1.0.2         
#> [46] tools_4.4.0          htmltools_0.5.8.1