Add support for data frames with an rvar column #604

Open
mattansb opened this issue Jun 4, 2023 · 4 comments
Labels
enhancement 💥 Implemented features can be improved or revised


@mattansb
Member

mattansb commented Jun 4, 2023

The rvar class is a super convenient way to wrangle posteriors.

Here is a basic workflow I've adopted:

library(brms)
library(posterior)
library(bayestestR)

mod <- brm(mpg ~ factor(cyl) + hp, data = mtcars,
           backend = "cmdstanr", cores = 4, refresh = 0)
#> Compiling Stan program...
#> Start sampling
#> Running MCMC with 4 parallel chains...
#> 
#> Chain 1 finished in 0.3 seconds.
#> Chain 2 finished in 0.3 seconds.
#> Chain 3 finished in 0.2 seconds.
#> Chain 4 finished in 0.2 seconds.
#> 
#> All 4 chains finished successfully.
#> Mean chain execution time: 0.3 seconds.
#> Total execution time: 0.5 seconds.
#> 

grid <- expand.grid(
  cyl = unique(mtcars$cyl),
  hp = mean(mtcars$hp)
)

grid$epred <- posterior_epred(mod, grid) |> rvar()

grid
#>   cyl       hp    epred
#> 1   6 146.6875 19 ± 1.3
#> 2   4 146.6875 25 ± 1.4
#> 3   8 146.6875 17 ± 1.3

Unfortunately, this data frame structure does not work with bayestestR functions (here is an example with describe_posterior()):

describe_posterior(grid)
#> Error in rbind(deparse.level, ...) : 
#>   numbers of columns of arguments do not match

And passing only the rvar column loses all the grid data:

describe_posterior(grid$epred)
#> Summary of Posterior Distribution
#> 
#> Parameter | Median |         95% CI |   pd |          ROPE | % in ROPE
#> ----------------------------------------------------------------------
#> x[1]      |  19.13 | [16.63, 21.74] | 100% | [-0.10, 0.10] |        0%
#> x[2]      |  25.11 | [22.24, 27.94] | 100% | [-0.10, 0.10] |        0%
#> x[3]      |  16.57 | [14.00, 19.26] | 100% | [-0.10, 0.10] |        0%

What I'd like is for the foo.data.frame() methods to check whether there is an rvar column. If there is exactly one, it should be summarized and the result combined with the remaining grid columns, like so:

describe_posterior.data.frame <- function(posteriors, ...) {
  if (any(is_rvar <- sapply(posteriors, inherits, "rvar"))) {
    if (sum(is_rvar, na.rm = TRUE) > 1L) stop("Cannot use more than 1 rvar")
    
    cl <- match.call()
    cl[[1]] <- quote(describe_posterior)
    cl$posteriors <- posteriors[,is_rvar, drop = TRUE]
    out <- eval(cl)
    out$Parameter <- NULL
    
    other_columns <- posteriors[,!is_rvar, drop = FALSE]
    to_add <- colnames(other_columns)
    out[to_add] <- other_columns
    out <- datawizard::data_relocate(out, select = to_add, before = 1)
    return(out)
  }

  # else, do as usual...
  bayestestR:::describe_posterior.data.frame(posteriors, ...)
}

describe_posterior(grid)
#> Summary of Posterior Distribution
#> 
#> cyl  |     hp | Median |         95% CI |   pd |          ROPE | % in ROPE
#> --------------------------------------------------------------------------
#> 6.00 | 146.69 |  19.19 | [16.50, 21.88] | 100% | [-0.10, 0.10] |        0%
#> 4.00 | 146.69 |  25.16 | [22.28, 27.99] | 100% | [-0.10, 0.10] |        0%
#> 8.00 | 146.69 |  16.55 | [13.94, 19.21] | 100% | [-0.10, 0.10] |        0%

Should be easy enough to implement...

@mattansb mattansb added the enhancement 💥 Implemented features can be improved or revised label Jun 4, 2023
@strengejacke
Member

I'm not quite sure what the function is doing. Isn't there an easy way to convert an rvar into a regular data frame that can be cbind()-ed to posteriors? Why do you need to deal with calls and evaluation?

@mattansb
Member Author

Nah, ignore all that gross code, I copied it from something else I had. We definitely can convert the rvar into a data frame, but I want to preserve the other columns and return them as well.
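
For illustration, a minimal sketch of that conversion without any call manipulation. It assumes simulated draws in place of a fitted brms model, and it hand-computes an equal-tailed 95% interval rather than calling describe_posterior(), so the summary columns are only stand-ins for what bayestestR would return. The names `draws`, `summ`, and `keep` are hypothetical.

```r
library(posterior)

set.seed(1)

# Simulated stand-in for posterior_epred(): 4000 draws x 3 grid rows
# (assumption: the means/SDs only roughly mimic the example above)
sim <- sapply(c(19, 25, 17), function(m) rnorm(4000, mean = m, sd = 1.3))

grid <- data.frame(cyl = c(6, 4, 8))
grid$epred <- rvar(sim)

# draws_of() returns the underlying array: rows = draws, columns = grid rows
draws <- draws_of(grid$epred)

# Summarize each grid row's draws (equal-tailed interval, computed by hand)
summ <- data.frame(
  Median  = unname(apply(draws, 2, median)),
  CI_low  = unname(apply(draws, 2, quantile, probs = 0.025)),
  CI_high = unname(apply(draws, 2, quantile, probs = 0.975))
)

# Keep the non-rvar columns and bind the summaries next to them
keep <- !vapply(grid, is_rvar, logical(1))
out <- cbind(grid[keep], summ)
out
```

The same pattern would apply if `summ` came from describe_posterior(grid$epred) with its Parameter column dropped: the rows of the summary line up with the rows of the grid, so a plain cbind() is enough.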

@strengejacke
Member

It could be roughly something like

library(brms)
library(posterior)
library(bayestestR)

mod <- brm(mpg ~ factor(cyl) + hp, data = mtcars,
           backend = "cmdstanr", cores = 4, refresh = 0)

grid <- expand.grid(
  cyl = unique(mtcars$cyl),
  hp = mean(mtcars$hp)
)

grid$epred <- posterior_epred(mod, grid) |> rvar()

x <- grid
x$epred <- NULL

rbind(
  bayestestR::describe_posterior(x),
  bayestestR::describe_posterior(grid$epred)
)

#> Summary of Posterior Distribution
#> 
#> Parameter | Median |           95% CI |   pd |          ROPE | % in ROPE
#> ------------------------------------------------------------------------
#> cyl       |   6.00 | [  4.10,   7.90] | 100% | [-0.10, 0.10] |        0%
#> hp        | 146.69 | [146.69, 146.69] | 100% | [-0.10, 0.10] |        0%
#> x[1]      |  19.16 | [ 16.55,  21.76] | 100% | [-0.10, 0.10] |        0%
#> x[2]      |  25.12 | [ 22.25,  28.03] | 100% | [-0.10, 0.10] |        0%
#> x[3]      |  16.56 | [ 13.87,  19.21] | 100% | [-0.10, 0.10] |        0%

with some nicer printing. I'm not sure what exactly rvar() does or what it is useful for, so I'm not sure whether this is the most appropriate implementation.

@mattansb
Member Author

mattansb commented May 8, 2024

rvar is a data type holding a random variable. So instead of having a vector of posterior samples, you can hold them in a single (scalar) rvar:

x <- rnorm(4000)
head(x)
#> [1] -0.627925890  0.675107948  1.375053570  1.195942455 -0.006494341  0.894640065

rvar(x)
#> rvar<4000>[1] mean ± sd:
#> [1] 0.024 ± 0.99

Or instead of having a matrix of posterior samples x estimates, you can have a vector of rvars:

m <- matrix(rnorm(4000*4), 4000, 4)
head(m)
#>            [,1]       [,2]        [,3]       [,4]
#> [1,] -0.8296439 -0.2135017  0.35422961  1.0199753
#> [2,] -0.7488206 -0.2938458  0.67141688  1.1986600
#> [3,] -0.7133707  1.0533782  0.20289016 -0.8099680
#> [4,] -0.4385958 -0.4919160  0.07978666  0.5154708
#> [5,] -0.5045125  1.5209865  0.58179130  0.8505250
#> [6,] -0.1048083 -0.7765123 -1.12228048 -0.9841341

rvar(m)
#> rvar<4000>[4] mean ± sd:
#> [1]  0.0049 ± 1.00   0.0151 ± 1.00   0.0038 ± 1.02  -0.0186 ± 0.99 

You can read more about it here: https://mc-stan.org/posterior/articles/rvar.html
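
As a small illustration of why this is convenient, here is a sketch (using simulated draws, not the model above) showing that arithmetic on rvars is applied draw by draw, so a contrast between two estimates is itself an rvar that can be summarized directly. The names `a`, `b`, and `contrast` are made up for the example.

```r
library(posterior)

set.seed(123)

# Two simulated posteriors with the same number of draws (assumed values)
a <- rvar(rnorm(4000, mean = 25, sd = 1.4))
b <- rvar(rnorm(4000, mean = 19, sd = 1.3))

# Subtraction is carried out within each draw; the result is again an rvar
contrast <- a - b
contrast

# Summaries over draws
mean(contrast)
Pr(contrast > 0)
```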


describe_posterior.data.frame <- function(posteriors, ...) {
  is_rvar <- which(sapply(posteriors, inherits, "rvar"))
  if (length(is_rvar) == 0L) {
    # Do usual stuff...
    return(bayestestR:::describe_posterior.data.frame(posteriors, ...))
  } else if (length(is_rvar) > 1L) {
    insight::format_error("Cannot use more than 1 rvar")
  }
  
  # describe_posterior(<rvar>)
  out <- describe_posterior(posteriors[[is_rvar]], ...)
  out$Parameter <- NULL
  
  # append other data frame columns to the beginning of the output
  df <- posteriors[,-is_rvar, drop = FALSE]
  df_nms <- colnames(df)
  out[df_nms] <- df
  out <- datawizard::data_relocate(out, select = df_nms, before = 1)

  return(out)
}

Here's an example of a typical workflow:

library(brms)
library(posterior)
library(bayestestR)

mod <- brm(mpg ~ factor(cyl) + hp, data = mtcars,
           backend = "cmdstanr", cores = 4, refresh = 0)
#> Compiling Stan program...
#> Start sampling
#> Running MCMC with 4 parallel chains...
#> 
#> Chain 1 finished in 0.3 seconds.
#> Chain 2 finished in 0.3 seconds.
#> Chain 3 finished in 0.2 seconds.
#> Chain 4 finished in 0.2 seconds.
#> 
#> All 4 chains finished successfully.
#> Mean chain execution time: 0.3 seconds.
#> Total execution time: 0.5 seconds.
#> 

grid <- expand.grid(
  cyl = unique(mtcars$cyl),
  hp = mean(mtcars$hp)
)

grid$epred <- posterior_epred(mod, grid) |> rvar()

grid
#>   cyl       hp    epred
#> 1   6 146.6875 19 ± 1.3
#> 2   4 146.6875 25 ± 1.4
#> 3   8 146.6875 17 ± 1.3

describe_posterior(grid)
#> Summary of Posterior Distribution
#> 
#> cyl  |     hp | Median |         95% CI |   pd |          ROPE | % in ROPE
#> --------------------------------------------------------------------------
#> 6.00 | 146.69 |  19.16 | [16.63, 21.63] | 100% | [-0.10, 0.10] |        0%
#> 4.00 | 146.69 |  25.10 | [22.20, 27.79] | 100% | [-0.10, 0.10] |        0%
#> 8.00 | 146.69 |  16.65 | [14.10, 19.29] | 100% | [-0.10, 0.10] |        0%
