
Weights become NaN with torch.compile optimizer capturable=True, lr=0.0, nn.Embedding #126514

Open
ad8e opened this issue May 17, 2024 · 7 comments
Labels
module: optimizer Related to torch.optim oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@ad8e
Contributor

ad8e commented May 17, 2024

🐛 Describe the bug

After an optimizer step, the weights become NaN.

Test case: train_distributed.txt (actually a .py file)

Code walkthrough:

At the top are imports of everything under the sun; ignore those.
The lines up to dist.init_process_group() are there to work with Slurm; your own setup will differ.
I init a 2-GPU TP mesh and define a very simple model with sharded outputs.
I create an optimizer. Crucially, its lr is 0.0. If I set the LR to a positive number, the NaNs do not appear.
I compile the optimizer step. If I do not use torch.compile, the NaNs do not appear.
If I use nn.Linear instead of nn.Embedding, the NaNs don't appear in my test case, but I'm not sure whether that holds in general.

Here's the Slurm script I use, but it's probably not compatible with your setup. The only important detail is that it uses 2 GPUs.
sample_slurm.txt

Error logs

Notice that only some rows of the embedding weight become NaN. I set the input IDs to 0 and 1, out of a vocab of 8192, so only rows 0 and 1 receive gradients; those rows stay finite, while the rows whose gradients are zero become NaN.

All nodes: slurm-h100-reserved-122-01
Primary: slurm-h100-reserved-122-01
pyxis: imported docker image: kurumuz@novelai/basedformer:4b6b9d3
Initialized process group
Initialized process group
shape torch.Size([8192])
shape torch.Size([8192])
loss rank0 tensor([11.8590, 10.0378, 11.8590,  ..., 14.6079, 11.8590, 14.6079],
       device='cuda:0', grad_fn=<_ToTorchTensorBackward>)
l1.weight DTensor(local_tensor=tensor([[ 1.8800,  0.1367],
        [ 0.6136,  0.4704],
        [ 0.6825, -1.4885],
        ...,
        [ 0.1573, -0.7616],
        [ 0.8190,  0.6479],
        [-0.4992, -0.8853]], device='cuda:0'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=1),)) grad DTensor(local_tensor=tensor([[ 1.1526, -0.0019],
        [ 0.5315,  0.1962],
        [ 0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000]], device='cuda:0'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=1),)) 0
l2.weight DTensor(local_tensor=tensor([[-0.7944, -0.5337,  0.1366,  0.2406],
        [-0.1300,  0.6589, -1.5099, -0.1818],
        [-0.9264,  0.3334, -0.3189,  1.6834],
        ...,
        [ 1.0073, -0.8948,  0.1003,  1.4572],
        [-1.5023, -0.2784, -0.3138,  0.0789],
        [-1.1018, -0.8476,  0.2612,  0.4110]], device='cuda:0'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=0),)) grad DTensor(local_tensor=tensor([[-6.3255e-01, -1.5145e-01,  2.8950e-01, -8.6569e-02],
        [-6.1381e-01, -1.5221e-01,  2.8793e-01, -8.4723e-02],
        [ 2.2526e-06,  1.4737e-06,  9.7696e-07,  1.4733e-06],
        ...,
        [ 1.6089e-05,  3.7329e-06,  6.2591e-06,  6.4757e-06],
        [ 6.1414e-07,  4.2662e-07, -2.5482e-07,  2.4957e-07],
        [ 8.1019e-07,  5.1636e-07,  4.8103e-07,  5.6353e-07]], device='cuda:0'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=0),)) 0
l2.bias DTensor(local_tensor=tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=0),)) grad DTensor(local_tensor=tensor([-5.0329e-01, -4.9669e-01,  3.2706e-06,  ...,  1.2613e-05,
         9.3097e-07,  1.1547e-06], device='cuda:0'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=0),)) 0
[rank0]:W0517 06:32:09.158000 140041545275200 torch/_logging/_internal.py:1024] [0/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
loss rank1 tensor([11.8590, 10.0378, 11.8590,  ..., 14.6079, 11.8590, 14.6079],
       device='cuda:1', grad_fn=<_ToTorchTensorBackward>)
l1.weight DTensor(local_tensor=tensor([[-1.7834, -0.2130],
        [ 0.6765,  0.5709],
        [-0.7983, -0.3547],
        ...,
        [ 0.2357, -0.4988],
        [ 1.2519,  1.2799],
        [-0.4103,  1.4584]], device='cuda:1'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=1),)) grad DTensor(local_tensor=tensor([[-0.5494, -0.1227],
        [ 0.6680,  0.2597],
        [ 0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000]], device='cuda:1'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=1),)) 1
l2.weight DTensor(local_tensor=tensor([[ 1.8750,  2.2957, -1.8079,  0.4607],
        [ 2.0036, -0.0731,  1.4216,  1.1791],
        [ 0.7447, -0.7806, -0.0764, -0.9730],
        ...,
        [ 0.7182,  0.4091,  0.9995,  0.1174],
        [-0.2682,  0.2800, -0.5404, -0.8055],
        [-1.3858,  0.9973, -0.0299,  0.3903]], device='cuda:1'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=0),)) grad DTensor(local_tensor=tensor([[ 3.6600e-04,  8.1472e-05, -4.6716e-04, -4.8854e-05],
        [ 1.9428e-04,  2.7638e-05,  8.3199e-05,  7.0824e-05],
        [ 4.1273e-06,  1.1523e-06, -1.5220e-06,  7.7401e-07],
        ...,
        [ 1.1623e-05,  4.0681e-06,  6.9684e-06,  6.2287e-06],
        [ 2.2720e-06,  1.2412e-06, -2.9601e-06,  9.0997e-08],
        [ 1.2958e-06,  9.2157e-07,  3.8395e-07,  8.3228e-07]], device='cuda:1'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=0),)) 1
l2.bias DTensor(local_tensor=tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:1'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=0),)) grad DTensor(local_tensor=tensor([2.8147e-04, 1.2472e-04, 3.5435e-06,  ..., 1.1281e-05, 2.9109e-06,
        1.9982e-06], device='cuda:1'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=0),)) 1
[rank1]:W0517 06:32:09.178000 140058576181056 torch/_logging/_internal.py:1024] [0/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
post opt l1.weight DTensor(local_tensor=tensor([[1.8800, 0.1367],
        [0.6136, 0.4704],
        [   nan,    nan],
        ...,
        [   nan,    nan],
        [   nan,    nan],
        [   nan,    nan]], device='cuda:0'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=1),)) 0
post opt l2.weight DTensor(local_tensor=tensor([[-0.7944, -0.5337,  0.1366,  0.2406],
        [-0.1300,  0.6589, -1.5099, -0.1818],
        [-0.9264,  0.3334, -0.3189,  1.6834],
        ...,
        [ 1.0073, -0.8948,  0.1003,  1.4572],
        [-1.5023, -0.2784, -0.3138,  0.0789],
        [-1.1018, -0.8476,  0.2612,  0.4110]], device='cuda:0'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=0),)) 0
post opt l2.bias DTensor(local_tensor=tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=0),)) 0
post opt l1.weight DTensor(local_tensor=tensor([[-1.7834, -0.2130],
        [ 0.6765,  0.5709],
        [    nan,     nan],
        ...,
        [    nan,     nan],
        [    nan,     nan],
        [    nan,     nan]], device='cuda:1'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=1),)) 1
post opt l2.weight DTensor(local_tensor=tensor([[ 1.8750,  2.2957, -1.8079,  0.4607],
        [ 2.0036, -0.0731,  1.4216,  1.1791],
        [ 0.7447, -0.7806, -0.0764, -0.9730],
        ...,
        [ 0.7182,  0.4091,  0.9995,  0.1174],
        [-0.2682,  0.2800, -0.5404, -0.8055],
        [-1.3858,  0.9973, -0.0299,  0.3903]], device='cuda:1'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=0),)) 1
post opt l2.bias DTensor(local_tensor=tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:1'), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=0),)) 1
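To check the pattern described above (NaN only in the rows that received no gradient) on a single tensor, a small helper can compare NaN rows against zero-gradient rows. This is a sketch: the helper names are hypothetical, and for the DTensor parameters in this report it assumes you pass the local shard (`.to_local()`) or the gathered tensor (`.full_tensor()`), not the DTensor itself.

```python
import torch

def nan_rows(weight: torch.Tensor) -> list[int]:
    """Indices of rows containing at least one NaN."""
    return torch.isnan(weight).any(dim=1).nonzero().flatten().tolist()

def zero_grad_rows(grad: torch.Tensor) -> list[int]:
    """Indices of rows whose gradient is exactly zero (e.g. embedding rows never looked up)."""
    return (grad == 0).all(dim=1).nonzero().flatten().tolist()

# Toy tensors shaped like the pattern in the logs: only IDs 0 and 1 were looked up.
weight = torch.ones(6, 2)
weight[2:] = float("nan")
grad = torch.zeros(6, 2)
grad[0] = torch.tensor([1.15, -0.0019])
grad[1] = torch.tensor([0.53, 0.196])

print(nan_rows(weight))                          # [2, 3, 4, 5]
print(nan_rows(weight) == zero_grad_rows(grad))  # True
```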

Minified repro

No response

Versions

PyTorch version: 2.4.0a0+ed76079
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.3
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.19.17-coreweave-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 525.125.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   52 bits physical, 57 bits virtual
Byte Order:                      Little Endian
CPU(s):                          128
On-line CPU(s) list:             0-127
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) Platinum 8462Y+
CPU family:                      6
Model:                           143
Thread(s) per core:              2
Core(s) per socket:              32
Socket(s):                       2
Stepping:                        8
CPU max MHz:                     4100.0000
CPU min MHz:                     800.0000
BogoMIPS:                        5600.00
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req hfi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr ibt amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization:                  VT-x
L1d cache:                       3 MiB (64 instances)
L1i cache:                       2 MiB (64 instances)
L2 cache:                        128 MiB (64 instances)
L3 cache:                        120 MiB (2 instances)
NUMA node(s):                    2
NUMA node0 CPU(s):               0-31,64-95
NUMA node1 CPU(s):               32-63,96-127
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.4.0a0+ed76079
[pip3] torchaudio==2.2.0a0+ea437b3
[pip3] torchvision==0.19.0a0+947ae1d
[pip3] triton==3.0.0
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar @bdhirsh @anijain2305 @chauhang @wanchaol @XilunWu @tianyu-l @d4l3k

@ad8e ad8e changed the title Weights become NaN with torch.compile optimizer, DTensor, lr=0.0 Weights become NaN with torch.compile optimizer, DTensor, lr=0.0, nn.Embedding May 17, 2024
@xmfan
Member

xmfan commented May 17, 2024

@bdhirsh is this related to the torchtitan NaN loss you were talking about?

@xmfan xmfan added module: dtensor distributed tensor tag module: optimizer Related to torch.optim high priority labels May 17, 2024
@xmfan xmfan added triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module triage review and removed triage review triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels May 20, 2024
@xmfan
Member

xmfan commented May 20, 2024

@ad8e Does the NaN reproduce on a single GPU?

@janeyx99
Contributor

janeyx99 commented May 20, 2024

Pretty sure this is caused by this line: https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py#L552

To confirm, @ad8e, you can likely repro this with just:

optim = torch.optim.AdamW(model.parameters(), lr=0.0, capturable=True)
...
optim.step()

The real solution is to allow foreach_div to support a Scalar as the first argument, but I'm not sure how hard that is (cc @crcrpar). It feels like we should be able to just add an overload. Regarding priority, I'm not sure this is high priority. How likely is this use case? Is there a real use case for having lr be 0?
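As we read the linked capturable branch of the foreach AdamW code (an assumption; the line number on `main` may have moved), `foreach_div` cannot take a Scalar as its first argument, so the step size -lr / (1 - beta1^t) is computed by first dividing the bias-correction tensor by lr and then taking the reciprocal. The minimal CPU sketch below mirrors that order of operations with ordinary tensor ops; the function names are hypothetical, the denominator is simplified (no amsgrad, no weight decay), and this is not the optimizer's actual code.

```python
import torch

def step_size_via_reciprocal(lr: float, beta1: float, step: torch.Tensor) -> torch.Tensor:
    # Divide-then-reciprocal order: 1 / ((beta1**t - 1) / lr) == -lr / (1 - beta1**t)
    bias_correction1 = beta1 ** step - 1       # negative for t >= 1
    bias_correction1 = bias_correction1 / lr   # lr == 0.0 -> -inf
    return torch.reciprocal(bias_correction1)  # 1 / -inf -> -0.0

def step_size_direct(lr: float, beta1: float, step: torch.Tensor) -> torch.Tensor:
    # What a Scalar-first division would allow: no division by lr at all.
    return -lr / (1 - beta1 ** step)

beta1, beta2, eps = 0.9, 0.999, 1e-8
step = torch.tensor(1.0)
exp_avg_sq = torch.tensor([0.0, 0.25])  # an unused row (zero grad) and a used row
bias_correction2_sqrt = (1 - beta2 ** step).sqrt()

for lr in (1e-3, 0.0):
    s = step_size_via_reciprocal(lr, beta1, step)
    # Simplified denominator: (sqrt(v) / sqrt(1 - beta2**t) + eps) / step_size
    denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt + eps) / s
    print(f"lr={lr}: step_size={s.item()}, direct={step_size_direct(lr, beta1, step).item()}, denom={denom.tolist()}")
```

With lr=0.0 the reciprocal route yields a step size of -0.0, and dividing the denominator by it produces -inf, whereas the direct form simply yields -0.0 with no infinities. Under IEEE arithmetic, 0 * inf and inf - inf are NaN, so any downstream operation that combines these infinities with the exact zeros of an unused embedding row can produce NaN; which operation does so in the DTensor path is not established here.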

@ad8e
Contributor Author

ad8e commented May 20, 2024

DTensor doesn't work when I change the TP mesh size from 2 to 1; I get:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/clusterstorage/workspace/kevin/nandtensor.py", line 99, in <module>
[rank1]:     model_tp = parallelize_module(model, tp_mesh, parallelize_plan=layer_plan)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/tensor/parallel/api.py", line 82, in parallelize_module
[rank1]:     random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/random.py", line 345, in _manual_seed
[rank1]:     tensor_parallel_rank = tp_mesh.get_local_rank()
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/device_mesh.py", line 502, in get_local_rank
[rank1]:     mesh_dim_group = not_none(self.get_group(mesh_dim))
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/device_mesh.py", line 411, in get_group
[rank1]:     _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2])
[rank1]: IndexError: list index out of range

which means the process group isn't being created when the mesh dimension has size 1, so I cannot test whether the NaN would appear on a single GPU.

If I remove the DTensor, like so:

# model_tp = parallelize_module(model, tp_mesh, parallelize_plan=layer_plan)
model_tp = model
...
# gas_loss = gas_loss.full_tensor() # commented out

Then no NaNs appear. So the NaN only appears with DTensor.

It's not high priority for me because DTensor TP is currently useless due to low performance, so I don't use it anywhere. If DTensor actually mattered (above 70B scale, or once it gets communication/computation overlap working), then LR = 0 would come up with linear warmup/decay schedules, where LR = 0.0 is common at the endpoints, though avoidable. Another use case is re-baking the AdamW second moment, which is necessary for resuming from a checkpoint saved without optimizer states (useful for saving disk space). This can be done with a very low LR instead of 0.0.
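Where LR = 0.0 comes from a schedule, the workaround mentioned above (a very low LR instead of 0.0) can be applied by clamping the schedule's multiplier to a small floor. A minimal sketch of a linear warmup/decay multiplier; `min_factor` is a hypothetical parameter of this helper, not a PyTorch option, and its default of 1e-6 is arbitrary.

```python
def warmup_decay_factor(step: int, warmup_steps: int, total_steps: int,
                        min_factor: float = 1e-6) -> float:
    """Multiplier on the base LR: linear warmup from 0 to 1, then linear decay back to 0.

    A plain linear schedule is exactly 0.0 at both endpoints, which is where
    lr == 0 shows up in practice. Clamping to min_factor keeps the LR positive.
    """
    if step < warmup_steps:
        factor = step / warmup_steps
    else:
        remaining = max(total_steps - step, 0)
        factor = remaining / max(total_steps - warmup_steps, 1)
    return max(factor, min_factor)

print(warmup_decay_factor(0, 10, 100))                  # 1e-06 instead of 0.0
print(warmup_decay_factor(0, 10, 100, min_factor=0.0))  # 0.0 (the problematic case)
```

Because `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=...)` multiplies each base LR by the returned factor, the helper can be passed to it through a lambda or `functools.partial`.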

If anyone else cared about DTensor, they would be able to spot the NaN issue and work around it in both cases, since it is not a silent failure.
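Since the NaNs are visible rather than silent, a cheap check after `opt.step()` can catch them early. A minimal sketch, assuming ordinary `nn.Module` parameters; the `hasattr(..., "to_local")` branch is our assumption for handling DTensor parameters by checking each rank's local shard, and the helper name is hypothetical.

```python
import torch

def nonfinite_params(model: torch.nn.Module) -> list[str]:
    """Names of parameters that contain NaN or Inf (e.g. right after an optimizer step)."""
    bad = []
    for name, p in model.named_parameters():
        t = p.detach()
        # DTensor parameters expose their local shard via .to_local(); plain tensors do not.
        if hasattr(t, "to_local"):
            t = t.to_local()
        if not torch.isfinite(t).all():
            bad.append(name)
    return bad

model = torch.nn.Embedding(8, 2)
print(nonfinite_params(model))  # []
```

In a training loop, one would call it right after `opt.step()` and stop (or roll back) when the returned list is non-empty.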

@ad8e
Contributor Author

ad8e commented May 20, 2024

I tried Jane's test case by taking the original DTensor TP=2 example and making these modifications:

opt = AdamW(...
    capturable=True, # this is new
)
...
# opt.step = torch.compile(opt.step) # this is removed

The NaNs appear. So her diagnosis is correct.

@wanchaol
Contributor

Is this actually related to DTensor, or is it more about torch.compile + the optimizer? Based on the analysis above, I think that if we just use a normal torch.Tensor with torch.compile and set lr=0.0, we should still repro the issue?

@ad8e
Contributor Author

ad8e commented May 20, 2024

The underlying bug is not in DTensor; it's in the optimizer. It's only that DTensor exposes this code path in the optimizer.

A normal torch.Tensor with torch.compile and lr=0.0 doesn't hit it; the key is the capturable argument that Jane mentioned.

@xmfan xmfan removed the module: dtensor distributed tensor tag label May 20, 2024
@xmfan xmfan changed the title Weights become NaN with torch.compile optimizer, DTensor, lr=0.0, nn.Embedding Weights become NaN with torch.compile optimizer capturable=True, lr=0.0, nn.Embedding May 20, 2024
@xmfan xmfan added triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module and removed high priority triage review labels May 20, 2024