Fused AdamW not supported with FSDP2 #126670

Closed
ad8e opened this issue May 20, 2024 · 6 comments
Labels: oncall: distributed, triaged

Comments

@ad8e
Contributor

ad8e commented May 20, 2024

🐛 Describe the bug

Error log:

[rank4]: Traceback (most recent call last):
[rank4]:   File "/clusterstorage/workspace/kevin/basedformer/train_distributed.py", line 1290, in <module>
[rank4]:     main(train_config, model_config_yaml)
[rank4]:   File "/clusterstorage/workspace/kevin/basedformer/train_distributed.py", line 1231, in main
[rank4]:     train(args, model, train_loader, opt, derived_args)
[rank4]:   File "/clusterstorage/workspace/kevin/basedformer/train_distributed.py", line 558, in train
[rank4]:     opt.step()  # torch.compile also doesn't like the grad scaler, even as a positional argument
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 420, in _fn
[rank4]:     return fn(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 37, in inner
[rank4]:     return fn(*args, **kwargs)
[rank4]:   File "/clusterstorage/workspace/kevin/basedformer/basedformer/optimizer.py", line 279, in step
[rank4]:     grad_scaler.step(self.optimizer)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/amp/grad_scaler.py", line 379, in step
[rank4]:     return optimizer.step(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/optim/lr_scheduler.py", line 108, in wrapper
[rank4]:     return func.__get__(opt, opt.__class__)(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py", line 493, in wrapper
[rank4]:     out = func(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py", line 87, in _use_grad
[rank4]:     ret = func(self, *args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/optim/adamw.py", line 225, in step
[rank4]:     adamw(
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py", line 159, in maybe_fallback
[rank4]:     return func(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/optim/adamw.py", line 753, in adamw
[rank4]:     func(
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 420, in _fn
[rank4]:     return fn(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/optim/adamw.py", line 665, in _fused_adamw
[rank4]:     torch._fused_adamw_(
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 24, in inner
[rank4]:     return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 420, in _fn
[rank4]:     return fn(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/api.py", line 296, in __torch_dispatch__
[rank4]:     return DTensor._op_dispatcher.dispatch(
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/dispatch.py", line 115, in dispatch
[rank4]:     op_info = self.unwrap_to_op_info(op_call, args, kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/dispatch.py", line 359, in unwrap_to_op_info
[rank4]:     raise RuntimeError(
[rank4]: RuntimeError: aten._fused_adamw_.tensor_lr: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

Happens when I set fused=True in AdamW.

Torchtitan is probably the best place to reproduce this.
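The message comes from DTensor's dispatcher (`unwrap_to_op_info` in the traceback above), which expects every tensor argument of a distributed operator to already be a DTensor. Under FSDP2 the sharded parameters and their gradients are DTensors, so the plain tensor passed to the `tensor_lr` overload (most likely the tensor `lr`) is the odd one out. The toy model below is a hypothetical simplification of that check as the error message describes it; `FakeTensor`, `FakeDTensor` and `check_no_mixed_args` are illustrative stand-ins, not PyTorch code, and the real dispatcher handles more cases.

```python
# Hypothetical, simplified model of the mixed-argument check described by the
# error message. These classes are illustrative stand-ins, not PyTorch types.


class FakeTensor:
    """Stand-in for a plain, unsharded torch.Tensor (e.g. a one-element lr)."""

    def __init__(self, value):
        self.value = value


class FakeDTensor(FakeTensor):
    """Stand-in for a DTensor, e.g. an FSDP2-sharded parameter or gradient."""


def check_no_mixed_args(op_name, args):
    """Return True, or raise if DTensors and plain tensors are mixed."""
    tensors = [a for a in args if isinstance(a, FakeTensor)]
    has_dtensor = any(isinstance(t, FakeDTensor) for t in tensors)
    has_plain = any(not isinstance(t, FakeDTensor) for t in tensors)
    if has_dtensor and has_plain:
        raise RuntimeError(
            f"{op_name}: got mixed torch.Tensor and DTensor, need to convert "
            "all torch.Tensor to DTensor before calling distributed operators!"
        )
    return True


params = [FakeDTensor(1.0), FakeDTensor(2.0)]

# A Python float lr is not a tensor at all, so the check passes.
check_no_mixed_args("float-lr overload", [*params, 1e-3])

# A plain-tensor lr next to DTensor params trips the check.
try:
    check_no_mixed_args("aten._fused_adamw_.tensor_lr", [*params, FakeTensor(1e-3)])
except RuntimeError as err:
    print(err)
```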

Versions

PyTorch version: 2.4.0a0+ed76079
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.3
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.19.17-coreweave-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 525.125.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   52 bits physical, 57 bits virtual
Byte Order:                      Little Endian
CPU(s):                          128
On-line CPU(s) list:             0-127
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) Platinum 8462Y+
CPU family:                      6
Model:                           143
Thread(s) per core:              2
Core(s) per socket:              32
Socket(s):                       2
Stepping:                        8
CPU max MHz:                     4100.0000
CPU min MHz:                     800.0000
BogoMIPS:                        5600.00
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req hfi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr ibt amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization:                  VT-x
L1d cache:                       3 MiB (64 instances)
L1i cache:                       2 MiB (64 instances)
L2 cache:                        128 MiB (64 instances)
L3 cache:                        120 MiB (2 instances)
NUMA node(s):                    2
NUMA node0 CPU(s):               0-31,64-95
NUMA node1 CPU(s):               32-63,96-127
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.4.0a0+ed76079
[pip3] torchaudio==2.2.0a0+ea437b3
[pip3] torchvision==0.19.0a0+947ae1d
[pip3] triton==3.0.0
[conda] Could not collect

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

@awgu
Contributor

awgu commented May 20, 2024

cc: @msaroufim @wanchaol

> PyTorch version: 2.4.0a0+ed76079

Does this mean that the fused AdamW support had not landed yet?

@ad8e
Contributor Author

ad8e commented May 20, 2024

This is the commit I use, from one week ago: ed76079

This is the fused Adam commit, from two weeks ago: 3407899

I didn't find further fused Adam commits, so the fused Adam support seems to have fully landed in the image I'm using.

@wz337
Contributor

wz337 commented May 20, 2024

@ad8e Thanks for reporting the issue. Looking into it now.

@wanchaol
Contributor

@ad8e I believe we missed two operator overloads for the case where lr is a tensor. If using a float lr is an option for you, that should work on main.
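A minimal sketch of the float-lr workaround, assuming the learning rate may arrive as a one-element tensor: convert it to a Python float before building the optimizer, so `torch._fused_adamw_` is called with a float `lr` instead of going through the `tensor_lr` overload. `float_lr` is a hypothetical helper; the FSDP2 wrapping step is omitted, and the CPU fallback exists only so the snippet runs without a GPU.

```python
import torch


def float_lr(lr):
    """Hypothetical helper: return lr as a Python float.

    With a float lr, fused AdamW uses the float-lr kernel overload rather
    than `aten._fused_adamw_.tensor_lr`.
    """
    if isinstance(lr, torch.Tensor):
        return lr.item()  # one-element tensor -> Python float
    return float(lr)


model = torch.nn.Linear(8, 8)

# fused=True needs parameters on a supported device (CUDA here); fall back to
# the default implementation (fused=None) when no GPU is available.
use_cuda = torch.cuda.is_available()
if use_cuda:
    model = model.cuda()

opt = torch.optim.AdamW(
    model.parameters(),
    lr=float_lr(torch.tensor(3e-4)),
    fused=True if use_cuda else None,
)
```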

@ad8e
Copy link
Contributor Author

ad8e commented May 20, 2024

LR being a float is fine. There are some interactions between float vs. tensor LRs and torch.compile, but this issue isn't a priority for me, since I'm just exploring FSDP2 to report bugs rather than doing big training runs with it.
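For context on the torch.compile interaction: a float `lr` is typically treated as a constant by the compiled optimizer step, so changing it (for example from a scheduler) can trigger recompilation, which is why PyTorch's compiled-optimizer tutorial wraps the LR in a tensor. The sketch below, using a hypothetical `set_lr` helper, shows the practical difference: a tensor `lr` is updated in place so a compiled step keeps seeing the same tensor object, while a float `lr` is simply reassigned.

```python
import torch


def set_lr(optimizer, new_lr):
    """Hypothetical helper: set the learning rate of every param group.

    Tensor lr: update in place, keeping the same tensor object.
    Float lr: reassign the Python float.
    """
    for group in optimizer.param_groups:
        if isinstance(group["lr"], torch.Tensor):
            group["lr"].fill_(new_lr)
        else:
            group["lr"] = float(new_lr)


model = torch.nn.Linear(4, 4)
opt_tensor_lr = torch.optim.AdamW(model.parameters(), lr=torch.tensor(1e-3))
opt_float_lr = torch.optim.AdamW(model.parameters(), lr=1e-3)

lr_tensor = opt_tensor_lr.param_groups[0]["lr"]  # remember the tensor object
set_lr(opt_tensor_lr, 5e-4)
set_lr(opt_float_lr, 5e-4)
```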

@wanchaol
Contributor

@ad8e Sounds good, we are fixing this in any case :)

@wanchaol added the oncall: distributed and triaged labels on May 20, 2024