forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[V1] Prefix caching (take 2) (vllm-project#9972)
Signed-off-by: Cody Yu <[email protected]> Signed-off-by: Maxime Fournioux <[email protected]>
- Loading branch information
1 parent
5027d23
commit 9dbf737
Showing
6 changed files
with
771 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
"""Compare the with and without prefix caching.""" | ||
from vllm.inputs import DecoderOnlyInputs | ||
from vllm.sampling_params import SamplingParams | ||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request | ||
from vllm.v1.core.kv_cache_utils import hash_block_tokens | ||
|
||
|
||
def make_request(request_id, prompt_token_ids): | ||
return Request( | ||
request_id=request_id, | ||
inputs=DecoderOnlyInputs(prompt_token_ids=prompt_token_ids), | ||
sampling_params=SamplingParams(max_tokens=17), | ||
eos_token_id=100, | ||
arrival_time=0, | ||
lora_request=None, | ||
) | ||
|
||
|
||
def test_prefill(): | ||
manager = KVCacheManager( | ||
block_size=16, | ||
num_gpu_blocks=10, | ||
sliding_window=False, | ||
enable_caching=True, | ||
num_preallocate_tokens=16, | ||
) | ||
|
||
# Complete 3 blocks (48 tokens) | ||
common_token_ids = [i for i in range(3) for _ in range(16)] | ||
|
||
# Fully cache miss | ||
# Incomplete 1 block (7 tokens) | ||
unique_token_ids = [3] * 7 | ||
req0 = make_request("0", common_token_ids + unique_token_ids) | ||
computed_blocks = manager.get_computed_blocks(req0) | ||
assert not computed_blocks | ||
blocks = manager.allocate_slots(req0, 55, computed_blocks) | ||
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] | ||
|
||
# Check full block metadata | ||
parent_block_hash = None | ||
for block_id in (0, 1, 2): | ||
block_hash = hash_block_tokens(parent_block_hash, | ||
manager.block_pool[block_id].token_ids) | ||
assert manager.block_pool[block_id].block_hash == block_hash | ||
assert manager.block_pool[block_id].ref_cnt == 1 | ||
assert manager.block_pool[block_id].num_hashed_tokens == 16 * ( | ||
block_id + 1) | ||
assert manager.block_pool[block_id].token_ids == tuple([block_id] * 16) | ||
parent_block_hash = block_hash | ||
|
||
# Check partial/preallocated block metadata | ||
for block_id in (3, 4): | ||
assert manager.block_pool[block_id].block_hash is None | ||
assert manager.block_pool[block_id].ref_cnt == 1 | ||
assert manager.block_pool[block_id].num_hashed_tokens == 0 | ||
if block_id == 3: | ||
assert manager.block_pool[block_id].token_ids == [3] * 7 | ||
else: | ||
assert not manager.block_pool[block_id].token_ids | ||
|
||
# Cache hit in the common prefix when the original block is still in use. | ||
# Incomplete 1 block (5 tokens) | ||
unique_token_ids = [3] * 5 | ||
req1 = make_request("1", common_token_ids + unique_token_ids) | ||
computed_blocks = manager.get_computed_blocks(req1) | ||
assert [b.block_id for b in computed_blocks] == [0, 1, 2] | ||
num_new_tokens = 53 - 3 * 16 | ||
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) | ||
assert [b.block_id for b in blocks] == [5, 6] | ||
for block in computed_blocks: | ||
assert block.ref_cnt == 2 | ||
|
||
# At this point, we should have 3 free blocks left. | ||
assert manager.free_block_queue.num_free_blocks == 3 | ||
|
||
manager.free(req0) | ||
manager.free(req1) | ||
|
||
# All blocks should be available. | ||
assert manager.free_block_queue.num_free_blocks == 10 | ||
# The order should be | ||
# [unallocated (7, 8)] | ||
# [unique_req0 (4, 3)] | ||
# [unique_req1 (6, 5)] | ||
# [common (2, 1, 0)] | ||
assert [ | ||
b.block_id for b in manager.free_block_queue.get_all_free_blocks() | ||
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0] | ||
|
||
# Cache hit in the common prefix when the original block is already free. | ||
# Incomplete 1 block (6 tokens) | ||
unique_token_ids = [3] * 6 | ||
req2 = make_request("2", common_token_ids + unique_token_ids) | ||
computed_block = manager.get_computed_blocks(req2) | ||
assert [b.block_id for b in computed_block] == [0, 1, 2] | ||
num_new_tokens = 53 - 3 * 16 | ||
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) | ||
assert [b.block_id for b in blocks] == [7, 8] | ||
|
||
# Although we only have 5 free blocks, we have 8 blocks in | ||
# the free block queue due to lazy removal. | ||
assert manager.free_block_queue.num_free_blocks == 5 | ||
assert all([ | ||
b.ref_cnt == 0 for b in manager.free_block_queue.get_all_free_blocks() | ||
]) | ||
assert len([b | ||
for b in manager.free_block_queue.get_all_free_blocks()]) == 5 | ||
|
||
manager.free(req2) | ||
|
||
# Cache miss and eviction. | ||
req3 = make_request("3", [99] * (16 * 9)) | ||
computed_blocks = manager.get_computed_blocks(req3) | ||
assert not computed_blocks | ||
blocks = manager.allocate_slots(req2, 16 * 9, computed_blocks) | ||
# This block ID order also checks the eviction order. | ||
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0] | ||
assert manager.free_block_queue.num_free_blocks == 0 | ||
assert manager.free_block_queue.free_list_head is None | ||
assert manager.free_block_queue.free_list_tail is None | ||
|
||
|
||
def test_decode(): | ||
manager = KVCacheManager( | ||
block_size=16, | ||
num_gpu_blocks=10, | ||
sliding_window=False, | ||
enable_caching=True, | ||
num_preallocate_tokens=16, | ||
) | ||
|
||
# Complete 3 blocks (48 tokens) | ||
common_token_ids = [i for i in range(3) for _ in range(16)] | ||
|
||
# Fully cache miss | ||
# Incomplete 1 block (7 tokens) | ||
unique_token_ids = [3] * 7 | ||
req0 = make_request("0", common_token_ids + unique_token_ids) | ||
computed_blocks = manager.get_computed_blocks(req0) | ||
assert not computed_blocks | ||
blocks = manager.allocate_slots(req0, 55, computed_blocks) | ||
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] | ||
|
||
# Append slots without allocating a new block. | ||
req0.num_computed_tokens = 55 | ||
for _ in range(4): | ||
req0.append_output_token_ids(8) | ||
new_blocks = manager.append_slots(req0, 4) | ||
assert new_blocks is not None and len(new_blocks) == 0 | ||
assert len(manager.block_pool[3].token_ids) == 11 | ||
|
||
# Append slots without allocating a new block, but start using the | ||
# preallocated block. | ||
req0.num_computed_tokens = 59 | ||
# 6 tokens to fill the previous block, and 10 tokens to fill | ||
# the preallocated block. | ||
for _ in range(5 + 10): | ||
req0.append_output_token_ids(7) | ||
new_blocks = manager.append_slots(req0, 15) | ||
assert new_blocks is not None and len(new_blocks) == 0 | ||
assert len(manager.block_pool[3].token_ids) == 16 | ||
assert len(manager.block_pool[4].token_ids) == 10 | ||
|
||
# Append slots with allocating a new block. | ||
req0.num_computed_tokens = 74 | ||
# 6 tokens to fill the previous block, and 10 tokens to fill | ||
# the preallocated block. | ||
for _ in range(6 + 11): | ||
req0.append_output_token_ids(12) | ||
new_blocks = manager.append_slots(req0, 17) | ||
# Plus one preallocated block. | ||
assert new_blocks is not None and len(new_blocks) == 2 | ||
assert len(manager.block_pool[4].token_ids) == 16 | ||
assert len(manager.block_pool[5].token_ids) == 11 | ||
assert len(manager.block_pool[6].token_ids) == 0 | ||
|
||
|
||
def test_evict(): | ||
manager = KVCacheManager( | ||
block_size=16, | ||
num_gpu_blocks=10, | ||
sliding_window=False, | ||
enable_caching=True, | ||
num_preallocate_tokens=16, | ||
) | ||
|
||
last_token_id = 5 * 16 + 7 | ||
req0 = make_request("0", list(range(last_token_id))) | ||
computed_blocks = manager.get_computed_blocks(req0) | ||
assert not computed_blocks | ||
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) | ||
assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated | ||
|
||
# 3 blocks. | ||
req1 = make_request("1", list(range(last_token_id, | ||
last_token_id + 3 * 16))) | ||
computed_blocks = manager.get_computed_blocks(req1) | ||
assert not computed_blocks | ||
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks) | ||
assert len(blocks) == 3 # 3 full blocks | ||
last_token_id += 3 * 16 | ||
|
||
assert manager.free_block_queue.num_free_blocks == 0 | ||
|
||
manager.free(req0) | ||
manager.free(req1) | ||
assert manager.free_block_queue.num_free_blocks == 10 | ||
assert [ | ||
b.block_id for b in manager.free_block_queue.get_all_free_blocks() | ||
] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7] | ||
|
||
# Touch the first 2 blocks. | ||
req2 = make_request("2", list(range(2 * 16 + 3))) | ||
computed_blocks = manager.get_computed_blocks(req2) | ||
assert [b.block_id for b in computed_blocks] == [0, 1] | ||
blocks = manager.allocate_slots(req2, 3, computed_blocks) | ||
assert [b.block_id for b in blocks] == [6, 5] | ||
assert manager.free_block_queue.num_free_blocks == 6 |
Oops, something went wrong.