Triton kernel doing more work uses fewer registers #126463
…e reduction kernels" Triton sometimes uses fewer registers for a more expensive kernel, which results in worse perf (#126463). This may make inductor end up with a sub-optimal config. Use a smaller max RBLOCK if the reduction potentially needs many registers. Will run perf tests. [ghstack-poisoned]
Karthikeyan Manivannan can take a look at what's going on with Triton.
I am seeing the same number of registers in both cases:

```
[~/triton (main)]$ TRITON_CACHE_DIR=$HOME/triton-cache/dump1 python ~/work/issues/126463/softmax_bwd_k1.py
[~/triton (main)]$ TRITON_CACHE_DIR=$HOME/triton-cache/dump2 python ~/work/issues/126463/softmax_bwd_k2.py
[~/triton (main)]$ ptxas --gpu-name=sm_90a -v ~/triton-cache/dump1/14efd82a05f3d6020f6e683159d9f173ce520e99f2797052912a4e8a7c182d60/triton_red_fused__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2.ptx
[~/triton (main)]$ ptxas --gpu-name=sm_90a -v ~/triton-cache/dump2/365f353e83c5ddc165a630bcdc9f4ca005601ec6d3d8147823186d9854336675/triton_red_fused__log_softmax__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2.ptx
```

This output is from 4faa131964a93d5022e284ceb99957f714e2d4e5, and I see the same behavior on 8e48e4fa454e652438f41ff00cbb9ed38485e0f8.
Can you check the currently pinned Triton in pytorch, 45fff310c891f5a92d55445adf8cc9d29df5841e?
I am seeing the same number of registers across both kernels on the pinned hash (45fff310) too.

```
[~/triton (45fff310)]$ ptxas --gpu-name=sm_90a -v ~//triton-cache/dump1/9b6fa6e98c5d58e5c3ff5d322a33db7a09b12531893a0ecf516dc9ddafc58cff/triton_red_fused__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2.ptx
[~/triton (45fff310)]$ ptxas --gpu-name=sm_90a -v ~//triton-cache/dump2/224ba495c345498737865c572e18a1c439b5b194a7d53044dc2bf7d82aaac31d/triton_red_fused__log_softmax__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2.ptx
```
Hmm, very interesting. One thing I can think of: the tests mentioned in the issue summary were run on an A100, while you seem to be seeing different behavior on an H100. Can you try on an A100 to see if you can repro?
Repros on A100.
The issue seems to arise from how ptxas allocates registers on A100. For k2, Triton produces the same PTX file for both H100 and A100; for whatever reason, ptxas uses fewer registers on A100.
Not sure how far we can take this, but what versions of ptxas are you using on A100 and H100? I'm wondering whether it's due to a different ptxas version or to the different GPUs.
I ran with the same ptxas version (V12.1.105) on both H100 and A100. Setting --gpu-name=sm_80 reduces register usage compared to setting --gpu-name=sm_90 for the same k2 bs=2048 PTX file.
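For comparing register counts across `--gpu-name` targets, a small helper that parses the `Used N registers` info lines out of `ptxas -v` output can make the diff mechanical. This is a sketch, not part of the thread's scripts: the `register_usage` name is ours, it assumes `ptxas` is on `PATH`, and it relies on the usual format of ptxas's verbose resource report.

```python
import re
import subprocess

def parse_register_counts(ptxas_stderr):
    """Extract per-kernel register counts from captured ptxas -v output.

    ptxas -v prints resource-usage lines such as:
        ptxas info    : Used 40 registers, 400 bytes cmem[0]
    one per compiled entry function.
    """
    return [int(n) for n in re.findall(r"Used (\d+) registers", ptxas_stderr)]

def register_usage(ptx_path, gpu_name):
    """Compile a PTX file for the given target and return register counts.

    Hypothetical helper: assumes ptxas is installed; discards the cubin.
    """
    result = subprocess.run(
        ["ptxas", f"--gpu-name={gpu_name}", "-v", "-o", "/dev/null", ptx_path],
        capture_output=True, text=True,
    )
    # ptxas writes its verbose resource report to stderr.
    return parse_register_counts(result.stderr)
```

Running `register_usage(path, "sm_80")` and `register_usage(path, "sm_90")` on the same dumped PTX file would surface the per-architecture difference discussed above without eyeballing the raw output.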
Triton sometimes uses fewer registers for a more expensive kernel, which results in worse perf (#126463). This may make inductor end up with a sub-optimal config. Use a smaller max RBLOCK if the reduction potentially needs many registers. Will run perf tests. Pull Request resolved: #126477 Approved by: https://github.com/jansel
We are trying to recompute softmax in backward pass so we can save a store and load of the large [BxT, V] tensor. More context in: #126348
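The recompute-in-backward idea can be illustrated with a small NumPy sketch (ours, not the actual inductor kernels): the backward pass rematerializes p = softmax(x) from the saved logits instead of loading a stored [BxT, V] probability tensor, trading FLOPs for memory traffic.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax along the last axis.
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def softmax_backward_recompute(x, grad_out):
    """Gradient of softmax w.r.t. its input, recomputing p in backward.

    With p = softmax(x) and g the upstream gradient:
        dx = p * (g - sum(g * p, axis=-1))
    Recomputing p here saves storing (and re-loading) the large
    probability tensor between forward and backward.
    """
    p = softmax(x)  # rematerialized instead of saved
    return p * (grad_out - (grad_out * p).sum(axis=-1, keepdims=True))
```

Because rows of the softmax Jacobian sum to zero, the resulting gradient sums to zero along the vocabulary axis, which is a handy sanity check for either variant.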
One weird thing we encountered: the backward-pass kernel that computes the gradient of the softmax input now needs to do a bit more work (recompute softmax) but uses fewer registers than before. The outcome is that inductor misses a much better Triton config, one that would be picked if we scaled RBLOCK down.
Old kernel without recomputation: https://gist.github.com/shunting314/f46a7598866e2ed69ce3b677a7694300 , script output:
New kernel with recomputation: https://gist.github.com/shunting314/7653edb6b2d036027b2ee7fd84a32353 , script output:
cc @Chillee , @eellison , @jansel as FYI
cc @htyu can we check why triton has such counter-intuitive behavior regarding register usage?