Fp8 all gather hack #1136

Open · wants to merge 3 commits into base: ngoyal_added_zero2_shard_modelparams_multiple_gpus from fp8_all_gather

Conversation

@jspark1105 commented Sep 17, 2023

This is based on ngoyal_added_zero2_shard_modelparams_multiple_gpus and adds hacks to use fp8 all-gather with NVIDIA's Transformer Engine (see the latest commit for the changes on top of the ngoyal_added_zero2_shard_modelparams_multiple_gpus branch).

This depends on the Transformer Engine changes in https://github.com/facebookresearch/TransformerEngine/pull/20.
See https://github.com/fairinternal/xlformers/pull/1403 for an example of how to use it.
It also depends on the PyTorch changes in pytorch/pytorch#109654.

To use fp8 all-gather, set compute_dtype=torch.float8_e4m3fn and mixed_precision=True (see the sketch after this description).
We separate out precision-critical parameters, such as affine weights for norms, as non-flattened params and hard-code them to use bf16.
We update scale/scale_inv inside forward, before _rebuild_full_params calls _cast_fp32_param_shards_to_fp16, whereas the TE baseline updates scale/scale_inv in prepare_forward. This is because we need the fp8 quantization of weights earlier, before the all-gather. (One could consider doing this in post-backward, but that has a problem: the bwd amax update only happens after the backward of all layers has finished, which can be later than post-backward, so we would not be using the latest bwd amax info for the scale/scale_inv update.)
We hard-code special handling for a few TransformerEngine layers (Linear, LayerNormLinear, and LayerNormMLP) in _cast_fp32_param_shards_to_fp16 to access their fp8 metadata and quantize with the right scales. (TODO: we may want to extract this into user-customizable callback functions.)
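A minimal sketch of wiring these settings up (assumes this PR's fairscale branch plus a TransformerEngine-based module; the helper name `wrap_for_fp8_allgather` is illustrative, not part of the PR):

```python
import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def wrap_for_fp8_allgather(block: nn.Module) -> FSDP:
    """Wrap a block so its flattened params are all-gathered in fp8.

    Call after torch.distributed is initialized. `block` would normally be a
    TE Linear / LayerNormLinear / LayerNormMLP; any nn.Module shows the shape
    of the call.
    """
    return FSDP(
        block,
        mixed_precision=True,               # keep fp32 master shards, cast copies for compute
        compute_dtype=torch.float8_e4m3fn,  # enables the fp8 all-gather path in this PR
        flatten_parameters=True,            # flatten params for the fused all-gather
    )
```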

CC @awgu @ngoyal2707 @vedanuj @jiecaoyu @yf225 @GD06

@facebook-github-bot added the CLA Signed label Sep 17, 2023
@jspark1105 jspark1105 marked this pull request as ready for review September 18, 2023 04:33
@jspark1105 jspark1105 changed the base branch from main to ngoyal_added_zero2_shard_modelparams_multiple_gpus October 4, 2023 23:05
@jspark1105 (Author) commented:
Will merge the main_grad-related changes with #1142.

@jspark1105 force-pushed the fp8_all_gather branch 2 times, most recently from bd70153 to af3d2d7 on October 5, 2023 03:49
                # Cast grad to FP32.
                grad_or_main_grad.data = grad_or_main_grad.to(param.dtype)
            elif self._is_fp8_dtype():
                # Use bf16 wgrad for fp8 weights (TODO: handle fp8 wgrad)
Reviewer (Member) commented:
Currently this is not working with the latest FP8 wgrad?

@jspark1105 (Author) replied:
This is meant to be future work for when we have fp8 reduce-scatter. I'll update the comment.
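For readers following the thread, a small sketch of the gradient-dtype decision this hunk makes (the helper and the fp8 dtype set are illustrative, not the PR's actual code):

```python
import torch

_FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)  # assumed fp8 dtype set


def wgrad_dtype(compute_dtype: torch.dtype, fp32_reduce_scatter: bool) -> torch.dtype:
    # Mirrors the branch above: cast grads to fp32 when fp32 reduce-scatter is
    # requested; otherwise fp8 weights keep bf16 weight grads until fp8
    # reduce-scatter is supported (the stated future work).
    if fp32_reduce_scatter:
        return torch.float32
    if compute_dtype in _FP8_DTYPES:
        return torch.bfloat16
    return compute_dtype
```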

@@ -1393,7 +1447,11 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:

         # For root and mixed precision, we convert the input to FP16 (no_grad is needed for
         # the conversion).
-        is_bf16 = self.compute_dtype == torch.bfloat16
+        is_bf16 = self.compute_dtype in [
Reviewer (Member) commented:
nit: is_bf16_or_fp8
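Following the nit, a sketch of what the renamed check could look like (the diff above is truncated, so the exact fp8 dtype list is an assumption):

```python
import torch


def is_bf16_or_fp8(compute_dtype: torch.dtype) -> bool:
    # Treat bf16 and the fp8 formats uniformly for the input-cast path.
    return compute_dtype in (
        torch.bfloat16,
        torch.float8_e4m3fn,  # the dtype the PR description recommends
        torch.float8_e5m2,    # assumed; may not be in the branch's actual list
    )
```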

@jspark1105 force-pushed the fp8_all_gather branch 2 times, most recently from b9b093b to a2b49d1 on October 7, 2023 03:10
@@ -2265,8 +2361,7 @@ def local_metadata_dict(self) -> Dict[str, Any]:
             backing_param_name = m.module.flat_param_names[i]
             names, shapes, numels = m.module.metadata(i)
         else:
-            assert len(m._param_name_groups[i]) == 1
-            backing_param_name = m._param_name_groups[i][0]
+            backing_param_name = m._param_name_groups[m._num_flatten_params][i - m._num_flatten_params]
@jspark1105 (Author) commented:
Need to make sure checkpointing works properly with this.
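A toy illustration of the indexing this hunk assumes (the group layout below is a guess made only to show the arithmetic; the names are hypothetical):

```python
num_flatten_params = 2  # hypothetical: two flattened param groups
param_name_groups = [
    ["layers.0.flat_param"],                           # flattened group 0
    ["layers.1.flat_param"],                           # flattened group 1
    ["layers.0.norm.weight", "layers.1.norm.weight"],  # non-flattened, precision-critical params
]
i = 3  # a backing-param index past the flattened groups
backing_param_name = param_name_groups[num_flatten_params][i - num_flatten_params]
assert backing_param_name == "layers.1.norm.weight"
```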

@jspark1105 force-pushed the fp8_all_gather branch 2 times, most recently from d92dc0f to 6a4d7f4 on October 15, 2023 18:56