
'predict' error if categorical variable is constant in 'newdata' #53

Open
AMBarbosa opened this issue Feb 20, 2023 · 2 comments

Comments


AMBarbosa commented Feb 20, 2023

'predict' for bart models fails with an error message if a categorical variable has only one level in 'newdata' (which is common in our line of work). Here's a reproducible example:

library(dbarts)

# generate some data as in ?bart examples:
f <- function(x) {
  10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
    10 * x[,4] + 5 * x[,5]
}
set.seed(99)
sigma <- 1.0
n     <- 100
x  <- matrix(runif(n * 10), n, 10)
Ey <- f(x)
y  <- rnorm(n, Ey, sigma)

# make one of the x variables categorical / character:
x <- data.frame(x)
x[,1] <- ifelse(x[,1] > mean(x[,1]), "high", "low")
head(x)

# fit a bart model:
set.seed(99)
bartFit <- bart(x, y, keeptrees = TRUE)

# simulate a different data frame for prediction, by jittering the numeric variables in x:
x_pred <- data.frame(x[ , 1, drop = FALSE], sapply(x[ , -1], jitter))
head(x_pred)
str(x_pred)

# predict to this new data frame:
prd <- predict(bartFit, newdata = x_pred)  # OK

# however, if we make the categorical variable constant in the prediction data frame:
x_pred$X1 <- x_pred$X1[1]
prd <- predict(bartFit, newdata = x_pred) 
# Error in validateXTest(x.test, attr(data@x, "term.labels"), ncol(data@x),  : 
#  number of columns in 'test' must be equal to that of 'x'

I can trick it into working by adding a copy of one of its rows to the top of 'newdata', changing the categorical variable in that new row to another of its levels (so it is no longer constant in this data frame), and then deleting the results for that row after calling 'predict':

x_pred <- rbind(x_pred[1, ], x_pred)
x_pred$X1[1] <- unique(x$X1)[2]
head(x_pred)
prd <- predict(bartFit, newdata = x_pred)  # OK
prd <- prd[ , -1]

But this is a cumbersome hack, and it is very hard to explain to students without distracting them from the models with these code patches. So a fix would be greatly appreciated. Thanks!

AMBarbosa (author) commented Mar 31, 2023

Actually (as I just found out), the problem occurs not only when the categorical variable is constant (i.e. has only one level) in 'newdata', but whenever it has fewer levels in 'newdata' than the variable that entered the model. So, if a character or categorical variable with 3 levels is included in a model, and we try to predict() with that model on a 'newdata' where the variable has only 2 of those levels, we again get "Error in validateXTest(x.test, attr(data@x, "term.labels"), ncol(data@x), : number of columns in 'test' must be equal to that of 'x'".

The hack in my comment above is not needed if the variable is a factor (not character) and the missing levels are added to its levels() in 'newdata'; but this is still a bit of a hassle, and in any case it should at least be documented and produce a less obscure error message, as it was very difficult to debug. Cheers!
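A minimal base-R sketch of the factor-based workaround described in the comment above. `alignLevels()` is a hypothetical helper, not part of dbarts: it assumes `train` is the data frame that was passed to `bart()`, and it converts every character or factor column of `newdata` into a factor carrying the full set of training levels. Note that `factor()` silently turns any value not among those levels into `NA`.

```r
# Hypothetical helper, not part of dbarts: give each character or factor
# column of 'newdata' the complete set of levels seen in the training data.
alignLevels <- function(newdata, train) {
  for (v in intersect(names(train), names(newdata))) {
    col <- train[[v]]
    if (is.character(col) || is.factor(col)) {
      # keep declared levels for factors; derive sorted levels for characters
      trainLevels <- if (is.factor(col)) levels(col) else levels(factor(col))
      # values absent from trainLevels would silently become NA here
      newdata[[v]] <- factor(newdata[[v]], levels = trainLevels)
    }
  }
  newdata
}

# Made-up data: 'X1' is constant ("high") in the prediction data.
train <- data.frame(X1 = c("high", "low", "high"), X2 = c(0.1, 0.5, 0.9),
                    stringsAsFactors = FALSE)
newdat <- data.frame(X1 = c("high", "high"), X2 = c(0.2, 0.3),
                     stringsAsFactors = FALSE)
aligned <- alignLevels(newdat, train)
print(levels(aligned$X1))  # "high" "low"
```

With the objects from the original example, the intended (untested) usage would be to fit on `x` after `x$X1 <- factor(x$X1)` and then call `predict(bartFit, newdata = alignLevels(x_pred, x))`.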


bachlaw commented Mar 25, 2024

I have also run into this from time to time. Perhaps it would help if the validation function simply checked that the test set contains no categorical levels absent from the training set, rather than requiring the levels to be identical. We don't care if the test set has fewer categories than the training set; that will happen all the time. We do care when the test set introduces new levels the training set has never seen.
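The check proposed in the comment above could look roughly like this in base R. `checkTestLevels()` is a hypothetical sketch, not dbarts's actual `validateXTest`: it accepts test data whose categorical levels are a subset of the training levels and stops with a descriptive error, naming the variable, only when a level was never seen in training. Missing values in the test data are ignored.

```r
# Hypothetical validation helper, not part of dbarts. It checks that every
# categorical level in 'test' also occurs in 'train'; fewer levels is allowed.
checkTestLevels <- function(train, test) {
  for (v in intersect(names(train), names(test))) {
    col <- train[[v]]
    if (is.character(col) || is.factor(col)) {
      trainLevels <- if (is.factor(col)) levels(col) else unique(col)
      testValues <- unique(as.character(test[[v]]))
      testValues <- testValues[!is.na(testValues)]  # ignore missing values
      unseen <- setdiff(testValues, trainLevels)
      if (length(unseen) > 0) {
        stop(sprintf("variable '%s' has level(s) not seen in training: %s",
                     v, paste(unseen, collapse = ", ")), call. = FALSE)
      }
    }
  }
  invisible(TRUE)
}

# Made-up training data with three levels.
train <- data.frame(X1 = c("high", "low", "medium"), X2 = c(0.1, 0.5, 0.9),
                    stringsAsFactors = FALSE)
# Fewer levels than training: passes silently.
checkTestLevels(train, data.frame(X1 = c("low", "low"), X2 = c(0.2, 0.3)))
# A level never seen in training: stops with a descriptive error.
msg <- tryCatch(
  checkTestLevels(train, data.frame(X1 = "extreme", X2 = 0.4)),
  error = conditionMessage
)
print(msg)
```

Comparing level sets, rather than the number of model-matrix columns, is what lets the constant or reduced-level cases through while still rejecting genuinely new categories.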
