
Sample from distribution without storing #1695

Open
danjenson opened this issue Dec 12, 2023 · 16 comments

Labels
enhancement New feature or request

Comments

@danjenson

danjenson commented Dec 12, 2023

I am currently working on a project where we embed a VAE decoder inside a model. Accordingly, we need to sample z values from a multivariate normal distribution, but we are not interested in the posterior over them. Here is an example model:

def model(y=None):
    var = numpyro.sample("variance", dist.HalfNormal())
    ls = numpyro.sample("lengthscale", dist.HalfNormal())
    z = numpyro.sample("z", dist.MultivariateNormal(jnp.zeros(2500), jnp.eye(2500)))  # <- want to sample but not store
    y_hat = numpyro.deterministic("y_hat", vae.decode(jnp.array([*z, ls, var])))
    sigma = numpyro.sample("sigma", dist.HalfNormal(0.1))
    numpyro.sample("obs", dist.Normal(y_hat[mask], sigma), obs=y)

Currently, we are running inference on 50x50 grids with a z dimension of 2500 (one z per grid point), which means a standard model stores 2500 z values per sample. We never use these values and would like to avoid storing them, to save memory and computation. We would greatly appreciate any advice!
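For scale: at float32 precision, 2500 z values × 1000 saved draws is 2.5 million floats, or roughly 10 MB per chain for z alone.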

@fehiepsi
Member

I don't think we store the latent values. Could you elaborate?

@fehiepsi fehiepsi added the question Further information is requested label Dec 15, 2023
@danjenson
Author

Sorry for the late reply. I may be misunderstanding, but the problem is that we are sampling latent variables that are nuisance parameters, so we don't need estimates of their posteriors. Is numpyro.sample still the correct construct for latent nuisance parameters, or is there a lighter-weight sampling procedure, e.g. a pure JAX method, that might be more appropriate?

@fehiepsi
Member

fehiepsi commented Jan 9, 2024

Are you using MCMC? There is collect_fields to filter out variables that are not required. If you are using SVI, then we don't store latent variables during training.

@renecotyfanboy
Contributor

I have the exact same issue!
@fehiepsi Could you elaborate on the use of collect_fields? I can't find relevant entries in the docs.

@fehiepsi
Member

fehiepsi commented Jan 9, 2024

See e.g. this comment https://forum.pyro.ai/t/reducing-mcmc-memory-usage/5639/4?u=fehiepsi

@danjenson
Author

danjenson commented Jan 10, 2024

I wasn't able to discern from that comment how to use collect_fields. It isn't an argument to NUTS(...), MCMC(...), or mcmc.run(..., collect_fields=...). Where and how do you pass collect_fields, and is it just a list of the variable names you want to keep? I'm using numpyro==0.13.2. Thank you!

@renecotyfanboy
Contributor

If my understanding is correct, the only way is to run the MCMC step by step and manually trace the parameters of interest.

@fehiepsi fehiepsi added enhancement New feature or request and removed question Further information is requested labels Jan 12, 2024
@fehiepsi
Member

Sorry, my brain was not working when I sent the previous comment. The argument name is extra_fields, not collect_fields. There is also a property named default_fields that lists the variables to be stored. I think we could enable an API to allow doing

mcmc = MCMC(NUTS(model))
mcmc.sampler.default_fields = ("z.foo", "z.bar")

following the changes in the forum comment (linked in my last comment).
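For reference, a minimal sketch of the existing extra_fields argument, which collects auxiliary sampler state (e.g. the potential energy) alongside the draws rather than filtering latent sites:

import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():
    numpyro.sample("x", dist.Normal(0, 1))

mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
# extra_fields gathers per-sample sampler statistics; it does not
# control which latent sites are stored.
mcmc.run(jax.random.PRNGKey(0), extra_fields=("potential_energy",))
print(mcmc.get_extra_fields().keys())  # dict_keys(['potential_energy'])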

@danjenson
Author

danjenson commented Jan 14, 2024

I'm still a bit lost on how I might do this for nested arrays. For instance, I am running the following model on an "image" of satellite data and I've got 60 subgrids of size 50x50. I sample a random vector z of 512 values for each subgrid for each sample. So, if I'm doing 1000 samples, that is 60 * 512 * 1000 values stored. However, I don't care about the posteriors of these values -- they are simply used to seed a generative model that I have inserted as a deterministic transformation (simulator.decode in the model). What is the best way to ignore the posteriors of the 60 * 512 z values?

def satellite_model(T=None):
    sigma_T = numpyro.sample("sigma_T", dist.HalfNormal(10))
    for b in range(num_subgrids):
        z = numpyro.sample(f"z[{b}]", dist.Normal(0, 1).expand([z_dim]))
        ls = numpyro.sample(f"ls[{b}]", dist.Beta(3, 6))
        var = numpyro.sample(f"var[{b}]", dist.LogNormal(0, 1))
        c = jnp.hstack([ls, var])
        mu = simulator.apply(
            {"params": params}, z, c, method=simulator.decode
        ).squeeze()
        numpyro.sample(
            f"T[{b}]",
            dist.Normal(mu[non_nan_idx[b]], sigma_T),
            obs=T[b][non_nan_idx[b]],
        )
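For scale: 60 × 512 × 1000 is about 30.7 million floats, or roughly 123 MB per chain at float32 precision.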

@martinjankowiak
Collaborator

If you have a model with density p(x, y) and y is a "nuisance" variable, in the sense that you don't care about its posterior but you still want to integrate out the uncertainty associated with its unknown value, it is still necessary to do inference over y: different y slices of p(x, y) lead to different conditional posteriors over x, so there is no way around it.

Of course, to save memory you needn't actually save all the y samples.

There is also another possibility, in which you are not actually trying to do "proper inference": y is fixed once at the beginning, or sampled from a fixed distribution at each step of inference. But that is not proper inference over x in the presence of uncertainty over y.
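To make the last alternative concrete, here is a minimal sketch based on the model in the original post (the vae object and mask index are assumed from that snippet): z is drawn once up front and treated as a constant, so nothing is stored for it, but, as noted above, uncertainty over z is not propagated.

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# z is fixed once outside the model; there is no "z" sample site, so no
# z draws are stored, but the posterior over the remaining parameters
# ignores the uncertainty in z.
z_fixed = dist.Normal(0, 1).sample(jax.random.PRNGKey(1), (2500,))

def model_fixed_z(y=None):
    var = numpyro.sample("variance", dist.HalfNormal())
    ls = numpyro.sample("lengthscale", dist.HalfNormal())
    y_hat = numpyro.deterministic("y_hat", vae.decode(jnp.array([*z_fixed, ls, var])))
    sigma = numpyro.sample("sigma", dist.HalfNormal(0.1))
    numpyro.sample("obs", dist.Normal(y_hat[mask], sigma), obs=y)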

@danjenson
Author

We do need to do inference over z, especially since we are using HMC and it will be calculating gradients over z in the latent space, but we would prefer not to save all of these samples, to reduce memory usage. Is there a way to do this?

@fehiepsi
Member

If you don't want to change the source code, then you can do

import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():
    numpyro.sample("x", dist.Normal(0, 1))
    numpyro.sample("y", dist.Normal(0, 1))

class CustomNUTS(NUTS):
    def postprocess_fn(self, args, kwargs):
        transform = super().postprocess_fn(args, kwargs)

        def new_transform(z):
            z = transform(z)
            z.pop("x")  # drop the nuisance site before samples are returned
            return z

        return new_transform

mcmc = MCMC(CustomNUTS(model), num_warmup=10, num_samples=20)
mcmc.run(jax.random.PRNGKey(0))
mcmc.get_samples().keys()
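Running this, mcmc.get_samples() returns only the "y" site; the "x" draws are removed by the custom postprocessing step.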

But it would be easy to support this feature: as outlined above, we could expose a setter for the kernel's default_fields property.

Let's keep this issue open in case a contributor wants to implement it. You can use the above CustomNUTS in the meantime.

@amifalk
Contributor

amifalk commented Apr 19, 2024

I often have a pattern where my random variable is a nuisance variable, but some deterministic function of it is meaningful. In this case, the desired behavior is more a property of the model than of the inference algorithm, so it's inconvenient to have to tamper with inference settings for every fit.

I would much prefer to have a flag in the numpyro.sample function to toggle whether or not a site is collected during MCMC.

a = numpyro.sample('a_', dist.MultivariateNormal(jnp.zeros(2500), jnp.eye(2500)), collect=False)
a = numpyro.deterministic('a', a * 2)

What do you think @fehiepsi ?

@fehiepsi
Member

Yes, we can add a field to the "infer" keyword, but that requires us to update all MCMC kernels. I feel that supporting mcmc.sampler.default_fields = ("z.a",) is simpler. What do you think?

@amifalk
Contributor

amifalk commented Apr 25, 2024

It looks like all the samplers create a trace on initialization, most via initialize_model. It should be easy to add a function to infer.util that takes the trace and returns the default fields. Even though we would have to update each one, I don't think it would add much complexity. We would just need to add a setter method for default_fields in the MCMCKernel superclass and add one line to each kernel.

def _init_state(self, ...):
    model_trace, ... = numpyro.infer.util.initialize_model(...)
    self.default_fields = numpyro.infer.util.get_default_fields(model_trace)

Is this solution ok for you? If so, I would be happy to draft up a PR.
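A hypothetical sketch of such a helper (get_default_fields is not an existing numpyro function; the collect flag under each site's infer dict follows the proposal above, and the "z.<name>" convention follows the earlier default_fields example):

def get_default_fields(model_trace):
    # Hypothetical: collect every non-observed sample site whose infer
    # dict does not opt out of collection via collect=False.
    return tuple(
        f"z.{name}"
        for name, site in model_trace.items()
        if site["type"] == "sample"
        and not site.get("is_observed", False)
        and site.get("infer", {}).get("collect", True)
    )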

@fehiepsi
Member

Yup, I think the solution looks good. Users can use either infer or default_fields.
