[Misc] Enhance attention selector #4751

Merged: 19 commits merged into main from attn-selector on May 13, 2024

Conversation

WoosukKwon (Collaborator):

This PR provides more information (such as block size and KV cache dtype) to the attention backend selector so that it can be used to find the appropriate attention backend. It also moves kv_cache_dtype from AttentionMetadata to Attention.

This PR is a prerequisite for #3648
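For orientation, below is a minimal sketch of the second change, moving kv_cache_dtype from per-call metadata onto the layer itself. The field sets and defaults are simplified assumptions for illustration, not the actual vLLM class definitions.

```python
# Simplified sketch: field sets and defaults are illustrative assumptions,
# not the real vLLM classes.
from dataclasses import dataclass


@dataclass
class AttentionMetadata:
    # Before this PR, kv_cache_dtype lived here alongside per-batch metadata.
    # Afterwards, only per-call information remains.
    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0


class Attention:
    def __init__(self, num_heads: int, head_size: int,
                 kv_cache_dtype: str = "auto") -> None:
        # After this PR, the KV-cache dtype is a static property of the layer,
        # known at construction time, so the backend selector can use it.
        self.num_heads = num_heads
        self.head_size = head_size
        self.kv_cache_dtype = kv_cache_dtype
```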

@rkooo567 (Collaborator) left a comment:

LGTM! One question regarding why we change the interface of get_attn_backend!

```diff
@@ -29,10 +30,22 @@ def __init__(
         num_kv_heads: Optional[int] = None,
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
+        cache_config: Optional[CacheConfig] = None,
```
rkooo567 (Collaborator) commented on this diff:

QQ: when is this None? (should we just not allow None here? Since cache config is supposed to be created by default?)

WoosukKwon (Collaborator, Author) replied:

Good point. In most situations, cache_config isn't None. However, I wanted to provide the flexibility to initialize the model without cache_config, which can be particularly useful in niche scenarios such as testing the model loader. For instance, some tests in test_tensorizer only use the HF config to initialize the model, without setting up a CacheConfig or ModelConfig. Additionally, allowing cache_config to be optional helps maintain consistency with the HF model interface, where a model can be instantiated solely with the HF config.
I think this adjustment makes the setup more versatile and aligns better with existing practices.
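For context, here is a minimal sketch of what that flexibility looks like at the call site. The CacheConfig fields and the fallback defaults below are assumptions for illustration, not the exact code in this PR.

```python
# Sketch only: CacheConfig fields and the fallback defaults are assumptions.
from typing import Optional


class CacheConfig:
    def __init__(self, block_size: int = 16, cache_dtype: str = "auto") -> None:
        self.block_size = block_size
        self.cache_dtype = cache_dtype


class Attention:
    def __init__(self, num_heads: int, head_size: int,
                 cache_config: Optional[CacheConfig] = None) -> None:
        if cache_config is not None:
            block_size = cache_config.block_size
            kv_cache_dtype = cache_config.cache_dtype
        else:
            # Lightweight paths (e.g. model-loader tests that only have the
            # HF config) can skip CacheConfig and fall back to defaults.
            block_size = 16
            kv_cache_dtype = "auto"
        self.num_heads = num_heads
        self.head_size = head_size
        self.block_size = block_size
        self.kv_cache_dtype = kv_cache_dtype


# Full engine path: cache_config comes from the engine configuration.
attn = Attention(num_heads=32, head_size=128, cache_config=CacheConfig())
# Test path: no CacheConfig needed, mirroring the HF-style interface.
attn_for_test = Attention(num_heads=32, head_size=128)
```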

```diff
         from vllm.attention.backends.flashinfer import FlashInferBackend
         return FlashInferBackend
     else:
         raise ValueError("Invalid attention backend.")


-def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
+def _which_attn_to_use(
+    num_heads: int,
```
rkooo567 (Collaborator) commented on this diff:

Is this change necessary? (It seems like most of the args are not used.)

WoosukKwon (Collaborator, Author) replied:

Good question! It's actually for the PR #3648 and future PRs where we need to consider block sizes and KV cache dtypes in selecting the backend.
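To make the motivation concrete, here is a rough sketch of how a selector can use the extra arguments. The backend names and the compatibility rules below are illustrative assumptions, not the actual vLLM selection logic.

```python
# Illustrative sketch: the fallback rules are assumptions chosen to show why
# block_size and kv_cache_dtype are useful inputs to backend selection.
import enum
from typing import Optional

import torch


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()


def _which_attn_to_use(
    num_heads: int,
    head_size: int,
    num_kv_heads: Optional[int],
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
) -> _Backend:
    # Hypothetical rule: the preferred backend only supports fp16/bf16.
    if dtype not in (torch.float16, torch.bfloat16):
        return _Backend.XFORMERS
    # Hypothetical rule: a quantized (e.g. fp8) KV cache is not supported by
    # the preferred backend, so fall back.
    if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
        return _Backend.XFORMERS
    # Hypothetical rule: the preferred backend needs a minimum block size.
    if block_size < 16:
        return _Backend.XFORMERS
    return _Backend.FLASH_ATTN
```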

vllm/attention/layer.py (review comment resolved)
@WoosukKwon merged commit 0fca3cd into main on May 13, 2024 (47 of 48 checks passed).
@WoosukKwon deleted the attn-selector branch on May 13, 2024 at 17:47.
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 19, 2024
dtrifiro pushed a commit to dtrifiro/vllm that referenced this pull request May 21, 2024
tybalex pushed a commit to tybalex/vllm-function-call that referenced this pull request May 25, 2024