Divergence checking does not maintain time-reversibility #1719

Open
howsiyu opened this issue Jan 12, 2024 · 3 comments
Labels
wontfix This will not be worked on

Comments

@howsiyu

howsiyu commented Jan 12, 2024

As per title. This can bias sampling when the log density has a jump close to `max_delta_energy`. Here's a concrete example, translated from @nhuurre's Stan code at https://discourse.mc-stan.org/t/divergence-check-does-not-satisfy-time-reversibility/33738.

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import seaborn as sns

# Sampling is biased when jump is near `max_delta_energy`.
jump = 1000.0

def model():
    x = numpyro.sample("x", dist.Normal())
    # Add `jump` to the log density for x > 0, so the target is effectively HalfNormal.
    numpyro.factor("jump", jnp.where(x > 0, jump, 0))

nuts_kernel = NUTS(model, target_accept_prob=0.6)
mcmc = MCMC(nuts_kernel, num_warmup=2_000, num_samples=20_000)
mcmc.run(jax.random.key(0))

# Compare the histogram of draws with the HalfNormal density they should match.
sns.histplot(x=mcmc.get_samples()['x'], stat="density")
xs = jnp.linspace(0.0, 3.5, 100)
sns.lineplot(x=xs, y=jnp.exp(dist.HalfNormal().log_prob(xs)))
```
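The asymmetry behind the bias can be isolated without running a sampler. The sketch below assumes, as in Stan, that a step is flagged as divergent when the Hamiltonian at the new state exceeds the Hamiltonian at the trajectory's starting state by more than `max_delta_energy`, taken here as 1000. The function names and the two (position, momentum) states are hypothetical, and the momentum is held fixed only to keep the arithmetic simple. Moving from the low-density side to the high-density side lowers the energy by about 1000 and passes the check, while the reverse move raises it by the same amount and fails, so a trajectory started from one state can reach the other but not vice versa.

```python
JUMP = 1000.0              # log-density jump at x = 0, as in the reproduction
MAX_DELTA_ENERGY = 1000.0  # assumed divergence threshold (Stan's default)

def potential(x):
    # U(x) = -log p(x) for a standard normal whose log density gains JUMP for x > 0
    return 0.5 * x * x - (JUMP if x > 0 else 0.0)

def hamiltonian(state):
    # Unit mass matrix, so the kinetic energy is p^2 / 2
    x, p = state
    return potential(x) + 0.5 * p * p

def is_divergent(start, new):
    # One-sided check: only an energy *increase* relative to the
    # trajectory's starting state can be flagged as a divergence.
    return hamiltonian(new) - hamiltonian(start) > MAX_DELTA_ENERGY

# Two hypothetical (position, momentum) states on opposite sides of the jump
z_left = (-0.5, 1.0)   # H = 0.625
z_right = (0.1, 1.0)   # H = -999.495

print(is_divergent(z_left, z_right))   # False: energy drops by about 1000.12
print(is_divergent(z_right, z_left))   # True: energy rises by about 1000.12
```

A time-reversible criterion would give the same answer for both orderings; the one-sided threshold is what makes the set of reachable states depend on where the trajectory started.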
fehiepsi added the wontfix (This will not be worked on) label on Jan 12, 2024
@fehiepsi
Member

fehiepsi commented Jan 12, 2024

I guess when a divergence happens, we will ignore the trajectory and diagnose the model to find the issue. So I don't think we need to worry about the time-reversibility here.

@howsiyu
Author

howsiyu commented Jan 12, 2024

> when a divergence happens, we will ignore the trajectory and diagnose the model to find the issue.

If that's the case, shouldn't numpyro stop immediately once it encounters a divergence post-warmup?
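Short of stopping mid-run, a user who wants to treat any post-warmup divergence as fatal can check after sampling. This is a minimal sketch of such an after-the-fact check; the helper name `assert_no_divergences` is hypothetical, and it only assumes a boolean array with one entry per post-warmup transition.

```python
import numpy as np

def assert_no_divergences(diverging):
    """Raise if any post-warmup transition was flagged as divergent.

    `diverging` is a boolean array with one entry per post-warmup
    transition (chains may be flattened together).
    Returns the number of transitions checked when none diverged.
    """
    diverging = np.asarray(diverging, dtype=bool)
    n_div = int(diverging.sum())
    if n_div > 0:
        # Flattened index of the first divergent transition
        first = int(np.argmax(diverging))
        raise RuntimeError(
            f"{n_div} divergent transition(s) after warmup; "
            f"first at sample index {first}"
        )
    return diverging.size
```

With the reproduction above, the flags can be recorded by calling `mcmc.run(jax.random.key(0), extra_fields=("diverging",))` and then checked with `assert_no_divergences(mcmc.get_extra_fields()["diverging"])`.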

@martinjankowiak
Collaborator

maybe numpyro should stop immediately because floating point errors violate time-reversibility ; )

3 participants