Length of weights #58

Open
n8layman opened this issue Jul 25, 2023 · 1 comment

@n8layman

I'm opening this issue to get more clarity about #57.

Whenever both test and weights are specified in bart2 I get the following error:

Error in model.frame.default(formula = as.formula("price ~ ."), data = train_data, : variable lengths differ (found for '(weights)')

However, I am unclear as to what length the weights vector should be. All three attempts below generate the same error.

library(tidyverse)
library(rsample)
library(dbarts)

resample <- rsample::initial_split(diamonds)
train_data <- rsample::training(resample)
test_data <- rsample::testing(resample)

# Attempt 1: weights taken from the training data
fit <- dbarts::bart2(formula = as.formula("price ~ ."),
                     data = train_data,
                     test = test_data,
                     weights = train_data |> pull(carat),
                     verbose = TRUE)

# Attempt 2: weights taken from the test data
fit <- dbarts::bart2(formula = as.formula("price ~ ."),
                     data = train_data,
                     test = test_data,
                     weights = test_data |> pull(carat),
                     verbose = TRUE)

# Attempt 3: weights taken from the full data set
fit <- dbarts::bart2(formula = as.formula("price ~ ."),
                     data = train_data,
                     test = test_data,
                     weights = diamonds |> pull(carat),
                     verbose = TRUE)

The only time it works is when train_data == test_data. I'm likely just misunderstanding something basic but I'm not clear where I'm going wrong.
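
For context on the message itself: the wording "variable lengths differ (found for '(weights)')" comes from base R's model.frame(), which raises it whenever a weights vector does not have one entry per row of the data frame it is evaluated against. The sketch below reproduces that wording using base R only. The names toy, w_ok and w_bad are made up for illustration, and this shows only where the message originates, not what dbarts does internally.

```r
# Base R only: reproduce the wording of the error with a weights vector
# whose length does not match the number of rows in the data.
set.seed(1)
toy <- data.frame(y = rnorm(5), x = runif(5))   # hypothetical toy data

w_ok  <- rep(1, nrow(toy))   # one weight per row: accepted
w_bad <- rep(1, 3)           # wrong length: rejected

# With matching lengths, model.frame() succeeds and keeps all rows.
mf   <- model.frame(y ~ x, data = toy, weights = w_ok)
n_ok <- nrow(mf)

# With mismatched lengths, capture the error message instead of stopping.
msg <- tryCatch(
  {
    model.frame(y ~ x, data = toy, weights = w_bad)
    NA_character_
  },
  error = function(e) conditionMessage(e)
)
print(msg)
```

If the training weights end up being checked against a data frame with a different number of rows (for example the test set), this is the error one would expect to see, which would also fit the observation that the call only works when the two sets have the same size.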

EoghanONeill commented Aug 1, 2023

I have received the same error message for the dbarts function. A different error occurs if the test and training data contain the same number of observations.

The error message occurs if both test data and weights are included as input, and there is no option for test data weights. Also, I assume outputs such as $train and $test are not weighted?

Perhaps this can be avoided by leaving weights out of the dbarts() call and instead setting them on the sampler before each call to sampler1$run():

sampler1$setWeights(weights = weightstemp)
sampler1$run()

The test weights are presumably unnecessary because they have no impact on the fitted model or test predictions?

Here is a reproducible example.

library(dbarts)

f <- function(x) {
  10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
    10 * x[,4] + 5 * x[,5]
}

set.seed(99)
sigma <- 1.0
n     <- 100
ntest <- 50


x  <- matrix(runif(n * 10), n, 10)
Ey <- f(x)
y  <- rnorm(n, Ey, sigma)

data <- data.frame(y00 = y, x00 = x)

xtest  <- matrix(runif(ntest * 10), ntest, 10)
Eytest <- f(xtest)
ytest  <- rnorm(ntest, Eytest, sigma)

datatest <- data.frame(y00 = NA,x00 = xtest)

weightstemp <- rep(1,n)
weightstemp_test <- rep(1,ntest)

control1 <- dbartsControl(n.samples = 1L,
                          n.chains = 1L,
                          n.threads = 1L)

sampler1 <- dbarts(y00 ~ .,
                   data =data,
                   test = datatest,
                   control = control1,
                   weights = weightstemp)

This produces the following error:

Error in model.frame.default(formula = ~. - y, data = list(y = c(NA, NA,  :
  variable lengths differ (found for '(weights)')

However, if the test data has the same length as the training data the following error occurs:

Error in validObject(.Object) :
  invalid class “dbartsData” object: 'weights.test' must be null or have the same number of rows as 'x.test'

Here is a reproducible example of the second error message:

library(dbarts)

f <- function(x) {
  10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
    10 * x[,4] + 5 * x[,5]
}

set.seed(99)
sigma <- 1.0
n     <- 100
ntest <- 100


x  <- matrix(runif(n * 10), n, 10)
Ey <- f(x)
y  <- rnorm(n, Ey, sigma)

data <- data.frame(y00 = y, x00 = x)


xtest  <- matrix(runif(ntest * 10), ntest, 10)
Eytest <- f(xtest)
ytest  <- rnorm(ntest, Eytest, sigma)

datatest <- data.frame(y00 = NA,x00 = xtest)

weightstemp <- rep(1,n)
weightstemp_test <- rep(1,ntest)


control1 <- dbartsControl(n.samples = 1L,
                          n.chains = 1L,
                          n.threads = 1L)

sampler1 <- dbarts(y00 ~ .,
                   data =data,
                   test = datatest,
                   control = control1,
                   weights = weightstemp)
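
The workaround floated above (omit weights from dbarts() and set them on the sampler before running) can be sketched as follows. This is a minimal sketch, not a confirmed fix: it reuses only the calls that already appear in this thread (dbartsControl, dbarts, setWeights, run), the non-uniform weights in w_train and the simplified mean function are made up for illustration, and whether weights set this way affect the returned $train and $test values is still the open question.

```r
library(dbarts)

set.seed(99)
n     <- 100
ntest <- 50

# Simulated data in the same shape as the reproducible examples above
# (the mean function is simplified here for brevity).
x     <- matrix(runif(n * 10), n, 10)
y     <- rnorm(n, 10 * x[, 4] + 5 * x[, 5], 1.0)
train <- data.frame(y00 = y, x00 = x)

xtest <- matrix(runif(ntest * 10), ntest, 10)
test  <- data.frame(y00 = NA, x00 = xtest)

# Hypothetical weights: one per TRAINING row, none for the test rows.
w_train <- runif(n, 0.5, 1.5)
stopifnot(length(w_train) == nrow(train))

control1 <- dbartsControl(n.samples = 1L,
                          n.chains = 1L,
                          n.threads = 1L)

# Build the sampler WITHOUT weights so the test data is never
# checked against the training weights ...
sampler1 <- dbarts(y00 ~ .,
                   data = train,
                   test = test,
                   control = control1)

# ... then attach the training weights and draw samples.
sampler1$setWeights(weights = w_train)
samples <- sampler1$run()

str(samples$test)
```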
