
hf_olmo: support flash attn 2 #471

Open
wants to merge 1 commit into base: main

Conversation

wade3han

Addresses #460. Tested with the simple snippet below:

import torch
import transformers

# Load with the flash-attention-2 backend; flash-attn only supports fp16/bf16.
model = transformers.AutoModelForCausalLM.from_pretrained("allenai/OLMo-7B-Instruct", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, trust_remote_code=True).cuda()
tokenizer = transformers.AutoTokenizer.from_pretrained("allenai/OLMo-7B-Instruct", trust_remote_code=True)

input_ids = torch.tensor(tokenizer.encode("Hello World! My name is")).unsqueeze(0).cuda()
print(tokenizer.decode(model.generate(input_ids)[0]))
# Hello World! My name is Emily and I am a second year student at the University of California,
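
To double-check that the flash-attention backend was actually picked up (assuming a transformers version recent enough to record the resolved implementation on the config), something like the following can be added after loading:

print(model.config._attn_implementation)  # expected: "flash_attention_2"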

@epwalsh
Member

epwalsh commented Mar 1, 2024

Someone familiar with transformers internals should review this (maybe @AkshitaB). I'm not sure what transformers does with this, but I'd be very cautious if they're monkey-patching our attention mechanism, since flash-attn expects a different input shape (the head and sequence dimensions are flipped compared to what we normally do).
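
For context, a minimal sketch (not OLMo's or transformers' actual code) of the layout mismatch being described: flash-attn's flash_attn_func takes tensors shaped (batch, seq_len, n_heads, head_dim), while the usual attention path works with (batch, n_heads, seq_len, head_dim), so a patched implementation has to transpose on the way in and back on the way out.

import torch

batch, n_heads, seq_len, head_dim = 2, 8, 16, 64

# Layout used by the standard attention path (e.g. torch.nn.functional.scaled_dot_product_attention).
q = torch.randn(batch, n_heads, seq_len, head_dim, dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# flash-attn wants the head and sequence dimensions swapped.
q_fa = q.transpose(1, 2)  # (batch, seq_len, n_heads, head_dim)
k_fa = k.transpose(1, 2)
v_fa = v.transpose(1, 2)

# out = flash_attn_func(q_fa, k_fa, v_fa, causal=True)  # requires flash-attn and a GPU
# out = out.transpose(1, 2)                             # back to (batch, n_heads, seq_len, head_dim)
print(q.shape, "->", q_fa.shape)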
