Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
revert
Browse files Browse the repository at this point in the history
Signed-off-by: Cody Yu <[email protected]>
comaniac authored and sakunkun committed Dec 31, 2024

Unverified

This user has not yet uploaded their public signing key.
1 parent aaf2649 commit f570cd3
Showing 3 changed files with 100 additions and 122 deletions.
159 changes: 70 additions & 89 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
@@ -44,11 +44,10 @@ def test_prefill():
unique_token_ids = [3] * 7
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
computed_blocks_and_num_evictable = manager.get_computed_blocks(req0)
computed_blocks = manager.get_computed_blocks(req0)
assert len(req0.kv_block_hashes) == 3
assert not computed_blocks_and_num_evictable[0]
blocks = manager.allocate_slots(req0, 55,
computed_blocks_and_num_evictable)
assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]

# Check full block metadata
@@ -69,15 +68,13 @@ def test_prefill():
# Incomplete 1 block (5 tokens)
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks_and_num_evictable = manager.get_computed_blocks(req1)
computed_blocks = manager.get_computed_blocks(req1)
assert len(req1.kv_block_hashes) == 3
assert [b.block_id
for b in computed_blocks_and_num_evictable[0]] == [0, 1, 2]
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
computed_blocks_and_num_evictable)
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5, 6]
for block in computed_blocks_and_num_evictable[0]:
for block in computed_blocks:
assert block.ref_cnt == 2

# At this point, we should have 3 free blocks left.
@@ -101,13 +98,11 @@ def test_prefill():
# Incomplete 1 block (6 tokens)
unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks_and_num_evictable = manager.get_computed_blocks(req2)
computed_blocks = manager.get_computed_blocks(req2)
assert len(req2.kv_block_hashes) == 3
assert [b.block_id
for b in computed_blocks_and_num_evictable[0]] == [0, 1, 2]
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens,
computed_blocks_and_num_evictable)
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [7, 8]

# Although we only have 5 free blocks, we have 8 blocks in
@@ -123,10 +118,9 @@ def test_prefill():

# Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 9))
computed_blocks_and_num_evictable = manager.get_computed_blocks(req3)
assert not computed_blocks_and_num_evictable[0]
blocks = manager.allocate_slots(req3, 16 * 9,
computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req3)
assert not computed_blocks
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
assert manager.free_block_queue.num_free_blocks == 0
@@ -151,10 +145,9 @@ def test_decode():
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
computed_blocks_and_num_evictable = manager.get_computed_blocks(req0)
assert not computed_blocks_and_num_evictable[0]
blocks = manager.allocate_slots(req0, 55,
computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]

# Append slots without allocating a new block.
@@ -199,19 +192,17 @@ def test_evict():

last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id)))
computed_blocks_and_num_evictable = manager.get_computed_blocks(req0)
assert not computed_blocks_and_num_evictable[0]
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated

# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16)))
computed_blocks_and_num_evictable = manager.get_computed_blocks(req1)
assert not computed_blocks_and_num_evictable[0]
blocks = manager.allocate_slots(req1, 3 * 16,
computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req1)
assert not computed_blocks
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
assert len(blocks) == 3 # 3 full blocks
last_token_id += 3 * 16

@@ -226,9 +217,9 @@ def test_evict():

# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks_and_num_evictable = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks_and_num_evictable[0]] == [0, 1]
blocks = manager.allocate_slots(req2, 3, computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks] == [0, 1]
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5]
assert manager.free_block_queue.num_free_blocks == 6

@@ -251,10 +242,9 @@ def test_hash_block_correct_reuse():
# Allocate 1 block and cache it.
num_tokens = block_size * 1
req = make_request("0", list(range(num_tokens)))
computed_blocks_and_num_evictable = manager.get_computed_blocks(req)
assert not computed_blocks_and_num_evictable[0]
blocks = manager.allocate_slots(req, num_tokens,
computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req)
assert not computed_blocks
blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
assert len(blocks) == 1

# Deallocate the block.
@@ -263,10 +253,9 @@ def test_hash_block_correct_reuse():
# Allocate a new block that's not full, make sure hash info on the
# block is cleared.
req = make_request("1", list(range(num_tokens - 1)))
computed_blocks_and_num_evictable = manager.get_computed_blocks(req)
assert not computed_blocks_and_num_evictable[0]
blocks = manager.allocate_slots(req, num_tokens - 1,
computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req)
assert not computed_blocks
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
assert len(blocks) == 1

assert manager.block_pool[blocks[0].block_id].block_hash is None
@@ -290,19 +279,17 @@ def test_computed_blocks_not_evicted():
# Allocate a block and cache it.
num_tokens = block_size * 1
req0 = make_request("0", list(range(num_tokens)))
computed_blocks_and_num_evictable = manager.get_computed_blocks(req0)
assert not computed_blocks_and_num_evictable[0]
blocks = manager.allocate_slots(req0, num_tokens,
computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 0

# Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
computed_blocks_and_num_evictable = manager.get_computed_blocks(req1)
assert not computed_blocks_and_num_evictable[0]
blocks = manager.allocate_slots(req1, num_tokens,
computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req1)
assert not computed_blocks
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 1

@@ -313,12 +300,12 @@ def test_computed_blocks_not_evicted():
# Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one.
req2 = make_request("2", list(range(num_tokens * 2)))
computed_blocks_and_num_evictable = manager.get_computed_blocks(req2)
assert len(computed_blocks_and_num_evictable[0]) == 1
assert computed_blocks_and_num_evictable[0][0].block_id == 0
computed_blocks = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 1
assert computed_blocks[0].block_id == 0

blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
computed_blocks_and_num_evictable)
computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 1

@@ -339,28 +326,26 @@ def test_basic_prefix_caching_disabled():

req1 = make_request("1", list(range(10))) # 2 blocks and some more

computed_blocks_and_num_evictable = manager.get_computed_blocks(req1)
assert not computed_blocks_and_num_evictable[0]
blocks = manager.allocate_slots(req1, 10,
computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req1)
assert not computed_blocks
blocks = manager.allocate_slots(req1, 10, computed_blocks)
assert len(blocks) == 3

# Free the blocks.
manager.free(req1)

# No caching.
req2 = make_request("2", list(range(16))) # shared prefix
computed_blocks_and_num_evictable = manager.get_computed_blocks(req2)
assert not computed_blocks_and_num_evictable[0]
blocks = manager.allocate_slots(req2, 16,
computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req2)
assert not computed_blocks
blocks = manager.allocate_slots(req2, 16, computed_blocks)
assert len(blocks) == 4

# New requests should not have any blocks.
req3 = make_request("3", list(range(4)))
computed_blocks_and_num_evictable = manager.get_computed_blocks(req3)
assert not computed_blocks_and_num_evictable[0]
blocks = manager.allocate_slots(req3, 4, computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req3)
assert not computed_blocks
blocks = manager.allocate_slots(req3, 4, computed_blocks)
assert not blocks


@@ -381,11 +366,10 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size)

req = make_request("0", list(range(block_size * 30)))
computed_blocks_and_num_evictable = manager.get_computed_blocks(req)
assert not computed_blocks_and_num_evictable[0]
computed_blocks = manager.get_computed_blocks(req)
assert not computed_blocks
# Just ask for 1 block.
blocks = manager.allocate_slots(req, block_size,
computed_blocks_and_num_evictable)
blocks = manager.allocate_slots(req, block_size, computed_blocks)
req.num_computed_tokens = block_size
assert len(blocks) == 1 + num_preallocated_blocks

@@ -480,17 +464,16 @@ def test_mm_prefix_caching():
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks_and_num_evictable = manager.get_computed_blocks(req0)
computed_blocks = manager.get_computed_blocks(req0)

# Completed block should have hashes with extra keys.
assert not computed_blocks_and_num_evictable[0]
assert not computed_blocks
assert len(req0.kv_block_hashes) == 3
assert req0.kv_block_hashes[0].extra_keys == (("aaa", 0), )
assert req0.kv_block_hashes[1].extra_keys == (("aaa", 5), ("bbb", 0))
assert req0.kv_block_hashes[2].extra_keys == (("bbb", 2), )

blocks = manager.allocate_slots(req0, 59,
computed_blocks_and_num_evictable)
blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
req0.num_computed_tokens = 59

@@ -515,8 +498,8 @@ def test_mm_prefix_caching():
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks_and_num_evictable = manager.get_computed_blocks(req1)
assert len(computed_blocks_and_num_evictable[0]) == 3
computed_blocks = manager.get_computed_blocks(req1)
assert len(computed_blocks) == 3


def test_prefill_not_enough_free_blocks_with_computed_blocks():
@@ -539,16 +522,16 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | ... |
common_token_ids = [i for i in range(3) for _ in range(16)]
req0 = make_request("0", common_token_ids)
computed_blocks_and_num_evictable = manager.get_computed_blocks(req0)
assert not computed_blocks_and_num_evictable[0]
manager.allocate_slots(req0, 48, computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
manager.allocate_slots(req0, 48, computed_blocks)
block_part0 = manager.req_to_blocks[req0.request_id]

# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2)
computed_blocks_and_num_evictable = manager.get_computed_blocks(req1)
assert computed_blocks_and_num_evictable[0] == block_part0
manager.allocate_slots(req1, 48, computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req1)
assert computed_blocks == block_part0
manager.allocate_slots(req1, 48, computed_blocks)
block_part1 = manager.req_to_blocks[req1.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| ... |
@@ -559,21 +542,19 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2 = make_request("2", [7] * block_size * 2)
computed_blocks_and_num_evictable = manager.get_computed_blocks(req2)
assert not computed_blocks_and_num_evictable[0]
manager.allocate_slots(req2, block_size * 2,
computed_blocks_and_num_evictable)
computed_blocks = manager.get_computed_blocks(req2)
assert not computed_blocks
manager.allocate_slots(req2, block_size * 2, computed_blocks)

# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2).
# In this case, the ref_cnt of the computed blocks should not be changed.
assert manager.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3)
computed_blocks_and_num_evictable = manager.get_computed_blocks(req3)
assert computed_blocks_and_num_evictable[0] == block_part1
computed_blocks = manager.get_computed_blocks(req3)
assert computed_blocks == block_part1
# Req3 cannot be allocated.
assert manager.allocate_slots(req3, 48,
computed_blocks_and_num_evictable) is None
assert manager.allocate_slots(req3, 48, computed_blocks) is None
# Block 0-2 are used by Req 1.
assert {block.ref_cnt for block in block_part1[:3]} == {1}
# Block 3-5 are free.
44 changes: 16 additions & 28 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional

from vllm.logger import init_logger
from vllm.utils import cdiv
@@ -69,30 +69,21 @@ def __init__(
# is finished.
self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}

def get_computed_blocks(
self, request: Request) -> Tuple[List[KVCacheBlock], int]:
def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
"""Get the computed (cached) blocks for the request.
Note that the computed blocks must be full.
When prompt length is divisible by the block size and all blocks
are computed, we force to recompute the last block. Note that we
have to re-compute an entire block because allocate_slots()
assumes num_computed_tokens is always a multiple of the block size.
This limitation can potentially be removed in the future to slightly
improve the performance.
Args:
request: The request to get the computed blocks.
Returns:
A list of blocks that are computed for the request, and the number
of free blocks (eviction candidates).
A list of blocks that are computed for the request.
"""
if not self.enable_caching:
# Prefix caching is disabled.
return [], 0
return []

computed_blocks, num_free_blocks = [], 0
computed_blocks = []

# The block hashes for the request may already be computed
# if the request was preempted and resumed.
@@ -107,17 +98,10 @@ def get_computed_blocks(
# not computed yet for sure.
if cached_block := self._get_cached_block(block_hash):
computed_blocks.append(cached_block)
num_free_blocks += cached_block.ref_cnt == 0
else:
break

# Remove the last computed block if all blocks are computed.
num_new_tokens = (request.num_tokens -
len(computed_blocks) * self.block_size)
if num_new_tokens == 0:
num_free_blocks -= computed_blocks.pop().ref_cnt == 0

return computed_blocks, num_free_blocks
return computed_blocks

def append_slots(
self,
@@ -199,16 +183,15 @@ def allocate_slots(
self,
request: Request,
num_tokens: int,
computed_blocks_and_num_evictable: Tuple[List[KVCacheBlock], int],
computed_blocks: List[KVCacheBlock],
) -> Optional[List[KVCacheBlock]]:
"""Allocate slots for a new request.
Args:
request: The request to allocate slots.
num_tokens: The number of tokens to allocate. Note that this does
not include the tokens that have already been computed.
computed_blocks_and_num_evictable: A tuple of computed blocks
and the number of them that are currently in the free queue.
computed_blocks: A list of computed blocks.
Returns:
A list of new allocated blocks.
@@ -217,10 +200,15 @@ def allocate_slots(
raise ValueError(
f"num_tokens must be greater than 0, got {num_tokens}")

computed_blocks, num_evictable = computed_blocks_and_num_evictable
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = len(
[blk for blk in computed_blocks if blk.ref_cnt == 0])

num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks >
self.free_block_queue.num_free_blocks - num_evictable):
if (num_required_blocks > self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks):
# Cannot allocate new blocks.
return None

19 changes: 14 additions & 5 deletions vllm/v1/core/scheduler.py
Original file line number Diff line number Diff line change
@@ -185,9 +185,8 @@ def schedule(self) -> "SchedulerOutput":

request = self.waiting[0]
# Get already-cached tokens.
computed_blocks_and_num_evictable = (
self.kv_cache_manager.get_computed_blocks(request))
computed_blocks, _ = computed_blocks_and_num_evictable
computed_blocks = self.kv_cache_manager.get_computed_blocks(
request)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
@@ -197,8 +196,18 @@ def schedule(self) -> "SchedulerOutput":
# `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if num_new_tokens == 0:
# This happens when prompt length is divisible by the block
# size and all blocks are cached. Now we force to recompute
# the last block. Note that we have to re-compute an entire
# block because allocate_slots() assumes num_computed_tokens
# is always a multiple of the block size. This limitation
# can potentially be removed in the future to slightly
# improve the performance.
num_computed_tokens -= self.block_size
num_new_tokens = self.block_size
computed_blocks.pop()
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0

# Schedule encoder inputs.
(encoder_inputs_to_schedule, num_new_tokens,
@@ -210,7 +219,7 @@ def schedule(self) -> "SchedulerOutput":
break

new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens, computed_blocks_and_num_evictable)
request, num_new_tokens, computed_blocks)
if new_blocks is None:
# The request cannot be scheduled.
break

0 comments on commit f570cd3

Please sign in to comment.