Deprecate device-specific GradScaler autocast API #126527
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126527
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 22b3690 with merge base 5fb11cd:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 139fa79ce287da588c7d6d967371057c9081219c Pull Request resolved: #126527
[ghstack-poisoned]
This deprecation is desirable, but we should remove all mentions of torch.cuda.amp.GradScaler and torch.cpu.amp.GradScaler in the codebase (e.g., in the tests) and replace them with the new usage.
This is one way to ensure that the recommended version of GradScaler will remain sufficiently tested too.
Please also update the docs
pytorch/docs/source/notes/amp_examples.rst
Line 112 in 86ad101
.. currentmodule:: torch.cuda.amp.GradScaler
I updated the deprecation warning here; I think it will be exposed under the gradient-scaling section at https://pytorch.org/docs/stable/amp.html#gradient-scaling. May I know if I captured your point?
@guangyey I mean that this page should also get updated to direct people to the recommended API |
I updated the doc. Could you help review this PR again?
@janeyx99 could you help review this PR again? Thanks very much~
* ``torch.GradScaler("cuda", args...)`` is equivalent to ``torch.cuda.amp.GradScaler(args...)``.
* ``torch.GradScaler("cpu", args...)`` is equivalent to ``torch.cpu.amp.GradScaler(args...)``.

.. warning::
    ``torch.cuda.amp.autocast(args...)`` and ``torch.cpu.amp.autocast(args...)`` will be deprecated. Please use ``torch.autocast("cuda", args...)`` or ``torch.autocast("cpu", args...)`` instead.
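For reference, a minimal sketch of the device-generic `torch.autocast` usage recommended by this warning; the module and tensor names here are illustrative placeholders, not part of the PR:

```python
import torch

net = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)

# Device-generic replacement for torch.cpu.amp.autocast(args...):
with torch.autocast("cpu", dtype=torch.bfloat16):
    y_cpu = net(x)

# Device-generic replacement for torch.cuda.amp.autocast(args...),
# guarded so the sketch also runs on CPU-only machines:
if torch.cuda.is_available():
    net_cuda = net.cuda()
    with torch.autocast("cuda", dtype=torch.float16):
        y_cuda = net_cuda(x.cuda())
```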
Include this fact here still? "For CPU, only lower precision floating point datatype of torch.bfloat16 is supported for now."
I remember that torch.float16 is already supported on CPU now, right @leslie-fang-intel?
I think CPU autocast also supports torch.float16 now.
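A quick way to check which dtypes CPU autocast accepts on a given build; this is an illustrative snippet (not part of the PR), and the float16 case is the one whose support is discussed above:

```python
import torch

x = torch.randn(4, 8)
w = torch.randn(8, 8)

# bfloat16 is the long-documented CPU autocast dtype.
with torch.autocast("cpu", dtype=torch.bfloat16):
    print(torch.mm(x, w).dtype)  # torch.bfloat16

# float16 on CPU autocast: on builds where it is supported this prints
# torch.float16; otherwise autocast warns and falls back to float32.
with torch.autocast("cpu", dtype=torch.float16):
    print(torch.mm(x, w).dtype)
```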
@@ -25,12 +25,9 @@ However, :class:`torch.autocast` and :class:`torch.GradScaler` are modular, and
As shown in the CPU example section of :class:`torch.autocast`, "automatic mixed precision training/inference" on CPU with
datatype of ``torch.bfloat16`` only uses :class:`torch.autocast`.

For CUDA and CPU, APIs are also provided separately:
Replace the mentions of torch.cpu.amp.GradScaler and torch.cuda.amp.GradScaler with just torch.amp.GradScaler in line 22.
Good catch. Updated.
Some last nits, thanks!
ghstack-source-id: f454a70e4fddb4af7db42d69d27b1f247004966d Pull Request resolved: #126527
Thanks for your approval. Have a nice day~
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #126531 Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/albanD, https://github.com/EikanWang ghstack dependencies: #126527
X-link: pytorch/pytorch#126527 Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/janeyx99, https://github.com/EikanWang Reviewed By: PaliC Differential Revision: D57838085 fbshipit-source-id: 09a29e2535e66643d212276779605c573391666f
Pull Request resolved: pytorch#126527 Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/janeyx99, https://github.com/EikanWang
Pull Request resolved: pytorch#126531 Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/albanD, https://github.com/EikanWang ghstack dependencies: pytorch#126527
Pull Request resolved: #126527 Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/janeyx99, https://github.com/EikanWang (cherry picked from commit c09205a)
Stack from ghstack (oldest at bottom):

# Motivation

## For `torch.amp.GradScaler`

- `torch.cpu.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cpu", args...)`.
- `torch.cuda.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cuda", args...)`.

So we intend to deprecate them and **strongly recommend** developers use `torch.amp.GradScaler`.

## For `custom_fwd` and `custom_bwd`

These decorators are a good way to make a custom function run correctly with or without autocast, even inside an autocast-enabled region, and the mechanism can be shared by other backends, like CPU and XPU. So we generalize them to be device-agnostic, move them into `torch/amp/autocast_mode.py`, and re-expose them as `torch.amp.custom_fwd` and `torch.amp.custom_bwd`. Meanwhile, we deprecate `torch.cuda.amp.custom_fwd` and `torch.cuda.amp.custom_bwd`.

# Additional Context

Add a UT to cover the deprecation warning. No additional UTs are needed to cover the functionality of `torch.amp.custom_fwd/bwd`; the existing UTs that previously covered `torch.cuda.amp.custom_fwd/bwd` cover them.

To facilitate review, we separate these code changes into two PRs. The first PR covers `torch.amp.GradScaler`; the follow-up covers `custom_fwd` and `custom_bwd`. A minimal before/after sketch is included after the cc list below.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @mcarilli @ptrblck @leslie-fang-intel @voznesenskym @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng
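A minimal before/after sketch of the migration described above; the `MyMM` function and surrounding names are placeholders, and the `device_type`/`cast_inputs` arguments shown for `torch.amp.custom_fwd`/`custom_bwd` follow the generalized API introduced by the follow-up PR in this stack:

```python
import torch

# GradScaler: before (deprecated) -> after (device-generic)
#   scaler = torch.cuda.amp.GradScaler()  /  torch.cpu.amp.GradScaler()
scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())

# custom_fwd/custom_bwd: before (deprecated) -> after (device-generic)
#   @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
class MyMM(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    def forward(ctx, a, b):
        # Inputs are cast to float32 before running the op under autocast.
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad):
        # Backward runs under the same autocast state as forward.
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)
```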