Maintain chain information from prediction functions #1534

Open
n-kall opened this issue Aug 7, 2023 · 3 comments

n-kall commented Aug 7, 2023

Currently, posterior_predict and similar functions return a plain array of draws with dimensions S x N (draws by observations), not a draws_array from posterior. I would like to combine these predictions with the posterior draws (in draws format), but because the array has no record of the chains, posterior::bind_draws fails ("chain_ids do not match").

Example

library(brms)
fit <- brm(yield ~ N * P * K, npk)
posterior::bind_draws(as_draws(fit), as_draws(posterior_predict(fit)))
#> Error: 'chain_ids' of bound objects do not match.

I'd like to avoid having to merge the chains of the posterior draws, so if it would be possible for the prediction functions from brms to return chain information (e.g. using the draws_array format), that would be very useful.
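The mismatch can be reproduced without fitting a model. The following is a minimal sketch using only the posterior package; `post` and `pred` are synthetic stand-ins (the package's built-in example draws and random numbers), not output from brms:

```r
library(posterior)

# Synthetic stand-ins, not real model output: the built-in example draws
# (100 iterations x 4 chains) and a 400 x 3 matrix of fake "predictions".
post <- example_draws()
pred <- as_draws_matrix(
  matrix(rnorm(400 * 3), nrow = 400,
         dimnames = list(NULL, paste0("pred[", 1:3, "]")))
)

nchains(post)  # 4
nchains(pred)  # 1: a plain matrix carries no chain information

# Binding along variables requires matching chain ids, so this call errors;
# the message is captured instead of stopping the script.
res <- tryCatch(bind_draws(post, pred), error = function(e) conditionMessage(e))
res
```

Because the matrix is treated as a single chain of 400 draws, the only way to bind it as-is would be to merge the chains of the posterior draws first, which is what this request is trying to avoid.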

paul-buerkner (Owner) commented

Agreed. However, I cannot just change the output format of such a frequently used function, so this has to be part of brms 3.0. I am definitely planning to output one of the posterior draws formats, but I am not yet sure which one.

paul-buerkner added this to the brms 3.0 milestone Aug 9, 2023

n-kall commented Aug 9, 2023

That makes total sense. In the meantime, here is a workaround function that I've been using which might be useful for others. It should also work for multivariate responses:

##' predictions as draws
##'
##' @param x brmsfit object
##' @param predict_fn function for predictions
##' @param prediction_names optional names of the predictions
##' @param ... further arguments passed to predict_fn
##' @return draws array of predictions
predictions_as_draws <- function(x, predict_fn, prediction_names = NULL, ...) {
  terms <- brms::brmsterms(x$formula)
  mv <- inherits(terms, "mvbrmsterms")
  # multivariate models prefix each prediction with its response name
  responses <- if (mv) terms$responses else ""
  predictions <- predict_fn(x, ...)
  n_chains <- posterior::nchains(x)
  n_draws <- posterior::ndraws(x)
  # array() silently recycles or truncates on a size mismatch, so require
  # predictions for every posterior draw (i.e. no ndraws or draw_ids in ...)
  stopifnot(dim(predictions)[1] == n_draws)
  if (!mv) {
    # add a response dimension in the univariate case
    dim(predictions) <- c(dim(predictions), 1)
  }
  n_obs <- dim(predictions)[2]
  pred_draws <- vector("list", length(responses))
  for (resp in seq_along(responses)) {
    # reshape S x N into iterations x chains x N for each response variable;
    # this assumes the draws are stored chain by chain
    predicted_draws <- posterior::as_draws_array(
      array(
        predictions[, , resp],
        dim = c(n_draws / n_chains, n_chains, n_obs)
      )
    )
    # name predicted variables, e.g. "<response>_pred[1]"
    posterior::variables(predicted_draws) <- paste0(
      responses[[resp]], "_pred[", seq_len(n_obs), "]"
    )
    pred_draws[[resp]] <- predicted_draws
  }
  # bind draws from different responses
  out <- posterior::bind_draws(pred_draws)
  if (!is.null(prediction_names)) {
    posterior::variables(out) <- prediction_names
  }
  out
}
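The core of the workaround is the reshape from an S x N matrix to an iterations x chains x N array. Here is a small base R sketch of that step on toy numbers; the values and the sizes `n_iter`, `n_chains`, and `n_obs` are made up for illustration, and it assumes, as the function does, that draws are stored chain by chain:

```r
# Toy stand-in for posterior_predict() output: S = 6 draws, N = 2 observations.
# Draws are assumed to be stored chain by chain (chain 1 first, then chain 2).
n_iter <- 3
n_chains <- 2
n_obs <- 2
predictions <- cbind(1:6, 11:16)  # rows are draws, columns are observations

# R fills arrays in column-major order, so within each observation the first
# n_iter draws land in chain 1 and the next n_iter draws in chain 2.
reshaped <- array(predictions, dim = c(n_iter, n_chains, n_obs))

reshaped[, 1, 1]  # chain 1, observation 1: 1 2 3
reshaped[, 2, 1]  # chain 2, observation 1: 4 5 6
reshaped[, 2, 2]  # chain 2, observation 2: 14 15 16
```

With the fit from the original example, a call such as predictions_as_draws(fit, brms::posterior_predict) should return a draws_array whose chain ids match those of the posterior draws, so that the two can be passed to posterior::bind_draws together. That call is not run here because it requires a fitted model.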

avehtari (Contributor) commented Jan 9, 2024

Would it be possible to add an option to choose the output format, and keep the current format as the default?
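One way to read this suggestion is an opt-in argument whose default leaves existing behaviour untouched. The sketch below is purely hypothetical: `format_predictions` and its `format` argument are illustrative names, not part of brms, and it operates on a plain matrix rather than a fitted model:

```r
# Hypothetical helper, not part of brms. The default "matrix" returns the
# current S x N output unchanged; "array" opts in to a chain-aware
# iterations x chains x N layout (assuming draws are stored chain by chain).
format_predictions <- function(predictions, n_chains,
                               format = c("matrix", "array")) {
  format <- match.arg(format)  # the first choice, "matrix", is the default
  if (format == "matrix") {
    return(predictions)
  }
  n_draws <- nrow(predictions)
  stopifnot(n_draws %% n_chains == 0)
  array(predictions, dim = c(n_draws / n_chains, n_chains, ncol(predictions)))
}

m <- cbind(1:6, 11:16)
dim(format_predictions(m, n_chains = 2))                    # 6 2
dim(format_predictions(m, n_chains = 2, format = "array"))  # 3 2 2
```

Existing code that relies on the matrix output would keep working, since nothing changes unless the caller asks for the new format.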
