Adding a new MCMC method #1662

Open
reubenharry opened this issue Oct 12, 2023 · 5 comments
Labels
enhancement New feature or request

Comments

@reubenharry

I'm a collaborator on this project https://github.com/JakobRobnik/MicroCanonicalHMC, and we're interested in either adding our algorithm to NumPyro, or using NumPyro in our codebase. With that in mind, we have a couple of questions:

  1. The simplest thing we'd like to do is to write a probabilistic program in NumPyro like
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def rosenbrock(d, Q):
    x = numpyro.sample("x", dist.Normal(jnp.ones(d // 2), jnp.ones(d // 2)))
    numpyro.sample("y", dist.Normal(jnp.square(x), jnp.sqrt(Q) * jnp.ones(d // 2)))
```

and then extract its density function $f : \mathbb{R}^d \to \mathbb{R}$ (for even $d$, $x$ and $y$ together have $d$ components). We want $f$ explicitly because it's what we need to pass to our code in order to run our inference algorithm. However, we had some difficulty extracting it simply from NumPyro, and I'm currently doing something a bit hacky, like:

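```python
from numpyro.handlers import condition, seed, trace

# model: a NumPyro model that takes no arguments
init_model_trace = trace(seed(model, rng_seed=0)).get_trace()
site_names = [name for name, site in init_model_trace.items() if site["type"] == "sample"]

def potential_fn(arr):
    # condition every sample site on its entry of arr, then sum the log-probabilities
    tr = trace(condition(model, dict(zip(site_names, arr)))).get_trace()
    return -sum(tr[name]["fn"].log_prob(arr[i]).sum() for i, name in enumerate(site_names))
```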

Is there a simpler and better way?

  2. We'd potentially be interested in adding our algorithm to NumPyro as a kernel alongside NUTS and HMC. Would that be of interest, and if so, do you have any guidance? The code for the HMC kernel looks quite complex, but perhaps there's a simpler example somewhere to follow.

Thanks!

Reuben

@martinjankowiak
Collaborator

hi @reubenharry -- yes it'd be great to add your mcmc method to numpyro and we can certainly help guide you along the process of getting a PR merged.

i see at least two options: i) you could implement a self-contained kernel that introduces no new dependencies; ii) you could implement a kernel that mostly consists of boilerplate and hands off the core of the algorithm to the MicroCanonicalHMC repo.

the benefit of the former is that there are no new dependencies, unit testing is contained within numpyro, etc., but the disadvantage is that any improvements to the algorithm won't filter down to numpyro without an explicit PR that implements the new functionality.

the benefit of the second approach is that numpyro can benefit from any algorithm improvements in the upstream repo, but the disadvantages include introducing a new dependency, the possibility of breaking changes upstream, and the possibility that maintenance of the upstream repo slows down or ceases entirely.

currently i believe we have one instance of the second path, namely nested sampling, which introduces a dependency on jaxns. @reubenharry do you have a preference? the second path probably only makes sense if you're committed to maintaining the repo and foresee the algorithm improving over time.

@fehiepsi do you have suggestions for point 1?

@reubenharry
Author

Thanks for your advice! In the medium term, the first option seems appealing, particularly since we're also working in JAX. The kernel itself is actually quite simple; the difficulty is the autotuning code, which is a little more involved, and it wasn't immediately obvious to me how much control over it I would have with option one. I also got a little intimidated by the HMC code, which has a few layers of abstraction, but I'm sure with guidance it wouldn't be so hard to do something similar :)

Currently we've opted for a simpler third option: express a program in NumPyro, extract its density function, and use that in our repo. I also did something somewhat like option 1. You can see both here: https://github.com/JakobRobnik/MicroCanonicalHMC/pull/18/files#diff-e81cce67759d32ecde8fc48bb864dd0ac7ecc01286a35ceab268e62e9181c0e3

I'll discuss with the other developers and see what their preferences are - perhaps at some point further down the road we can all chat in person.

@fehiepsi
Member

Hi @reubenharry, you can make a new kernel as a subclass of MCMCKernel. To convert a NumPyro model to a potential function, you can use the initialize_model helper. This helper also returns a postprocess function that converts unconstrained values into constrained values, which can be used in the postprocess_fn method of the MCMCKernel. Let us know if something is unclear.

@reubenharry
Author

Thanks! I'm having one issue:

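```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import initialize_model

def m():
    mu = numpyro.sample('mu', dist.Normal(3, 1))
    nu = numpyro.sample('nu', dist.Normal(mu+1, 2))

rng_key = jax.random.PRNGKey(0)
rng_key, init_key = jax.random.split(rng_key)
init_params, potential_fn_gen, *_ = initialize_model(
    init_key,
    m,
    model_args=(),
    dynamic_args=True,
)

print(potential_fn_gen()(jnp.array([4,2])))  # raises the AssertionError below
```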

```
  File "/opt/homebrew/lib/python3.11/site-packages/numpyro/distributions/continuous.py", line 2035, in sample
    assert is_prng_key(key)
AssertionError
```

@tare
Contributor

tare commented Oct 24, 2023

This should work

```python
print(potential_fn_gen()(init_params.z))
```

Please check how init_params.z looks.

@fehiepsi added the enhancement (New feature or request) label Oct 29, 2023
4 participants