
Flash-Attention #85

Open

Xreki opened this issue Dec 20, 2022 · 8 comments

Labels
enhancement New feature or request

Comments

@Xreki

Xreki commented Dec 20, 2022

No description provided.

@robotcator

robotcator commented Dec 28, 2022

Hi, if you want to use the optimized flash attention code, you can check it out here, and this document may be helpful. Hope this helps.

@lhatsk

lhatsk commented Jan 1, 2023

I run into NaNs if I enable flash attention.

unicore.nan_detector | NaN detected in output of model.evoformer.blocks.47.tri_att_end.mha.linear_v, shape: torch.Size([1, 292, 292, 128]), backward

WARNING | unicore.nan_detector | NaN detected in output of model.evoformer.blocks.21.msa_att_row.mha.linear_v, shape: torch.Size([1, 256, 184, 256]), backward

I also get lots of new warnings:
UserWarning: Using non-full backward hooks on a Module that does not return a single Tensor or a tuple of Tensors is deprecated and will be removed in future versions. This hook will be missing some of the grad_output. Please use register_full_backward_hook to get the documented behavior.

Is it working for you @Xreki?

A100 with bfloat16 enabled
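
The deprecation warning points at PyTorch's register_full_backward_hook. A minimal sketch of a NaN-reporting backward hook using that API, with a toy model standing in for the Uni-Fold model (this is not unicore's actual nan_detector, just an illustration of the suggested API):

```python
import torch
import torch.nn as nn

def nan_hook(module, grad_input, grad_output):
    # Report NaNs in the gradients flowing out of this module during backward.
    for g in grad_output:
        if g is not None and torch.isnan(g).any():
            print(f"NaN detected in grad_output of {module.__class__.__name__}")

# Toy model standing in for the Uni-Fold model.
model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))
for _, submodule in model.named_modules():
    # register_full_backward_hook replaces the deprecated register_backward_hook
    # that triggers the warning above.
    submodule.register_full_backward_hook(nan_hook)

out = model(torch.randn(4, 128, requires_grad=True))
out.sum().backward()
```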

@robotcator

Can you provide some details about your installation of flash attention? It seems that the backward pass did not work correctly.

@Xreki
Author

Xreki commented Jan 3, 2023

@lhatsk It seems OK for me. I used the docker image dptechnology/unicore:latest-pytorch1.12.1-cuda11.6-flashattn and tested the monomer model with the demo data on a single A100 GPU, using bfloat16, and saw no NaNs.

@lhatsk

lhatsk commented Jan 3, 2023

I installed flash attention from source according to the README (torch 1.12.1 + CUDA 11.2).
I tested it with multimer finetuning on 4 GPUs distributed over two nodes. The NaNs don't appear right away. Interestingly, I also get NaNs with OpenFold when I enable flash attention (different data, different cluster, different software setup, monomer), but there it happens in the pTM computation.

@robotcator

Can you write a single test for the flash_attn interface with an input of shape [1, 292, 292, 128], so that we can check whether the function works properly?
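
For reference, a standalone forward/backward check along those lines might look like the sketch below. It assumes the flash_attn_unpadded_func interface from the flash-attn package of that era (the function name differs across versions) and interprets [1, 292, 292, 128] as 292 sequences of length 292 with 4 heads of dim 32; the head split is an assumption.

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func  # name varies by flash-attn version

# Interpret [1, 292, 292, 128] as batch=1, 292 sequences of length 292,
# with 128 = 4 heads x 32 head_dim (the head split is an assumption).
batch, n_seq, seqlen, nheads, headdim = 1, 292, 292, 4, 32
total = batch * n_seq * seqlen

q, k, v = [
    torch.randn(total, nheads, headdim, dtype=torch.bfloat16,
                device="cuda", requires_grad=True)
    for _ in range(3)
]
# Cumulative sequence lengths required by the unpadded (varlen) interface.
cu_seqlens = torch.arange(0, (batch * n_seq + 1) * seqlen, seqlen,
                          dtype=torch.int32, device="cuda")

# Last positional argument is dropout_p = 0.0.
out = flash_attn_unpadded_func(q, k, v, cu_seqlens, cu_seqlens,
                               seqlen, seqlen, 0.0)
out.sum().backward()

for name, t in [("out", out), ("dq", q.grad), ("dk", k.grad), ("dv", v.grad)]:
    assert not torch.isnan(t).any(), f"NaN in {name}"
print("flash attention forward/backward produced no NaNs")
```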

@lhatsk

lhatsk commented Jan 5, 2023

Just running _flash_attn(q, k, v) works without NaNs. I have now also tested it with the pre-compiled package and Uni-Fold monomer, and I get NaNs there as well. It seems to happen after two or three samples.
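
Since the NaNs only show up after a couple of samples, PyTorch's anomaly detection can help pinpoint which backward op produces them first (a general debugging aid, not something specific to flash-attn; the toy model and input below just stand in for the real training step):

```python
import torch
import torch.nn as nn

# detect_anomaly makes backward report the forward-pass stack trace of the
# op that produced a NaN gradient. It is slow, so wrap only a few training
# steps while debugging.
model = nn.Linear(128, 128)
x = torch.randn(2, 128, requires_grad=True)

with torch.autograd.detect_anomaly():
    out = model(x)
    out.sum().backward()
```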

@guolinke
Member

guolinke commented Feb 3, 2023

You can now try flash-attention using this branch: https://github.com/dptech-corp/Uni-Fold/tree/flash-attn

@ZiyaoLi changed the title from "Is the newest optimized version with flash-attention updated to the repo now? And how can I test it?" to "Flash-Attention" on Feb 22, 2023
@ZiyaoLi closed this as completed on Apr 11, 2023
@ZiyaoLi reopened this on Apr 11, 2023
@ZiyaoLi added the "enhancement" (New feature or request) label on Apr 11, 2023