
[V1] Prefix caching for multimodal language models #11187

Merged: 10 commits into vllm-project:main on Dec 18, 2024

Conversation

comaniac (Collaborator) commented Dec 13, 2024

This PR enables prefix caching for VLMs. Specifically, we enhanced the KV block hash to support extra keys with the image hash and offset.

Block Hash Format

Take a series of 3 blocks as an example: T0,T1,P00,P01 | P02,P03,P04,T2 | T3,P10,P11,P12, where Ti is the i-th text token and Pxy is the y-th placeholder token of the x-th image, so this prompt has 2 images (P0 and P1). Assuming the image hashes of P0 and P1 are aaa and bbb, respectively, and mm_positions=[(offset=2, length=5), (offset=9, length=3)], the hashes of the 3 blocks are as follows:

# (Parent hash,
#  token ID w. placeholders,
#  image hash, start)
hash0 = hash(None, T0,T1,P00,P01, (aaa,0))
hash1 = hash(hash0, P02,P03,P04,T2, (aaa,2))
hash2 = hash(hash1, T3,P10,P11,P12, (bbb,0))

A more straightforward alternative is to embed the image hash and offset directly into the token sequence:

hash0 = hash(None, T0,T1,(aaa,0),(aaa,1))

We don't adopt this approach because it requires traversing all input tokens and replacing each placeholder token with such a tuple.
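
For illustration, here is a minimal, self-contained Python sketch of this hashing scheme. The names (block_extra_keys, hash_blocks, BLOCK_SIZE, MMPosition) are hypothetical and the real implementation in vllm/v1/core/kv_cache_utils.py differs in detail; this only mirrors the format described above.

from dataclasses import dataclass
from typing import Optional

BLOCK_SIZE = 4  # block size used in the example above

@dataclass
class MMPosition:
    offset: int  # start index of the placeholder run in the token sequence
    length: int  # number of placeholder tokens for this image

def block_extra_keys(block_start: int, block_end: int,
                     mm_positions: list[MMPosition],
                     mm_hashes: list[str]) -> tuple:
    # Collect (image_hash, start_offset_within_image) for every image whose
    # placeholder run overlaps the token range [block_start, block_end).
    extra_keys = []
    for pos, img_hash in zip(mm_positions, mm_hashes):
        img_start, img_end = pos.offset, pos.offset + pos.length
        if img_start < block_end and block_start < img_end:
            # Offset of the first overlapping placeholder, relative to the image.
            extra_keys.append((img_hash, max(block_start, img_start) - img_start))
    return tuple(extra_keys)

def hash_blocks(token_ids: list[int],
                mm_positions: list[MMPosition],
                mm_hashes: list[str]) -> list[int]:
    # Chain the hashes: each block hash covers the parent hash, the block's
    # token IDs (with placeholders), and the extra keys of overlapping images.
    hashes: list[int] = []
    parent: Optional[int] = None
    num_full_blocks = len(token_ids) // BLOCK_SIZE
    for idx in range(num_full_blocks):
        start, end = idx * BLOCK_SIZE, (idx + 1) * BLOCK_SIZE
        extra = block_extra_keys(start, end, mm_positions, mm_hashes)
        parent = hash((parent, tuple(token_ids[start:end]), extra))
        hashes.append(parent)
    return hashes

# Example from the description: 12 tokens, 2 images.
# hash_blocks(list(range(12)), [MMPosition(2, 5), MMPosition(9, 3)], ["aaa", "bbb"])
# yields extra keys ("aaa", 0), ("aaa", 2), ("bbb", 0) for the 3 blocks.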

Performance Optimization

To reduce the overhead of computing the extra keys for each block, this PR adds an optimization that caches the computed hash values in Request, guaranteeing that each block hash for a request is computed only once.
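
As a rough sketch of this caching idea (hypothetical attribute and function names, reusing block_extra_keys and BLOCK_SIZE from the sketch above; the actual code lives in vllm/v1/request.py and vllm/v1/core/kv_cache_utils.py):

class Request:
    def __init__(self, token_ids, mm_positions, mm_hashes):
        self.token_ids = token_ids
        self.mm_positions = mm_positions
        self.mm_hashes = mm_hashes
        # Block hashes computed so far for this request (append-only cache).
        self.cached_block_hashes: list[int] = []

def get_block_hashes(request: Request) -> list[int]:
    # Only hash the full blocks that have not been hashed yet; previously
    # computed hashes are reused from the request's cache.
    num_full_blocks = len(request.token_ids) // BLOCK_SIZE
    for idx in range(len(request.cached_block_hashes), num_full_blocks):
        start, end = idx * BLOCK_SIZE, (idx + 1) * BLOCK_SIZE
        parent = request.cached_block_hashes[-1] if request.cached_block_hashes else None
        extra = block_extra_keys(start, end, request.mm_positions, request.mm_hashes)
        request.cached_block_hashes.append(
            hash((parent, tuple(request.token_ids[start:end]), extra)))
    return request.cached_block_hashes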

Benchmark

We benchmarked the throughput using Llava-1.6-Mistral-7B with 500 prompts on an L40S GPU. The image hit rate is set to 30%, meaning we have 500*0.7=350 unique images and 500-350=150 redundant requests. We put the redundant requests together to achieve the best cache locality, to better illustrate the effectiveness of prefix caching. The benchmark script is https://gist.github.com/comaniac/ea26df17fdffa533cf53d53b8455bc31
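
As a rough illustration of this request mix (hypothetical code; the actual logic is in the mmmu_bench.py gist linked above):

num_prompts = 500
image_hit_rate = 0.3
num_unique = int(num_prompts * (1 - image_hit_rate))  # 350 unique images
num_redundant = num_prompts - num_unique               # 150 redundant requests

unique_images = [f"image_{i}.jpg" for i in range(num_unique)]
# Redundant requests reuse already-seen images and are grouped together at the
# end so their prefixes are still resident in the cache when they are scheduled.
redundant_images = [unique_images[i % num_unique] for i in range(num_redundant)]
image_per_request = unique_images + redundant_images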

VLLM_USE_V1=1 VLLM_ENABLE_V1_MULTIPROCESSING=1 python3 mmmu_bench.py --model llava-hf/llava-v1.6-mistral-7b-hf --num-prompts 500 --image-hit-rate 0.3 --no-enable-prefix-caching
> Throughput: 3.84 req/s

VLLM_USE_V1=1 VLLM_ENABLE_V1_MULTIPROCESSING=1 python3 mmmu_bench.py --model llava-hf/llava-v1.6-mistral-7b-hf --num-prompts 500  --image-hit-rate 0.3 --mm-cache-preprocessor --no-enable-prefix-caching
> Throughput: 3.85 req/s

VLLM_USE_V1=1 VLLM_ENABLE_V1_MULTIPROCESSING=1 python3 mmmu_bench.py --model llava-hf/llava-v1.6-mistral-7b-hf --num-prompts 500  --image-hit-rate 0.3 --mm-cache-preprocessor
> Throughput: 7.08 req/s

Note: Prefix caching for VLMs is now enabled by default, but it requires the image hashes from the mm cache preprocessor, so the following command (prefix caching enabled without the mm cache preprocessor) will result in an error. @alexm-neuralmagic please let me know what's the best practice for this.

VLLM_USE_V1=1 VLLM_ENABLE_V1_MULTIPROCESSING=1 python3 mmmu_bench.py --model llava-hf/llava-v1.6-mistral-7b-hf --num-prompts 500  --image-hit-rate 0.3

cc @alexm-neuralmagic @ywang96 @rickyyx


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which executes a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

ywang96 (Member) left a comment

@comaniac Thanks for this great work! Overall the code looks clean and I have left some comments. PTAL!


mergify bot commented Dec 15, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @comaniac.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Dec 15, 2024
rickyyx (Contributor) left a comment

This looks really great! I mainly chimed in for some nits.

The only main question I have is on the generate_block_hash_extra_keys routine, which I feel could be made easier to reason about. But I might be overlooking some constraints that drove its current implementation.

sleepwalker2017 commented Dec 16, 2024

export VLLM_USE_V1=1
Is this a must? When I export it, vLLM complains:

ERROR 12-16 12:00:25 core.py:263] Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

If I run without this, it runs ok.

alexm-redhat (Collaborator) commented:

@comaniac I will modify the code so you don't get an error without mm cache preprocessor. Will do it on your PR and send you the patch.

comaniac added 6 commits (Signed-off-by: Cody Yu <[email protected]>)
comaniac (Collaborator, Author) commented:

All comments should have been addressed. PTAL @ywang96 @alexm-neuralmagic @rickyyx.

Highlights:

  • Alex's patch is applied so we don't have to enable the mm preprocessor to make prefix caching work, although enabling it is still recommended for better performance.
  • The default behavior of enable_prefix_caching is changed to the following (a rough sketch follows this list):
    • v0: Default off and force off for MM models.
    • v1, text-only models: Default on.
    • v1, MM models: Default off.
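
A rough sketch of this default resolution (hypothetical function and argument names; the real logic is in vllm/engine/arg_utils.py, and note that a later comment in this thread changes v1 to enable prefix caching by default for all models):

from typing import Optional

def resolve_enable_prefix_caching(user_value: Optional[bool],
                                  use_v1: bool,
                                  is_multimodal_model: bool) -> bool:
    # v0: force off for MM models, otherwise default off.
    if not use_v1:
        if is_multimodal_model:
            return False
        return bool(user_value)
    # v1: honor an explicit user setting, otherwise default on for
    # text-only models and off for MM models.
    if user_value is not None:
        return user_value
    return not is_multimodal_model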

comaniac (Collaborator, Author) commented:

Note: CI failure is unrelated.

comaniac added the ready label on Dec 17, 2024
alexm-redhat (Collaborator) left a comment

LGTM! @comaniac thanks for making prefix caching work for VLMs! Just some nits

Signed-off-by: Cody Yu <[email protected]>
ywang96 (Member) left a comment

LGTM! I've shared some benchmark results on Slack.

The negative impact of APC is minimal even at a 0% hit rate, so I think this PR is good to go!

Signed-off-by: Cody Yu <[email protected]>
comaniac (Collaborator, Author) commented:

I found that it's tricky to configure a different default value of prefix caching for MM models, because we don't know which model will be served when creating the engine config from the CLI. So now I enable prefix caching by default for all models in v1. We should mention in the blog post/announcement that if users encounter any errors with MM in v1, disabling prefix caching is one of the things they could try as a workaround.

cc @ywang96 @WoosukKwon
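
A minimal sketch of the suggested workaround, explicitly disabling prefix caching for a multimodal model (the --no-enable-prefix-caching flag used in the benchmark commands above is the CLI equivalent):

import os
os.environ["VLLM_USE_V1"] = "1"  # use the V1 engine, as in the commands above

from vllm import LLM

# If a multimodal model hits errors with prefix caching in V1, turning it off
# is one workaround to try.
llm = LLM(
    model="llava-hf/llava-v1.6-mistral-7b-hf",
    enable_prefix_caching=False,
)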

Signed-off-by: Cody Yu <[email protected]>
rickyyx (Contributor) left a comment

Signed-off-by: Cody Yu <[email protected]>
simon-mo merged commit bf8717e into vllm-project:main on Dec 18, 2024
52 of 54 checks passed
comaniac deleted the v1-vlm-cache branch on December 18, 2024 at 01:00
SageMoore pushed a commit to neuralmagic/vllm that referenced this pull request Dec 19, 2024
BKitor pushed a commit to BKitor/vllm that referenced this pull request Dec 30, 2024
ywang96 (Member) commented Dec 31, 2024

Going to rename this PR to Prefix caching for multimodal language models since the underlying logic is not tied to image input format!

ywang96 changed the title from "[V1] Prefix caching for vision language models" to "[V1] Prefix caching for multimodal language models" on Dec 31, 2024