I'm not entirely sure what's going on here. The following model works with vanilla NUTS, but after I reparameterize it with NeuTraReparam, running NUTS raises TypeError: mul got incompatible shapes for broadcasting: (3, 5), (5, 5).
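For context, this is the standard broadcasting failure. Shapes are aligned from the right, and each aligned pair of sizes must either be equal or include a 1. For (3, 5) and (5, 5) the trailing 5s match, but the leading 3 and 5 do not. The sketch below re-implements that rule in plain Python purely to illustrate it; in practice `numpy.broadcast_shapes` does the same job. One plausible reading, which this sketch does not confirm, is that after reparameterization a size-3 batch dimension from one plate ends up aligned against a size-5 dimension from another.

```python
def broadcast_shape(*shapes):
    """Broadcast shape of the given array shapes under NumPy/JAX rules.

    Shapes are right-aligned; each aligned pair of sizes must be equal or 1.
    Raises ValueError when the shapes cannot be broadcast together.
    """
    ndim = max(len(s) for s in shapes)
    # Left-pad each shape with 1s so that all shapes have the same rank.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for sizes in zip(*padded):
        # Sizes of 1 stretch to match; all other sizes must agree.
        distinct = {n for n in sizes if n != 1}
        if len(distinct) > 1:
            raise ValueError(f"incompatible shapes for broadcasting: {shapes}")
        result.append(distinct.pop() if distinct else 1)
    return tuple(result)


print(broadcast_shape((3, 1), (1, 5)))  # (3, 5)
print(broadcast_shape((5,), (3, 5)))    # (3, 5)
try:
    broadcast_shape((3, 5), (5, 5))     # the pair from the traceback
except ValueError as err:
    print(err)
```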
If I remove the top two plates and replace the latents with the constants
P_cov=jnp.eye(k, k)
Q_cov=jnp.eye(q, q)
the code runs but I get the following warnings:
<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site '_P_log_prob'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
mcmc.run(jax.random.PRNGKey(4), X, Y)
<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site '_Q_log_prob'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
mcmc.run(jax.random.PRNGKey(4), X, Y)
<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site 'Y'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
mcmc.run(jax.random.PRNGKey(4), X, Y)
Maybe it has something to do with having multiple plate names with the same dimension?
Minimal example: