Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
schalkdaniel committed Mar 27, 2023
1 parent 429421d commit 4b8e982
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .Rbuildignore
Expand Up @@ -41,7 +41,7 @@ cran-comments.md
^revdep$
^README\.Rmd$
^README-.*\.png$
^LICENSE\.md$
^LICENSE$
^cran-comments\.md$
man/figures/*
.lintr
Expand Down
7 changes: 3 additions & 4 deletions R/mlr_learner.R
Expand Up @@ -50,7 +50,7 @@ cleanList = function(ll) {
#' @title Component-wise gradient boosting learner
#'
#' @description
#' A [Learner] for a component-wise boosting model implemented in [compboost::Compboost].
#' A [Learner][mlr3::Learner] for a component-wise boosting model implemented in [Compboost].
#'
#' @importFrom mlr3 mlr_learners Learner
LearnerCompboost = R6::R6Class("LearnerCompboost", inherit = Learner,
Expand Down Expand Up @@ -244,7 +244,7 @@ LearnerCompboost = R6::R6Class("LearnerCompboost", inherit = Learner,
#' @title Component-wise gradient boosting classification learner
#'
#' @description
#' A [Learner] for a component-wise boosting model implemented in [compboost::Compboost].
#' A [Learner][mlr3::Learner] for a component-wise boosting model implemented in [Compboost].
#'
#' @examples
#' task = mlr3::tsk("german_credit")
Expand All @@ -266,7 +266,7 @@ LearnerClassifCompboost = R6::R6Class("LearnerClassifCompboost", inherit = Learn
#' @title Component-wise gradient boosting regression learner
#'
#' @description
#' A [Learner] for a component-wise boosting model implemented in [compboost::Compboost].
#' A [Learner][mlr3::Learner] for a component-wise boosting model implemented in [Compboost].
#'
#' @examples
#' task = mlr3::tsk("mtcars")
Expand All @@ -283,4 +283,3 @@ LearnerRegrCompboost = R6::R6Class("LearnerRegrCompboost", inherit = LearnerComp
}
)
)

2 changes: 1 addition & 1 deletion man/LearnerClassifCompboost.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/LearnerCompboost.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/LearnerRegrCompboost.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/compboost.cpp
Expand Up @@ -304,6 +304,9 @@ std::map<std::string, arma::mat> Compboost::predictIndividual (const std::map<st

arma::vec Compboost::predict (const std::map<std::string, std::shared_ptr<data::Data>>& data_map, const bool& as_response) const
{
if (data_map.size() == 0) {
throw std::range_error("Require data in 'data_map' for prediction.");
}
arma::mat pred(data_map.begin()->second->getNObs(), _sh_ptr_response->getResponse().n_cols, arma::fill::zeros);

if (_sh_ptr_response->getInitialization().n_rows == 1)
Expand Down
5 changes: 5 additions & 0 deletions tests/testthat/test_api.R
Expand Up @@ -32,6 +32,10 @@ test_that("train works", {
expect_output(cboost$train(4000))
expect_output(cboost$print())

expect_error(expect_warning(cboost$predict(iris)))
expect_error(expect_warning(cboost$predict(mtcars[, 1, FALSE])))
expect_error(cboost$predict(data.frame()))

expect_error(cboost$addBaselearner("wt", "spline", BaselearnerPSpline, degree = 3,
n_knots = 10, penalty = 2, differences = 2))

Expand All @@ -47,6 +51,7 @@ test_that("train works", {
expect_equal(sort(cboost$getBaselearnerNames()), sort(c("mpg_cat_A_binary", "mpg_cat_B_binary", "hp_spline")))
expect_equal(cboost$bl_factory_list$getRegisteredFactoryNames(), sort(c("mpg_cat_A_binary", "mpg_cat_B_binary", "hp_spline")))


expect_equal(cboost$getCurrentIteration(), 4000)
expect_length(cboost$getInbagRisk(), 4001)
expect_length(cboost$getSelectedBaselearner(), 4000)
Expand Down

0 comments on commit 4b8e982

Please sign in to comment.