Skip to content

Commit

Permalink
[V1] Prefix caching (take 2) (#9972)
Browse files Browse the repository at this point in the history
Signed-off-by: Cody Yu <[email protected]>
  • Loading branch information
comaniac authored Nov 8, 2024
1 parent 42b4f46 commit 201fc07
Show file tree
Hide file tree
Showing 6 changed files with 771 additions and 66 deletions.
9 changes: 1 addition & 8 deletions benchmarks/benchmark_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def main(args):
random.seed(args.seed)
if args.dataset_path is not None:
print(f"Start to sample {args.num_prompts} prompts"
"from {args.dataset_path}")
f"from {args.dataset_path}")
filtered_datasets = sample_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
Expand All @@ -142,13 +142,6 @@ def main(args):
repeat_count=args.repeat_count,
sort=args.sort)

print("------warm up------")
test_prefix(
llm=llm,
prompts=prompts,
sampling_params=sampling_params,
)

print("------start generating------")
test_prefix(
llm=llm,
Expand Down
219 changes: 219 additions & 0 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
"""Compare the with and without prefix caching."""
from vllm.inputs import DecoderOnlyInputs
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import hash_block_tokens


def make_request(request_id, prompt_token_ids):
return Request(
request_id=request_id,
inputs=DecoderOnlyInputs(prompt_token_ids=prompt_token_ids),
sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100,
arrival_time=0,
lora_request=None,
)


def test_prefill():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
sliding_window=False,
enable_caching=True,
num_preallocate_tokens=16,
)

# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)]

# Fully cache miss
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]

# Check full block metadata
parent_block_hash = None
for block_id in (0, 1, 2):
block_hash = hash_block_tokens(parent_block_hash,
manager.block_pool[block_id].token_ids)
assert manager.block_pool[block_id].block_hash == block_hash
assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool[block_id].num_hashed_tokens == 16 * (
block_id + 1)
assert manager.block_pool[block_id].token_ids == tuple([block_id] * 16)
parent_block_hash = block_hash

# Check partial/preallocated block metadata
for block_id in (3, 4):
assert manager.block_pool[block_id].block_hash is None
assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool[block_id].num_hashed_tokens == 0
if block_id == 3:
assert manager.block_pool[block_id].token_ids == [3] * 7
else:
assert not manager.block_pool[block_id].token_ids

# Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens)
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req1)
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5, 6]
for block in computed_blocks:
assert block.ref_cnt == 2

# At this point, we should have 3 free blocks left.
assert manager.free_block_queue.num_free_blocks == 3

manager.free(req0)
manager.free(req1)

# All blocks should be available.
assert manager.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (7, 8)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
assert [
b.block_id for b in manager.free_block_queue.get_all_free_blocks()
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]

# Cache hit in the common prefix when the original block is already free.
# Incomplete 1 block (6 tokens)
unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_block = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_block] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [7, 8]

# Although we only have 5 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
assert manager.free_block_queue.num_free_blocks == 5
assert all([
b.ref_cnt == 0 for b in manager.free_block_queue.get_all_free_blocks()
])
assert len([b
for b in manager.free_block_queue.get_all_free_blocks()]) == 5

manager.free(req2)

# Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 9))
computed_blocks = manager.get_computed_blocks(req3)
assert not computed_blocks
blocks = manager.allocate_slots(req2, 16 * 9, computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
assert manager.free_block_queue.num_free_blocks == 0
assert manager.free_block_queue.free_list_head is None
assert manager.free_block_queue.free_list_tail is None


def test_decode():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
sliding_window=False,
enable_caching=True,
num_preallocate_tokens=16,
)

# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)]

# Fully cache miss
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]

# Append slots without allocating a new block.
req0.num_computed_tokens = 55
for _ in range(4):
req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0
assert len(manager.block_pool[3].token_ids) == 11

# Append slots without allocating a new block, but start using the
# preallocated block.
req0.num_computed_tokens = 59
# 6 tokens to fill the previous block, and 10 tokens to fill
# the preallocated block.
for _ in range(5 + 10):
req0.append_output_token_ids(7)
new_blocks = manager.append_slots(req0, 15)
assert new_blocks is not None and len(new_blocks) == 0
assert len(manager.block_pool[3].token_ids) == 16
assert len(manager.block_pool[4].token_ids) == 10

# Append slots with allocating a new block.
req0.num_computed_tokens = 74
# 6 tokens to fill the previous block, and 10 tokens to fill
# the preallocated block.
for _ in range(6 + 11):
req0.append_output_token_ids(12)
new_blocks = manager.append_slots(req0, 17)
# Plus one preallocated block.
assert new_blocks is not None and len(new_blocks) == 2
assert len(manager.block_pool[4].token_ids) == 16
assert len(manager.block_pool[5].token_ids) == 11
assert len(manager.block_pool[6].token_ids) == 0


def test_evict():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
sliding_window=False,
enable_caching=True,
num_preallocate_tokens=16,
)

last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id)))
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated

# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16)))
computed_blocks = manager.get_computed_blocks(req1)
assert not computed_blocks
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
assert len(blocks) == 3 # 3 full blocks
last_token_id += 3 * 16

assert manager.free_block_queue.num_free_blocks == 0

manager.free(req0)
manager.free(req1)
assert manager.free_block_queue.num_free_blocks == 10
assert [
b.block_id for b in manager.free_block_queue.get_all_free_blocks()
] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7]

# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks] == [0, 1]
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5]
assert manager.free_block_queue.num_free_blocks == 6
Loading

0 comments on commit 201fc07

Please sign in to comment.