FSDPMixedPrecision setting, Logit norm growth, z-loss. #514

Open
maximilianmbeck opened this issue Mar 21, 2024 · 0 comments
Labels
type/question An issue that's a question

Comments

@maximilianmbeck

❓ The question

Hi all,

Why did you reduce the gradients in float32, as reported in Section 3.1 of the OLMo paper?

We ran some experiments on this and observed that setting reduce_dtype=bfloat16 in training setups with more than 4 nodes causes the output logit norm to grow.

I am curious: did you make a similar observation? Did you track the output logit norm during training?
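For context, this is roughly how such a metric can be tracked; a minimal sketch assuming a generic PyTorch training loop (the helper name and metric keys are illustrative, not from any particular codebase):

```python
import torch

def logit_norm_stats(logits: torch.Tensor) -> dict:
    # logits: (batch, seq_len, vocab_size) output of the final projection.
    # Compute the per-token L2 norm in float32 and reduce to scalars for logging.
    with torch.no_grad():
        per_token_norm = logits.detach().float().norm(dim=-1)  # (batch, seq_len)
        return {
            "logit_norm/mean": per_token_norm.mean().item(),
            "logit_norm/max": per_token_norm.max().item(),
        }

# In the training loop, e.g.:
#   stats = logit_norm_stats(logits)
#   wandb.log(stats, step=step)   # or any other logger
```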

More concretely, while training our models we also observed a growth of the output logit norm, which led to Infs in our PPL metrics (NOT in the loss, the loss was still fine) at some point later in training.
Even though we observed that we could mitigate this by adding a regularizing loss that pushes down the output logits (similar to the z-loss suggested in the PaLM paper), we tried to avoid using such a loss.
Instead, we investigated the PyTorch FSDP MixedPrecision settings, as we suspected bfloat16 to be the cause of the issue.
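For reference, a minimal sketch of the PaLM-style z-loss we mean (the helper name is illustrative; 1e-4 is the coefficient suggested in the PaLM paper):

```python
import torch
import torch.nn.functional as F

def cross_entropy_with_z_loss(logits: torch.Tensor, targets: torch.Tensor,
                              z_loss_coeff: float = 1e-4) -> torch.Tensor:
    # Standard next-token cross entropy over the flattened sequence.
    ce = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    # PaLM-style z-loss: penalize log(Z)^2, where Z = sum_i exp(logit_i).
    # This pushes the softmax normalizer towards 1 and thereby keeps the logits small.
    log_z = torch.logsumexp(logits.float(), dim=-1)  # (batch, seq_len)
    z_loss = z_loss_coeff * (log_z ** 2).mean()
    return ce + z_loss
```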

We trained two Transformer-like models of size 125M and 1.3B on next-token prediction on 4, 16 and 32 nodes (see below).
We trained for approx. 10k steps with the hyperparameters specified below.

Note: drd in the run names below corresponds to the reduce_dtype setting of FSDPMixedPrecision.
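To make explicit which setting we mean: reduce_dtype is the corresponding field of torch.distributed.fsdp.MixedPrecision. A minimal sketch of the two configurations we compare (param_dtype and buffer_dtype are shown only for illustration; the runs differ in reduce_dtype):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# Gradients are all-reduced in bfloat16 (the drd-bfloat16 runs below).
mp_reduce_bf16 = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Gradients are upcast to float32 for the all-reduce (the drd-float32 runs below).
mp_reduce_fp32 = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)

# model = ...  # the transformer to be wrapped
# fsdp_model = FSDP(model,
#                   sharding_strategy=ShardingStrategy.NO_SHARD,
#                   mixed_precision=mp_reduce_fp32)
```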

Experiment 1:

A model with 125M parameters trained on 4 nodes and on 16 nodes, both with reduce_dtype=bfloat16.
As the sharding strategy we use NO_SHARD.

Brown: B24E768gbs256--s-NO_SHARD-nn-16-drd-bfloat16-sn-125M-utc-1-l-0.0003-wd-0.1-nb-24-ed-768-seed-42
Blue: B24E768gbs256--s-NO_SHARD-nn-4 -drd-bfloat16-sn-125M-utc-1-l-0.0003-wd-0.1-nb-24-ed-768-seed-0

(Screenshot: fsdpprecision_125M_nn4_vs_nn16)

Experiment 2:

A model with 1.3B parameters trained on 32 nodes with DDP, with FSDP NO_SHARD and reduce_dtype=float32, and with FSDP NO_SHARD and reduce_dtype=bfloat16.
We compare FSDP with sharding strategy NO_SHARD to DDP.

Grey: B48E2048gbs512--s-NO_SHARD -nn-32-drd-float32 -sn-1_3B-utc-1-l-0.0002-wd-0.1-nb-48-ed-2048-seed-42
Red: B48E2048gbs512--s-DDP -nn-32-drd-bfloat16-sn-1_3B-utc-1-l-0.0002-wd-0.1-nb-48-ed-2048-seed-42
Green: B48E2048gbs512--s-NO_SHARD -nn-32-drd-bfloat16-sn-1_3B-utc-1-l-0.0002-wd-0.1-nb-48-ed-2048-seed-42

(Screenshots: fsdpprecision_1_3B_fsdpsettings, fsdpprecision_1_3B_fsdpsettings-zoom)

We observe that setting reduce_dtype=bfloat16 in training setups with more than 4 nodes causes the output logit norm to grow.
When training with FSDP and reduce_dtype=float32, or when training with DDP (we think that DDP also reduces gradients in float32), the output logit norm did not grow.
In other experiments we even observed that the growth of the output logit norm scales roughly linearly with the number of nodes (when using reduce_dtype=bfloat16).

The tricky thing is that this behavior is not visible in the loss (see screenshots), so it is hard to trace this issue back to FSDP MixedPrecision.

We think this is a severe issue that needs more investigation, since the reduce dtype has a major impact on training speed and one would actually prefer bfloat16 for higher training throughput.
