AutoContinuous/funsor bug? #1713

Closed
amifalk opened this issue Jan 5, 2024 · 3 comments · Fixed by #1796


@amifalk
Contributor

amifalk commented Jan 5, 2024

Here's a reproducible example taken almost directly from the Gaussian Mixture Model tutorial. The AutoContinuous guide seems to be the failure mode: AutoNormal runs fine, but AutoDiagonalNormal (an AutoContinuous subclass) fails.

import jax.numpy as jnp
import jax.random as random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, TraceEnum_ELBO, autoguide
from numpyro.handlers import block, seed

data = jnp.array([0.0, 1.0, 10.0, 11.0, 12.0])

K = 2  # Fixed number of components.

def model(data):
    # Global variables.
    weights = numpyro.sample("weights", dist.Dirichlet(0.5 * jnp.ones(K)))
    scale = numpyro.sample("scale", dist.LogNormal(0.0, 2.0))
    
    with numpyro.plate("components", K):
        locs = numpyro.sample("locs", dist.Normal(0.0, 10.0))

    with numpyro.plate("data", len(data)):
        # Local variables.
        assignment = numpyro.sample("assignment", dist.Categorical(weights), 
                                    infer={"enumerate":"parallel"})
        numpyro.sample("obs", dist.Normal(locs[assignment], scale), obs=data)
        
# AutoNormal: this works
guide = autoguide.AutoNormal(block(seed(model, rng_seed=0), hide=['assignment']))
svi = SVI(model, guide, numpyro.optim.Adam(0.003), TraceEnum_ELBO())
svi_result = svi.run(random.PRNGKey(0), 100, data)

# AutoDiagonalNormal (an AutoContinuous guide): this fails
guide = autoguide.AutoDiagonalNormal(block(seed(model, rng_seed=0), hide=['assignment']))
svi = SVI(model, guide, numpyro.optim.Adam(0.003), TraceEnum_ELBO())
svi_result = svi.run(random.PRNGKey(0), 100, data)

Here's the associated stack trace.

File .../jax/lib/python3.11/site-packages/numpyro/contrib/funsor/enum_messenger.py:429
    426 if msg["kwargs"]["dim_type"] in (DimType.GLOBAL, DimType.VISIBLE):
    427     for name in msg["args"][0].inputs:
    428         self._saved_globals += (
--> 429             (name, _DIM_STACK.global_frame.name_to_dim[name]),
    430         )

KeyError: 'components'

If I replace the components plate with locs = numpyro.sample("locs", dist.Normal(0.0, 10.0).expand((K,)).to_event(1)), I get the same KeyError, but for the 'data' plate instead.

@fehiepsi
Member

fehiepsi commented Jan 5, 2024

Hi @amifalk, AutoContinuous does not work with enumerated models. We should raise a better error message for this.

@amifalk
Contributor Author

amifalk commented Jan 5, 2024

Is there a way to add this functionality (even if only for a subset of models), or is it a limitation of numpyro?

@fehiepsi
Member

fehiepsi commented Jan 5, 2024

Yes, it's a limitation of the black-box AutoContinuous guides. It would be much easier to write custom guides for your models.
