
Triton kernel doing more work uses less registers #126463

Open
shunting314 opened this issue May 16, 2024 · 9 comments
Labels
oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@shunting314
Contributor

shunting314 commented May 16, 2024

We are trying to recompute softmax in the backward pass so we can save a store and load of the large [BxT, V] tensor. More context in: #126348

One weird thing we encountered is that the kernel in the backward pass that computes the gradient of the softmax input now needs to do a bit more work (recompute softmax) but uses fewer registers than before. The outcome is that Inductor misses a much better Triton config, which would be picked if we scaled RBLOCK down.

Old kernel without recomputation: https://gist.github.com/shunting314/f46a7598866e2ed69ce3b677a7694300 , script output:

rblock=2048 kernel.n_regs=40 ms=16.207
rblock=1024 kernel.n_regs=32 ms=12.833

New kernel with recomputation: https://gist.github.com/shunting314/7653edb6b2d036027b2ee7fd84a32353 , script output:

rblock=2048 kernel.n_regs=32 ms=17.223
rblock=1024 kernel.n_regs=32 ms=12.831

cc @Chillee, @eellison, @jansel as FYI
cc @htyu: can we check why Triton has such counter-intuitive behavior regarding register usage?
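For reference, here is a minimal sketch (not the gist scripts above) of how numbers like these can be collected: launch a simple row-sum reduction at two RBLOCK sizes, read the register count off the handle returned by the launch, and time each variant with triton.testing.do_bench. The kernel, the tensor shape, and the assumption that a launch returns a compiled-kernel handle exposing n_regs are illustrative and may vary across Triton versions.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def row_sum_kernel(x_ptr, out_ptr, n_cols, RBLOCK: tl.constexpr):
    # One program per row; reduce the row in RBLOCK-sized chunks.
    row = tl.program_id(0)
    acc = tl.zeros([RBLOCK], dtype=tl.float32)
    for off in range(0, n_cols, RBLOCK):
        cols = off + tl.arange(0, RBLOCK)
        mask = cols < n_cols
        acc += tl.load(x_ptr + row * n_cols + cols, mask=mask, other=0.0)
    tl.store(out_ptr + row, tl.sum(acc, axis=0))


x = torch.randn(8192, 50257, device="cuda")   # roughly a [BxT, V] shape
out = torch.empty(8192, device="cuda")
for rblock in (2048, 1024):
    handle = row_sum_kernel[(x.shape[0],)](x, out, x.shape[1], RBLOCK=rblock)
    ms = triton.testing.do_bench(
        lambda: row_sum_kernel[(x.shape[0],)](x, out, x.shape[1], RBLOCK=rblock)
    )
    print(f"rblock={rblock} kernel.n_regs={handle.n_regs} ms={ms:.3f}")
```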

@shunting314 shunting314 added oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels May 16, 2024
shunting314 added a commit that referenced this issue May 17, 2024
…e reduction kernels"


Triton sometimes uses fewer registers for a more expensive kernel, which results in worse perf (#126463). This may make Inductor end up with a sub-optimal config. Use a smaller max RBLOCK if the reduction potentially needs many registers.

Will run perf test..


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
shunting314 added a commit that referenced this issue May 17, 2024
…els"


Triton sometimes uses fewer registers for a more expensive kernel, which results in worse perf (#126463). This may make Inductor end up with a sub-optimal config. Use a smaller max RBLOCK if the reduction potentially needs many registers.

Will run perf test..


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
@htyu
Contributor

htyu commented May 17, 2024

Karthikeyan Manivannan can take a look at what's going on with Triton.

@karthik-man

I am seeing the same number of registers across both kernels:

[~/triton (main)]$ TRITON_CACHE_DIR=$HOME/triton-cache/dump1 python ~/work/issues/126463/softmax_bwd_k1.py
14efd82a05f3d6020f6e683159d9f173ce520e99f2797052912a4e8a7c182d60 rblock=2048 kernel.n_regs=40 ms=10.224
a979388b7845426248da805049b53870c5afb5499771d04b7b3b953214069c7d rblock=1024 kernel.n_regs=32 ms=9.156

[~/triton (main)]$ TRITON_CACHE_DIR=$HOME/triton-cache/dump2 python ~/work/issues/126463/softmax_bwd_k2.py
365f353e83c5ddc165a630bcdc9f4ca005601ec6d3d8147823186d9854336675 rblock=2048 kernel.n_regs=40 ms=11.312
718152d0e5ad5222fa796d108f80924ea227699ca5b2edc6b8f6e36eca5bdc80 rblock=1024 kernel.n_regs=32 ms=9.141

[~/triton (main)]$ ptxas --gpu-name=sm_90a -v ~/triton-cache/dump1/14efd82a05f3d6020f6e683159d9f173ce520e99f2797052912a4e8a7c182d60/triton_red_fused__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2.ptx
ptxas info : 0 bytes gmem
ptxas info : Compiling entry function 'triton_red_fused__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2' for 'sm_90a'
ptxas info : Function properties for triton_red_fused__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 40 registers

[~/triton (main)]$ ptxas --gpu-name=sm_90a -v ~/triton-cache/dump2/365f353e83c5ddc165a630bcdc9f4ca005601ec6d3d8147823186d9854336675/triton_red_fused__log_softmax__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2.ptx
ptxas info : 0 bytes gmem
ptxas info : Compiling entry function 'triton_red_fused__log_softmax__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2' for 'sm_90a'
ptxas info : Function properties for triton_red_fused__log_softmax__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 40 registers

This output is from 4faa131964a93d5022e284ceb99957f714e2d4e5, and I see the same behavior on 8e48e4fa454e652438f41ff00cbb9ed38485e0f8.

@shunting314
Contributor Author

Can you check the current pinned Triton in PyTorch, 45fff310c891f5a92d55445adf8cc9d29df5841e?

@karthik-man

I am seeing the same number of registers across both kernels on the pinned hash (45fff310) too.

[~/triton (45fff310)]$ ptxas --gpu-name=sm_90a -v ~//triton-cache/dump1/9b6fa6e98c5d58e5c3ff5d322a33db7a09b12531893a0ecf516dc9ddafc58cff/triton_red_fused__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2.ptx
ptxas info : 0 bytes gmem
ptxas info : Compiling entry function 'triton_red_fused__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2' for 'sm_90a'
ptxas info : Function properties for triton_red_fused__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 40 registers

[~/triton (45fff310)]$ ptxas --gpu-name=sm_90a -v ~//triton-cache/dump2/224ba495c345498737865c572e18a1c439b5b194a7d53044dc2bf7d82aaac31d/triton_red_fused__log_softmax__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2.ptx
ptxas info : 0 bytes gmem
ptxas info : Compiling entry function 'triton_red_fused__log_softmax__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2' for 'sm_90a'
ptxas info : Function properties for triton_red_fused__log_softmax__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 40 registers

@shunting314
Contributor Author

Hmm, very interesting. One thing I can think of is that the tests mentioned in the issue summary were run on an A100, while you seem to see different behavior on an H100.

Can you try on an A100 to see if you repro?

@karthik-man

Repros on A100.

@karthik-man

The issue seems to arise from how ptxas allocates registers on A100. For k2, Triton produces the same PTX file for both H100 and A100. For whatever reason, ptxas uses fewer registers on A100.

@shunting314
Contributor Author

For whatever reason, ptxas uses fewer registers on A100.

Not sure how far we can go. But what versions of ptxas are you using on A100 and H100? I'm wondering if it's due to a different ptxas version or to the different GPUs.

shunting314 added a commit that referenced this issue May 21, 2024
…e reduction kernels"


Triton sometimes uses fewer registers for a more expensive kernel, which results in worse perf (#126463). This may make Inductor end up with a sub-optimal config. Use a smaller max RBLOCK if the reduction potentially needs many registers.

Will run perf test..


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
shunting314 added a commit that referenced this issue May 21, 2024
…els"


Triton sometimes uses fewer registers for a more expensive kernel, which results in worse perf (#126463). This may make Inductor end up with a sub-optimal config. Use a smaller max RBLOCK if the reduction potentially needs many registers.

Will run perf test..


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
@karthik-man

karthik-man commented May 22, 2024

I ran with the same ptxas version (V12.1.105) on both H100 and A100. Setting --gpu-name=sm_80 reduces register usage compared to setting --gpu-name=sm_90 for the same k2 bs=2048 PTX file.
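A small helper (hypothetical, not something posted in this thread) for reproducing that comparison: run ptxas over the same cached PTX file for both targets and pull the "Used N registers" line out of the -v report. The PTX path is a placeholder; point it at a file under your TRITON_CACHE_DIR. Only flags already shown above are used.

```python
import re
import subprocess

# Placeholder path: substitute a real kernel PTX file from the Triton cache.
PTX = "/path/to/triton-cache/<hash>/<kernel_name>.ptx"

for arch in ("sm_80", "sm_90a"):
    # ptxas writes the -v resource-usage report to stderr.
    proc = subprocess.run(
        ["ptxas", f"--gpu-name={arch}", "-v", PTX, "-o", "/dev/null"],
        capture_output=True,
        text=True,
    )
    match = re.search(r"Used (\d+) registers", proc.stderr)
    print(arch, match.group(1) if match else proc.stderr.strip())
```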

pytorchmergebot pushed a commit that referenced this issue May 22, 2024
Triton sometimes uses fewer registers for a more expensive kernel, which results in worse perf (#126463). This may make Inductor end up with a sub-optimal config. Use a smaller max RBLOCK if the reduction potentially needs many registers.

Will run perf test..

Pull Request resolved: #126477
Approved by: https://github.com/jansel
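For context, a simplified illustration of the heuristic the merged change describes; this is not the actual Inductor code from #126477, and the threshold and RBLOCK values are made up. The idea is to cap the maximum RBLOCK the autotuner considers when a reduction is expected to keep many values live, so the better-performing small-RBLOCK config stays among the candidates.

```python
def max_rblock_for(n_live_buffers: int,
                   default_max_rblock: int = 2048,
                   reduced_max_rblock: int = 1024,
                   live_buffer_threshold: int = 8) -> int:
    """Return the largest RBLOCK the autotuner should consider.

    n_live_buffers is a stand-in for however register pressure is estimated
    (e.g. the number of loads/intermediates kept live per reduction element).
    """
    if n_live_buffers >= live_buffer_threshold:
        # Register-hungry reductions: shrink the max RBLOCK.
        return reduced_max_rblock
    return default_max_rblock


# A softmax-recompute backward with many intermediates gets the smaller cap,
# while a simple sum reduction keeps the default.
print(max_rblock_for(n_live_buffers=10))  # 1024
print(max_rblock_for(n_live_buffers=2))   # 2048
```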