Skip to content

Commit c6b0a7d

Browse files
authored
[V1] Simplify prefix caching logic by removing num_evictable_computed_blocks (vllm-project#11310)
1 parent a30482f commit c6b0a7d

File tree

1 file changed

+2
-11
lines changed

1 file changed

+2
-11
lines changed

vllm/v1/core/kv_cache_manager.py

+2-11
Original file line numberDiff line numberDiff line change
@@ -201,32 +201,23 @@ def allocate_slots(
201201
f"num_tokens must be greater than 0, got {num_tokens}")
202202

203203
# Touch the computed blocks to make sure they won't be evicted.
204-
num_evictable_computed_blocks = 0
205204
if self.enable_caching:
206205
self._touch(computed_blocks)
207-
208-
# If a computed block of a request is an eviction candidate (in the
209-
# free queue and ref_cnt == 0), it cannot be counted as a free block
210-
# when allocating this request.
211-
num_evictable_computed_blocks = len(
212-
[blk for blk in computed_blocks if blk.ref_cnt == 0])
213206
else:
214207
assert not computed_blocks, (
215208
"Computed blocks should be empty when "
216209
"prefix caching is disabled")
217210

218211
num_required_blocks = cdiv(num_tokens, self.block_size)
219-
if (num_required_blocks > self.free_block_queue.num_free_blocks -
220-
num_evictable_computed_blocks):
212+
if (num_required_blocks > self.free_block_queue.num_free_blocks):
221213
# Cannot allocate new blocks.
222214
return None
223215

224216
# Determine the number of new blocks to allocate considering
225217
# preallocated blocks.
226218
num_new_blocks = min(
227219
num_required_blocks + self.num_preallocate_blocks,
228-
self.free_block_queue.num_free_blocks -
229-
num_evictable_computed_blocks,
220+
self.free_block_queue.num_free_blocks,
230221
# Should not exceed the maximum number of blocks per request.
231222
# This is especially because the block table has the shape
232223
# [..., max_num_blocks_per_req].

0 commit comments

Comments (0)