Offset/multiplier noncentering gets trapped at low sigma #3088

Open
jsocolar opened this issue May 7, 2021 · 3 comments

jsocolar commented May 7, 2021

Summary:

When a normal distribution is non-centered with the built-in offset/multiplier syntax and the standard deviation is initialized to a very low value, sigma crashes to an extremely low value and never recovers. Manual non-centering does not have this problem. This makes non-centering via offset/multiplier unusable for some models (see below).

For more, see: https://discourse.mc-stan.org/t/offset-multiplier-initialization/20712

Description and Reprex

Here's @bbbales2 on the issue, pasted over from discourse (link above):

Here is the offset-multiplier model

parameters{
  real<lower = 0> sigma;
  real<multiplier=sigma> x;
}
model{
  sigma ~ std_normal();
  x ~ normal(0, sigma);
}

Here is the manual offset-multiplier model:

parameters{
  real<lower = 0> sigma;
  real x_raw;
}
transformed parameters {
  real x = x_raw * sigma;
}
model{
  sigma ~ std_normal();
  x ~ normal(0, sigma);
  target += log(sigma);
}
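
For reference, here is a short derivation of why the two programs above should target the same density. It assumes the built-in `multiplier=sigma` applies the affine map x = sigma * x_raw with log Jacobian log(sigma), as the Stan documentation describes; `x_raw` is the name used in the manual model.

```latex
\begin{aligned}
x &= \sigma\, x_{\mathrm{raw}}, \qquad
  \left|\frac{\partial x}{\partial x_{\mathrm{raw}}}\right| = \sigma \\
p(x_{\mathrm{raw}} \mid \sigma)
  &= \mathrm{Normal}(\sigma x_{\mathrm{raw}} \mid 0, \sigma)\,\sigma \\
  &= \frac{1}{\sqrt{2\pi}\,\sigma}
     \exp\!\left(-\frac{(\sigma x_{\mathrm{raw}})^2}{2\sigma^2}\right)\sigma \\
  &= \frac{1}{\sqrt{2\pi}}\exp\!\left(-\frac{x_{\mathrm{raw}}^2}{2}\right)
   = \mathrm{Normal}(x_{\mathrm{raw}} \mid 0, 1)
\end{aligned}
```

Under that assumption both models give x_raw a standard normal density for any sigma > 0, so the difference in behavior presumably comes from somewhere other than the target density itself.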

Code to run them is

library(tidyverse)
library(cmdstanr)

mod1 = cmdstan_model("mod1.stan")
inits_chain_1 = list(sigma = 1e-20)
fit1 = mod1$sample(chains = 1, init = list(inits_chain_1), iter_sampling = 1000)
fit1$summary()

mod2 = cmdstan_model("mod2.stan")
fit2 = mod2$sample(chains = 1, init = list(inits_chain_1), iter_sampling = 1000)
fit2$summary()

You get output like this for the built-in offset-multiplier:

> fit1$summary()
# A tibble: 3 x 10
  variable      mean    median    sd   mad        q5       q95  rhat ess_bulk
  <chr>        <dbl>     <dbl> <dbl> <dbl>     <dbl>     <dbl> <dbl>    <dbl>
1 lp__     -9.01e+38 -9.01e+38     0     0 -9.01e+38 -9.01e+38    NA       NA
2 sigma     3.12e-20  3.12e-20     0     0  3.12e-20  3.12e-20    NA       NA
3 x        -1.33e+ 0 -1.33e+ 0     0     0 -1.33e+ 0 -1.33e+ 0    NA       NA
# … with 1 more variable: ess_tail <dbl>
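
As a rough arithmetic check on the table above (a sketch only, using the rounded values as printed), the reported lp__ is about what the quadratic term of the normal log density gives when x is of order 1 and sigma is of order 1e-20. The variable names below are just for illustration, and this does not by itself establish why the sampler ends up in that state.

```r
# Values as rounded in the fit1$summary() output.
x           <- -1.33
sigma       <- 3.12e-20
lp_reported <- -9.01e38

# Quadratic term of normal(x | 0, sigma): -(x / sigma)^2 / 2.
# The log(sigma) terms are only of order tens, so they are negligible here.
lp_quadratic <- -0.5 * (x / sigma)^2

# Relative difference between the quadratic term and the reported lp__.
rel_diff <- abs(lp_quadratic - lp_reported) / abs(lp_reported)

print(lp_quadratic)
print(rel_diff)
```

The quadratic term comes out within about 1% of the reported lp__, which is consistent with rounding in the printed values. It also implies x / sigma is of order 1e19, far from the order-1 values a standard normal x_raw would take.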

And the output with a custom offset-multiplier looks like:

> fit2$summary()
# A tibble: 4 x 10
  variable     mean  median    sd   mad      q5    q95  rhat ess_bulk ess_tail
  <chr>       <dbl>   <dbl> <dbl> <dbl>   <dbl>  <dbl> <dbl>    <dbl>    <dbl>
1 lp__     -1.57    -1.24   0.987 0.761 -3.64   -0.563  1.00     292.     430.
2 sigma     0.821    0.703  0.594 0.591  0.0667  2.02   1.00     437.     366.
3 x_raw     0.0115   0.0396 0.983 1.01  -1.64    1.59   1.00     523.     660.
4 x         0.00698  0.0104 0.967 0.547 -1.66    1.51   1.00     492.     503.

This is quite repeatable: the custom code doesn’t have a problem with these inits, but the built-in version does.

Encountering this issue "in the wild"

Some classes of model reliably pinch through very small standard deviations early in warmup. Here's an example--not extreme enough to hit the "sticky boundary", but enough to show why it can be an issue. Notice how in early warmup sigma pinches down to a very low value before recovering. This is consistent across seeds.

data{
  int n;
  real y[n];
}

parameters{
  real<lower = 0> sigma;
  real<multiplier=sigma> x[n];
}

model{
  sigma ~ std_normal();
  x ~ normal(0, sigma);
  y ~ normal(x, .01);
}

Code to run:

library(cmdstanr)
pinch <- cmdstan_model("pinch.stan")
set.seed(10)
n <- 50000
pinch_samples <- pinch$sample(data = list(n=n, y=rnorm(n)),
                              chains = 1, save_warmup = T,
                              iter_warmup = 30, iter_sampling = 1)
pinch_csv <- read_cmdstan_csv(pinch_samples$output_files())  # $draws() is still really slow on many-parameter models
plot(pinch_csv$warmup_draws[,1,"sigma"])

[Image: "Screen Shot 2021-05-07 at 12 32 38 PM", the plot of sigma over the warmup draws produced by the code above]

It's not the end of the world, because manual non-centering still works fine, but this issue makes offset/multiplier noncentering unusable in some of the models I work with.

Current Version:

v2.26.1

jsocolar commented May 7, 2021

Let me know if this actually needs to be filed against math or somewhere else.

bbbales2 commented May 8, 2021

This is interesting and seems quite annoying, though I don't know what to do about it. @LuZhangstat @bob-carpenter @avehtari for visibility

@rok-cesnovar rok-cesnovar transferred this issue from stan-dev/cmdstan Nov 22, 2021
@rok-cesnovar

Moving this to Stan, as I don't think the interface can really help with this problem.
