From c31d4a57a6b639900a7c70b6e844db0116c2f9f6 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sat, 14 Dec 2024 00:51:25 +0900 Subject: [PATCH] [Core] support LoRA and prompt adapter in content-based hashing for Block Manager v2 prefix caching (#8240) --- tests/core/block/test_prefix_caching_block.py | 65 ++++++++++++++++++- tests/core/utils.py | 10 +++ vllm/core/block/block_table.py | 46 +++++++++---- vllm/core/block/common.py | 19 ++++-- vllm/core/block/cpu_gpu_block_allocator.py | 43 ++++++++---- vllm/core/block/interfaces.py | 32 ++++++--- vllm/core/block/naive_block.py | 10 ++- vllm/core/block/prefix_caching_block.py | 55 ++++++++++++---- vllm/core/block_manager.py | 8 ++- vllm/sequence.py | 13 ++++ 10 files changed, 246 insertions(+), 55 deletions(-) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index bbeb4b3a58f2a..29ac3a3c86cb4 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -5,7 +5,7 @@ import pytest -from tests.core.utils import create_dummy_sequence +from tests.core.utils import create_dummy_lora_sequence, create_dummy_sequence from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator from vllm.core.block.interfaces import Block, BlockAllocator from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, @@ -801,6 +801,7 @@ def create_immutable_chain( block_size: int, token_ids: List[int], allocator: PrefixCachingBlockAllocator, + extra_hash: Optional[int] = None, ) -> List[PrefixCachingBlock]: """Helper method which creates a chain of blocks. 
""" @@ -816,7 +817,9 @@ def create_immutable_chain( block_size:(block_number + 1) * block_size] prev_block = allocator.allocate_immutable_block( - prev_block=prev_block, token_ids=block_token_ids) + prev_block=prev_block, + token_ids=block_token_ids, + extra_hash=extra_hash) blocks.append(prev_block) return blocks @@ -931,3 +934,61 @@ def test_correct_block_hash(): allocator.mark_blocks_as_computed([]) assert tracker.get_num_cached_tokens(seq) == len(tokens) + + @staticmethod + def test_correct_extra_hash(): + """ + Test that the block hash is correctly computed based on the extra hash, + ensuring it matches the allocator's block hash, specifically for the + LoRA case, and that the correct number of cached tokens is retrieved. + """ + block_size = 4 + allocator = CpuGpuBlockAllocator.create( + allocator_type="prefix_caching", + num_gpu_blocks=16, + num_cpu_blocks=16, + block_size=block_size, + ) + gpu_allocator = allocator._allocators[Device.GPU] + + tracker = ComputedBlocksTracker( + allocator=allocator, + block_size=block_size, + enable_caching=True, + ) + + tokens = list(range(block_size * 4)) + + # Create a dummy LoRA sequence with a specific LoRA ID. + lora_seq = create_dummy_lora_sequence(request_id=0, + token_ids=tokens, + block_size=block_size, + lora_int_id=1) + + _ = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=tokens, + allocator=gpu_allocator, + extra_hash=lora_seq.extra_hash(), + ) + + allocator.mark_blocks_as_computed([]) + + # Create different dummy sequences that have the same token IDs + # but different LoRA IDs. + seq = create_dummy_sequence(request_id=1, + token_ids=tokens, + block_size=block_size) + + different_lora_seq = create_dummy_lora_sequence(request_id=2, + token_ids=tokens, + block_size=block_size, + lora_int_id=2) + + # Due to the different LoRA IDs, corresponding blocks are not cached. 
+ assert tracker.get_num_cached_tokens(seq) == 0 + assert tracker.get_num_cached_tokens(different_lora_seq) == 0 + + # The number of cached tokens matches the length of the tokens + # for the cached LoRA sequence. + assert tracker.get_num_cached_tokens(lora_seq) == len(tokens) diff --git a/tests/core/utils.py b/tests/core/utils.py index 277368b57b938..16703cd19fa1e 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -46,6 +46,16 @@ def create_dummy_prompt( return prompt, seq_group +def create_dummy_lora_sequence(request_id: int, token_ids: List[int], + block_size: int, lora_int_id: int) -> Sequence: + return Sequence(seq_id=request_id, + inputs=token_inputs(token_ids), + block_size=block_size, + lora_request=LoRARequest(lora_name="dummy", + lora_path="/dummy", + lora_int_id=lora_int_id)) + + def create_dummy_sequence(request_id: int, token_ids: List[int], block_size: int) -> Sequence: return Sequence( diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index d10cb29ef4a7c..dca0b3fe8d304 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -80,7 +80,8 @@ def get_num_required_blocks(token_ids: List[int], def allocate(self, token_ids: List[int], - device: Device = Device.GPU) -> None: + device: Device = Device.GPU, + extra_hash: Optional[int] = None) -> None: """Allocates memory blocks for storing the given sequence of token IDs. This method allocates the required number of blocks to store the given @@ -90,12 +91,16 @@ def allocate(self, token_ids (List[int]): The sequence of token IDs to be stored. device (Device, optional): The device on which the blocks should be allocated. Defaults to Device.GPU. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefix caching block. 
""" assert not self._is_allocated assert token_ids blocks = self._allocate_blocks_for_token_ids(prev_block=None, token_ids=token_ids, - device=device) + device=device, + extra_hash=extra_hash) self.update(blocks) self._num_full_slots = len(token_ids) @@ -108,7 +113,8 @@ def update(self, blocks: List[Block]) -> None: def append_token_ids(self, token_ids: List[int], num_lookahead_slots: int = 0, - num_computed_slots: Optional[int] = None) -> None: + num_computed_slots: Optional[int] = None, + extra_hash: Optional[int] = None) -> None: """Appends a sequence of token IDs to the existing blocks in the BlockTable. @@ -130,6 +136,9 @@ def append_token_ids(self, Without sliding window, None can be passed. Without chunked prefill, it should be the same as _num_full_slots. + extra_hash (Optional[int]): The hash value of additional + factors such as adapters that influence the block, apart + from the token_ids. """ assert self._is_allocated, "no blocks have been allocated" assert len(self._blocks) > 0 @@ -149,7 +158,8 @@ def append_token_ids(self, # Ensure there are enough empty slots for the new tokens plus # lookahead slots self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + - num_lookahead_slots) + num_lookahead_slots, + extra_hash=extra_hash) # Update the blocks with the new tokens first_block_idx = self._num_full_slots // self._block_size @@ -160,7 +170,9 @@ def append_token_ids(self, self._num_full_slots += len(token_ids) - def ensure_num_empty_slots(self, num_empty_slots: int) -> None: + def ensure_num_empty_slots(self, + num_empty_slots: int, + extra_hash: Optional[int] = None) -> None: """Ensures that the BlockTable has at least the specified number of empty slots available. @@ -171,6 +183,9 @@ def ensure_num_empty_slots(self, num_empty_slots: int) -> None: Args: num_empty_slots (int): The minimum number of empty slots required. 
+ extra_hash (Optional[int]): The hash value of additional + factors such as adapters that influence the block, apart + from the token_ids. """ # Currently the block table only supports # appending tokens to GPU blocks. @@ -187,7 +202,9 @@ def ensure_num_empty_slots(self, num_empty_slots: int) -> None: assert len(self._blocks) > 0 self._blocks.append( self._allocator.allocate_mutable_block( - prev_block=self._blocks[-1], device=device)) + prev_block=self._blocks[-1], + device=device, + extra_hash=extra_hash)) def fork(self) -> "BlockTable": """Creates a new BlockTable instance with a copy of the blocks from the @@ -259,9 +276,12 @@ def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: # ones after the appended ones. return sequence_token_ids[self.num_full_slots:] - def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], - token_ids: List[int], - device: Device) -> List[Block]: + def _allocate_blocks_for_token_ids( + self, + prev_block: Optional[Block], + token_ids: List[int], + device: Device, + extra_hash: Optional[int] = None) -> List[Block]: blocks: List[Block] = [] block_token_ids = [] @@ -275,8 +295,10 @@ def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], if block_token_ids: blocks.extend( self._allocator.allocate_immutable_blocks( - prev_block, block_token_ids=block_token_ids, - device=device)) + prev_block, + block_token_ids=block_token_ids, + device=device, + extra_hash=extra_hash)) prev_block = blocks[-1] if tail_token_ids: @@ -284,7 +306,7 @@ def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], cur_token_ids = tail_token_ids[0] block = self._allocator.allocate_mutable_block( - prev_block=prev_block, device=device) + prev_block=prev_block, device=device, extra_hash=extra_hash) block.append_token_ids(cur_token_ids) blocks.append(block) diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index eb190adfbe802..c03b5932eafb6 100644 --- a/vllm/core/block/common.py +++ 
b/vllm/core/block/common.py @@ -177,7 +177,8 @@ def __init__(self, block_size: int, create_block: Block.Factory, token_ids=[], block_size=self._block_size, allocator=self._allocator, - block_id=None)) + block_id=None, + extra_hash=None)) def increase_pool(self): """Doubles the internal pool size @@ -194,10 +195,15 @@ def increase_pool(self): token_ids=[], block_size=self._block_size, allocator=self._allocator, - block_id=None)) - - def init_block(self, prev_block: Optional[Block], token_ids: List[int], - block_size: int, physical_block_id: Optional[int]) -> Block: + block_id=None, + extra_hash=None)) + + def init_block(self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + physical_block_id: Optional[int], + extra_hash: Optional[int] = None) -> Block: if len(self._free_ids) == 0: self.increase_pool() assert len(self._free_ids) > 0 @@ -210,7 +216,8 @@ def init_block(self, prev_block: Optional[Block], token_ids: List[int], token_ids=token_ids, block_size=block_size, allocator=block._allocator, # type: ignore[attr-defined] - block_id=physical_block_id) + block_id=physical_block_id, + extra_hash=extra_hash) block.pool_id = pool_id # type: ignore[attr-defined] return block diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 3197af3c2b7a4..3a57487a6cd8a 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -121,23 +121,32 @@ def allocate_or_get_null_block(self) -> Block: self.allocate_mutable_block(None, Device.GPU)) return self._null_block - def allocate_mutable_block(self, prev_block: Optional[Block], - device: Device) -> Block: + def allocate_mutable_block(self, + prev_block: Optional[Block], + device: Device, + extra_hash: Optional[int] = None) -> Block: """Allocates a new mutable block on the specified device. Args: prev_block (Optional[Block]): The previous block to in the sequence. Used for prefix hashing. 
device (Device): The device on which to allocate the new block. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefix caching block. Returns: Block: The newly allocated mutable block. """ - return self._allocators[device].allocate_mutable_block(prev_block) - - def allocate_immutable_blocks(self, prev_block: Optional[Block], - block_token_ids: List[List[int]], - device: Device) -> List[Block]: + return self._allocators[device].allocate_mutable_block( + prev_block, extra_hash=extra_hash) + + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device, + extra_hash: Optional[int] = None) -> List[Block]: """Allocates a new group of immutable blocks with the provided block token IDs on the specified device. @@ -147,17 +156,22 @@ def allocate_immutable_blocks(self, prev_block: Optional[Block], block_token_ids (List[int]): The list of block token IDs to be stored in the new blocks. device (Device): The device on which to allocate the new block. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefix caching block. Returns: List[Block]: The newly allocated list of immutable blocks containing the provided block token IDs. """ return self._allocators[device].allocate_immutable_blocks( - prev_block, block_token_ids) + prev_block, block_token_ids, extra_hash=extra_hash) - def allocate_immutable_block(self, prev_block: Optional[Block], + def allocate_immutable_block(self, + prev_block: Optional[Block], token_ids: List[int], - device: Device) -> Block: + device: Device, + extra_hash: Optional[int] = None) -> Block: """Allocates a new immutable block with the provided token IDs on the specified device. @@ -167,13 +181,16 @@ def allocate_immutable_block(self, prev_block: Optional[Block], token_ids (List[int]): The list of token IDs to be stored in the new block. 
device (Device): The device on which to allocate the new block. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefix caching block. Returns: Block: The newly allocated immutable block containing the provided token IDs. """ return self._allocators[device].allocate_immutable_block( - prev_block, token_ids) + prev_block, token_ids, extra_hash=extra_hash) def free(self, block: Block) -> None: """Frees the memory occupied by the given block. @@ -387,6 +404,10 @@ def is_full(self): def prev_block(self): return self._proxy.prev_block + @property + def extra_hash(self): + return None + @property def computed(self): return self._proxy.computed diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 06f4851af3466..985a1098b6cd1 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -50,6 +50,11 @@ def is_full(self) -> bool: def prev_block(self) -> Optional["Block"]: pass + @property + @abstractmethod + def extra_hash(self) -> Optional[int]: + return None + @property @abstractmethod def computed(self) -> bool: @@ -81,6 +86,8 @@ def __call__( block_size: int, allocator: "BlockAllocator", block_id: Optional[int] = None, + computed: bool = False, + extra_hash: Optional[int] = None, ) -> "Block": pass @@ -99,18 +106,20 @@ def content_hash(self) -> Optional[int]: class BlockAllocator(ABC): @abstractmethod - def allocate_mutable_block(self, prev_block: Optional[Block]) -> Block: + def allocate_mutable_block(self, prev_block: Optional[Block], + extra_hash: Optional[int]) -> Block: pass @abstractmethod def allocate_immutable_block(self, prev_block: Optional[Block], - token_ids: List[int]) -> Block: + token_ids: List[int], + extra_hash: Optional[int]) -> Block: pass @abstractmethod - def allocate_immutable_blocks( - self, prev_block: Optional[Block], - block_token_ids: List[List[int]]) -> List[Block]: + def allocate_immutable_blocks(self, prev_block: 
Optional[Block], + block_token_ids: List[List[int]], + extra_hash: Optional[int]) -> List[Block]: pass @abstractmethod @@ -197,14 +206,18 @@ def find_cached_blocks_prefix( class DeviceAwareBlockAllocator(ABC): @abstractmethod - def allocate_mutable_block(self, prev_block: Optional[Block], - device: Device) -> Block: + def allocate_mutable_block(self, + prev_block: Optional[Block], + device: Device, + extra_hash: Optional[int] = None) -> Block: pass @abstractmethod - def allocate_immutable_block(self, prev_block: Optional[Block], + def allocate_immutable_block(self, + prev_block: Optional[Block], token_ids: List[int], - device: Device) -> Block: + device: Device, + extra_hash: Optional[int] = None) -> Block: pass @abstractmethod @@ -213,6 +226,7 @@ def allocate_immutable_blocks( prev_block: Optional[Block], block_token_ids: List[List[int]], device: Device, + extra_hash: Optional[int] = None, ) -> List[Block]: pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index a2af5ad6362c1..9b94918ab38ef 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -63,6 +63,7 @@ def __init__( def allocate_immutable_block(self, prev_block: Optional[Block], token_ids: List[int], + extra_hash: Optional[int] = None, device: Optional[Device] = None) -> Block: """Allocates a new immutable block with the given token IDs, linked to the previous block. @@ -85,6 +86,7 @@ def allocate_immutable_blocks( self, prev_block: Optional[Block], block_token_ids: List[List[int]], + extra_hash: Optional[int] = None, device: Optional[Device] = None) -> List[Block]: assert device is None num_blocks = len(block_token_ids) @@ -106,6 +108,7 @@ def allocate_immutable_blocks( def allocate_mutable_block(self, prev_block: Optional[Block], + extra_hash: Optional[int] = None, device: Optional[Device] = None) -> Block: """Allocates a new mutable block, linked to the previous block. 
@@ -355,7 +358,8 @@ def __init__(self, block_size: int, allocator: BlockAllocator, block_id: Optional[int] = None, - _cow_target: Optional[Block] = None): + _cow_target: Optional[Block] = None, + extra_hash: Optional[int] = None): self._token_ids: List[int] = [] self._block_size = block_size self._prev_block = prev_block @@ -441,6 +445,10 @@ def block_size(self) -> int: def prev_block(self) -> Optional["Block"]: return self._prev_block + @property + def extra_hash(self): + return None + @property def content_hash(self) -> Optional[int]: return None diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index b736167f6ceb4..1238303234deb 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -126,6 +126,7 @@ def _create_block( allocator: BlockAllocator, block_id: Optional[int] = None, computed: bool = False, + extra_hash: Optional[int] = None, ) -> Block: # Bind block to self. allocator = self @@ -137,11 +138,13 @@ def _create_block( block_id=block_id, allocator=allocator, computed=computed, + extra_hash=extra_hash, ) def allocate_immutable_block(self, prev_block: Optional[Block], token_ids: List[int], + extra_hash: Optional[int] = None, device: Optional[Device] = None) -> Block: """Allocates an immutable block with the given token IDs, reusing cached blocks if possible. 
@@ -160,7 +163,8 @@ def allocate_immutable_block(self, block = self._block_pool.init_block(prev_block=prev_block, token_ids=token_ids, block_size=self._block_size, - physical_block_id=None) + physical_block_id=None, + extra_hash=extra_hash) assert block.content_hash is not None cached_block_id = self._cached_blocks.get(block.content_hash, None) @@ -173,7 +177,7 @@ def allocate_immutable_block(self, self._block_pool.free_block(block) # No cached block => Allocate a new block - block = self.allocate_mutable_block(prev_block) + block = self.allocate_mutable_block(prev_block, extra_hash=extra_hash) block.append_token_ids(token_ids) return block @@ -181,17 +185,20 @@ def allocate_immutable_blocks( self, prev_block: Optional[Block], block_token_ids: List[List[int]], + extra_hash: Optional[int] = None, device: Optional[Device] = None) -> List[Block]: blocks = [] for token_ids in block_token_ids: prev_block = self.allocate_immutable_block(prev_block=prev_block, token_ids=token_ids, - device=device) + device=device, + extra_hash=extra_hash) blocks.append(prev_block) return blocks def allocate_mutable_block(self, prev_block: Optional[Block], + extra_hash: Optional[int] = None, device: Optional[Device] = None) -> Block: """Allocates a mutable block. If there are no free blocks, this will evict unused cached blocks. 
@@ -210,7 +217,8 @@ def allocate_mutable_block(self, block = self._block_pool.init_block(prev_block=prev_block, token_ids=[], block_size=self._block_size, - physical_block_id=block_id) + physical_block_id=block_id, + extra_hash=extra_hash) assert not block.computed assert block.content_hash is None return block @@ -382,7 +390,8 @@ def fork(self, last_block: Block) -> List[Block]: prev_block=prev_block, token_ids=block.token_ids, block_size=self._block_size, - physical_block_id=block_id) + physical_block_id=block_id, + extra_hash=block.extra_hash) forked_blocks.append(forked_block) prev_block = forked_blocks[-1] @@ -608,10 +617,12 @@ def swap_in(self, blocks: List[Block]) -> None: # existing "block" object if block.is_full: tmp_block = self.allocate_immutable_block( - prev_block=block.prev_block, token_ids=block.token_ids) + prev_block=block.prev_block, + token_ids=block.token_ids, + extra_hash=block.extra_hash) else: tmp_block = self.allocate_mutable_block( - prev_block=block.prev_block) + prev_block=block.prev_block, extra_hash=block.extra_hash) tmp_block.append_token_ids(block.token_ids) block_id = tmp_block.block_id @@ -679,6 +690,8 @@ class PrefixCachingBlock(Block): caching block allocator associated with this block. block_id (Optional[int], optional): The physical block index of this block. Defaults to None. + extra_hash (Optional[int]): The hash value of additional factors + such as adapters that influence the block, apart from the token_ids. 
""" def __init__( @@ -689,6 +702,7 @@ def __init__( allocator: BlockAllocator, block_id: Optional[int] = None, computed: bool = False, + extra_hash: Optional[int] = None, ): assert isinstance(allocator, PrefixCachingBlockAllocator), ( "Currently this class is only tested with " @@ -702,6 +716,7 @@ def __init__( self._allocator = allocator self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME self._computed = computed + self._extra_hash = extra_hash # On the first time, we create the block object, and next we only # reinitialize it @@ -811,6 +826,10 @@ def token_ids(self) -> List[int]: def prev_block(self) -> Optional[Block]: return self._prev_block + @property + def extra_hash(self) -> Optional[int]: + return self._extra_hash + @property def content_hash(self) -> Optional[int]: """Return the content-based hash of the current block, or None if it is @@ -841,18 +860,19 @@ def content_hash(self) -> Optional[int]: self._cached_content_hash = PrefixCachingBlock.hash_block_tokens( is_first_block, prev_block_hash, - cur_block_token_ids=self.token_ids) + cur_block_token_ids=self.token_ids, + extra_hash=self._extra_hash) return self._cached_content_hash @staticmethod - def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], - cur_block_token_ids: List[int]) -> int: + def hash_block_tokens(is_first_block: bool, + prev_block_hash: Optional[int], + cur_block_token_ids: List[int], + extra_hash: Optional[int] = None) -> int: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. - NOTE: Content-based hashing does not yet support LoRA. - Parameters: - is_first_block (bool): A flag indicating if the block is the first in the sequence. @@ -860,12 +880,15 @@ def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], if this is the first block. - cur_block_token_ids (List[int]): A list of token ids in the current block. 
The current block is assumed to be full. + - extra_hash (Optional[int]): The hash value of additional factors + such as adapters that influence the block, apart from the token_ids. Returns: - int: The computed hash value for the block. """ assert (prev_block_hash is None) == is_first_block - return hash((is_first_block, prev_block_hash, *cur_block_token_ids)) + return hash((is_first_block, prev_block_hash, *cur_block_token_ids, + extra_hash)) class ComputedBlocksTracker: @@ -935,12 +958,18 @@ def _update_seq_hashes(self, seq: Sequence) -> None: assert len(token_ids) >= (i + 1) * self._block_size block_token_ids = token_ids[i * self._block_size:(i + 1) * self._block_size] + + # NOTE: If there are any factors affecting the block besides + # token_ids, they should be added as input to extra_hash. + extra_hash = seq.extra_hash() + # This has to be kept in sync with the allocator's hash # calculation. block_hash = PrefixCachingBlock.hash_block_tokens( is_first_block=prev_block_hash is None, prev_block_hash=prev_block_hash, cur_block_token_ids=block_token_ids, + extra_hash=extra_hash, ) block_hashes_recorded.append(block_hash) prev_block_hash = block_hash diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 209487c6b4f9e..b41e848221882 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -151,8 +151,13 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: max_block_sliding_window=self.max_block_sliding_window, ) if seq.get_token_ids(): + # NOTE: If there are any factors affecting the block besides + # token_ids, they should be added as input to extra_hash. + extra_hash = seq.extra_hash() + # Add blocks to the block table only if the sequence is non empty. 
- block_table.allocate(seq.get_token_ids()) + block_table.allocate(token_ids=seq.get_token_ids(), + extra_hash=extra_hash) return block_table @@ -238,6 +243,7 @@ def append_slots( token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), num_lookahead_slots=num_lookahead_slots, num_computed_slots=seq.data.get_num_computed_tokens(), + extra_hash=seq.extra_hash(), ) # Return any new copy-on-writes. new_cows = self.block_allocator.clear_copy_on_writes() diff --git a/vllm/sequence.py b/vllm/sequence.py index ddb9ca5944f10..cc3d96fc93a79 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -527,6 +527,19 @@ def hash_of_block(self, logical_idx: int) -> int: hashed_tokens = self.data.get_prefix_token_ids(num_tokens) return hash((hashed_tokens, self.lora_int_id)) + def extra_hash(self) -> Optional[int]: + """ + Compute a hash of the per-sequence factors other than token_ids (the + prompt adapter ID and the LoRA ID) for prefix caching, or return None + when both IDs are 0. It is mixed into each block's content hash. + """ + if self.prompt_adapter_id == 0 and self.lora_int_id == 0: + return None + + # NOTE: If there are additional factors influencing the block aside from + # token_ids, include them as input parameters to the hash. + return hash((self.prompt_adapter_id, self.lora_int_id)) + def num_hashed_tokens_of_block(self, logical_idx: int): return logical_idx * self.block_size + self.block_size