bug in NeuTraReparam #1694

Open
amifalk opened this issue Dec 7, 2023 · 1 comment
Labels
bug Something isn't working

amifalk commented Dec 7, 2023

Minimal example:

import jax
import jax.numpy as jnp
from jax.random import PRNGKey

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Trace_ELBO, SVI
from numpyro.infer.reparam import NeuTraReparam
from numpyro.infer.autoguide import AutoBNAFNormal

n = 100
p = 10 # n_dim x
q = 5 # n_dim y
k = min(3, p, q) # n_dim latent

X = dist.MultivariateNormal(jnp.zeros(p), jnp.eye(p, p)).sample(PRNGKey(0), (n,))
Y = dist.MultivariateNormal(jnp.zeros(q), jnp.eye(q, q)).sample(PRNGKey(1), (n,))

def model(X, Y=None):
    with numpyro.plate('_k', k):
        P_cov = numpyro.sample('P_cov', dist.InverseGamma(3, 1))

    with numpyro.plate('_q', q):
        Q_cov = numpyro.sample('Q_cov', dist.InverseGamma(3, 1))
    
    P_cov = P_cov * jnp.eye(k, k)
    Q_cov = Q_cov * jnp.eye(q, q)

    with numpyro.plate('p', p):
        P = numpyro.sample('P', dist.MultivariateNormal(jnp.zeros(k), P_cov))
    
    with numpyro.plate('k', k):
        Q = numpyro.sample('Q', dist.MultivariateNormal(jnp.zeros(q), Q_cov))
        
    with numpyro.plate('n', n):
        Z = X @ P # low rank representation of X
        Y_pred = Z @ Q # transform back into Y via Q

        return numpyro.sample('Y', dist.MultivariateNormal(Y_pred, jnp.eye(q, q)), obs=Y)

# --- this works ---
mcmc = MCMC(NUTS(model), num_warmup=50, num_samples=50)
mcmc.run(jax.random.PRNGKey(2), X, Y) 

# --- this fails ---
guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8, 8])
svi = SVI(model, guide, numpyro.optim.Adam(0.003), Trace_ELBO())

svi_result = svi.run(jax.random.PRNGKey(3), 5_000, X, Y)
neutra = NeuTraReparam(guide, svi_result.params)

mcmc = MCMC(NUTS(neutra.reparam(model)), num_warmup=1_000, num_samples=3_000)
mcmc.run(jax.random.PRNGKey(4), X, Y)

I'm not entirely sure what's going on here. The model above works with vanilla NUTS, but running NUTS after reparameterizing with NeuTraReparam raises TypeError: mul got incompatible shapes for broadcasting: (3, 5), (5, 5).
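For readers following along, here is a small sketch of what the error message implies. It assumes (this is an inference from the message, not something stated in the thread) that the failing multiplication is `Q_cov * jnp.eye(q, q)`, since `jnp.eye(q, q)` is the only (5, 5) operand in the model. The helper name `broadcasts` is made up for illustration.

```python
import jax.numpy as jnp

k, q = 3, 5


def broadcasts(lhs_shape, rhs_shape):
    """Return True if the two shapes are broadcast-compatible."""
    try:
        jnp.broadcast_shapes(lhs_shape, rhs_shape)
        return True
    except (ValueError, TypeError):
        return False


# Shape Q_cov has in the plain model: one value per element of plate '_q'.
expected_ok = broadcasts((q,), (q, q))

# Shape reported in the error: an extra leading dimension of size k.
observed_ok = broadcasts((k, q), (q, q))

print(expected_ok, observed_ok)
```

If that reading is right, `Q_cov` picks up an unexpected leading dimension of size k = 3 once the model is reparameterized, which would point at a plate of size 3 leaking onto a value it should not cover.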

If I remove the top two plates and replace the latents with the constants

P_cov =  jnp.eye(k, k)
Q_cov = jnp.eye(q, q)

the code runs but I get the following warnings:

<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site '_P_log_prob'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
  mcmc.run(jax.random.PRNGKey(4), X, Y)
<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site '_Q_log_prob'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
  mcmc.run(jax.random.PRNGKey(4), X, Y)
<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site 'Y'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
  mcmc.run(jax.random.PRNGKey(4), X, Y)

Maybe it has something to do with having multiple plate names with the same dimension?

@fehiepsi added the bug label on Dec 7, 2023

fehiepsi commented Dec 7, 2023

Thanks @amifalk! This is a bug because we allow a plate to be applied to the unconstrained value:

z_unconstrained = numpyro.sample(
    "{}_shared_latent".format(self.guide.prefix),
    self.guide.get_base_dist().mask(False),
)

A temporary fix is to remove the plate from the first site:

P_cov = numpyro.sample('P_cov', dist.InverseGamma(3, 1).expand([k]).to_event())
