Constant values for non-parametric/BART component predictions with rbart_vi #63

Open
marjoleinF opened this issue Dec 9, 2023 · 0 comments

Thank you for the great work on dbarts!
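With its default settings, `rbart_vi` returns the same constant value for every BART-component prediction (`type = "bart"`), while the same call with small sampler settings works as expected. I've seen this on multiple datasets and it seems counterintuitive, but I might be overlooking something: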

```r
## example from function rbart_vi
f <- function(x) {
  10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
    10 * x[,4] + 5 * x[,5]
}

set.seed(99)
sigma <- 1.0
n     <- 100

x  <- matrix(runif(n * 10), n, 10)
Ey <- f(x)
y  <- rnorm(n, Ey, sigma)

n.g <- 10
g <- sample(n.g, length(y), replace = TRUE)
sigma.b <- 1.5
b <- rnorm(n.g, 0, sigma.b)

y <- y + b[g]

df <- as.data.frame(x)
colnames(df) <- paste0("x_", seq_len(ncol(x)))
df$y <- y
df$g <- g

## low numbers to reduce run time (works fine)
set.seed(42)
rbartFit <- rbart_vi(y ~ . - g, df, group.by = g,
                     n.samples = 40L, n.burn = 10L, n.thin = 2L,
                     n.chains = 1L,
                     n.trees = 25L, n.threads = 1L)
head(predict(rbartFit, newdata = df[1:5,], group.by = df[1:5,]$g, type = "ppd"))
##         [,1]     [,2]     [,3]     [,4]     [,5]
##[1,] 3.314817 17.32898 20.01549 1.172588 19.14767
##[2,] 9.574205 15.93575 14.94165 5.810090 21.45006
##[3,] 9.927755 14.90057 16.09273 3.274698 19.30742
##[4,] 8.075448 13.32316 17.34530 1.973367 18.65846
##[5,] 9.501832 13.13619 19.36455 3.259061 16.43262
##[6,] 6.951968 16.95147 16.73997 5.225885 20.07958
head(predict(rbartFit, newdata = df[1:5,], group.by = df[1:5,]$g, type = "bart"))
##          [,1]     [,2]     [,3]     [,4]     [,5]
##[1,]  6.045431 16.89998 16.67287 4.972953 20.60489
##[2,]  9.460054 15.46783 15.19634 5.326231 20.54520
##[3,]  8.894971 16.68499 14.30837 6.777138 19.86926
##[4,]  9.668461 17.78030 14.42151 6.032235 20.18210
##[5,] 10.771178 16.67549 16.71494 5.109395 18.24331
##[6,]  7.975217 18.83930 15.40911 5.932053 18.58619

## default rbart_vi settings yield constants for BART-component predictions
set.seed(42)
rbartFit <- rbart_vi(y ~ . - g, df, group.by = g)
head(predict(rbartFit, newdata = df[1:5,], group.by = df[1:5,]$g, type = "ppd"))
##         [,1]     [,2]     [,3]     [,4]     [,5]
##[1,] 17.51067 11.59959 17.29335 14.27244 17.42347
##[2,] 14.94680 12.06890 15.79795 13.94135 16.37092
##[3,] 16.14532 11.71485 17.15503 13.54932 16.84710
##[4,] 16.27910 11.84871 16.80194 13.02856 16.48098
##[5,] 15.98320 12.77454 16.56062 13.61612 16.37043
##[6,] 15.77031 12.38151 16.85415 14.70622 14.87979
head(predict(rbartFit, newdata = df[1:5,], group.by = df[1:5,]$g, type = "bart"))
##         [,1]     [,2]     [,3]     [,4]     [,5]
##[1,] 14.94328 14.94328 14.94328 14.94328 14.94328
##[2,] 14.94328 14.94328 14.94328 14.94328 14.94328
##[3,] 14.94328 14.94328 14.94328 14.94328 14.94328
##[4,] 14.94328 14.94328 14.94328 14.94328 14.94328
##[5,] 14.94328 14.94328 14.94328 14.94328 14.94328
##[6,] 14.94328 14.94328 14.94328 14.94328 14.94328
```
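To check the constancy programmatically rather than by eye, something like the following sketch might help. It uses only base R. `summarize_variation` is a hypothetical helper, not part of dbarts, and the two matrices are stand-ins shaped like the printed output above, not real fits.

```r
## Hypothetical helper (not part of dbarts): summarise how much a
## prediction matrix varies across its rows and columns.
summarize_variation <- function(pred, tol = 1e-8) {
  pred <- as.matrix(pred)
  list(
    n_unique    = length(unique(as.vector(pred))),  # number of distinct values
    sd_by_row   = apply(pred, 1, sd),               # spread within each row
    sd_by_col   = apply(pred, 2, sd),               # spread within each column
    is_constant = diff(range(pred)) < tol           # TRUE if all values coincide
  )
}

## Stand-in matrices (illustrative values only, not real model output)
constant_pred <- matrix(14.94328, nrow = 6, ncol = 5)
varying_pred  <- matrix(c(6.05, 9.46, 8.89, 16.90, 15.47, 16.68),
                        nrow = 3, ncol = 2)

const_summary   <- summarize_variation(constant_pred)
varying_summary <- summarize_variation(varying_pred)

print(const_summary$is_constant)
print(varying_summary$is_constant)
```

On an actual fit, the same helper could be applied to the result of `predict(rbartFit, newdata = df[1:5,], group.by = df[1:5,]$g, type = "bart")` to confirm that every value is identical.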
