From 499bd7e4a606e53a292e58ebab0e422840a26c09 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 31 Oct 2024 17:41:13 -0700 Subject: [PATCH] fix allocation est Signed-off-by: Cody Yu --- vllm/v1/core/kv_cache_manager.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index d17204d21f9d7..80c3a7d3634cc 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -241,15 +241,25 @@ def allocate_slots( raise ValueError( f"num_tokens must be greater than 0, got {num_tokens}") + # If a computed block is an eviction candidate (in the free queue), + # it cannot be counted as a free block when estimating whether we + # can allocate new blocks for this request. + num_evictable_computed_blocks = len([ + bid for bid in computed_block_ids + if self.block_pool[bid].ref_cnt == 0 + ]) + num_required_blocks = cdiv(num_tokens, self.block_size) - if num_required_blocks > self.num_free_blocks: + if (num_required_blocks > + self.num_free_blocks - num_evictable_computed_blocks): # Cannot allocate new blocks. return None # Determine the number of new blocks to allocate considering # preallocated blocks. - num_new_blocks = min(num_required_blocks + self.num_preallocate_blocks, - self.num_free_blocks) + num_new_blocks = min( + num_required_blocks + self.num_preallocate_blocks, + self.num_free_blocks - num_evictable_computed_blocks) # Get the token IDs for the blocks being allocated for hashing. # Note that we expect this function to be called only once per # request, so we must have all new token IDs in the prompt.