Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core][Hash][Automatic Prefix caching] Accelerating the hashing function by avoiding deep copies #4696

Merged
merged 12 commits into from May 14, 2024

Conversation

KuntaiDu
Copy link
Contributor

@KuntaiDu KuntaiDu commented May 8, 2024

The hashing function implemented in automatic prefix caching is running slowly. To benchmark its performance, I implemented a new benchmarking script benchmarks/benchmark_hashing.py. It shows that the hashing takes 47.98% of the total inference time when all prefix blocks are cached.

After profiling, over 90% of the time is spent on preparing the data for hashing. The reason is that there are multiple deep copies involved while preparing the data.
As shown in the figure,
image
Line 254, 255, 256 are taking about 30% of the total runtime of function hash_of_block due to deep copy, and the hashing function itself only takes 10% of the hash_of_block's runtime.

To fix this issue, I implemented a new auxiliary function in vllm/sequence.py called get_prefix_token_ids. This function does a similar job as get_token_ids()[0:num_tokens], except for the return value of this function is directly hashable (thus no need to use the deep copies to prepare data for hashing).

After this fix, the hashing function is accelerated by 5x, now hashing only takes 15% of the total inference time.

Side note: this pr assumes that the prompt_token_ids inside SequenceData, once created, will not change (as I will convert prompt_token_ids to tuple and reuse it for all future get_prefix_token_ids request). For now, I will assume that this is true. Please leave comments if you feel this may not hold in certain cases.

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@@ -119,6 +119,7 @@ def __init__(
output_token_ids = []

self.prompt_token_ids = prompt_token_ids
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to store it as both list and a tuple?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically we do not have to, as long as there is no code that directly change the value of prompt_token_ids attribute inside the class SequenceGroup.
I did a code search that confirms that there is no such code inside vLLM repo now. The search keyword is ".prompt_token_ids"
image
And there is no such code that changes prompt_token_ids besides the initialization function.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make it a tuple then, IMO. This also signifies prompt tokens are immutable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, most of the existing codes are assuming the attribute prompt_token_ids to be of type List[int], and only the hashing function requires it to be Tuple[int] so that it is immutable and thus hashable. From this perspective, I guess it is worthwhile to have an extra copy of prompt_token_ids so that both hashing is fast, and developers can still treat prompt_token_ids as a list.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we at least make the attribute with list private and only allow access (and not setting) through a property? Not perfect since it can still be modified in place but better than nothing.

Another idea is to subclass list to create FrozenList with mutable methods set to raise an exception but that's a lot more complex.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since immutable prompt token IDs is the most important assumption, I'd also suggest to change it to tuple directly. It seems not an issue to me to change all List[int] to Tuple[int, ...]. At least type annotation should not be a blocker.

Copy link
Contributor Author

@KuntaiDu KuntaiDu May 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have changed the data type of prompt_token_ids to Tuple[int, ...], and fixed typing conflicts caused by this change.

@KuntaiDu KuntaiDu requested review from comaniac and Yard1 May 9, 2024 20:56
Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks for the timely fix!

…t aligns with the expectation of test cases, plus moving benchmark_hashing.py to a new folder
vllm/outputs.py Outdated
@@ -84,7 +84,7 @@ def __init__(
) -> None:
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = tuple(prompt_token_ids)
self.prompt_token_ids = list(prompt_token_ids)
Copy link
Collaborator

@Yard1 Yard1 May 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to do this? maybe it will be better to change tests instead.

Copy link
Contributor Author

@KuntaiDu KuntaiDu May 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test case corresponding to this change is:
In tests/conftest.py:

        for req_output in req_outputs:
            ...
            prompt_ids = req_output.prompt_token_ids
            ...
            for sample in req_output.outputs:
                ...
                req_sample_output_ids.append(prompt_ids + output_ids)
                ...
            ...

And the code prompt_ids + output_ids requires prompt_ids to be list.

I guess processing vllm's output using code like prompt_ids + output_ids may be common in current vllm-based apps. So maybe keeping the prompt_token_ids attribute in RequestOutputs as list would be better for compatibility's sake.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I think it would be fine to keep it as tuple and cast it to list in the test code tbh

@cadedaniel
Copy link
Collaborator

btw, we will move to the BlockManagerV2 soon, which has better design for hashing. see

def content_hash(self) -> Optional[int]:

Copy link
Collaborator

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like we should avoid tuple -> list conversion. I measured locally, and it takes more than 10 us, which is pretty expensive.

vllm/outputs.py Outdated
@@ -75,7 +75,7 @@ def __init__(
self,
request_id: str,
prompt: str,
prompt_token_ids: List[int],
prompt_token_ids: Union[List[int], Tuple[int, ...]],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove tuple input?

vllm/sequence.py Outdated
@@ -139,7 +139,18 @@ def get_output_len(self) -> int:
return len(self.output_token_ids)

def get_token_ids(self) -> List[int]:
return self.prompt_token_ids + self.output_token_ids
return list(self.prompt_token_ids) + self.output_token_ids
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove tuple input and remove conversion!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. For performance we need to avoid tuple -> list conversion. How about storing the list version as prompt_token_ids for accessing, and stores the tuple version in _prompt_token_ids_tuple for hashing speedup purposes?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I think that sounds good to me. (it should be fine since prompt tokens are not going to be changed)

Copy link
Collaborator

@rkooo567 rkooo567 May 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert this part? (ping me when it is done!)

) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
"""Get prefix tokens, and make the return value hashable"""
prompt_length = len(self.prompt_token_ids)
if num_tokens > prompt_length:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when does this happen?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This happens when calculating hashes for both the user input (i.e. the prompt tokens) and the LLM-generated output (i.e. output tokens ).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it happen under normal circumstance or is it only for recomputation case?

Copy link
Contributor Author

@KuntaiDu KuntaiDu May 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it will happen for normal cases (inside function _allocate_last_physical_block in vllm/sequence.py).

vllm/sequence.py Outdated
output_token_ids: Optional[List[int]] = None,
) -> None:
if output_token_ids is None:
output_token_ids = []

self.prompt_token_ids = prompt_token_ids
self.prompt_token_ids: Tuple[int, ...] = tuple(prompt_token_ids)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comment this should not be changed?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also consider making it private attr

self._prompt_token_id

?

@KuntaiDu KuntaiDu requested review from rkooo567 and Yard1 May 10, 2024 18:01
Copy link
Collaborator

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should revert get_token_ids! I will approve it after that

vllm/sequence.py Outdated
@@ -139,7 +139,18 @@ def get_output_len(self) -> int:
return len(self.output_token_ids)

def get_token_ids(self) -> List[int]:
return self.prompt_token_ids + self.output_token_ids
return list(self.prompt_token_ids) + self.output_token_ids
Copy link
Collaborator

@rkooo567 rkooo567 May 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert this part? (ping me when it is done!)

) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
"""Get prefix tokens, and make the return value hashable"""
prompt_length = len(self.prompt_token_ids)
if num_tokens > prompt_length:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it happen under normal circumstance or is it only for recomputation case?

@KuntaiDu
Copy link
Contributor Author

I think you should revert get_token_ids! I will approve it after that

Done

@KuntaiDu KuntaiDu requested a review from rkooo567 May 14, 2024 05:06
@rkooo567
Copy link
Collaborator

before this PR: Hashing took 0.16 seconds,14.13% of the total runtime.
after this PR: 4.75%

@rkooo567 rkooo567 merged commit ccb63a8 into vllm-project:main May 14, 2024
59 checks passed
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 19, 2024
@KuntaiDu KuntaiDu deleted the kuntai-hashing branch May 20, 2024 17:29
dtrifiro pushed a commit to dtrifiro/vllm that referenced this pull request May 21, 2024
tybalex pushed a commit to tybalex/vllm-function-call that referenced this pull request May 25, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants