Sample from distribution without storing #1695
Comments
I don't think we store the latent values. Could you elaborate?
Sorry for the late reply. I may be misunderstanding, but the problem is that we are sampling latent variables that are nuisance parameters, so we don't need estimates of their posteriors. Is using …
Are you using MCMC? There is …
I have the exact same issue!
See e.g. this comment: https://forum.pyro.ai/t/reducing-mcmc-memory-usage/5639/4?u=fehiepsi
I wasn't able to discern from that comment how to use …
If my understanding is correct, the only way is to run the MCMC step by step and manually trace the parameters of interest.
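For reference, the step-by-step idea can be sketched outside of NumPyro like this. This is a toy random-walk Metropolis sampler on a two-variable target; all names (`step`, `run_and_trace`) are illustrative, not NumPyro API. The point is simply that each step's nuisance draw `z` is overwritten rather than accumulated, while the parameter of interest is traced manually:

```python
import math
import random

def step(state, rng):
    """One random-walk Metropolis step on (theta, z) for a toy
    N(0, 1) x N(0, 1) target; returns the new state."""
    theta, z = state
    proposal = (theta + rng.gauss(0, 0.5), z + rng.gauss(0, 0.5))
    logp = lambda s: -0.5 * (s[0] ** 2 + s[1] ** 2)
    if math.log(rng.random()) < logp(proposal) - logp(state):
        return proposal
    return state

def run_and_trace(num_samples, seed=0):
    rng = random.Random(seed)
    state = (0.0, 0.0)
    thetas = []  # only theta is stored; the nuisance z is discarded each step
    for _ in range(num_samples):
        state = step(state, rng)
        thetas.append(state[0])  # manual trace of the field of interest
    return thetas

thetas = run_and_trace(1000)
```

Memory then scales with the number of traced fields rather than the full state dimension.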
Sorry, my brain was not working when I sent the previous comment. The argument name is `default_fields`:

```python
mcmc = MCMC(NUTS(model))
mcmc.sampler.default_fields = ("z.foo", "z.bar")
```

following the changes in the forum comment (linked in my last comment).
I'm still a bit lost on how I might do this for nested arrays. For instance, I am running the following model on an "image" of satellite data, and I've got 60 subgrids of size 50x50. I sample a random vector `z` per subgrid:

```python
def satellite_model(T=None):
    # num_subgrids, z_dim, simulator, params, and non_nan_idx are defined elsewhere
    sigma_T = numpyro.sample("sigma_T", dist.HalfNormal(10))
    for b in range(num_subgrids):
        z = numpyro.sample(f"z[{b}]", dist.Normal(0, 1).expand([z_dim]))
        ls = numpyro.sample(f"ls[{b}]", dist.Beta(3, 6))
        var = numpyro.sample(f"var[{b}]", dist.LogNormal(0, 1))
        c = jnp.hstack([ls, var])
        mu = simulator.apply(
            {"params": params}, z, c, method=simulator.decode
        ).squeeze()
        numpyro.sample(
            f"T[{b}]",
            dist.Normal(mu[non_nan_idx[b]], sigma_T),
            obs=T[b][non_nan_idx[b]],
        )
```
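With site names like these, one low-tech workaround (assuming the `z[b]` sites were collected anyway) is to drop them from the samples dictionary by prefix immediately after sampling, so at least downstream memory is freed. This is a plain-Python sketch with made-up data, not actual NumPyro output:

```python
# Hypothetical collected samples, keyed by site name as in the model above.
samples = {
    "sigma_T": [1.0, 1.1],
    "z[0]": [[0.1] * 4, [0.2] * 4],  # nuisance latent vectors
    "ls[0]": [0.3, 0.4],
}

# Keep everything except the z[b] sites.
kept = {k: v for k, v in samples.items() if not k.startswith("z[")}
```

This does not avoid the memory spike during sampling itself, but it keeps the retained trace small.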
If you have a model with density, then of course, to save memory, you needn't actually save all the … There's also another possibility in which you're not actually trying to do "proper inference", and maybe instead …
We do need to do inference over …
If you don't want to change the source code, then you can do …

But it's easy to support this feature. As outlined above, we can: …

Let's keep this issue open in case a contributor wants to support this feature. You can use the above …
I find I often have a pattern where my random variable is a nuisance variable, but some deterministic function of it is meaningful. In this case, the desired behavior is more a function of the model than of the inference algorithm, so it's inconvenient to have to tamper with settings for every fit. I would much prefer to have a flag in the …:

```python
a = numpyro.sample('a_', dist.MultivariateNormal(jnp.zeros(2500), jnp.eye(2500)), collect=False)
a = numpyro.deterministic('a', a * 2)
```

What do you think, @fehiepsi?
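To make the proposed semantics concrete, here is a toy sketch (hypothetical API, not NumPyro): sites sampled with `collect=False` are excluded from the stored trace, while `deterministic` sites derived from them are kept.

```python
# Minimal stand-ins for sample/deterministic primitives, recording a trace.
trace = {}

def sample(name, value, collect=True):
    trace[name] = {"value": value, "collect": collect}
    return value

def deterministic(name, value):
    trace[name] = {"value": value, "collect": True}
    return value

a = sample("a_", [0.0] * 4, collect=False)    # nuisance draw, not stored
a = deterministic("a", [v * 2 for v in a])    # derived quantity, stored

# Only sites flagged for collection end up in the returned samples.
collected = {k: v["value"] for k, v in trace.items() if v["collect"]}
```

The nuisance site `a_` never reaches the collected samples, but its deterministic transform `a` does.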
Yes, we can add a field to the …
It looks like all the samplers create a trace on initialization, most via:

```python
def _init_state(self, ...):
    model_trace, ... = numpyro.infer.util.initialize_model(...)
    self.default_fields = numpyro.infer.util.get_default_fields(model_trace)
```

Is this solution ok for you? If so, I would be happy to draft up a PR.
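As a rough sketch of what such a `get_default_fields` helper could do (the trace structure below is illustrative, not NumPyro's actual trace format): collect unobserved sample sites unless they opted out of collection.

```python
def get_default_fields(model_trace):
    """Return the names of sites that should be collected by default:
    unobserved sample sites that did not set collect=False."""
    return tuple(
        name
        for name, site in model_trace.items()
        if site["type"] == "sample"
        and not site["is_observed"]
        and site.get("collect", True)
    )

# A toy trace mirroring the satellite model above: one parameter of
# interest, one nuisance latent, one observed site.
model_trace = {
    "sigma_T": {"type": "sample", "is_observed": False},
    "z": {"type": "sample", "is_observed": False, "collect": False},
    "T": {"type": "sample", "is_observed": True},
}
```

Here only `sigma_T` would be collected; `z` opts out and `T` is observed.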
Yup, I think the solution looks good. Users can use either …
I am currently working on a project where we embed a VAE decoder inside a model. Accordingly, we need to sample `z`s from a multivariate normal distribution, but we are not interested in the posterior of the `z`s. Here is an example model: …

Currently, we are running inference on 50x50 grids with a `z` dimension of 2500 (one `z` per point in the grid), which means a standard model saves 2500 `z`s per step. We never use these `z`s and would like to prevent storing them, to save memory and computation. We would greatly appreciate any advice!