Skip to content

Commit

Permalink
comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
comaniac committed Nov 5, 2024
1 parent a0e9a8e commit e6bd231
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,13 @@ def __init__(
self.num_preallocate_tokens = num_preallocate_tokens
self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)

# A block pool of all KV-cache blocks.
self.block_pool: List[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
]
# Free block queue that constructs and manipulates a doubly linked
# list of free blocks (including eviction candidates when caching is
# enabled).
self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool)

# {block_hash: {block ID: block}}. A cached block is
Expand Down Expand Up @@ -249,7 +253,7 @@ def append_slots(
# slots, but we cannot allocate new blocks due to the limit.
return None

# Assign token IDs to already allocated blocks.
# When caching is enabled, assign token IDs to already allocated blocks.
new_token_ids = None
parent_block_id = None
if self.enable_caching:
Expand Down Expand Up @@ -343,11 +347,21 @@ def allocate_slots(
num_required_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks)
# Get the token IDs for the blocks being allocated for hashing.
# Note that we expect this function to be called only once per
# request, so we must have all new token IDs in the prompt.

num_computed_tokens = len(computed_block_ids) * self.block_size

# When caching is enabled, get the new token IDs and the parent block
# ID to generate cache keys.
new_token_ids = None
parent_block_id = None
if self.enable_caching:
# Touch the computed blocks to make sure they won't be evicted.
self._touch(computed_block_ids)

# Get the token IDs for the blocks being allocated for hashing.
# Note that we expect allocate_slots to be called only once per
# new request, so num_computed_tokens + num_tokens must be less
# than or equal to the total number of tokens in the prompt.
new_token_ids = request.prompt_token_ids[
num_computed_tokens:num_computed_tokens + num_tokens]
if not new_token_ids:
Expand All @@ -356,15 +370,10 @@ def allocate_slots(
f"#prompt_tokens={len(request.prompt_token_ids)} < "
f"#computed_tokens={num_computed_tokens}")

# Touch the computed blocks to make sure they won't be evicted.
self._touch(computed_block_ids)

# Get the parent block ID to construct the block chain.
parent_block_id = computed_block_ids[
-1] if computed_block_ids else None
else:
new_token_ids = None
parent_block_id = None

new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids,
parent_block_id)
new_block_ids = [blk.block_id for blk in new_blocks]
Expand Down

0 comments on commit e6bd231

Please sign in to comment.