Add Gradient scaler #832

Open
coreylowman opened this issue Jul 26, 2023 · 1 comment

Comments

@coreylowman
Owner

With the addition of the AMP<F> dtype, we also need to add gradient scaling, which is commonly used with AMP (automatic mixed precision) training to keep small gradients from underflowing in half precision.

I think the frontend interface could look something like:

let mut scaler = GradientScaler { ... }; // similar fields to PyTorch's GradScaler

// this would do both parts that you currently have to do separately in PyTorch:
// 1. scale the loss by the current scale factor
// 2. unscale the gradients before returning them
let grads = scaler.scaled_backward(loss);

We may have to add some methods to Gradients to support scaling them.
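
To make the scaling logic concrete, here is a minimal standalone sketch of the dynamic loss scaling scheme described in the PyTorch docs. It is not dfdx code: it works on a plain `f32` loss and a `&mut [f32]` slice standing in for `Gradients`, and it splits the proposed `scaled_backward` into two hypothetical methods, `scale_loss` and `unscale_and_update`, because the backward pass itself is out of scope here. The field names and default values are assumptions borrowed from PyTorch's `GradScaler` arguments (`init_scale`, `growth_factor`, `backoff_factor`, `growth_interval`).

```rust
/// Hypothetical dynamic loss scaler. Fields mirror the arguments of
/// PyTorch's `GradScaler`; this is not an existing dfdx type.
#[derive(Debug, Clone)]
pub struct GradientScaler {
    /// Current factor the loss is multiplied by.
    pub scale: f32,
    /// Multiplier applied to `scale` after enough consecutive finite steps.
    pub growth_factor: f32,
    /// Multiplier applied to `scale` when an inf/NaN gradient is seen.
    pub backoff_factor: f32,
    /// Number of consecutive finite steps required before growing `scale`.
    pub growth_interval: u32,
    /// Consecutive finite steps seen since the last scale change.
    pub good_steps: u32,
}

impl Default for GradientScaler {
    fn default() -> Self {
        // Assumed defaults, taken from PyTorch's `GradScaler`.
        Self {
            scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            good_steps: 0,
        }
    }
}

impl GradientScaler {
    /// Step 1: scale the loss before running the backward pass.
    pub fn scale_loss(&self, loss: f32) -> f32 {
        loss * self.scale
    }

    /// Step 2: divide the gradients by the scale in place, then update the
    /// scale. Returns `false` if any gradient is inf/NaN, in which case the
    /// caller should skip the optimizer step.
    pub fn unscale_and_update(&mut self, grads: &mut [f32]) -> bool {
        let inv = 1.0 / self.scale;
        let mut finite = true;
        for g in grads.iter_mut() {
            *g *= inv;
            finite &= g.is_finite();
        }
        if finite {
            self.good_steps += 1;
            if self.good_steps >= self.growth_interval {
                self.scale *= self.growth_factor;
                self.good_steps = 0;
            }
        } else {
            // Overflow: shrink the scale and restart the growth counter.
            self.scale *= self.backoff_factor;
            self.good_steps = 0;
        }
        finite
    }
}
```

In a training loop this would be used as: scale the loss, run backward on the scaled loss, call `unscale_and_update` on the resulting gradients, and only apply the optimizer update when it returns `true`. A real `scaled_backward` would fold those steps into one call and operate on dfdx's own gradient storage.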

Originally posted by @coreylowman in #424 (comment)

@coreylowman
Owner Author

Posting the PyTorch documentation on gradient scaling here: https://pytorch.org/docs/stable/amp.html#gradient-scaling
