pmap over num_particles in SVI #1645

Open
snehjp2 opened this issue Sep 22, 2023 · 1 comment
Labels
enhancement New feature or request

Comments


snehjp2 commented Sep 22, 2023

Hi,

In Trace_ELBO, the num_particles argument effectively introduces a batch size for estimating the ELBO gradient when num_particles > 1. By default, the computation is vectorized over the particles. Is it possible to also distribute this batch dimension over devices (e.g. when running on multiple GPUs)? My particular application is prone to JAX OOM errors and would benefit from distributing the particles with jax.pmap.

fehiepsi added the enhancement (New feature or request) label Sep 22, 2023
@fehiepsi
Member

If you get OOM, you can set vectorize_particles to False. You can also use PositionalSharding like in MCMC, I guess.

If you want to pmap over particles, could you make a PR for it? I think we can simply allow vectorize_particles to be a callable and call it here.
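To illustrate the "callable" idea outside of NumPyro, here is a rough sketch in plain JAX. It is not NumPyro's implementation: `single_particle_elbo`, `elbo_over_particles`, and the three `*_particles` mappers are hypothetical names, and the Gaussian model and guide are a toy stand-in. The point is only that the per-particle function can be mapped over RNG keys by any strategy with the same `(fn, keys)` interface, including one that shards keys across devices with `jax.pmap`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def single_particle_elbo(rng_key, loc, log_scale, x):
    # Toy setup: model z ~ N(0, 1), x ~ N(z, 1); guide q(z) = N(loc, scale).
    scale = jnp.exp(log_scale)
    z = loc + scale * jax.random.normal(rng_key)  # reparameterized sample
    log_joint = norm.logpdf(z, 0.0, 1.0) + norm.logpdf(x, z, 1.0)
    log_q = norm.logpdf(z, loc, scale)
    return log_joint - log_q


def vmap_particles(fn, keys):
    # All particles in one vectorized batch (highest peak memory).
    return jax.vmap(fn)(keys)


def lax_map_particles(fn, keys):
    # Particles evaluated sequentially (lowest peak memory).
    return jax.lax.map(fn, keys)


def pmap_particles(fn, keys):
    # Shard particles across local devices, vectorize within each device.
    # Assumes the number of particles is divisible by the device count.
    n_dev = jax.local_device_count()
    sharded = keys.reshape(n_dev, -1, *keys.shape[1:])
    return jax.pmap(jax.vmap(fn))(sharded).reshape(-1)


def elbo_over_particles(rng_key, loc, log_scale, x, num_particles, particle_map):
    # `particle_map` plays the role of the proposed callable.
    keys = jax.random.split(rng_key, num_particles)
    elbos = particle_map(lambda k: single_particle_elbo(k, loc, log_scale, x), keys)
    return jnp.mean(elbos)


key = jax.random.PRNGKey(0)
n = 4 * jax.local_device_count()
elbo_vmap = elbo_over_particles(key, 0.3, -0.5, 1.2, n, vmap_particles)
elbo_seq = elbo_over_particles(key, 0.3, -0.5, 1.2, n, lax_map_particles)
elbo_pmap = elbo_over_particles(key, 0.3, -0.5, 1.2, n, pmap_particles)
```

All three strategies consume the same keys, so they should agree up to floating-point error; they differ only in where and when each particle is computed. On a single-device machine `pmap_particles` degenerates to a vmap on that one device.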
