mean_accept_prob significantly different after warmup #1786

Open
jonny-so opened this issue Apr 19, 2024 · 7 comments

Comments

@jonny-so

jonny-so commented Apr 19, 2024

I notice that after warmup, the mean_accept_prob is significantly higher than both target_accept_prob and the mean_accept_prob observed during warmup, even on a trivial isotropic Gaussian example. Minimal working example:

```python
import jax.numpy as jnp
from jax.lax import scan
from numpyro.infer.hmc import hmc

# Potential energy of a standard isotropic Gaussian.
def potential(x):
    return 0.5 * jnp.sum(x**2)

d = 10
nwarmup = 100000
nsamples = 100000

init_kernel, sample_kernel = hmc(potential, algo='HMC')
# Adapt only the step size (default target_accept_prob=0.8); keep the identity mass matrix.
hmc_state = init_kernel(init_params=jnp.zeros(d), num_warmup=nwarmup, adapt_step_size=True, adapt_mass_matrix=False)

# Warmup phase: the first `nwarmup` calls to sample_kernel adapt the step size.
hmc_state = scan(lambda s, _: (sample_kernel(s), None), hmc_state, None, length=nwarmup)[0]
print("post warmup", hmc_state.mean_accept_prob)

# Sampling phase: the step size is now fixed.
hmc_state = scan(lambda s, _: (sample_kernel(s), None), hmc_state, None, length=nsamples)[0]
print("post samples", hmc_state.mean_accept_prob)
```

Output:

```
post warmup 0.7992318
post samples 0.97852784
```

Am I misusing something here?

@fehiepsi
Member

fehiepsi commented Apr 30, 2024

In the early phase, I guess the sampler tends to reject many proposals, hence the smaller accept_prob compared with the sampling phase. We use dual averaging to adapt the step size and update the step size at the end of the warm-up phase:

```python
# note: at the end of warmup phase, use average of log step_size
```
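For readers unfamiliar with dual averaging, here is a minimal standalone sketch of the step-size update from Algorithm 5 of Hoffman & Gelman (the paper linked later in this thread). It is not NumPyro's code: `accept_prob_model` is a hypothetical, deterministic acceptance curve standing in for a real HMC transition, and gamma=0.05, t0=10, kappa=0.75 are the values suggested in the paper. The thing to notice is that the step size kept after warmup is `exp(log_eps_bar)`, a weighted average of log step sizes, not the last iterate.

```python
import math


def accept_prob_model(step_size):
    # Hypothetical stand-in for the HMC acceptance probability:
    # it decreases monotonically as the step size grows.
    return math.exp(-step_size ** 2)


def dual_averaging(num_warmup, target=0.8, init_step_size=1.0,
                   gamma=0.05, t0=10.0, kappa=0.75):
    mu = math.log(10.0 * init_step_size)  # shrinkage point for log(step_size)
    h_bar = 0.0        # running average of (target - accept_prob)
    log_eps_bar = 0.0  # weighted average of log(step_size), kept after warmup
    step_size = init_step_size
    for m in range(1, num_warmup + 1):
        # Acceptance probability of a transition made with the current step size.
        alpha = accept_prob_model(step_size)
        eta = 1.0 / (m + t0)
        h_bar = (1.0 - eta) * h_bar + eta * (target - alpha)
        log_eps = mu - math.sqrt(m) / gamma * h_bar
        # Average on the log scale, weighting recent iterates more (m ** -kappa).
        weight = m ** (-kappa)
        log_eps_bar = weight * log_eps + (1.0 - weight) * log_eps_bar
        step_size = math.exp(log_eps)
    # Return both the last iterate and the averaged step size used after warmup.
    return step_size, math.exp(log_eps_bar)


last, averaged = dual_averaging(2000)
print("last iterate", last, "accept", accept_prob_model(last))
print("averaged", averaged, "accept", accept_prob_model(averaged))
```

With this smooth, noise-free acceptance curve both values settle near the step size where `accept_prob_model` equals 0.8. With a real sampler the per-iteration acceptance is noisy, so the iterates fluctuate and the choice of averaging on the log scale starts to matter.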

@fehiepsi fehiepsi added the question Further information is requested label Apr 30, 2024
@jonny-so
Author

jonny-so commented Apr 30, 2024

I see that they won't be the same, but the eventual acceptance rate is almost 100%, suggesting that the learned step size is too small. Note that I am targeting the default acceptance rate of 80%. Could this be the same issue the Stan developers discuss in stan-dev/stan#3105?
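As a rough sanity check of the reasoning that near-100% acceptance means the step size is too small, the sketch below runs single HMC proposals with a leapfrog integrator on the same isotropic Gaussian potential and averages the acceptance probability for a few step sizes. It is plain Python, not NumPyro; the identity mass matrix, dimension 10, and a fixed trajectory length of 2π (with `ceil(2π / step_size)` leapfrog steps) are assumptions chosen to resemble the issue's setup.

```python
import math
import random


def leapfrog_accept_prob(step_size, rng, dim=10, trajectory_length=2 * math.pi):
    # One HMC proposal for U(q) = 0.5 * |q|^2 with an identity mass matrix.
    q = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # start from a draw of the target
    p = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # fresh momentum
    energy0 = 0.5 * sum(x * x for x in q) + 0.5 * sum(v * v for v in p)
    num_steps = max(1, math.ceil(trajectory_length / step_size))
    for _ in range(num_steps):
        p = [v - 0.5 * step_size * x for v, x in zip(p, q)]  # half kick (grad U(q) = q)
        q = [x + step_size * v for x, v in zip(q, p)]        # full drift
        p = [v - 0.5 * step_size * x for v, x in zip(p, q)]  # half kick
    energy1 = 0.5 * sum(x * x for x in q) + 0.5 * sum(v * v for v in p)
    # Metropolis acceptance probability based on the total energy error.
    return min(1.0, math.exp(energy0 - energy1))


def mean_accept_prob(step_size, num_draws=1000, seed=0):
    rng = random.Random(seed)
    total = sum(leapfrog_accept_prob(step_size, rng) for _ in range(num_draws))
    return total / num_draws


for eps in (0.2, 0.5, 1.0, 1.5):
    print("step_size", eps, "mean accept prob", mean_accept_prob(eps))
```

Acceptance falls as the step size grows, so an acceptance rate far above the 0.8 target is consistent with the adapted step size being smaller than it needs to be.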

@fehiepsi fehiepsi added discussion and removed question Further information is requested labels Apr 30, 2024
@fehiepsi
Copy link
Member

You're right, the step size seems too small. I'll look into the adaptation dynamics later this week. If you are interested, you can extract more information from the scan body function, such as step_size, accept_prob, etc.

@fehiepsi fehiepsi modified the milestone: 0.15 May 12, 2024
@fehiepsi
Member

fehiepsi commented May 13, 2024

@jonny-so This turns out to be an issue with the dual averaging algorithm we use:

```python
import jax.numpy as jnp
from jax.lax import scan
from numpyro.infer.hmc import hmc

def potential(x):
    return 0.5 * jnp.sum(x**2)

d = 10
nwarmup = 10000
nsamples = 10000

init_kernel, sample_kernel = hmc(potential, algo='HMC')
hmc_state = init_kernel(init_params=jnp.zeros(d), num_warmup=nwarmup, adapt_step_size=True, adapt_mass_matrix=False)

# Warmup, additionally recording the step size used at each iteration.
hmc_state_warmup, step_sizes = scan(lambda s, _: (sample_kernel(s), s.adapt_state.step_size), hmc_state, None, length=nwarmup)
print("post warmup", hmc_state_warmup.mean_accept_prob)

hmc_state = scan(lambda s, _: (sample_kernel(s), None), hmc_state_warmup, None, length=nsamples)[0]
print("post samples", hmc_state.mean_accept_prob)

# Geometric vs. arithmetic mean of the step sizes over the last 50 warmup iterations.
print("exp(mean(log(last_50_step_sizes)))", jnp.exp(jnp.log(step_sizes[-50:]).mean()))
print("mean(last_50_step_sizes)", step_sizes[-50:].mean())
```

Output:

```
post warmup 0.7974499
post samples 0.97814786
exp(mean(log(last_50_step_sizes))) 0.8056808
mean(last_50_step_sizes) 1.37691
```

We run dual averaging over the last window buffer (50 steps) of the warmup phase, so the step size estimate is biased: exp(mean(log(step_size))) <= mean(step_size). Let me think a bit more about what we can do here. One option is to expose a configuration option for the length of the last window buffer, so that the estimate is better. What do you think?
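The bias described here is an instance of Jensen's inequality: because log is concave, the geometric mean exp(mean(log x)) never exceeds the arithmetic mean mean(x), with equality only when all values are equal, and the gap widens as the step sizes spread out. A tiny check with hypothetical step sizes (not the values from the run above):

```python
import math
import statistics

# Hypothetical step sizes standing in for the last warmup iterations.
step_sizes = [0.3, 0.6, 1.2, 2.4, 0.9]

# exp(mean(log(x))): the geometric mean, i.e. averaging on the log scale.
geometric = math.exp(statistics.fmean([math.log(s) for s in step_sizes]))
# mean(x): the arithmetic mean.
arithmetic = statistics.fmean(step_sizes)

print("exp(mean(log(step_sizes)))", geometric)
print("mean(step_sizes)", arithmetic)
```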

cc @martinjankowiak, do you have any suggestions for dealing with this issue?

@martinjankowiak
Collaborator

I'm not sure, but if you wanted to reduce that specific bias, I guess you could use the formula for the mean of a log-normal distribution:

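```python
import jax.numpy as jnp

# step_sizes: per-iteration warmup step sizes, as recorded by the scan in the previous comment.
log_step_sizes = jnp.log(step_sizes[-50:])
# Mean of a log-normal distribution: exp(mu + sigma^2 / 2).
jnp.exp(log_step_sizes.mean() + 0.5 * log_step_sizes.var())
```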
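The formula comes from the log-normal distribution: if log(step_size) ~ N(mu, sigma^2), then E[step_size] = exp(mu + sigma^2 / 2). The simulation below draws synthetic log step sizes under that assumption (mu = log 0.8 and sigma = 0.5 are arbitrary choices, not fitted to the run above) and compares the plain geometric mean, the corrected estimate, and the sample arithmetic mean.

```python
import math
import random
import statistics

rng = random.Random(0)
# Synthetic log step sizes drawn from N(mu, sigma^2): the log-normal assumption.
mu, sigma = math.log(0.8), 0.5
log_step_sizes = [rng.gauss(mu, sigma) for _ in range(20000)]
step_sizes = [math.exp(x) for x in log_step_sizes]

m = statistics.fmean(log_step_sizes)
v = statistics.pvariance(log_step_sizes)  # population variance, like jnp.var's default
geometric = math.exp(m)            # exp(mean(log(step_sizes)))
corrected = math.exp(m + 0.5 * v)  # log-normal mean correction
arithmetic = statistics.fmean(step_sizes)

print("geometric", geometric)
print("corrected", corrected)
print("arithmetic", arithmetic)
```

Under the log-normal assumption the corrected value lands close to the arithmetic mean, while exp(mean(log)) stays noticeably below it. Whether the arithmetic mean is the step size one actually wants after warmup is a separate question the thread leaves open.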

@fehiepsi
Member

It looks like the implementation agrees with Algorithm 5 in https://jmlr.org/papers/volume15/hoffman14a/hoffman14a.pdf#page=18.62. I guess it is better to let users control the last window size.
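To make "let users control the last window size" concrete, here is a sketch of a Stan-style windowed warmup schedule (an initial fast buffer, slow windows that double in length, and a terminal fast buffer) in which the terminal buffer length is just a parameter. The function name and its parameters are hypothetical, not NumPyro's API; the defaults 75/50/25 are Stan's documented defaults, and the short-warmup case is simplified.

```python
def adaptation_schedule(num_warmup, init_buffer=75, term_buffer=50, base_window=25):
    """Return (start, end) inclusive index pairs for each warmup window (hypothetical helper)."""
    if num_warmup < init_buffer + term_buffer + base_window:
        # Stan handles this case differently; this sketch simply uses a single window.
        return [(0, num_warmup - 1)]
    schedule = [(0, init_buffer - 1)]      # initial fast buffer
    start = init_buffer
    end_slow = num_warmup - term_buffer    # exclusive end of the slow windows
    window = base_window
    while start < end_slow:
        next_start = start + window
        # If the next (doubled) window would not fit, stretch this one to the end.
        if next_start + 2 * window > end_slow:
            next_start = end_slow
        schedule.append((start, next_start - 1))
        start = next_start
        window *= 2
    schedule.append((end_slow, num_warmup - 1))  # terminal buffer (step size only)
    return schedule


print(adaptation_schedule(1000))
print(adaptation_schedule(1000, term_buffer=200))
```

With the defaults and 1000 warmup steps this gives slow windows of 25, 50, 100, 200 and 500 iterations followed by a 50-step terminal buffer; passing a larger `term_buffer` shrinks the last slow window and lengthens the final step-size-only phase.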

@jonny-so
Author

Sorry for the delay; I've been flat out with the NeurIPS deadline. I need to think about this a bit, but I'm taking a week off to recover... I'll come back to you soon.
