
[FR] Support Automatic Mixed Precision training #3316

Open
austinv11 opened this issue Jan 31, 2024 · 7 comments
Labels
enhancement, help wanted (Issues suitable for, and inviting external contributions)

Comments

@austinv11 (Contributor)

Issue Description

Better support for mixed precision training would be extremely helpful, at least for SVI. I can manually cast data to float16 or bfloat16, but I am unable to leverage PyTorch's automatic mixed precision (AMP) training, because AMP requires using the GradScaler class in the optimization loop to scale gradients in a mixed-precision-aware manner. See the documentation for more info: https://pytorch.org/docs/stable/amp.html

It would be nice to have support for using this class within Pyro's optimizers to enable AMP.
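
For reference, a minimal sketch of the standard AMP recipe from the PyTorch docs; model, loss_fn, optimizer, and loader here are placeholders for ordinary PyTorch objects, not Pyro ones:

import torch

scaler = torch.cuda.amp.GradScaler()
for data, target in loader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output = model(data)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients, skips the step if they overflowed
    scaler.update()                # adjust the scale factor for the next iteration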

@fritzo fritzo added enhancement help wanted Issues suitable for, and inviting external contributions labels Feb 1, 2024
@fritzo fritzo changed the title [Feature Request] Support AMP training [FR] Support Automatic Mixed Precision training Feb 1, 2024
@austinv11 (Contributor, Author)

@fritzo I might be willing to try to tackle this. Do you have any opinions on how to expose the functionality to the end user?

@fritzo (Member) commented Feb 2, 2024

Hi @austinv11, thanks for offering. I'd guess there are a few ways we could support AMP in Pyro:

  1. Use Pyro's ELBOModule to construct a differentiable loss function as in the lightning tutorial, then do standard PyTorch training with AMP. I think Pyro's code already supports this; we'd just need improved documentation and maybe an example:
  • Add a docstring to ELBOModule explaining how it is created and why it is useful.
  • Add the ELBO.__call__ method to sphinx's :special-members: list here.
  • Add an examples/svi_amp.py similar to examples/svi_lightning.py.
  2. Do something similar, but with the Trace_ELBO.differentiable_loss() method (see the sketch after this list).
  3. Add more native AMP support to pyro.optim's wrapper class. This seems intricate and more difficult to maintain, though.
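
For (2), I'd imagine something roughly like the following untested sketch; model, data, and the learning rate are placeholders, and the AMP pieces just follow PyTorch's documented autocast/GradScaler recipe:

import pyro
import torch
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoNormal

guide = AutoNormal(model)
elbo = Trace_ELBO()
elbo.differentiable_loss(model, guide, data)  # run once so that param sites exist
params = [p for _, p in pyro.get_param_store().named_parameters()]
optimizer = torch.optim.Adam(params, lr=1e-3)

scaler = torch.cuda.amp.GradScaler()
for step in range(1000):
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = elbo.differentiable_loss(model, guide, data)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients, skips the step on overflow
    scaler.update()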

Would you be interested in getting (1) or (2) working for yourself and then contributing docs to show how you did it? We're happy to answer any questions about Pyro, but I think you know more about AMP than we do 🙂

@austinv11 (Contributor, Author)

It looks like I might need to try option 3, since AMP-aware gradient scaling requires access to the optimizer's step() function.

I could try adding a boolean flag to PyroOptim to enable AMP. Once that flag is enabled, the user would additionally need to manually use PyTorch's autocast context manager within their models.

But I could see most users wanting to simply activate AMP for their entire model rather than for specific portions of code. Do you think it might be worth adding a new ELBO function that autocasts the entire model for the user?
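
To illustrate, here's a purely hypothetical sketch (neither the amp flag nor an autocasting ELBO exists today; the toy model is just for illustration):

import pyro
import pyro.distributions as dist
import torch

def model(data):
    # With only a PyroOptim flag, the user still wraps the heavy compute in autocast themselves:
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loc = pyro.param("loc", torch.zeros(data.shape[-1], device=data.device))
        with pyro.plate("data", data.shape[0]):
            pyro.sample("obs", dist.Normal(loc, 1.0).to_event(1), obs=data)

optim = pyro.optim.Adam({"lr": 1e-3}, amp=True)  # hypothetical flag, does not exist yet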

@fritzo (Member) commented Feb 7, 2024

Let me try again to persuade you towards options (1) or (2) 😄, admitting I don't know your details or how AMP works.

Back in the early days of Pyro we decided to wrap PyTorch's optimizer classes so we could have more control over dynamically created parameters. In practice this made Pyro's optimization idioms incompatible with other frameworks built on top of PyTorch, e.g. lightning, horovod, AMP, and new higher-order optimizers. To work around this incompatibility we've since added ways to compute differentiable losses in Pyro, so that optimization can be done entirely using torch idioms, without ever using pyro.optim.

For example, instead of the original Pyro-idiomatic optimization:

def model(args):
    ...
guide = AutoNormal(model)
elbo = Trace_ELBO()
optim = pyro.optim.Adam(...)  # <---- pyro idioms
svi = SVI(model, guide, optim, elbo)
for step in range(...):
    svi.step(args)

you can use torch-idiomatic optimizers:

class Model(PyroModule):
    def forward(self, args):
        ...
model = Model()
guide = AutoNormal(model)
elbo = Trace_ELBO()
loss_fn = elbo(model, guide)  # an ELBOModule, i.e. a torch.nn.Module
loss_fn(args)  # run once so lazily-initialized parameters exist
optimizer = torch.optim.Adam(loss_fn.parameters(), ...)  # <---- torch idioms
for step in range(...):
    optimizer.zero_grad()
    loss = loss_fn(args)
    loss.backward()
    optimizer.step()  # <---- Can we use AMP here?

What I'm hoping is that, by switching to torch-native optimizers, it will be easy or even trivial to support AMP.
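
For instance (untested, and continuing the names from the sketch above), I'd guess AMP could slot in roughly like this:

scaler = torch.cuda.amp.GradScaler()
for step in range(...):
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_fn(args)
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # <---- AMP-aware replacement for optimizer.step()
    scaler.update()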

That said, we'd still be open to adding AMP support to pyro.optim if you can find a simple, maintainable way to do so 🙂.

@austinv11 (Contributor, Author)

Ah, I see what you mean. Am I correct in understanding that this wouldn't be compatible with the SVI trainer and would require using PyroModules then?

@ilia-kats (Contributor)

That is also incompatible with models/guides that dynamically create parameters during training, if I understand correctly.

@fritzo (Member) commented Feb 8, 2024

@austinv11 @ilia-kats correct.
