AutoNormal, AutoDelta, and AutoGuideList do not support subsamples of variable size. #1739

Open
tillahoffmann opened this issue Feb 19, 2024 · 2 comments
Labels: enhancement (New feature or request)
Milestone: 0.15

Comments

@tillahoffmann
Contributor

AutoNormal, AutoDelta, and AutoGuideList raise an exception in SVI when the subsample size varies across log_density evaluations. Here is an example that reproduces the issue (run on master).

import numpyro
from jax import numpy as jnp


def model(n, x=None, subsample_size=None):
    mu = numpyro.sample("mu", numpyro.distributions.Normal())
    with numpyro.plate("n", n, subsample_size=subsample_size):
        numpyro.sample("x", numpyro.distributions.Normal(mu, 1), obs=x)


def demo(guide_cls):
    n = 10
    x_obs = jnp.zeros(n)
    guide = guide_cls(model)

    with numpyro.handlers.seed(rng_seed=0):
        # Initialize the guide with the full dataset, get a trace, and replay against
        # the model.
        guide(n, x_obs)
        guide_trace = numpyro.handlers.trace(guide).get_trace()
        replayed = numpyro.handlers.replay(model, guide_trace)

        print("evaluate log density for full data")
        numpyro.infer.util.log_density(replayed, (n, x_obs), {}, {})

        print("evaluate log density for subsampled data")
        numpyro.infer.util.log_density(replayed, (n, x_obs[:3], 3), {}, {})

        print("done")

# This works just fine.
demo(numpyro.infer.autoguide.AutoDiagonalNormal)
# This raises an error (see traceback below).
demo(numpyro.infer.autoguide.AutoNormal)

The traceback for the failed call is as follows.

evaluate log density for full data
evaluate log density for subsampled data
.../numpyro/playground/test.py:7: UserWarning: subsample_size does not match len(subsample), 3 vs 10. Did you accidentally use different subsample_size in the model and guide?
  with numpyro.plate("n", n, subsample_size=subsample_size):
Traceback (most recent call last):
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 151, in broadcast_shapes
    return _broadcast_shapes_cached(*shapes)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/util.py", line 287, in wrapper
    return cached(config.config._trace_context(), *args, **kwargs)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/util.py", line 280, in cached
    return f(*args, **kwargs)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 157, in _broadcast_shapes_cached
    return _broadcast_shapes_uncached(*shapes)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 173, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (10,)]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File ".../numpyro/numpyro/infer/util.py", line 80, in log_density
    broadcast_shapes(guide_shape, model_shape)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 153, in broadcast_shapes
    return _broadcast_shapes_uncached(*shapes)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 173, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (10,)]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File ".../numpyro/playground/test.py", line 32, in <module>
    demo(numpyro.infer.autoguide.AutoNormal)
  File ".../numpyro/playground/test.py", line 27, in demo
    numpyro.infer.util.log_density(replayed, (n, x_obs[:3], 3), {}, {})
  File ".../numpyro/numpyro/infer/util.py", line 82, in log_density
    raise ValueError(
ValueError: Model and guide shapes disagree at site: 'x': (10,) vs (3,)

I think the issue is that these guides use _create_plates, which in turn uses the prototype trace to determine the subsample size:

for name, frame in sorted(self._prototype_frames.items()):
    if name not in self.plates:
        full_size = self._prototype_frame_full_sizes[name]
        self.plates[name] = numpyro.plate(
            name, full_size, dim=frame.dim, subsample_size=frame.size
        )

The prototype trace is only created on the first invocation, so each plate's subsample size is fixed at whatever it was in that call. When a later call uses a different mini-batch size, the guide's plate no longer matches the model's. Guides inheriting from AutoContinuous do not call _create_plates and do not use plates in their __call__ method. I couldn't quite figure out why some guides do and some don't.

@fehiepsi
Member

fehiepsi commented Feb 20, 2024

This is a good point. I guess a better check is to make sure that there are no latent variables under the subsample plates. When that is the case, there is no need to specify the create_plates argument.

@fehiepsi fehiepsi added the enhancement New feature or request label Feb 20, 2024
@fehiepsi fehiepsi added this to the 0.15 milestone May 12, 2024
@fehiepsi
Member

@tillahoffmann sorry for my earlier, misleading comment. For subsampling, the usage is:

create_plates = lambda n, x, subsample_size=None: numpyro.plate("n", n, subsample_size=subsample_size)
AutoNormal(..., create_plates=create_plates)
