[V1] Prefix caching #9668

Closed
comaniac wants to merge 12 commits

Conversation

comaniac
Collaborator

@comaniac comaniac commented Oct 24, 2024

This PR adds prefix caching to V1.

Data Structure

  • Block pool: A pool of KV-cache blocks, indexed by block ID, that is used for the entire engine lifecycle.
  • Free block queue: A queue of free blocks available for allocation. Blocks in this queue may still be reused (cache hit) by other requests.
  • Cached block map: A mapping from block hash to a list of blocks. We keep a list because we don't de-duplicate blocks (see "Duplication" below for details). On a cache hit, we always allocate the first block in the list to aggregate references.
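
As a rough illustration, these three structures could be sketched in Python as below. The names (KVCacheBlock, BlockPoolSketch, lazy_remove_set, and so on) are assumptions for this writeup rather than the PR's actual classes; lazy_remove_set supports the lazy-removal trick described under "Allocate Slots" below.

```python
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Deque, Dict, List, Optional, Set


@dataclass
class KVCacheBlock:
    block_id: int
    ref_cnt: int = 0
    block_hash: Optional[int] = None
    token_ids: List[int] = field(default_factory=list)


class BlockPoolSketch:
    def __init__(self, num_gpu_blocks: int):
        # Block pool: one block object per block ID, living for the entire
        # engine lifecycle.
        self.block_pool: List[KVCacheBlock] = [
            KVCacheBlock(block_id=i) for i in range(num_gpu_blocks)
        ]
        # Free block queue: blocks available for (re)allocation. Blocks in
        # this queue may still be cache hits for new requests.
        self.free_block_queue: Deque[KVCacheBlock] = deque(self.block_pool)
        # Cached block map: block hash -> list of blocks with that hash
        # (no de-duplication; see "Duplication" below).
        self.cached_block_map: Dict[int, List[KVCacheBlock]] = defaultdict(list)
        # Block IDs marked for lazy removal from the free block queue.
        self.lazy_remove_set: Set[int] = set()
```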

Algorithms

Allocate Slots

When a request is scheduled for the first time, allocate_slots() is used to allocate blocks for the currently scheduled prompt tokens. If the prompt is chunked due to chunked prefill, we only allocate blocks for the scheduled tokens. In addition to the scheduled tokens, we also pre-allocate empty blocks to reduce allocation overhead.

With prefix caching, when we attempt to allocate a full block, we compute its block hash and query the cached block map. There are three possible outcomes (see the sketch after the notes below):

  1. Cache miss: Allocate a new block from the free block queue. The newly allocated block may need to be evicted from the cache (its previous contents are discarded).
  2. Cache hit and the block is in the free block queue: Reuse the block and mark it for removal from the queue.
  3. Cache hit and the block is not in the free block queue (it is also being used by other requests): Reuse the block.

Note 1: When a cache hit lands on a block that is in the free block queue, we put the block in a "lazy remove set" instead of removing it from the queue immediately, because removing an arbitrary element from a queue takes O(N). Instead, when we allocate a new block and the block at the front of the queue is marked for lazy removal, we pop it and move on to the next one.

Note 2: On a cache miss, when we allocate a new block, the token IDs are added to the allocated block so that its hash can be constructed. The block is also added to the cache once it is full.
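
A minimal sketch of this lookup, written as an extra method on the BlockPoolSketch class from the sketch above; the method name and the exact eviction bookkeeping are illustrative assumptions, not this PR's implementation (capacity checks are omitted).

```python
def _allocate_full_block(self, block_hash: int) -> KVCacheBlock:
    cached = self.cached_block_map.get(block_hash)
    if cached:
        # Outcomes 2 and 3: cache hit. Reuse the first block in the list
        # to aggregate references.
        block = cached[0]
        block.ref_cnt += 1
        if block.ref_cnt == 1:
            # Outcome 2: the block was sitting in the free block queue.
            # Removing it now would be O(N), so mark it for lazy removal.
            self.lazy_remove_set.add(block.block_id)
        return block

    # Outcome 1: cache miss. Pop from the free queue, skipping blocks that
    # were marked for lazy removal earlier.
    while True:
        block = self.free_block_queue.popleft()
        if block.block_id in self.lazy_remove_set:
            self.lazy_remove_set.discard(block.block_id)
            continue
        break
    # Evict the popped block's previous contents from the cached block map.
    if block.block_hash is not None:
        peers = self.cached_block_map[block.block_hash]
        peers[:] = [b for b in peers if b.block_id != block.block_id]
        if not peers:
            del self.cached_block_map[block.block_hash]
    block.block_hash = None
    block.token_ids = []
    block.ref_cnt = 1
    return block
```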

Append Slots

When a request is scheduled again, append_slots() is used to allocate more blocks if needed. This happens for continued chunked prefill or for decode. The steps are as follows (a sketch follows the list):

  1. Check the allocated slots (empty slots in a partial block and preallocated blocks), and add token IDs to these slots.
  2. If the allocated blocks are full, add them to the cache.
  3. If the allocated slots are insufficient, allocate new blocks.
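
A compact sketch of these three steps, again as a method on the sketch class above; _allocate_block() and _add_to_cache() are hypothetical helpers standing in for the real allocation and caching logic.

```python
def append_slots(self, blocks: List[KVCacheBlock],
                 new_token_ids: List[int], block_size: int) -> None:
    # Step 1: start from the first block that still has empty slots
    # (the partial block, followed by any preallocated blocks).
    idx = next((i for i, b in enumerate(blocks)
                if len(b.token_ids) < block_size), len(blocks))
    for token_id in new_token_ids:
        # Step 3: the allocated slots are insufficient; allocate a new block.
        if idx == len(blocks):
            blocks.append(self._allocate_block())  # hypothetical helper
        blocks[idx].token_ids.append(token_id)
        # Step 2: a block that just became full is added to the cache.
        if len(blocks[idx].token_ids) == block_size:
            self._add_to_cache(blocks[idx])  # hypothetical helper
            idx += 1
```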

Free

When a request finishes, the reference count of each of its blocks is decreased by 1. If a block's reference count drops to 0, it is freed (pushed to the free block queue). Note that since we allocate new blocks by popping from the free block queue, the block order in the queue is also the eviction order. With the LRU eviction policy, the eviction order is:

  1. The least recently accessed block is evicted first.
  2. Among blocks with the same access time, the one with the most hashed tokens is evicted first, because it is the last block of a sequence and is less likely to be shared with other requests.

We maintain this order by pushing a request's free blocks to the queue in reverse order, so that (see the sketch after this list):

  1. The order in which requests are freed reflects access time: a block freed earlier appears closer to the front of the queue.
  2. When pushing a sequence of blocks to the queue, the last block (the one with the most hashed tokens) goes first.
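
A short sketch of the free path on the same structures; pushing in reversed order is what keeps the eviction order described above.

```python
def free(self, request_blocks: List[KVCacheBlock]) -> None:
    # Push in reversed order so that, among blocks freed at the same time,
    # the tail block (the one with the most hashed tokens) sits closer to
    # the front of the queue and is therefore evicted first.
    for block in reversed(request_blocks):
        block.ref_cnt -= 1
        if block.ref_cnt == 0:
            self.free_block_queue.append(block)
```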

Get Computed Blocks

Before calling allocate_slots(), the scheduler calls get_computed_block_ids() to find out how many blocks hit the cache. This function simply computes the hashes of the full blocks and queries the cache for existing block IDs; it does not allocate any blocks or change any block metadata.
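
A sketch of this lookup, continuing the structures above. It assumes each block's hash chains the previous block's hash with the block's token IDs (consistent with the code comment quoted in the review thread below); the exact hashing scheme in the PR may differ.

```python
def get_computed_block_ids(self, token_ids: List[int],
                           block_size: int) -> List[int]:
    computed: List[int] = []
    prev_hash: Optional[int] = None
    # Only full blocks can be cached; the trailing partial block is skipped.
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block_tokens = tuple(token_ids[start:start + block_size])
        prev_hash = hash((prev_hash, block_tokens))
        cached = self.cached_block_map.get(prev_hash)
        if not cached:
            # If this block misses, later blocks cannot be cached either.
            break
        computed.append(cached[0].block_id)
    return computed
```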

Duplication

Since V1 prepares inputs incrementally, the block table is append-only. This can result in duplication, as shown below. Suppose we have 2 identical requests (same prompt with greedy sampling) arriving at different times:

Time 1

req1: [0, 1, 2, 3 (partial, 14/16)]

Time 2

req1: [0, 1, 2, 3 (partial, 15/16)]
req2: [0, 1, 2, 4 (partial, 14/16)] # Partial block cannot be shared so we allocate a new block for req2

Time 3

req1: [0, 1, 2, 3 (full)] # Block 3 is now sharable
req2: [0, 1, 2, 4 (partial, 15/16)]

Time 4

req1: [0, 1, 2, 3 (full)]
req2: [0, 1, 2, 4 (full)]

At time 4, block 4 becomes full and has the same hash and content as block 3. In the vLLM V0 block manager, we would free block 4 and assign block 3 to req2 in the next step. However, we cannot do this in V1 because the block table is append-only. As a result, at this point the cache looks like:

block_0_hash: [block0]
block_1_hash: [block1]
block_2_hash: [block2]
block_3_hash: [block3, block4]
  • When another request hits block 3's hash, we always allocate block 3.
  • Block 4 will be freed once req2 is done.

We consider this acceptable for practical use cases, because:

  1. Only partial blocks can potentially be duplicated. This happens at the last block of a prompt, or in the first N blocks of decode.
  2. Only identical prompts with greedy sampling encounter this issue, which is not a practical use case.

cc @WoosukKwon @zhuohan123


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run the other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@WoosukKwon
Collaborator

This is amazing! Do you happen to have any performance benchmarks?

@comaniac
Collaborator Author

> This is amazing! Do you happen to have any performance benchmarks?

I plan to run some tests today and will report back once I have something.

@comaniac
Collaborator Author

comaniac commented Oct 24, 2024

Benchmark

  • Model: neuralmagic/Meta-Llama-3-8B-Instruct-FP8
  • GPU: L40S

Benchmark Prefix Caching

Note that I disabled the warmup phase in this script, because after warming up we would be benchmarking exactly the same requests, which is not a realistic scenario.

Command

VLLM_USE_V1=1 python3 benchmarks/benchmark_prefix_caching.py \
--model neuralmagic/Meta-Llama-3-8B-Instruct-FP8 \
--num-prompts 200 --repeat-count 2 \
--input-length-range 256:512 \
--dataset-path ../ShareGPT_V3_unfiltered_cleaned_split.json \
--seed 0 \
[--enable-prefix-caching]

| Version | Input (tok/s) | Output (tok/s) | Cost Time (s) |
|---|---|---|---|
| v1. main branch (no cache) | 17777.03 | 481.86 | 8.56 |
| v1. this PR w/o cache | 17829.40 | 483.28 | 8.54 |
| v1. this PR w. cache (49% hit rate) | 32149.17 | 871.43 | 4.86 |

Benchmark Serving

Server command

VLLM_USE_V1=1 vllm serve neuralmagic/Meta-Llama-3-8B-Instruct-FP8 \
--disable-log-requests [--enable-prefix-caching]

Client command

python3 benchmarks/benchmark_serving.py --backend vllm \
--model neuralmagic/Meta-Llama-3-8B-Instruct-FP8 \
--dataset-name random --random-input-len 550 --random-output-len 150 \
--random-prefix-len 330 --seed 0 --request-rate 8 --num-prompts 500

| Version | Mean TTFT (ms) | Mean TPOT (ms) |
|---|---|---|
| v1. main branch (no cache) | 193.03 | 42.45 |
| v1. this PR w/o cache | 199.42 | 43.07 |
| v1. this PR w. cache (37% hit rate) | 125.42 | 32.70 |

Full Results

v1. main, w/o cache
============ Serving Benchmark Result ============
Successful requests:                     500
Benchmark duration (s):                  68.31
Total input tokens:                      440000
Total generated tokens:                  74812
Request throughput (req/s):              7.32
Output token throughput (tok/s):         1095.11
Total Token throughput (tok/s):          7535.89
---------------Time to First Token----------------
Mean TTFT (ms):                          193.03
Median TTFT (ms):                        174.06
P99 TTFT (ms):                           445.08
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          42.45
Median TPOT (ms):                        43.63
P99 TPOT (ms):                           54.82
---------------Inter-token Latency----------------
Mean ITL (ms):                           42.97
Median ITL (ms):                         30.04
P99 ITL (ms):                            127.54
==================================================

v1. PR w/o cache
============ Serving Benchmark Result ============
Successful requests:                     500
Benchmark duration (s):                  68.35
Total input tokens:                      440000
Total generated tokens:                  74812
Request throughput (req/s):              7.32
Output token throughput (tok/s):         1094.52
Total Token throughput (tok/s):          7531.86
---------------Time to First Token----------------
Mean TTFT (ms):                          199.42
Median TTFT (ms):                        177.75
P99 TTFT (ms):                           461.73
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          43.07
Median TPOT (ms):                        43.40
P99 TPOT (ms):                           55.52
---------------Inter-token Latency----------------
Mean ITL (ms):                           43.60
Median ITL (ms):                         30.32
P99 ITL (ms):                            127.74
==================================================

v1. PR w. cache
============ Serving Benchmark Result ============
Successful requests:                     500
Benchmark duration (s):                  68.15
Total input tokens:                      440000
Total generated tokens:                  74862
Request throughput (req/s):              7.34
Output token throughput (tok/s):         1098.52
Total Token throughput (tok/s):          7555.07
---------------Time to First Token----------------
Mean TTFT (ms):                          125.42
Median TTFT (ms):                        117.17
P99 TTFT (ms):                           266.42
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          32.70
Median TPOT (ms):                        32.73
P99 TPOT (ms):                           37.35
---------------Inter-token Latency----------------
Mean ITL (ms):                           33.09
Median ITL (ms):                         27.30
P99 ITL (ms):                            81.41
==================================================

@comaniac comaniac force-pushed the v1_prefix_caching branch 3 times, most recently from 5de6d2e to 2b45c13 Compare October 25, 2024 01:09
@comaniac
Collaborator Author

After some thought, I feel the current approach of lazily removing reused blocks may be inefficient when the cache hit rate is high. It might still be better to implement a simple LRU cache with a doubly linked list. I'll try it and benchmark later.
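
For reference, a free-block list based on a doubly linked list (O(1) push, pop, and arbitrary removal when a free block gets a cache hit) could look roughly like the sketch below. This is only an illustration of the idea being discussed here, not the implementation that eventually landed.

```python
class FreeBlockNode:
    """Node wrapping a free KV-cache block."""
    __slots__ = ("block", "prev", "next")

    def __init__(self, block):
        self.block = block
        self.prev = None
        self.next = None


class FreeBlockList:
    """Doubly linked list of free blocks: O(1) push, pop, and removal."""

    def __init__(self) -> None:
        self.head = None  # next eviction candidate (freed earliest)
        self.tail = None

    def push(self, node: FreeBlockNode) -> None:
        # Append to the tail (most recently freed).
        node.prev, node.next = self.tail, None
        if self.tail is not None:
            self.tail.next = node
        else:
            self.head = node
        self.tail = node

    def remove(self, node: FreeBlockNode) -> None:
        # Unlink in O(1) when a free block gets a cache hit and is reused.
        if node.prev is not None:
            node.prev.next = node.next
        else:
            self.head = node.next
        if node.next is not None:
            node.next.prev = node.prev
        else:
            self.tail = node.prev
        node.prev = node.next = None

    def popleft(self):
        # Pop the eviction candidate at the head, if any.
        node = self.head
        if node is not None:
            self.remove(node)
        return node
```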

@WoosukKwon WoosukKwon self-requested a review October 25, 2024 17:06
@WoosukKwon
Collaborator

QQ: How much perf do we lose if we enable prefix caching but get a 0% cache hit rate?

@comaniac
Collaborator Author

> QQ: How much perf do we lose if we enable prefix caching but get a 0% cache hit rate?

I did a serving benchmark before and didn't observe an obvious performance regression with a 0% cache hit rate. I can run another experiment later.

@njhill
Member

njhill commented Oct 25, 2024

This is an awesome writeup, thanks @comaniac, makes a lot of sense to me. One thing we could think about is separating cache maintenance operations as something that can be done in parallel with the forward pass. e.g. instead of DLL (though I guess DLL should also be efficient so perhaps there's negligible value in that).

@comaniac
Collaborator Author

> This is an awesome writeup, thanks @comaniac, makes a lot of sense to me. One thing we could think about is separating cache maintenance operations as something that can be done in parallel with the forward pass. e.g. instead of DLL (though I guess DLL should also be efficient so perhaps there's negligible value in that).

Yeah, a DLL is definitely efficient in terms of time complexity, but adding a node requires creating a Python object each time. I'm afraid this may introduce non-negligible latency overhead. Updating the cache asynchronously is an interesting idea and I'll think about it!

@comaniac
Collaborator Author

comaniac commented Oct 28, 2024

Updated: Thanks @njhill for the tips. I'm now using a separate thread to perform the costly O(n) operations asynchronously at the end of each schedule step, so they can overlap with the forward pass; a rough sketch of the idea follows, and the benchmark results are below.
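
For illustration only, the hand-off could look roughly like the sketch below. The class name and the exact integration points are assumptions for this writeup, not the code in this PR.

```python
import queue
import threading


class AsyncCacheMaintainer:
    """Run costly O(n) cache-maintenance work off the scheduling thread so
    it can overlap with the GPU forward pass (illustrative sketch only)."""

    def __init__(self) -> None:
        self._work: queue.Queue = queue.Queue()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def submit(self, fn, *args) -> None:
        # Called by the scheduler at the end of a schedule step.
        self._work.put((fn, args))

    def wait(self) -> None:
        # Called before the next step needs a consistent view of the cache.
        self._work.join()

    def _loop(self) -> None:
        while True:
            fn, args = self._work.get()
            try:
                fn(*args)
            finally:
                self._work.task_done()
```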

Server command

VLLM_USE_V1=1 vllm serve neuralmagic/Meta-Llama-3-8B-Instruct-FP8 \
--disable-log-requests [--enable-prefix-caching]

Client command (37% hit rate)

python3 benchmarks/benchmark_serving.py --backend vllm \
--model neuralmagic/Meta-Llama-3-8B-Instruct-FP8 \
--dataset-name random --random-input-len 550 --random-output-len 150 \
--random-prefix-len 330 --seed 0 --request-rate 8 --num-prompts 500

| Version | Mean TTFT (ms) | Mean TPOT (ms) |
|---|---|---|
| v1. main branch (no cache) | 195.66 | 46.29 |
| v1. this PR w/o cache | 195.53 | 44.35 |
| v1. this PR w. cache (0% hit rate) | 199.69 | 46.58 |
| v1. this PR w. cache (37% hit rate) | 133.98 | 35.74 |

Full Results

v1. main, w/o cache
============ Serving Benchmark Result ============
Successful requests:                     500
Benchmark duration (s):                  68.44
Total input tokens:                      440000
Total generated tokens:                  73403
Request throughput (req/s):              7.31
Output token throughput (tok/s):         1072.48
Total Token throughput (tok/s):          7501.24
---------------Time to First Token----------------
Mean TTFT (ms):                          195.66
Median TTFT (ms):                        176.93
P99 TTFT (ms):                           479.71
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          46.29
Median TPOT (ms):                        43.17
P99 TPOT (ms):                           146.91
---------------Inter-token Latency----------------
Mean ITL (ms):                           43.81
Median ITL (ms):                         31.57
P99 ITL (ms):                            126.50
==================================================

v1. PR w/o cache
============ Serving Benchmark Result ============
Successful requests:                     500
Benchmark duration (s):                  68.71
Total input tokens:                      440000
Total generated tokens:                  74812
Request throughput (req/s):              7.28
Output token throughput (tok/s):         1088.89
Total Token throughput (tok/s):          7493.07
---------------Time to First Token----------------
Mean TTFT (ms):                          195.53
Median TTFT (ms):                        178.54
P99 TTFT (ms):                           437.57
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          44.35
Median TPOT (ms):                        45.22
P99 TPOT (ms):                           55.34
---------------Inter-token Latency----------------
Mean ITL (ms):                           44.90
Median ITL (ms):                         31.79
P99 ITL (ms):                            126.00
==================================================

v1. PR w. cache (0% hit rate)
============ Serving Benchmark Result ============
Successful requests:                     500
Benchmark duration (s):                  68.40
Total input tokens:                      440000
Total generated tokens:                  73403
Request throughput (req/s):              7.31
Output token throughput (tok/s):         1073.12
Total Token throughput (tok/s):          7505.73
---------------Time to First Token----------------
Mean TTFT (ms):                          199.69
Median TTFT (ms):                        180.24
P99 TTFT (ms):                           492.53
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          46.58
Median TPOT (ms):                        43.50
P99 TPOT (ms):                           145.34
---------------Inter-token Latency----------------
Mean ITL (ms):                           44.11
Median ITL (ms):                         31.37
P99 ITL (ms):                            129.37
==================================================

v1. PR w. cache (37.5% hit rate)
============ Serving Benchmark Result ============
Successful requests:                     500
Benchmark duration (s):                  68.59
Total input tokens:                      440000
Total generated tokens:                  74861
Request throughput (req/s):              7.29
Output token throughput (tok/s):         1091.39
Total Token throughput (tok/s):          7506.07
---------------Time to First Token----------------
Mean TTFT (ms):                          133.98
Median TTFT (ms):                        123.28
P99 TTFT (ms):                           298.46
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          35.74
Median TPOT (ms):                        35.47
P99 TPOT (ms):                           41.00
---------------Inter-token Latency----------------
Mean ITL (ms):                           36.20
Median ITL (ms):                         30.40
P99 ITL (ms):                            83.16
==================================================

@comaniac comaniac marked this pull request as ready for review October 28, 2024 22:41
@njhill
Member

njhill commented Oct 29, 2024

Thanks @comaniac. I think the plan is to have SPMD style workers, where even with single-GPU the worker will run in a separate process to the scheduler. We can then move the async maintenance to be done between sending/receiving the input/output for each step. Alternatively this could be achieved with a call-back similar to what's done in v0 with the async output processing. But I think @WoosukKwon wanted to avoid that if possible.

Otherwise, having a separate thread might interfere with the subsequent critical loop processing before it reaches the GPU forward pass.

@robertgshaw2-neuralmagic
Collaborator

> This is amazing! Do you happen to have any performance benchmarks?

cc @tlrmchlsmth re: workers

@comaniac
Collaborator Author

> Thanks @comaniac. I think the plan is to have SPMD style workers, where even with single-GPU the worker will run in a separate process to the scheduler. We can then move the async maintenance to be done between sending/receiving the input/output for each step. Alternatively this could be achieved with a call-back similar to what's done in v0 with the async output processing. But I think @WoosukKwon wanted to avoid that if possible.
>
> Otherwise, having a separate thread might interfere with the subsequent critical loop processing before it reaches the GPU forward pass.

Yeah, if we had an async scheduler that would make sense, but at this moment we don't have one (and I believe we will have a sync scheduler anyway). What you mentioned can definitely happen (we asynchronously process operations before entering the forward pass on the GPU). I don't have a better idea right now, though. Which approach do you think is preferable, then: the previous approach, with some extra overhead when the hit rate is high, or the DLL approach, which introduces more code and another data structure?

@robertgshaw2-neuralmagic
Collaborator

Wow, these results look good even at a low cache hit rate.

vllm/v1/core/kv_cache_manager.py
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if cached_block := self._get_cached_block(block_hash):
Contributor

I found one corner case that I had to handle carefully with V0:

  • If a sequence has 3 blocks, [b0, b1, b2], and b0, b1 are both cached, but already evicted.
  • There are only 2 blocks that could be allocated (i.e. the already freed b1, b2)

When determining if the sequence can be allocated, IIUC, the current impl would:

  1. See there are 2 blocks cached (b1, b2)
  2. Calculate that only b3 contains new tokens
  3. Make it allocatable?

But in fact, this would run out of blocks because as one allocates the b0, b1, there are no more blocks for b3.

Collaborator Author

> If a sequence has 3 blocks, [b0, b1, b2], and b0, b1 are both cached, but already evicted.

Note that the term "evict" means the block is no longer in the free queue nor in the cached block map; it has been re-allocated to store new tokens. I guess what you meant is that b0 and b1 are in the free queue but not yet evicted? In that case, yes, we could allocate b0 and b1 to another request. If b0 and b1 are evicted, then a new request won't hit the cache.

> But in fact, this would run out of blocks because as one allocates the b0, b1, there are no more blocks for b3.

This should not happen, because when we reuse b0 and b1 for request A, both of them are removed from the free queue (and num_free_blocks decreases by 2).

Contributor

Oh yeah, I mean in the free queue. I think I might be missing something in the V1 impl still. I guess the scenario I wanted to clarify is if b0 and b1 are in the "free queue".

@comaniac
Collaborator Author

comaniac commented Nov 4, 2024

Re-take at #9972

@comaniac comaniac closed this Nov 4, 2024