[Inductor] Generate triton block pointers for discontiguous strided tensors #125077
Comments
I think not every non-contiguous access will cause Inductor to skip block_ptr. E.g., for `a + b.t()`, here is the kind of code Inductor generates, which uses block_ptr for all 3 memory accesses:
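(The code block from this comment was lost in the capture; the following is a hand-written approximation rather than Inductor's verbatim output. The 64x64 size, kernel name, and block sizes are made up.)

```python
import triton
import triton.language as tl

@triton.jit
def add_t_kernel(in_ptr0, in_ptr1, out_ptr0,
                 YBLOCK: tl.constexpr, XBLOCK: tl.constexpr):
    # Hypothetical 2D tiled kernel for out = a + b.t(), 64x64 inputs.
    yoffset = tl.program_id(1) * YBLOCK
    xoffset = tl.program_id(0) * XBLOCK
    # a is row-major: strides (64, 1)
    a = tl.load(tl.make_block_ptr(in_ptr0, shape=[64, 64], strides=[64, 1],
                                  offsets=[yoffset, xoffset],
                                  block_shape=[YBLOCK, XBLOCK], order=[1, 0]),
                boundary_check=[0, 1])
    # b.t() is the same storage with swapped strides: (1, 64)
    bt = tl.load(tl.make_block_ptr(in_ptr1, shape=[64, 64], strides=[1, 64],
                                   offsets=[yoffset, xoffset],
                                   block_shape=[YBLOCK, XBLOCK], order=[0, 1]),
                 boundary_check=[0, 1])
    # Output is row-major again
    tl.store(tl.make_block_ptr(out_ptr0, shape=[64, 64], strides=[64, 1],
                               offsets=[yoffset, xoffset],
                               block_shape=[YBLOCK, XBLOCK], order=[1, 0]),
             a + bt, boundary_check=[0, 1])
```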
Thanks, this is good context. So it seems like 2D block pointers are already possible; it's just that Inductor might not take advantage of them in the case of padded rows coming from `as_strided`.
There is nothing special about `as_strided`. In that case Inductor decided to generate a 1D kernel (since both dimensions had the same contiguity), but required a 2D load. Similarly, if you have a 2D kernel but a 3D/4D load, then block ptr won't be used.

Option 1: Change the tiling algorithm here: `pytorch/torch/_inductor/codegen/triton.py`, lines 3851 to 3854 at c5b1a4c.
If you trigger a 2D tiled kernel, then block_ptr should get used.

Option 2: Generate a 2D load, then call `tl.reshape`. Something like the sketch below:
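(The original snippet is missing; this is a rough sketch of the idea, with made-up sizes for a 4x4 view whose rows have stride 8. `in_ptr0`, `xoffset`, and `XBLOCK` stand in for the usual Inductor kernel arguments.)

```python
# Inside a 1D kernel over a flat xindex: do a 2D block-pointer load, then
# flatten the block back to the kernel's 1D iteration space. This assumes
# XBLOCK (and hence xoffset) is a multiple of the row length (4 here),
# which is what the multiple_of guards mentioned below would enforce.
block_2d = tl.load(
    tl.make_block_ptr(in_ptr0,
                      shape=[4, 4],      # logical (rows, cols) of the view
                      strides=[8, 1],    # 8 = padded row stride
                      offsets=[xoffset // 4, 0],
                      block_shape=[XBLOCK // 4, 4],
                      order=[1, 0]),
    boundary_check=[0])
tmp0 = tl.reshape(block_2d, [XBLOCK])
```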
This would require some `multiple_of` guards to ensure correctness. This would be a bit more flexible.
Thanks @jansel for the suggestions. I can take a shot at this. Would Option 2 break the requirement that tiling dims == block pointer dims? That seems preferable, but I might attempt Option 1 first just to get things working.
Yes, that is what I meant by "This would be a bit more flexible."
I think I have a reasonable draft of Option 2. It pattern matches on the div/modulo indexing expression to extract the strides and offset. I'm struggling with the `multiple_of` guards, though. Instead of shape guards, would it be possible to use `boundary_check` on the block-pointer loads? I think that could work if we check that the iteration ranges are all powers of 2. (Is this always true?)
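(For reference, not from the original thread: on a block-pointer load, out-of-bounds handling can be requested in-kernel like this.)

```python
# Mask out-of-bounds rows/columns of the block and fill them with zeros,
# rather than guarding on the tensor's shape at compile time.
val = tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
```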
I think there are some correctness issues with that, because the iteration order must match exactly between all loads/stores in the kernel. The guards I was talking about would need to be on the shape of the tensor being loaded.
🚀 The feature, motivation and pitch
I ran the following program to test what triton code is generated from a discontiguous tensor:
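(The program itself was lost in the capture; the following reconstruction matches the description in this issue, though the exact sizes and op are guesses.)

```python
import torch

def foo(x):
    # any simple pointwise op over the strided view
    return x + 1.0

full = torch.randn(8, 8, device="cuda")
# 4x4 view with row stride 8: discontiguous, with padding between rows
view = torch.as_strided(full, size=(4, 4), stride=(8, 1))

compiled = torch.compile(foo)
out = compiled(view)
```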
The generated kernel was:
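(The captured kernel is also missing. Based on the description below, it would have looked roughly like this; the names and constants are approximate.)

```python
import triton
import triton.language as tl

@triton.jit
def triton_poi_fused_add_0(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 4   # column within the 4x4 view
    x1 = xindex // 4  # row within the 4x4 view
    # Standard pointer for the input: div/modulo skips over the row padding
    tmp0 = tl.load(in_ptr0 + (x0 + 8 * x1), xmask)
    tmp1 = tmp0 + 1.0
    # Block pointer for the (contiguous) output
    tl.store(tl.make_block_ptr(out_ptr0, shape=[16], strides=[1],
                               offsets=[xoffset], block_shape=[XBLOCK],
                               order=[0]),
             tl.broadcast_to(tmp1, [XBLOCK]), boundary_check=[0])
```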
It seems like Inductor generates a block pointer for the output, but reverts back to standard pointers for the input. Whereas if I don't call `torch.as_strided` on the input, I see block pointers for both. I am wondering if it's possible for Inductor to generate something like this instead:
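(Approximate reconstruction of the proposal; the kernel name and block sizes are illustrative.)

```python
import triton
import triton.language as tl

@triton.jit
def triton_proposed(in_ptr0, out_ptr0, YBLOCK: tl.constexpr, XBLOCK: tl.constexpr):
    yoffset = tl.program_id(1) * YBLOCK  # blocking over rows
    xoffset = tl.program_id(0) * XBLOCK  # XBLOCK refers to columns, not a linear index
    # strides=[8, 1] expresses the padded rows directly; no div/modulo needed
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[4, 4], strides=[8, 1],
                                     offsets=[yoffset, xoffset],
                                     block_shape=[YBLOCK, XBLOCK], order=[1, 0]),
                   boundary_check=[0, 1])
    tmp1 = tmp0 + 1.0
    tl.store(tl.make_block_ptr(out_ptr0, shape=[4, 4], strides=[4, 1],
                               offsets=[yoffset, xoffset],
                               block_shape=[YBLOCK, XBLOCK], order=[1, 0]),
             tmp1, boundary_check=[0, 1])
```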
This would use the `strides` argument to `tl.make_block_ptr` to express that the input tensor is discontiguous. On GPUs, this could avoid the address calculation using division and modulo, which might yield some performance benefit. There is probably a much bigger win for accelerators like MTIA with simpler memory systems, where this code maps very naturally to DMA engines. Without this, simpler accelerators might have a tough time handling padding between the rows of a tensor.

Is this feature feasible? The main change I see is that here `XBLOCK` would refer to the columns of the input matrix, as opposed to the linear index. It would also be possible to block on rows.

Alternatives
In principle, it's possible for the Triton compiler to recognize this pattern under the hood. But it seems like that would require reading a whole number of rows, i.e., `XBLOCK` must be a multiple of the row length. Also, the analysis could get complex when division and modulo are involved. I'm wondering if it makes more sense to handle this in Inductor.

Instead of block pointers, it's also possible to simplify the address calculation for standard pointers, such as:
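(The snippet is missing; an approximate reconstruction of the simplification:)

```python
# Instead of div/modulo on a flat index:
#   tmp0 = tl.load(in_ptr0 + (x0 + 8 * x1), xmask)  # x0 = xindex % 4, x1 = xindex // 4
# keep the index two-dimensional so the address math stays affine:
yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
tmp0 = tl.load(in_ptr0 + (8 * yindex + xindex), (yindex < 4) & (xindex < 4))
```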
This form could more easily be converted to a block representation inside the Triton compiler.
Additional context
No response
cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire
cc @shunting314 based on offline conversations. We were hoping for input from @jansel.