thoughts about a set_inits() or gen_inits() function? #1645

Open
venpopov opened this issue Apr 9, 2024 · 3 comments

venpopov commented Apr 9, 2024

For some of the models we are developing in https://github.com/venpopov/bmm/, we have run into the problem that the models require us to specify initial values; otherwise sampling cannot start. I know how to do this for a specific model: I look up the names of the parameters in the parameters block of the generated Stan code and their dimensions in the generated Stan data, and then write a function that generates n lists of random initial values from some specification (where n is the number of chains). However, we cannot currently figure out how to set inits in a way that generalizes to arbitrary formulas that users provide.

I was thinking it would be helpful to have an exported function in brms, such as set_inits(), in which you could specify how initial values should be generated without needing to know the internal names of the parameters or their dimensions. I know this topic has come up a few times over the years, and in one issue you mentioned that it would be nice, but that it is unclear what the design of such a function should look like.

I had an idea. A function like set_inits could be modeled after set_prior. Let's say you have the following model code (for a wiener family, but this is just an example):

```r
library(brms)
library(rtdists)
options(width = 1000)
data <- speed_acc
data$resp <- ifelse(speed_acc$response == "word", 1, 0)
data <- data[!data$censor,]

formula <- bf(rt | dec(resp) ~ stim_cat + (stim_cat | id),
              bs ~ stim_cat + (stim_cat | id),
              ndt ~ 1 + (1 | id),
              bias ~ 1 + (1 | id))

family <- wiener()

default_prior(formula, data, family)
#>                   prior     class            coef group resp dpar nlpar lb ub       source
#>                  (flat)         b                                                  default
#>                  (flat)         b stim_catnonword                             (vectorized)
#>                  lkj(1)       cor                                                  default
#>                  lkj(1)       cor                    id                       (vectorized)
#>  student_t(3, 0.6, 2.5) Intercept                                                  default
#>    student_t(3, 0, 2.5)        sd                                        0         default
#>    student_t(3, 0, 2.5)        sd                    id                  0    (vectorized)
#>    student_t(3, 0, 2.5)        sd       Intercept    id                  0    (vectorized)
#>    student_t(3, 0, 2.5)        sd stim_catnonword    id                  0    (vectorized)
#>          logistic(0, 1) Intercept                            bias                  default
#>    student_t(3, 0, 2.5)        sd                            bias        0         default
#>    student_t(3, 0, 2.5)        sd                    id      bias        0    (vectorized)
#>    student_t(3, 0, 2.5)        sd       Intercept    id      bias        0    (vectorized)
#>                  (flat)         b                              bs                  default
#>                  (flat)         b stim_catnonword              bs             (vectorized)
#>       normal(-0.6, 1.3) Intercept                              bs                  default
#>    student_t(3, 0, 2.5)        sd                              bs        0         default
#>    student_t(3, 0, 2.5)        sd                    id        bs        0    (vectorized)
#>    student_t(3, 0, 2.5)        sd       Intercept    id        bs        0    (vectorized)
#>    student_t(3, 0, 2.5)        sd stim_catnonword    id        bs        0    (vectorized)
#>                  (flat) Intercept                             ndt                  default
#>    student_t(3, 0, 2.5)        sd                             ndt        0         default
#>    student_t(3, 0, 2.5)        sd                    id       ndt        0    (vectorized)
#>    student_t(3, 0, 2.5)        sd       Intercept    id       ndt        0    (vectorized)
```

Created on 2024-04-09 with reprex v2.1.0

Then imagine having a function set_inits, which would be called like this:

```r
inits <- set_inits("normal(-0.6, 0.5)", class = "Intercept", dpar = "bs") +
  set_inits("normal(0, 0.2)", class = "b", dpar = "bs") +
  set_inits("uniform(0, 0.3)", class = "sd", dpar = "bs") +
  set_inits("normal(0, 0.5)", class = "z", dpar = "bs") +  # on standardized group-level effects
  set_inits("uniform(-6, -2)", class = "Intercept", dpar = "ndt") +
  set_inits("uniform(0, 0.3)", class = "sd", dpar = "ndt") +
  set_inits("uniform(-0.1, 0.1)", class = "z", dpar = "ndt")  # on standardized group-level effects
# etc. for the remaining parameters
```

and the result could then be passed to the init argument of brm():

```r
brm(formula, data, family, init = inits)
```
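To make the proposal concrete, here is a minimal sketch of how set_inits() objects might be represented, assuming (hypothetically) that each call returns a one-row data frame of specifications and that + stacks them, similar in spirit to how set_prior() results are combined. Nothing here exists in brms; the class name initspec and the argument list are made up for illustration.

```r
# Hypothetical sketch: set_inits() is NOT a brms function.
# Each call records one specification as a one-row data frame.
set_inits <- function(dist, class, coef = "", group = "", dpar = "") {
  spec <- data.frame(dist = dist, class = class, coef = coef,
                     group = group, dpar = dpar, stringsAsFactors = FALSE)
  class(spec) <- c("initspec", "data.frame")
  spec
}

# Adding two specifications stacks their rows into one initspec object.
"+.initspec" <- function(e1, e2) {
  out <- rbind(as.data.frame(e1), as.data.frame(e2))
  class(out) <- c("initspec", "data.frame")
  out
}

inits <- set_inits("normal(-0.6, 0.5)", class = "Intercept", dpar = "bs") +
  set_inits("uniform(0, 0.3)", class = "sd", dpar = "ndt") +
  set_inits("uniform(-0.1, 0.1)", class = "z", dpar = "ndt")
```

Keeping the specifications as plain rows would let brm() resolve them to Stan parameter names and dimensions later, at the point where it generates the Stan code and data anyway.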

Internally, this function would create the same random initial values that you currently have to code manually, for example:

```r
sdata <- standata(formula, data, family)
inits_fun <- function() {
  list(
    Intercept_bs = rnorm(1, -0.6, 0.5),
    b_bs = array(rnorm(sdata$Kc_bs, 0, 0.2)),
    # group-level SDs and standardized effects (M_2 x N_2) for bs
    sd_2 = array(runif(sdata$M_2, 0, 0.3)),
    z_2 = array(rnorm(sdata$M_2 * sdata$N_2, 0, 0.5), dim = c(sdata$M_2, sdata$N_2)),
    # group-level SDs and standardized effects (M_3 x N_3) for ndt
    sd_3 = array(runif(sdata$M_3, 0, 0.3)),
    z_3 = array(runif(sdata$M_3 * sdata$N_3, -0.1, 0.1), dim = c(sdata$M_3, sdata$N_3)),
    Intercept_ndt = runif(1, -6, -2)
  )
}
# one list of initial values per chain
inits <- replicate(4, inits_fun(), simplify = FALSE)
```
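The manual version boils down to two steps: drawing from a named distribution, and reshaping the draws to each parameter's dimensions. A minimal sketch of that step, assuming the Stan parameter names and dimensions are already known; draw_from_spec() and make_inits_fun() are hypothetical helpers, only normal and uniform are handled, and the dimensions below are illustrative rather than taken from the speed_acc data.

```r
# Hypothetical helper: turn a Stan-style string such as "normal(-0.6, 0.5)"
# into n random draws. Only normal and uniform are supported in this sketch.
draw_from_spec <- function(dist, n) {
  expr <- str2lang(dist)                       # parse into an R call
  fname <- as.character(expr[[1]])             # distribution name
  args <- vapply(as.list(expr)[-1], eval, numeric(1))  # numeric arguments
  switch(fname,
    normal  = rnorm(n, args[1], args[2]),
    uniform = runif(n, args[1], args[2]),
    stop("unsupported distribution: ", fname)
  )
}

# Hypothetical helper: returns a function producing one list of inits.
# dims maps Stan parameter names to dimensions; absent names are scalars.
make_inits_fun <- function(specs, dims) {
  function() {
    out <- lapply(names(specs), function(par) {
      d <- dims[[par]]
      if (is.null(d)) return(draw_from_spec(specs[[par]], 1))
      array(draw_from_spec(specs[[par]], prod(d)), dim = d)
    })
    setNames(out, names(specs))
  }
}

# Illustrative inputs; in practice the dimensions would come from the Stan
# data (e.g. M_2 and N_2), which is exactly what set_inits() should hide.
specs <- c(Intercept_bs = "normal(-0.6, 0.5)",
           sd_2 = "uniform(0, 0.3)",
           z_2 = "normal(0, 0.5)")
dims <- list(sd_2 = 2, z_2 = c(2, 17))
inits_fun <- make_inits_fun(specs, dims)
inits <- replicate(4, inits_fun(), simplify = FALSE)  # one list per chain
```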

What do you think? I think this would avoid having to figure out the appropriate dimensions and the structure in which to pass the inits (e.g., arrays vs. scalars), and the inits could be specified using the familiar set_prior syntax.

I tried to play around with this concept, but the code that generates the parameter names and dimensions is too deeply nested for me to follow. The building blocks should be there, though: the function would need to repurpose some of the code that currently generates the Stan code and Stan data.

If you like the idea, I can try a simple prototype, but I would likely need your help to make it work in general.

(PS: in principle, I can already generate a function like inits_fun above by first generating the Stan data and Stan code to get the parameter names and dimensions, but I'm worried about how reliable that is. For example, switching the order of terms in the formula also switches the indexing of the z_* and sd_* parameters. It also constructs the data and Stan code unnecessarily, since our call to brm will construct them again anyway. So I would love for such a set_inits() function to be native to brms.)

@paul-buerkner

I agree this would be a nice feature, and I also like your general idea. Currently, the mapping between parameter names and priors is indeed buried deep inside the Stan code generation. But if we abstracted this mapping, we could use it for both priors and inits without much code duplication.

A simple prototype as a basis for discussion sounds good. Just make sure not to spend too much time on it for now, since we will likely have to change the initial approach in various ways.
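One way to picture such an abstracted mapping is a table keyed by class and dpar that resolves user-facing specifications to Stan parameter names. The sketch below writes that table by hand, using the names from the wiener example in this issue (sd_2/z_2 for bs, sd_3/z_3 for ndt); in brms it would have to be produced during Stan code generation. All object names are hypothetical.

```r
# Hypothetical mapping table, written by hand for illustration; brms would
# have to derive it while generating the Stan code.
par_map <- data.frame(
  class     = c("Intercept", "b", "sd", "z", "Intercept", "sd", "z"),
  dpar      = c("bs", "bs", "bs", "bs", "ndt", "ndt", "ndt"),
  stan_name = c("Intercept_bs", "b_bs", "sd_2", "z_2",
                "Intercept_ndt", "sd_3", "z_3"),
  stringsAsFactors = FALSE
)

# User-facing specifications refer only to class and dpar, never to sd_3 etc.
init_specs <- data.frame(
  dist  = c("normal(-0.6, 0.5)", "uniform(0, 0.3)", "uniform(-0.1, 0.1)"),
  class = c("Intercept", "sd", "z"),
  dpar  = c("bs", "ndt", "ndt"),
  stringsAsFactors = FALSE
)

# Resolve each specification to its Stan parameter name via a join.
resolved <- merge(init_specs, par_map, by = c("class", "dpar"))
stan_specs <- setNames(resolved$dist, resolved$stan_name)
```

If the formula terms were reordered, only par_map would change while init_specs stays the same, and in principle the same lookup could serve set_prior() specifications as well.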

venpopov commented Apr 9, 2024

Cool. Yes, I will put up something that works for a simple case (perhaps just for the Intercept and population-level parameters), together with a design doc noting what would be needed for full functionality, and open a draft PR to use as the basis for discussion.

paul-buerkner commented Apr 9, 2024 via email
