
Use BFloat16 in distributed quantization when supported by NCCL #125113

Closed
cyyever wants to merge 2 commits from the quantization_bf16 branch

Conversation

@cyyever cyyever (Collaborator) commented Apr 28, 2024

This PR enables BFloat16 in torch/csrc/distributed/c10d/quantization/quantization_gpu.cu.

pytorch-bot bot commented Apr 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125113

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4694f99 with merge base 91a4740:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed and release notes: distributed (c10d) labels Apr 28, 2024
@cyyever cyyever requested a review from dagitses April 28, 2024 02:07
@dagitses dagitses (Collaborator) left a comment

FYI I'm no longer a part of the project, so I can't approve changes.

@@ -69,15 +69,16 @@ at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) {

  auto output = at::empty(
      {nrows, output_columns},
      input.options().dtype(at::kHalf)); // at::kHalf
#if HAS_NCCL_BF16_DATATYPE
      input.options().dtype(at::kBFloat16));
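Read together with the second diff fragment further below, the change appears to select the output dtype at build time. Here is a minimal sketch of that pattern, assuming the previous at::kHalf path sits behind an #else branch and wrapping the allocation in a hypothetical helper (allocate_quantized_output is not a name from the PR):

    // Sketch only, not verbatim from the PR. HAS_NCCL_BF16_DATATYPE is assumed
    // to be a macro that the build defines to 0 or 1 depending on whether the
    // linked NCCL supports the bfloat16 datatype.
    #include <ATen/ATen.h>

    #ifndef HAS_NCCL_BF16_DATATYPE
    #define HAS_NCCL_BF16_DATATYPE 0
    #endif

    at::Tensor allocate_quantized_output(
        const at::Tensor& input, int64_t nrows, int64_t output_columns) {
      return at::empty(
          {nrows, output_columns},
    #if HAS_NCCL_BF16_DATATYPE
          input.options().dtype(at::kBFloat16));  // BF16 output when NCCL supports it
    #else
          input.options().dtype(at::kHalf));      // otherwise keep the FP16 output
    #endif
    }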
Collaborator

FYI, I don't think you need to do this one in the preprocessor; you should be able to do it like:

input.options().dtype(HAS_NCCL_BF16_DATATYPE ? at::kBFloat16 : at::kHalf));

Collaborator Author

HAS_NCCL_BF16_DATATYPE is a macro, and I think it's better to format the code like this so that the old branch is easy to identify and remove in the future.
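For illustration, a short sketch contrasting the two styles under discussion; it assumes HAS_NCCL_BF16_DATATYPE always expands to a plain 0 or 1, reuses input from the hunk above, and the variable names are invented for the example:

    // Style used in the PR: a preprocessor branch. The dead branch is easy to
    // find and delete wholesale once BF16 support can be assumed everywhere.
    #if HAS_NCCL_BF16_DATATYPE
      auto opts_preprocessor = input.options().dtype(at::kBFloat16);
    #else
      auto opts_preprocessor = input.options().dtype(at::kHalf);
    #endif

    // Style suggested in the review: a conditional expression. Both branches are
    // always compiled, and the macro must expand to a usable boolean constant.
    auto opts_ternary =
        input.options().dtype(HAS_NCCL_BF16_DATATYPE ? at::kBFloat16 : at::kHalf);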

@cyyever cyyever added the ciflow/trunk label Apr 29, 2024
@cyyever cyyever requested a review from Skylion007 April 29, 2024 01:43
@cpuhrsch cpuhrsch requested a review from wconstab April 30, 2024 19:46
@cpuhrsch cpuhrsch added the triaged label Apr 30, 2024
          reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::Half>())
#endif
      );
  C10_CUDA_KERNEL_LAUNCH_CHECK();
Contributor

What does the C10_CUDA_KERNEL_LAUNCH_CHECK function do? What's the purpose of uncommenting it?
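For background on the first question: C10_CUDA_KERNEL_LAUNCH_CHECK() is the c10 macro placed after a kernel launch to surface launch errors (roughly, it checks cudaGetLastError() and throws if the launch failed). A minimal usage sketch; the kernel and launch configuration below are made up, not taken from quantization_gpu.cu:

    #include <cstdint>
    #include <cuda_runtime.h>
    #include <c10/cuda/CUDAException.h>

    // Hypothetical kernel: scale each element in place.
    __global__ void scale_kernel(float* data, int64_t n, float s) {
      const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        data[i] *= s;
      }
    }

    void launch_scale(float* data, int64_t n, float s, cudaStream_t stream) {
      const int threads = 256;
      const int blocks = static_cast<int>((n + threads - 1) / threads);
      scale_kernel<<<blocks, threads, 0, stream>>>(data, n, s);
      // Kernel launches are asynchronous and return no error value themselves;
      // the macro checks cudaGetLastError() so bad launches fail loudly here.
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }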

@cyyever
Collaborator Author

cyyever commented May 1, 2024

@pytorchmergebot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

andoorve pushed a commit to andoorve/pytorch that referenced this pull request May 1, 2024
…rch#125113)

This PR enables BFloat16 in torch/csrc/distributed/c10d/quantization/quantization_gpu.cu.

Pull Request resolved: pytorch#125113
Approved by: https://github.com/kwen2501
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
…rch#125113)

This PR enables BFloat16 in torch/csrc/distributed/c10d/quantization/quantization_gpu.cu.

Pull Request resolved: pytorch#125113
Approved by: https://github.com/kwen2501
@cyyever cyyever deleted the quantization_bf16 branch May 5, 2024 04:22