-
Notifications
You must be signed in to change notification settings - Fork 21.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Sync torch_FA2 and FA2 flash_api] + [Expose seqused_k & alibi_slopes arguments] #126520
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice catch
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 mandatory check(s) failed. The first few are: Dig deeper by viewing the failures on hud |
@pytorchbot rebase |
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
Successfully rebased |
ab8d295
to
8e82ea8
Compare
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 mandatory check(s) failed. The first few are: Dig deeper by viewing the failures on hud |
These changes look good |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
… _flash_attention_forward
…unction & copy ref
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
… arguments] (pytorch#126520) 1. **Expose seqused_k & alibi_slopes arguments**: - This can be used when your sequence length k is not the full extent of the tensor. This is useful for kv cache scenarios and was not previously supported in the FA2 TORCH integration. We need these arguments for external xformers lib call to the _flash_attention_forward API. Before: ``` std::optional<Tensor> seqused_k = c10::nullopt; std::optional<Tensor> alibi_slopes = c10::nullopt; ``` After: ``` _flash_attention_forward(... std::optional<Tensor>& seqused_k, std::optional<Tensor>& alibi_slopes, ``` 2. There is a difference between the **TORCH_FA2_flash_api:mha_fwd** and **FA2_flash_api:mha_fwd** (same for **mha_varlen_fwd**) at the query transposition (GQA) step. The **CHECK_SHAPE** is applied on the original query vs the reshaped query. This causes an error (because of the shape constraint) for such inputs: ``` q = torch.randn([7, 1, 4, 256], dtype=torch.bfloat16, device='cuda') k = torch.randn([7, 51, 1, 256], dtype=torch.bfloat16, device='cuda') v = torch.randn([7, 51, 1, 256], dtype=torch.bfloat16, device='cuda') ``` ![image](https://github.com/pytorch/pytorch/assets/927999/77ea6bf6-b6e9-4f3f-96a9-8d952956ddd9) - i've modified the code as little as possible, but if you prefer a more verbose change like the following, dont hesitate to tell me: ``` at::Tensor swapped_q = seqlenq_ngroups_swapped ? q.reshape({batch_size, num_heads_k, num_heads / num_heads_k, head_size_og}).transpose(1, 2) : q; if (seqlenq_ngroups_swapped) { seqlen_q = num_heads / num_heads_k; num_heads = num_heads_k; } CHECK_SHAPE(swapped_q, batch_size, seqlen_q, num_heads, head_size_og); ``` Pull Request resolved: pytorch#126520 Approved by: https://github.com/drisspg
Before:
After:
The CHECK_SHAPE is applied on the original query vs the reshaped query. This causes an error (because of the shape constraint) for such inputs: