Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fairscale support for only performing allreduce in last microbatch #1168

Draft
wants to merge 1 commit into
base: ngoyal_changes_for_pp_fp8
Choose a base branch
from

Conversation

jiecaoyu
Copy link

What does this PR do?

Fairscale changes for supporting only performing allreduce in the last microbatch.

The main change in xlformers: https://github.com/fairinternal/xlformers/pull/3236 .

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 18, 2024
Copy link

@jspark1105 jspark1105 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow didn't realize norm main grad allreduce in the last microbatch requires this much change (or some of changes are not related to this specific PR?).

@@ -70,6 +70,9 @@ def __new__(cls, params: Sequence[nn.Parameter], requires_grad: bool = True) ->
def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
"""Initialize the _param_numels and _param_shapes lists."""
self._param_numels = [p.numel() for p in params]
self._param_require_allreduce = [

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps more descriptive name is _param_require_tp_allreduce?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! I have made the modification.

@jiecaoyu jiecaoyu changed the base branch from main to ngoyal_changes_for_pp_fp8 March 18, 2024 17:14
@jiecaoyu jiecaoyu force-pushed the ngoyal_changes_for_pp_fp8_jiecaoyu_norm_allreduce branch from 9eba19b to 74a7313 Compare March 22, 2024 09:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants