[Performance]: Automatic Prefix Caching in multi-turn conversations #4917

Closed
hmellor opened this issue May 20, 2024 · 15 comments
Labels
performance Performance-related issues

Comments

@hmellor
Collaborator

hmellor commented May 20, 2024

I'm interested in the automatic prefix caching feature for multi-turn conversations but I can't seem to observe a performance improvement when prefix caching is enabled. This tweet from @vllm_project indicates that automatic prefix caching should benefit this use case.

I am using the following commands to start the vLLM server:

python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --port 7001 --gpu-memory-utilization 0.5 --disable-log-requests --enforce-eager

python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --port 7001 --gpu-memory-utilization 0.5 --disable-log-requests --enforce-eager --enable-prefix-caching

And the following script to simulate a multi-turn conversation from a user:

import time
from openai import OpenAI

user_messages = [
    "Tell me your ten favourite films",
    "Who directed each of these films?",
    "Which director has the most experience?",
    "What other films has this director directed?",
    "Do these films have anything in common?",
    "Which of those films is the oldest?",
    "How old was the director when this was released?",
]

client = OpenAI(api_key="api_key", base_url="http://localhost:7001/v1")

messages = []

start = time.perf_counter()
for user_message in user_messages:
    messages.append(dict(role="user", content=user_message))
    output = client.chat.completions.create(
        messages=messages,
        model="meta-llama/Llama-2-7b-chat-hf",
        temperature=0.0,
    )
    print(output.usage)
    assistant_message = output.choices[0].message
    messages.append(dict(role=assistant_message.role, content=assistant_message.content))
stop = time.perf_counter()
print(f"{stop - start = }")

With automatic prefix caching disabled I see:

$ python test.py 
CompletionUsage(completion_tokens=598, prompt_tokens=16, total_tokens=614)
CompletionUsage(completion_tokens=238, prompt_tokens=629, total_tokens=867)
CompletionUsage(completion_tokens=321, prompt_tokens=882, total_tokens=1203)
CompletionUsage(completion_tokens=465, prompt_tokens=1219, total_tokens=1684)
CompletionUsage(completion_tokens=446, prompt_tokens=1700, total_tokens=2146)
CompletionUsage(completion_tokens=212, prompt_tokens=2162, total_tokens=2374)
CompletionUsage(completion_tokens=58, prompt_tokens=2393, total_tokens=2451)
stop - start = 35.292656753677875

And with automatic prefix caching enabled I see:

$ python test.py 
CompletionUsage(completion_tokens=598, prompt_tokens=16, total_tokens=614)
CompletionUsage(completion_tokens=238, prompt_tokens=629, total_tokens=867)
CompletionUsage(completion_tokens=321, prompt_tokens=882, total_tokens=1203)
CompletionUsage(completion_tokens=468, prompt_tokens=1219, total_tokens=1687)
CompletionUsage(completion_tokens=459, prompt_tokens=1703, total_tokens=2162)
CompletionUsage(completion_tokens=197, prompt_tokens=2178, total_tokens=2375)
CompletionUsage(completion_tokens=60, prompt_tokens=2394, total_tokens=2454)
stop - start = 35.605276009999216

Is this expected?

@hmellor hmellor added the performance Performance-related issues label May 20, 2024
@hmellor
Collaborator Author

hmellor commented May 20, 2024

CC @robertgshaw2-neuralmagic (the tweet said the feature was added by Neural Magic, so you might have some insight here)

@hmellor hmellor changed the title [Performance]: [Performance]: Automatic Prefix Caching in multi-turn conversations May 20, 2024
@robertgshaw2-neuralmagic
Collaborator

Will take a look at this case

@comaniac
Contributor

I'm also interested in this issue, so I benchmarked it today using the latest main branch, which already uses the flash-attn kernel for prefix caching. However, even though I've verified cache hits in the prefix cache, I found no speedup when running the above script. I'll also investigate a bit.

@robertgshaw2-neuralmagic
Collaborator

cc @SageMoore fyi

@robertgshaw2-neuralmagic
Collaborator

robertgshaw2-neuralmagic commented May 20, 2024

I am not sure what GPU this is, but on an A100 we can do ~15,000 prefill tokens/sec at fp16, so even a 2,000-token prefill should only take ~0.13 seconds to process. Since APC only skips prefill computation, there is only about 0.5s worth of time across the whole conversation that can be optimized in this case. As a result, I would not really expect to see a speedup here (and in fact there is some overhead associated with managing another layer of indirection).
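
To make that concrete, here is a rough back-of-the-envelope check using the prompt sizes reported above (a sketch only; the ~15,000 tokens/sec figure is the A100 fp16 estimate from this comment, not a measurement):

# Upper bound on the time APC could save in the conversation above:
# it can only skip prefill compute, never decode.
PREFILL_TOKENS_PER_SEC = 15_000  # assumed A100 fp16 prefill throughput

# prompt_tokens reported for each turn in the no-APC run
prompt_tokens = [16, 629, 882, 1219, 1700, 2162, 2393]

total_prefill_tokens = sum(prompt_tokens)  # ~9,000 tokens
max_saving_sec = total_prefill_tokens / PREFILL_TOKENS_PER_SEC

print(f"total prefill tokens: {total_prefill_tokens}")
print(f"max time APC could skip: {max_saving_sec:.2f}s")  # ~0.6s of a ~35s run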

APC really is useful for cases with long shared prefills and short decodes, such as:

  • Shared system prompt
  • Few shot prompting

@comaniac
Contributor

Thanks for the good hint. I changed the script to report the latency of every request instead of only the total time (a sketch of the change is included after the results). Here are the results on an L4 GPU:

w/o APC

stop - start = 28.363408592998894
CompletionUsage(completion_tokens=453, prompt_tokens=16, total_tokens=469)
stop - start = 13.253794727999775
CompletionUsage(completion_tokens=211, prompt_tokens=485, total_tokens=696)
stop - start = 15.47434264399999
CompletionUsage(completion_tokens=245, prompt_tokens=712, total_tokens=957)
stop - start = 22.77607062900279
CompletionUsage(completion_tokens=357, prompt_tokens=974, total_tokens=1331)
stop - start = 25.096272947001125
CompletionUsage(completion_tokens=392, prompt_tokens=1348, total_tokens=1740)
stop - start = 2.3558405980002135
CompletionUsage(completion_tokens=30, prompt_tokens=1757, total_tokens=1787)
stop - start = 2.8636473680016934
CompletionUsage(completion_tokens=37, prompt_tokens=1806, total_tokens=1843)

w. APC

stop - start = 28.40403065999999
CompletionUsage(completion_tokens=453, prompt_tokens=16, total_tokens=469)
stop - start = 13.463971014996787
CompletionUsage(completion_tokens=211, prompt_tokens=485, total_tokens=696)
stop - start = 15.43624263699894
CompletionUsage(completion_tokens=245, prompt_tokens=712, total_tokens=957)
stop - start = 22.343338724000205
CompletionUsage(completion_tokens=355, prompt_tokens=974, total_tokens=1329)
stop - start = 25.549687523998728
CompletionUsage(completion_tokens=403, prompt_tokens=1346, total_tokens=1749)
stop - start = 1.933658195001044
CompletionUsage(completion_tokens=30, prompt_tokens=1766, total_tokens=1796)
stop - start = 2.3811154130016803
CompletionUsage(completion_tokens=37, prompt_tokens=1815, total_tokens=1852)

That seems to align with your analysis.
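
For reference, a minimal sketch of the per-request timing change, where each chat.completions.create call is timed individually. This is a reconstruction of the modification, not necessarily the exact script used:

import time
from openai import OpenAI

user_messages = [  # the same seven user turns as in the original script
    "Tell me your ten favourite films",
    "Who directed each of these films?",
    "Which director has the most experience?",
    "What other films has this director directed?",
    "Do these films have anything in common?",
    "Which of those films is the oldest?",
    "How old was the director when this was released?",
]

client = OpenAI(api_key="api_key", base_url="http://localhost:7001/v1")
messages = []

for user_message in user_messages:
    messages.append(dict(role="user", content=user_message))
    start = time.perf_counter()  # time this request only
    output = client.chat.completions.create(
        messages=messages,
        model="meta-llama/Llama-2-7b-chat-hf",
        temperature=0.0,
    )
    stop = time.perf_counter()
    print(f"{stop - start = }")
    print(output.usage)
    assistant_message = output.choices[0].message
    messages.append(dict(role=assistant_message.role, content=assistant_message.content))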

@hmellor
Collaborator Author

hmellor commented May 20, 2024

I am not sure what GPU this is

My results came from an A100 40GB with --gpu-memory-utilization 0.5 and --enforce-eager (both of which would make my experiments slower).

there are only 0.5s worth of time that can be optimized in this case

Ok, so it's simply a case of my test not being suitable. If I were running a model with a more expensive prefill (i.e. bigger than 7B) and with longer prompts, I'd start to be able to observe the difference within a single conversation (albeit a subtle one).

Presumably there is also a concurrency benefit, because the slot that would have been scheduled to execute the cached prefill can instead be used to process the prefill (or decode) of a different request?

@KuntaiDu
Collaborator

The key requirement for automatic prefix caching to give a sizable improvement is that the ratio between input token length and output token length is very large (ideally more than 100x).

This is a strong workload requirement, and that kind of workload commonly occurs only in specific applications (e.g. asking questions about a very long software manual).
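
As a rough illustration of why the ratio matters (a sketch: the ~15,000 prefill tokens/sec figure comes from the earlier comment, while the ~30 decode tokens/sec per request is my own assumption):

# Fraction of a request's time spent in prefill, i.e. the part APC can skip.
PREFILL_TOK_PER_SEC = 15_000   # estimate from earlier in this thread
DECODE_TOK_PER_SEC = 30        # assumed per-request decode speed

def prefill_fraction(prompt_tokens: int, output_tokens: int) -> float:
    prefill = prompt_tokens / PREFILL_TOK_PER_SEC
    decode = output_tokens / DECODE_TOK_PER_SEC
    return prefill / (prefill + decode)

# ~4x input:output ratio, similar to the conversation above: almost nothing to save
print(f"{prefill_fraction(2000, 500):.1%}")    # ~0.8%
# ~100x input:output ratio (long manual, short answer): a visible chunk to save
print(f"{prefill_fraction(20000, 200):.1%}")   # ~16.7%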

@hmellor hmellor closed this as completed May 21, 2024
@hmellor hmellor reopened this May 23, 2024
@hmellor
Collaborator Author

hmellor commented May 23, 2024

I ran a better test and have an interesting graph:

[Graph omitted: per-turn timings for a range of first prompt sizes]

Regardless of the first prompt's size, there seems to be a large fixed cost on turn 1 (i.e. the second turn), but not on subsequent turns.

@SageMoore:

  • Have you seen this while developing the feature?
  • Do you know where in the PrefixCachingBlockAllocator this comes from?

@robertgshaw2-neuralmagic
Collaborator

robertgshaw2-neuralmagic commented May 23, 2024

@hmellor - this is caused by Triton JIT compilation. The first time the server runs context_fwd_attention, Triton JIT-compiles the kernel, which slows us down. I have been meaning to finish off a PR that runs the JIT compilation during profiling, but it has become lower priority because, if you use the latest main with flash attention, this issue is resolved: the flash-attn kernels are used rather than Triton for context_fwd_attention.
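
As a workaround until then, one option (just a sketch, not something shipped in vLLM) is to send a throwaway two-turn request before any timed runs so the JIT cost is paid up front:

# Warm-up sketch: trigger the context_fwd_attention Triton JIT once before benchmarking.
# Uses the same server and model as the script above; the prompt content is arbitrary.
from openai import OpenAI

client = OpenAI(api_key="api_key", base_url="http://localhost:7001/v1")
model = "meta-llama/Llama-2-7b-chat-hf"

warmup = [{"role": "user", "content": "Hello"}]
first = client.chat.completions.create(messages=warmup, model=model, max_tokens=16)

# The second turn re-sends the now-cached prefix, exercising the prefix-caching kernel.
warmup.append({"role": first.choices[0].message.role,
               "content": first.choices[0].message.content})
warmup.append({"role": "user", "content": "And again"})
client.chat.completions.create(messages=warmup, model=model, max_tokens=16)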

@robertgshaw2-neuralmagic
Collaborator

note: this will happen once per instantiation of the server

@hmellor
Collaborator Author

hmellor commented May 23, 2024

if you use latest main with flash attention this issue is resolved

Is that the flash attention from the "pip install vllm-flash-attn for better performance." info log I've seen?

@comaniac
Contributor

Yes. You can just pip install vllm-flash-attn and make sure you see the log Using FlashAttention-2 when launching the server.

@robertgshaw2-neuralmagic
Collaborator

I think it's now installed automatically

https://github.com/vllm-project/vllm/blob/main/setup.py#L356

@hmellor
Collaborator Author

hmellor commented May 23, 2024

Ok, thanks for clearing that up for me!

@hmellor hmellor closed this as completed May 23, 2024