From a200257f81fbab3e48bea25fe78232d1f9d2c011 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Thu, 24 Oct 2024 20:30:07 +0000
Subject: [PATCH] fix

Signed-off-by: Cody Yu
---
 vllm/v1/core/kv_cache_manager.py | 23 ++++++++++++++++++-----
 vllm/v1/core/scheduler.py        | 11 ++++++++++-
 2 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 3d1d19ecc1dfd..b70cf9d80e9f0 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -27,6 +27,13 @@ class KVCacheBlock:
     # is closer to the end of a prompt and more likely to be evicted.
     num_hashed_tokens: int = 0
 
+    def reset(self):
+        self.prev_block_id = None
+        self.ref_cnt = 0
+        self.token_ids.clear()
+        self.block_hash = None
+        self.num_hashed_tokens = 0
+
 
 class KVCacheManager:
 
@@ -180,8 +187,8 @@ def append_slots(
             # No new block is needed. We caching is enabled,
             # then token_id_idx must be equal to len(new_token_ids),
             # meaning that all tokens are added to allocated blocks.
-            assert not self.enable_caching or token_id_idx == len(
-                new_token_ids)
+            assert not self.enable_caching or token_id_idx == num_tokens, \
+                f"{token_id_idx=} != {num_tokens=}"
             return []
 
         num_new_blocks = num_required_blocks - len(req_block_ids)
@@ -195,8 +202,12 @@ def append_slots(
         if self.enable_caching:
             new_token_ids = new_token_ids[token_id_idx:]
             prev_block_id = req_block_ids[-1]
-        new_block_ids = self._get_new_blocks(num_new_blocks, new_token_ids,
-                                             prev_block_id)
+        else:
+            new_token_ids = None
+            prev_block_id = None
+        new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids,
+                                          prev_block_id)
+        new_block_ids = [blk.block_id for blk in new_blocks]
         req_block_ids.extend(new_block_ids)
         return new_block_ids
 
@@ -235,7 +246,8 @@ def allocate_slots(
         # request, so we must have all new token IDs in the prompt.
         num_computed_tokens = len(computed_block_ids) * self.block_size
         if self.enable_caching:
-            new_token_ids = request.prompt_token_ids[num_computed_tokens:]
+            new_token_ids = request.prompt_token_ids[
+                num_computed_tokens:num_computed_tokens + num_tokens]
             if not new_token_ids:
                 raise RuntimeError(
                     "Failed to infer the token IDs for allocation. "
@@ -337,6 +349,7 @@ def _get_new_blocks(
             else:
                 del self.cached_block_hash_to_block[block_hash][
                     curr_block.block_id]
+            curr_block.reset()
             curr_block.ref_cnt = 1
             ret.append(curr_block)
 
diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index 41659ff62747d..2587b0fc8876e 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -34,7 +34,7 @@ def __init__(
             block_size=self.cache_config.block_size,
             num_gpu_blocks=num_gpu_blocks,
             sliding_window=self.cache_config.sliding_window,
-            enable_caching=True)
+            enable_caching=self.cache_config.enable_prefix_caching)
         self.block_size = self.cache_config.block_size
 
         # Scheduling constraints.
@@ -137,6 +137,15 @@ def schedule(self) -> "SchedulerOutput":
             # `request.num_prompt_tokens` to consider the resumed requests,
             # which have output tokens.
             num_new_tokens = request.num_tokens - num_computed_tokens
+            if num_new_tokens == 0:
+                # FIXME: This happens when the prompt length is divisible by
+                # the block size and all blocks are cached. We have to
+                # support query_len=0 in the model runner to handle this case.
+                # For now we force recomputation of the last block, which
+                # hurts performance and introduces duplication.
+                num_computed_tokens -= self.block_size
+                num_new_tokens = self.block_size
+                computed_block_ids.pop()
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0
             new_block_ids = self.kv_cache_manager.allocate_slots(
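
Note (not part of the patch): a minimal, standalone sketch of the full-cache-hit workaround described in the FIXME above, using hypothetical values for the block size, prompt length, and cached block IDs. It only illustrates the arithmetic of stepping back one block so the model runner always sees a non-empty query.

# Hypothetical numbers; the real values come from the scheduler and KV cache manager.
block_size = 16
num_prompt_tokens = 32                  # divisible by block_size
computed_block_ids = [0, 1]             # every prompt block hit the prefix cache

num_computed_tokens = len(computed_block_ids) * block_size
num_new_tokens = num_prompt_tokens - num_computed_tokens  # == 0 on a full hit

if num_new_tokens == 0:
    # Force recomputation of the last block, mirroring the FIXME path above.
    num_computed_tokens -= block_size
    num_new_tokens = block_size
    computed_block_ids.pop()

assert num_new_tokens == block_size
assert computed_block_ids == [0]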