Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Sync torch_FA2 and FA2 flash_api] + [Expose seqused_k & alibi_slopes arguments] #126520

Closed
wants to merge 6 commits into from

Conversation

lvaleriu
Copy link
Contributor

@lvaleriu lvaleriu commented May 17, 2024

  1. Expose seqused_k & alibi_slopes arguments:
  • This can be used when your sequence length k is not the full extent of the tensor. This is useful for kv cache scenarios and was not previously supported in the FA2 TORCH integration. We need these arguments for external xformers lib call to the _flash_attention_forward API.
    Before:
  std::optional<Tensor> seqused_k = c10::nullopt;
  std::optional<Tensor> alibi_slopes = c10::nullopt;

After:

_flash_attention_forward(...
    std::optional<Tensor>& seqused_k,
    std::optional<Tensor>& alibi_slopes,
  1. There is a difference between the TORCH_FA2_flash_api:mha_fwd and FA2_flash_api:mha_fwd (same for mha_varlen_fwd) at the query transposition (GQA) step.

The CHECK_SHAPE is applied on the original query vs the reshaped query. This causes an error (because of the shape constraint) for such inputs:

q = torch.randn([7, 1, 4, 256], dtype=torch.bfloat16, device='cuda')
k = torch.randn([7, 51, 1, 256], dtype=torch.bfloat16, device='cuda')
v = torch.randn([7, 51, 1, 256], dtype=torch.bfloat16, device='cuda')

image

  • i've modified the code as little as possible, but if you prefer a more verbose change like the following, dont hesitate to tell me:
at::Tensor swapped_q = seqlenq_ngroups_swapped 
    ? q.reshape({batch_size, num_heads_k, num_heads / num_heads_k, head_size_og}).transpose(1, 2)
    : q;

if (seqlenq_ngroups_swapped) {
    seqlen_q = num_heads / num_heads_k;
    num_heads = num_heads_k;
}

CHECK_SHAPE(swapped_q, batch_size, seqlen_q, num_heads, head_size_og);

Copy link

pytorch-bot bot commented May 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126520

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 58 New Failures, 5 Unrelated Failures

As of commit e8c29f3 with merge base ec8b254 (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@lvaleriu lvaleriu requested a review from drisspg May 17, 2024 07:46
Copy link
Contributor

@drisspg drisspg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch

@drisspg drisspg added ciflow/trunk Trigger trunk jobs on your pull request topic: not user facing topic category labels May 17, 2024
@drisspg
Copy link
Contributor

drisspg commented May 17, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@drisspg
Copy link
Contributor

drisspg commented May 17, 2024

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased fix_fa2_gqa onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fix_fa2_gqa && git pull --rebase)

@drisspg
Copy link
Contributor

drisspg commented May 17, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@lvaleriu lvaleriu changed the title check_shape for the transposed query [Sync torch_FA2 and FA2 flash_api] + [Expose seqused_k & alibi_slopes arguments] May 21, 2024
@drisspg
Copy link
Contributor

drisspg commented May 24, 2024

These changes look good

@lvaleriu
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@lvaleriu
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Aidyn-A pushed a commit to tinglvv/pytorch that referenced this pull request May 30, 2024
… arguments] (pytorch#126520)

1. **Expose seqused_k & alibi_slopes arguments**:
- This can be used when your sequence length k is not the full extent of the tensor. This is useful for kv cache scenarios and was not previously supported in the FA2 TORCH integration. We need these arguments for external xformers lib call to the _flash_attention_forward API.
Before:
```
  std::optional<Tensor> seqused_k = c10::nullopt;
  std::optional<Tensor> alibi_slopes = c10::nullopt;
```
After:
```
_flash_attention_forward(...
    std::optional<Tensor>& seqused_k,
    std::optional<Tensor>& alibi_slopes,
```

2. There is a difference between the **TORCH_FA2_flash_api:mha_fwd** and **FA2_flash_api:mha_fwd** (same for **mha_varlen_fwd**) at the query transposition (GQA) step.

The **CHECK_SHAPE** is applied on the original query vs the reshaped query. This causes an error (because of the shape constraint) for such inputs:
```
q = torch.randn([7, 1, 4, 256], dtype=torch.bfloat16, device='cuda')
k = torch.randn([7, 51, 1, 256], dtype=torch.bfloat16, device='cuda')
v = torch.randn([7, 51, 1, 256], dtype=torch.bfloat16, device='cuda')
```

![image](https://github.com/pytorch/pytorch/assets/927999/77ea6bf6-b6e9-4f3f-96a9-8d952956ddd9)

- i've modified the code as little as possible, but if you prefer a more verbose change like the following, dont hesitate to tell me:
```
at::Tensor swapped_q = seqlenq_ngroups_swapped
    ? q.reshape({batch_size, num_heads_k, num_heads / num_heads_k, head_size_og}).transpose(1, 2)
    : q;

if (seqlenq_ngroups_swapped) {
    seqlen_q = num_heads / num_heads_k;
    num_heads = num_heads_k;
}

CHECK_SHAPE(swapped_q, batch_size, seqlen_q, num_heads, head_size_og);
```
Pull Request resolved: pytorch#126520
Approved by: https://github.com/drisspg
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged topic: not user facing topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants