HMCGibbs with chain_method="vectorized" #1725

Open
WolfgangEnzi opened this issue Jan 30, 2024 · 4 comments
Labels
bug Something isn't working

Comments

@WolfgangEnzi

I am trying to use HMCGibbs sampling with more than one chain using chain_method="vectorized", but there appears to be a problem with splitting the random keys.

Consider this toy example, copied from the NumPyro documentation, in which I only changed chain_method and the number of chains:

from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMCGibbs

def model():
    x = numpyro.sample("x", dist.Normal(0.0, 2.0))
    y = numpyro.sample("y", dist.Normal(0.0, 2.0))
    numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    y = hmc_sites['y']
    new_x = dist.Normal(0.8 * (1-y), jnp.sqrt(0.8)).sample(rng_key)
    return {'x': new_x}

hmc_kernel = NUTS(model)
kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=['x'])
mcmc = MCMC(kernel, num_warmup=100, num_chains=2, num_samples=100, progress_bar=False, chain_method='vectorized',)
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()

Running the code above raises the following error:
TypeError: split accepts a single key, but was given a key array of shape (2,) != (). Use jax.vmap for batching.

Is there a way to make the vectorized option available for HMCGibbs sampling?

fehiepsi added the bug label on Jan 30, 2024
@fehiepsi
Member

Could you change this line to jax.vmap(...) with the default parallel method to see if it works for HMCGibbs?
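For context on the two chain methods: in NumPyro, "parallel" maps a single-chain run over devices with `jax.pmap`, whereas "vectorized" hands the kernel a batch of keys and relies on the kernel to batch its own work (HMC does this with `jax.vmap`, per the hmc.py lines linked later in this thread). The sketch below is not NumPyro's code; `run_one_chain` is a hypothetical stand-in for a single-chain sampler that, like HMCGibbs, splits its key at every step and therefore only works when given a single key.

```python
import jax
import jax.numpy as jnp
from jax import random


def run_one_chain(rng_key, num_steps=5):
    """Hypothetical single-chain 'sampler': a Gaussian random walk driven by one key."""

    def step(carry, _):
        key, x = carry
        key, subkey = random.split(key)  # requires a single key
        x = x + random.normal(subkey)
        return (key, x), x

    _, xs = jax.lax.scan(step, (rng_key, jnp.zeros(())), None, length=num_steps)
    return xs


keys = random.split(random.PRNGKey(0), 2)  # one key per chain

# Batched on one device: each mapped call still sees a single key.
draws_vmap = jax.vmap(run_one_chain)(keys)

# One chain per device: needs at least as many devices as chains.
if jax.local_device_count() >= keys.shape[0]:
    draws_pmap = jax.pmap(run_one_chain)(keys)
    print(jnp.allclose(draws_pmap, draws_vmap))
```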

@WolfgangEnzi
Author

WolfgangEnzi commented Feb 13, 2024

Choosing the "parallel" option and changing that line to jax.vmap did not work for me: the chains still seem to be processed sequentially.

@fehiepsi
Member

Did you set the host device count to the number of chains? See https://num.pyro.ai/en/stable/utilities.html#set-host-device-count.

@CKrawczyk

I think the init method of the HMCGibbs class needs something similar to https://github.com/pyro-ppl/numpyro/blob/master/numpyro/infer/hmc.py#L782-L790, to detect whether it receives a single key or a batch of keys and to vmap its initialization and sampling functions accordingly.

When I wrote a custom Gibbs sampler (one that takes an HMC step for each conditional rather than drawing from a known distribution), I got chain_method="vectorized" working this way, so it should be doable for this sampler as well.
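The single-key-versus-batch check can be sketched on its own. The helper name `is_single_key` is hypothetical; the sketch in this comment calls NumPyro's `is_prng_key` for the same purpose, and its exact implementation may differ. This version assumes a JAX recent enough to have typed keys (`random.key`, `jax.dtypes.prng_key`) and handles both those and legacy `random.PRNGKey` keys.

```python
import jax
import jax.numpy as jnp
from jax import random


def is_single_key(key):
    """Hypothetical helper: True for one PRNG key, False for a batch of keys."""
    if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
        # Typed keys (random.key) are scalars; a batch adds a leading axis.
        return key.shape == ()
    # Legacy keys (random.PRNGKey) are uint32 arrays of shape (2,).
    return key.shape == (2,) and key.dtype == jnp.uint32


print(is_single_key(random.PRNGKey(0)))                   # True
print(is_single_key(random.split(random.PRNGKey(0), 2)))  # False
print(is_single_key(random.key(0)))                       # True
```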

I imagine it would look a bit like:

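from jax import device_put, vmap
from numpyro.infer.hmc_gibbs import HMCGibbsState
from numpyro.util import is_prng_key

def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
    model_kwargs = {} if model_kwargs is None else model_kwargs.copy()

    def init_fn(init_params, rng_key):
        ...  # existing single-chain initialization, producing z and hmc_state
        return HMCGibbsState(z, hmc_state, rng_key)

    if is_prng_key(rng_key):
        # One key: a single chain, so the existing sampler is used unchanged.
        init_state = init_fn(init_params, rng_key)
        self._sample_fn = self._sample_one_chain
    else:
        # A batch of keys (chain_method="vectorized"): map over the chain axis;
        # model_args and model_kwargs are shared by all chains.
        init_state = vmap(init_fn)(init_params, rng_key)
        self._sample_fn = vmap(self._sample_one_chain, in_axes=(0, None, None))
    return device_put(init_state)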

Then rename the current sample method to _sample_one_chain, and add a new sample method that calls self._sample_fn.

It might need a bit of extra logic to work as expected, but I think this is roughly what the solution would look like.
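The init/sample split proposed in this comment can be tried on a toy kernel without touching NumPyro. Everything here (`ToyState`, `ToyGibbsKernel`) is hypothetical and only mirrors the pattern: `init` picks `_sample_fn` once, and `sample` just delegates to it. The single-key check assumes legacy `random.PRNGKey` keys, and the update reuses the conditional from gibbs_fn in the original example with y held fixed.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import random


class ToyState(NamedTuple):
    x: jax.Array
    rng_key: jax.Array


class ToyGibbsKernel:
    """Hypothetical kernel showing the proposed init/sample dispatch."""

    def init(self, rng_key):
        def init_fn(rng_key):
            return ToyState(jnp.zeros(()), rng_key)

        if rng_key.shape == (2,):  # a single legacy uint32 key
            self._sample_fn = self._sample_one_chain
            return init_fn(rng_key)
        # A batch of keys: vmap init and sample over the leading chain axis.
        self._sample_fn = jax.vmap(self._sample_one_chain, in_axes=(0, None))
        return jax.vmap(init_fn)(rng_key)

    def _sample_one_chain(self, state, y):
        # random.split needs a single key, hence the vmap for batched chains.
        rng_key, key_x = random.split(state.rng_key)
        # Same conditional as gibbs_fn: x | y ~ Normal(0.8 * (1 - y), sqrt(0.8)).
        new_x = 0.8 * (1 - y) + jnp.sqrt(0.8) * random.normal(key_x)
        return ToyState(new_x, rng_key)

    def sample(self, state, y):
        return self._sample_fn(state, y)


kernel = ToyGibbsKernel()
state = kernel.init(random.split(random.PRNGKey(0), 2))  # two chains
for _ in range(3):
    state = kernel.sample(state, 0.5)  # y fixed for illustration only
print(state.x.shape)  # (2,)
```

Choosing `_sample_fn` once in `init` keeps `sample` free of shape checks. A real HMCGibbs change would presumably also need the inner HMC kernel and the user's gibbs_fn to run per chain, which this toy leaves out.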
