Weights become NaN with torch.compile optimizer capturable=True, lr=0.0, nn.Embedding #126514
Comments
@bdhirsh is this related to the torchtitan NaN loss you were talking about?
@ad8e Does the NaN repro with a single GPU?
Pretty sure this is caused by this line https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py#L552. To confirm, @ad8e, you can likely repro this with just a minimal snippet along these lines:
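A minimal sketch of such a snippet (reconstructed; it assumes a single-tensor, capturable, compiled AdamW step with lr=0.0, and is not the comment's original code):

```python
import torch

# Reconstructed sketch (assumption: a single-tensor, capturable, compiled
# AdamW step with lr=0.0 is what's meant; not the original snippet).
p = torch.zeros(8, device="cuda", requires_grad=True)
p.grad = torch.randn_like(p)
opt = torch.optim.AdamW([p], lr=0.0, capturable=True, foreach=True)

@torch.compile
def step():
    opt.step()

step()
# NaNs expected here if the capturable divide-by-lr on the linked line is hit.
print(torch.isnan(p).any())
```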
The real solution is to allow foreach_div to support Scalar as the first argument, but I'm not sure how hard that is (cc @crcrpar). It feels like we should be able to just add an overload. Regarding priority, I'm not sure this is high pri. How likely is this use case? Is there a real use case for having lr be 0?
DTensor doesn't work when I change the TP mesh size from 2 to 1: I receive an error which means the process group isn't being created when the dim size is 1. So I cannot test whether the NaN appears with a single GPU. If I remove the DTensor, no NaNs appear. So the NaN only appears with DTensor. It's not high priority for me because DTensor TP is currently useless to me due to low performance, so I don't use it anywhere. If DTensor actually mattered (above 70B scale, or if it finally gets comm/compute overlap working), then lr=0.0 would affect linear decay/warmup, where lr=0.0 is common at the endpoints but avoidable. Another use case would be re-baking the AdamW second moment, which is needed when resuming from a checkpoint saved without optimizer states (useful for saving disk space); this can be done with a very low LR instead of 0.0. Anyone else who cared about DTensor would be able to spot the NaN issue and work around it in both cases, since it is not a silent failure.
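For context, a sketch of how a schedule reaches lr = 0.0 at the endpoints (illustrative names and numbers, not from the testcase):

```python
# Hedged sketch (illustrative names/numbers): linear warmup then linear decay
# evaluates to lr == 0.0 at step 0 and at the final step, which is how a real
# run can feed lr=0.0 into the optimizer.
def linear_warmup_decay(step: int, warmup: int = 100, total: int = 1000,
                        peak_lr: float = 3e-4) -> float:
    if step < warmup:
        return peak_lr * step / warmup                  # 0.0 at step 0
    return peak_lr * (total - step) / (total - warmup)  # 0.0 at step == total
```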
I tried Jane's testcase, taking the original DTensor TP=2 example and making the modifications she suggested. The NaNs appear, so her diagnosis is correct.
Is this actually related to DTensor, or is this more about torch.compile + the optimizer? Based on the analysis above, I think if we just use a normal torch.Tensor with torch.compile and set lr=0.0, we should still repro the issue?
The underlying bug is not in DTensor; it's in the optimizer. It's only that DTensor exposes this code path in the optimizer. A normal torch.Tensor with torch.compile and lr=0.0 doesn't hit it; it's the capturable path that does.
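For reference, a sketch of why the two branches treat lr=0.0 differently (paraphrased mechanics, not the exact adamw.py source):

```python
import torch

lr = 0.0
bc1 = 1 - 0.9**1  # bias_correction1 at step 1, beta1 = 0.9

# Non-capturable path (paraphrased): lr stays in the numerator, so lr == 0.0
# is just a zero-sized step.
step_size = lr / bc1  # 0.0, harmless

# Capturable path (paraphrased): _foreach_div has no Scalar-as-first-arg
# overload, so the code divides BY lr and takes a reciprocal afterwards;
# lr == 0.0 becomes a divide-by-zero along the way.
t = torch.tensor([0.9**1 - 1.0])
torch._foreach_div_([t], lr)     # -> -inf
torch._foreach_reciprocal_([t])  # -> -0.0 in eager; NaN-prone once compile fuses ops
print(step_size, t.item())
```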
🐛 Describe the bug
After an optimizer step, the weights become NaN.
Testcase: train_distributed.txt (actually .py)
Code walkthrough:
At the top are imports of everything under the sun; ignore those.
The lines up to dist.init_process_group() are there to work with Slurm; your own setup will be different.
I init a 2-GPU TP mesh and define a very simple model with sharded outputs.
I create an optimizer. Crucially, its lr is 0.0. If I set the LR to a positive number, the NaNs do not appear.
I compile the optimizer. If I do not torch.compile it, the NaNs do not appear.
If I use nn.Linear instead of nn.Embedding, the NaNs don't appear in my testcase, but I'm not sure whether that's generally true. (A condensed sketch of these steps follows below.)
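A condensed sketch of those steps (sizes and names are assumptions; the attached train_distributed.txt is the authoritative script):

```python
# Launch with 2 GPUs, e.g. torchrun --nproc_per_node=2 repro.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))  # 2-GPU TP mesh

model = nn.Embedding(8192, 256, device="cuda")      # very simple model
parallelize_module(model, mesh, ColwiseParallel())  # sharded outputs

# Crucially, lr is 0.0; a positive LR makes the NaNs disappear.
opt = torch.optim.AdamW(model.parameters(), lr=0.0, capturable=True, foreach=True)

input_ids = torch.tensor([0, 1], device="cuda")  # only rows 0 and 1 are active
model(input_ids).square().mean().backward()

torch.compile(opt.step)()  # without torch.compile, no NaNs
print(torch.isnan(model.weight.to_local()).any())
```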
Here's the slurm script I use, but it's probably not compatible with your setup. The only important detail is that it uses 2 GPUs.
sample_slurm.txt
Error logs
Notice that only some weights in the embedding layer become NaN. This is because only some weights are active: I set the input IDs to 0 and 1, out of a vocab of 8192.
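A quick check consistent with this, continuing from the sketch above (names carried over from that sketch):

```python
# Rows that went NaN should be exactly the active input IDs, 0 and 1.
weight = model.weight.full_tensor()  # gather the sharded DTensor
nan_rows = torch.isnan(weight).any(dim=1).nonzero().flatten()
print(nan_rows)  # expected: tensor([0, 1], device='cuda:0')
```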
Minified repro
No response
Versions
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar @bdhirsh @anijain2305 @chauhang @wanchaol @XilunWu @tianyu-l @d4l3k