[feature request] Parallelism support for sequential plate/guide-side enumeration #3219

Open
amifalk opened this issue May 22, 2023 · 3 comments

amifalk commented May 22, 2023

For mixture models with arbitrary distributions over each feature, sampling currently must be done serially, even though these operations are trivially parallelizable.

To sample priors from a hierarchical mixture model with one continuous and one binary feature, you would need to do something like:

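import pyro
import pyro.distributions as dist

n_components = 3  # illustrative value
with pyro.plate('components', n_components):
    for i in pyro.plate('features', 2):
        if i == 0:
            # continuous feature: per-component mean and variance
            pyro.sample('mu', dist.Normal(0., 1.))
            pyro.sample('sigma_sq', dist.InverseGamma(1., 1.))
        elif i == 1:
            # binary feature: per-component success probability
            pyro.sample('theta', dist.Beta(0.1, 0.1))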

For mixture models with a large number of features, this can become very slow.

I would love to be able to use a Joblib-like syntax for loops like these, i.e.

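from joblib import Parallel, delayed

# One entry per feature; each entry lists that feature's [site name, prior] pairs.
features = [[['mu', dist.Normal(0, 1)], ['sigma_sq', dist.InverseGamma(1, 1)]], [['theta', dist.Beta(.1, .1)]]]

# sample_priors is a user-defined helper (not shown) that calls pyro.sample on each pair.
with pyro.plate('components', n_components):
    Parallel(n_jobs=-1)(delayed(sample_priors)(features[i]) for i in pyro.plate('features', 2))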

I have tried something like this, but the Joblib backend and Pyro don't play nicely together: the model doesn't converge.

In a similar vein, adding parallelism for sequential guide-side enumeration could also enable dramatic speedups. For example, when trying to fit CrossCat with SVI and two truncated stick-breaking processes over views and clusters (my personal use case), enumerating out the view assignments in the model is not possible. Enumerating the views out in the guide is much too slow if the enumerated runs can't be executed simultaneously over multiple cores. Since each model run doesn't share information with the others, it seems like this should be possible in theory.

I realize this may be difficult for reasons mentioned in #2354, but is any parallelism like this possible in Pyro?

amifalk changed the title from "[feature request] Multi-core parallelism support for sequential plate/guide-side enumeration" to "[feature request] Parallelism support for sequential plate/guide-side enumeration" on May 22, 2023

fritzo (Member) commented May 28, 2023

Hmm, I'd guess the most straightforward approach to inter-distribution CPU parallelism would be to rely on the PyTorch JIT by simply using JitTrace_ELBO or a similar JIT-compiled ELBO.

Pros:

  • it's a one-line change
  • let PyTorch systems folks solve the problem

Cons:

  • the PyTorch jit seems to break every other release, and doesn't seem engineered to work with large compute graphs as arise in Pyro
  • jit-traced models require fixed static model structure

pavleb (Contributor) commented Dec 1, 2023

@amifalk Did you make any progress in this area? I'm facing the same issue when doing model selection over a set of models with significantly different structure. I have a partial solution that uses poutine.mask to mask out the log-likelihood terms, in the model and guide traces, of the models that are not currently selected by the discrete enumeration. With this approach, parallel enumeration can be used.

However, for complicated model structures and a large set of models, the masking becomes quite complicated and prone to mistakes that cannot be easily debugged.

amifalk (Author) commented Dec 1, 2023

Sorry, no updates currently @pavleb. We ended up resolving our speed issues by moving over to NumPyro.

3 participants