[NT] Implementing Multi-Head Attention with NestedTensors #125214
Comments
@clessig Did you try this with the `torch.jagged` layout?
@cpuhrsch: thanks, `torch.jagged` helped. Next error:

```
Traceback (most recent call last):
...
```

The documentation for reshape isn't very detailed, but this is what I would have expected to work.
@clessig it looks like you want to reshape a 3D nested tensor -> a 4D shape:
Something like this will work:

```python
import torch

shapes = [(19, 1024), (18, 1024), (22, 1024), (19, 1024), (13, 1024),
          (18, 1024), (21, 1024), (17, 1024), (22, 1024), (19, 1024)]
a = torch.nested.as_nested_tensor(
    [torch.randn(*shape, device="cuda") for shape in shapes],
    layout=torch.jagged,
)
print(a.shape, a.dim())  # torch.Size([10, j1, 1024]) 3

# do projection
lin = torch.nn.Linear(1024, 1024, bias=False, device="cuda")
q = lin(a)
print(q.shape, q.dim())  # torch.Size([10, j1, 1024]) 3

# split heads: 1024 -> 8 heads x 128 dims
p = q.unflatten(-1, [8, 128])
# alternative reshape() calls:
# p = q.reshape(-1, -1, 8, 128)
# p = q.reshape(10, -1, 8, 128)
print(p.shape, p.dim())  # torch.Size([10, j1, 8, 128]) 4
```

I suggest using `unflatten()`.
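From there, a sketch of how the split heads could feed into SDPA — sharing one projection for q/k/v just to keep the example short:

```python
import torch.nn.functional as F

# Reuse the projection for k/v for brevity; real MHA uses separate
# q/k/v projections. transpose(1, 2) moves heads ahead of the ragged
# dim, giving the [batch, heads, seq, head_dim] layout SDPA expects.
k = lin(a).unflatten(-1, [8, 128])
v = lin(a).unflatten(-1, [8, 128])

out = F.scaled_dot_product_attention(
    p.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
)
print(out.shape, out.dim())  # torch.Size([10, 8, j1, 128]) 4
```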
Hi @jbschlosser, many thanks for the example. This works now, and I also extended it to use flash attention. I also found your example here: https://github.com/pytorch/tutorials/blob/main/prototype_source/nestedtensor.py. Unfortunately, torch.compile is broken with the inductor backend:
When I try to compile my code I get a different error:
Is there already an issue for the problems with inductor and nested tensors? Any ETA for a fix or workaround? Thanks!
Hey @clessig, sorry for the trouble. The first problem you mentioned above stemmed from a bad interaction between PT2 tracing and tensor subclasses. The second problem is likely related to #118446, and we are actively working on addressing it. In the meantime, a workaround is to move nested tensor construction outside of the compiled region and pass the nested tensors as inputs to the compiled region instead. This should work fine. Let me know if you run into other issues - thanks!
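A minimal sketch of that workaround (the function and names here are just illustrative):

```python
import torch

lin = torch.nn.Linear(1024, 1024, device="cuda")

@torch.compile
def project(x):
    # the nested tensor arrives as an input; no NT construction inside
    return lin(x)

# build the nested tensor *outside* the compiled region
nt = torch.nested.as_nested_tensor(
    [torch.randn(s, 1024, device="cuda") for s in (19, 18, 22)],
    layout=torch.jagged,
)
out = project(nt)
```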
Hi @jbschlosser, yes, the first issue is fixed with the latest nightly (I even got it before posting, but then only ran my code and not your example again :(). Performance is, however, not what one would expect:

```
=== with torch.compile ===
...
```

That's on an A100 with the latest CUDA. Did you see a speed-up? Thanks!
Yes, I saw a speedup of somewhere in the 3-5x range locally. I'll investigate this on my A100 machine; there may have been some graph-break-related regression. Does passing `fullgraph=True` surface any graph breaks?
fullgraph=True breaks the compile (as you suspected). I had modified your example to use float16 to try out flash attention. Switching back to float32, I see a speedup of 2x:

```
nested tensor multi-head attention takes 0.005227703019045293 seconds
```

For reference, float16: ...

Should flash attention work, or will it fall back to the regular implementation?
Yes, flash should work here as long as all inputs match what it supports. I'd expect flash to be selected if it was compiled for you and if the inputs / MHA module are all converted to a supported dtype (e.g. float16). To require flash explicitly, use:

```python
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    ...  # your attention call, e.g. scaled_dot_product_attention
```

which will error out if flash isn't selectable as the backend. Running locally with ...
Ok, this was/is what I am doing. I was wondering because, for me, the performance is identical with and without flash attention:

with flash: ...

without flash: ...

Great to see that there's so much work on this! (very useful for what I am doing)
If flash is available, it's the first priority to be selected. So I'd expect the same results for the same inputs with or without the use of the `sdpa_kernel` context manager.
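If it's helpful, you can sanity-check which backends are enabled on your build (enabled is not the same as selected for a given input):

```python
import torch

# Report which SDPA backends are enabled; selection for a given call
# still depends on dtype, device, and input shapes.
print(torch.backends.cuda.flash_sdp_enabled())
print(torch.backends.cuda.mem_efficient_sdp_enabled())
print(torch.backends.cuda.math_sdp_enabled())
```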
It's a work in progress but it's coming along! If you can provide any more details on the types of support you'll need (any op coverage gaps you run into, etc.), we can use that to help prioritize our efforts :) |
🚀 The feature, motivation and pitch
Nested tensors are supported by PyTorch's flash attention implementation (cf. https://gist.github.com/victoroliv2/3668f07e11a0757febb6e55a8d78592a), and this gives a notable speedup (approx. 25%) compared to alternative options. But extending this example to a full multi-head attention implementation does not work at the moment, since flash attention expects 3D tensors in the nested_tensor while nn.Linear requires 2D tensors:

```
RuntimeError: Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 4. Dense tensor dim: 2
```

This restriction on nn.Linear also seems odd to me. One could in principle construct the nested_tensor only after the projection, but since this involves a copy operation it is rather inefficient and will likely negate any benefit from flash attention with nested tensors.
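For illustration, a sketch of that copy-based alternative (shapes are arbitrary):

```python
import torch

# Project plain dense tensors first, then pack the results into a
# nested tensor. Packing from a list copies each projected tensor,
# which is the inefficiency noted above.
lin = torch.nn.Linear(1024, 1024, bias=False)
dense = [torch.randn(s, 1024) for s in (19, 18, 22)]

projected = [lin(t) for t in dense]
nt = torch.nested.as_nested_tensor(projected)  # copies each entry
```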
Alternatives
No response
Additional context
Here's a minimal example:
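A sketch along these lines reproduces the error (shapes are illustrative):

```python
import torch

# Build a 4D nested tensor [batch, heads, seq_i, head_dim] — the shape
# flash attention consumes — then try to project it with nn.Linear.
nt = torch.nested.nested_tensor(
    [torch.randn(8, s, 128) for s in (19, 18, 22)]  # (heads, seq_i, head_dim)
)
print(nt.dim())  # 4

lin = torch.nn.Linear(128, 128)
out = lin(nt)  # RuntimeError: Linear requires nested_tensor.dim == 3 ...
```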
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @erichan1 @mikaylagawarecki