diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 8eb6c8ad7606b..6ee03dc2258b8 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -117,7 +117,8 @@ def main(args): input_length_range = tuple(map(int, args.input_length_range.split(':'))) random.seed(args.seed) if args.dataset_path is not None: - print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}") + print(f"Start to sample {args.num_prompts} prompts " + f"from {args.dataset_path}") filtered_datasets = sample_requests( dataset_path=args.dataset_path, num_requests=args.num_prompts, diff --git a/tests/core/block/test_block_manager.py b/tests/core/block/test_block_manager.py index 16e2876c015b1..742ebf870ac5a 100644 --- a/tests/core/block/test_block_manager.py +++ b/tests/core/block/test_block_manager.py @@ -256,10 +256,9 @@ def test_can_allocate_with_prefix_cache( # Allocate the seq 1 block_manager.allocate(seq_group_1) - # Mark the seq 1 as computed (This shoudl be done by the scheduler in reality) - block_manager.mark_blocks_as_computed( - seq_group=seq_group_1, token_chunk_size=len(tokens_1) - ) + # Mark the seq 1 as computed (This should be done by the scheduler in reality) + block_manager.mark_blocks_as_computed(seq_group=seq_group_1, + token_chunk_size=len(tokens_1)) # Test if allocatable of seq 2. seq_group_2 = create_seq_group( @@ -399,7 +398,9 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots, watermark=0, enable_caching=enable_caching) prompt, seq_group = create_dummy_prompt( - "1", prompt_length=(num_gpu_blocks - 1) * block_size - 1, block_size=block_size + "1", + prompt_length=(num_gpu_blocks - 1) * block_size - 1, + block_size=block_size, ) prompt.status = SequenceStatus.WAITING block_manager.allocate(seq_group) diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py index 06f1e297695a9..54c49657b737d 100644 --- a/tests/core/block/test_block_table.py +++ b/tests/core/block/test_block_table.py @@ -105,8 +105,7 @@ def test_allocate_prefix_caching(block_size: int, sequence_len: int): block_size=block_size, block_allocator=allocator, enable_prefix_caching=True, - ) - ) + )) seq = make_sequence(alloc_i, token_ids, block_size) block_tables[-1].allocate(seq=seq, device=Device.GPU) @@ -148,7 +147,8 @@ def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str, block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) for i in range(5): @@ -193,7 +193,8 @@ def test_append_token_ids_allocation(block_size: int, sequence_len: int, block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) num_expected_blocks_before_append = len( @@ -250,7 +251,8 @@ def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int, block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) num_expected_blocks_before_append = len( @@ -308,7 +310,8 @@ def test_append_token_ids_correct_content(block_size: int, sequence_len: int, 
block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) seq = make_sequence(0, token_ids, block_size) block_table.allocate(seq=seq, device=Device.GPU) @@ -353,7 +356,8 @@ def test_fork(seq_len: int, block_size: int, allocator_type: str): block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) seq = make_sequence(0, token_ids, block_size) @@ -414,7 +418,8 @@ def test_cow(block_size: int, sequence_len: int, append_len: int, original_block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) num_expected_non_cow_blocks = cdiv(sequence_len, block_size) @@ -504,7 +509,8 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int, original_block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) seq = make_sequence(0, token_ids, block_size) @@ -590,7 +596,8 @@ def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int, block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) seq = make_sequence(0, token_ids, block_size) block_table.allocate(seq=seq, device=Device.GPU) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 8c8690bc23d1b..8195888f6bef8 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -795,9 +795,8 @@ def test_get_cached_blocks(): block_size = 16 num_blocks = 5 - allocator = PrefixCachingBlockAllocator( - block_size=block_size, num_blocks=num_blocks - ) + allocator = PrefixCachingBlockAllocator(block_size=block_size, + num_blocks=num_blocks) # 1. Allocate a list of blocks block_hashes = [random.randint(1, 1000000) for _ in range(num_blocks)] @@ -825,12 +824,11 @@ def test_get_cached_blocks(): result = allocator.get_cached_blocks(cached_hashes) assert ( result == expected_cached_blocks - ), f"Expected {expected_cached_blocks}, but got {result}, with test case {cached_hashes}. blcok hashes = {block_hashes}" + ), f"Expected {expected_cached_blocks}, but got {result}, with test case {cached_hashes}. 
block hashes = {block_hashes}" # Test with some non-existent hashes non_existent_hash = max(block_hashes) + 1 test_hashes = block_hashes[:3] + [non_existent_hash] + block_hashes[3:] result = allocator.get_cached_blocks(test_hashes) - assert ( - result == block_hashes[0:3] - ), f"Expected {block_hashes[0:3]}, but got {result}" + assert (result == block_hashes[0:3] + ), f"Expected {block_hashes[0:3]}, but got {result}" diff --git a/tests/core/utils.py b/tests/core/utils.py index 6a539db9fc3f7..5c2e9e9c9ca1e 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -116,7 +116,7 @@ def create_dummy_prompt_encoder_decoder( def create_seq_group( seq_prompt_len: int = 1024, - seq_output_lens: GenericSequence[int] = (128,), + seq_output_lens: GenericSequence[int] = (128, ), request_id: str = "0", seq_id_start: int = 0, sampling_params: Optional[SamplingParams] = None, diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 9b3782e97fe72..6b6ab7a85a5b9 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -43,10 +43,10 @@ def test_mixed_requests( cached_prompt = example_prompts[cached_position] with vllm_runner( - model, - dtype=dtype, - enable_prefix_caching=True, - enable_chunked_prefill=enable_chunked_prefill, + model, + dtype=dtype, + enable_prefix_caching=True, + enable_chunked_prefill=enable_chunked_prefill, ) as vllm_model: # Run the first prompt so the cache is populated vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens) diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index 74962541dfb22..1554f306903e9 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -55,30 +55,6 @@ def __init__( self._max_block_sliding_window = max_block_sliding_window self._num_full_slots = self._get_num_token_ids() - # @staticmethod - # def get_num_required_blocks(token_ids: List[int], - # block_size: int, - # num_lookahead_slots: int = 0) -> int: - # """Calculates the minimum number of blocks required to store a given - # sequence of token IDs along with any look-ahead slots that may be - # required (like in multi-step + chunked-prefill). - - # This assumes worst-case scenario, where every block requires a new - # allocation (e.g. ignoring prefix caching). - - # Args: - # token_ids (List[int]): The sequence of token IDs to be stored. - # block_size (int): The maximum number of tokens that can be stored in - # a single block. - # num_lookahead_slots (int): look-ahead slots that the sequence may - # require. - - # Returns: - # int: The minimum number of blocks required to store the given - # sequence of token IDs along with any required look-ahead slots. 
- # """ - # return cdiv(len(token_ids) + num_lookahead_slots, block_size) - def allocate( self, token_ids: List[int], @@ -100,14 +76,13 @@ def allocate( if not token_ids: return - blocks = self._allocate_blocks_for_token_ids( - token_ids, block_hashes, device - ) + blocks = self._allocate_blocks_for_token_ids(token_ids, block_hashes, + device) self.update(blocks) self._num_full_slots = len(token_ids) def update(self, blocks: List[Block]) -> None: - """Resets the table to the newly provided blocks + """Resets the table to the newly provided blocks (with their corresponding block ids) """ self._blocks.update(blocks) @@ -164,17 +139,14 @@ def append_slots( # Update the blocks with the new tokens first_block_idx = self._num_full_slots // self._block_size token_blocks = self._chunk_token_blocks_for_append(token_ids) - - if len(token_blocks) != len(block_hashes): - breakpoint() - assert len(token_blocks) == len( block_hashes ), "chunked token_ids and block_hashes must have the same length" for i, token_block in enumerate(token_blocks): block_hash = block_hashes[i] - self._blocks.append_token_ids(first_block_idx + i, token_block, block_hash) + self._blocks.append_token_ids(first_block_idx + i, token_block, + block_hash) self._num_full_slots += len(token_ids) @@ -304,10 +276,9 @@ def _allocate_blocks_for_token_ids( self._allocator.allocate_immutable_blocks( prev_block, block_token_ids=block_token_ids, - block_hashes=block_hashes[: len(block_token_ids)], + block_hashes=block_hashes[:len(block_token_ids)], device=device, - ) - ) + )) prev_block = blocks[-1] if tail_token_ids: @@ -315,8 +286,7 @@ def _allocate_blocks_for_token_ids( assert block_hashes[-1] is None cur_token_ids = tail_token_ids[0] block = self._allocator.allocate_mutable_block( - prev_block=prev_block, device=device - ) + prev_block=prev_block, device=device) block.append_token_ids(cur_token_ids, block_hash=None) blocks.append(block) diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index c2117ccaaeb50..6815a3808876e 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -254,9 +254,8 @@ def update(self, blocks: List[Block]): for block in self._blocks: self._add_block_id(block.block_id) - def append_token_ids( - self, block_index: int, token_ids: List[int], block_hash: Optional[int] - ) -> None: + def append_token_ids(self, block_index: int, token_ids: List[int], + block_hash: Optional[int]) -> None: block = self._blocks[block_index] prev_block_id = block.block_id diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 1674514f354e7..0c8c00a461e99 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -139,7 +139,7 @@ def allocate_immutable_blocks( prev_block: Optional[Block], block_token_ids: List[List[int]], device: Device, - block_hashes: Optional[List[Optional[int]]] = None, + block_hashes: List[Optional[int]], ) -> List[Block]: """Allocates a new group of immutable blocks with the provided block token IDs on the specified device. @@ -156,15 +156,7 @@ def allocate_immutable_blocks( containing the provided block token IDs. 
""" return self._allocators[device].allocate_immutable_blocks( - prev_block, block_token_ids, block_hashes - ) - - def get_allocated_cached_blocks( - self, - block_hashes: List[int], - device: Device, - ) -> List[int]: - return self._allocators[device].get_allocated_cached_blocks(block_hashes) + prev_block, block_token_ids, block_hashes) def allocate_immutable_block(self, prev_block: Optional[Block], token_ids: List[int], @@ -353,14 +345,12 @@ def get_and_reset_swaps(self) -> List[Tuple[int, int]]: self._swap_mapping.clear() return list(mapping.items()) - def find_cached_blocks_prefix( - self, block_hashes: List[int], allocated: bool - ) -> List[int]: + def find_cached_blocks_prefix(self, block_hashes: List[int], + allocated: bool) -> List[int]: # Prefix caching only supported on GPU. device = Device.GPU return self._allocators[device].find_cached_blocks_prefix( - block_hashes, allocated - ) + block_hashes, allocated) class NullBlock(Block): @@ -376,7 +366,9 @@ def __init__(self, proxy: Block): super().__init__() self._proxy = proxy - def append_token_ids(self, token_ids: List[BlockId]): + def append_token_ids(self, + token_ids: List[BlockId], + block_hash: Optional[int] = None) -> None: raise ValueError("null block should not be modified") @property @@ -429,4 +421,5 @@ def content_hash(self): return self._proxy.content_hash def set_content_hash(self, content_hash: Optional[int]) -> None: - raise NotImplementedError("NullBlock does not support set_content_hash") + raise NotImplementedError( + "NullBlock does not support set_content_hash") diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index f55133f92ee0c..08149a0aabb17 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -10,9 +10,9 @@ class Block(ABC): @abstractmethod - def append_token_ids( - self, token_ids: List[int], block_hash: Optional[int] = None - ) -> None: + def append_token_ids(self, + token_ids: List[int], + block_hash: Optional[int] = None) -> None: pass @property @@ -204,9 +204,8 @@ def get_prefix_cache_hit_rate(self) -> float: pass @abstractmethod - def find_cached_blocks_prefix( - self, block_hashes: List[int], allocated: bool - ) -> List[int]: + def find_cached_blocks_prefix(self, block_hashes: List[int], + allocated: bool) -> List[int]: pass class NoFreeBlocksError(ValueError): @@ -227,9 +226,13 @@ def allocate_immutable_block(self, prev_block: Optional[Block], pass @abstractmethod - def allocate_immutable_blocks(self, prev_block: Optional[Block], - block_token_ids: List[List[int]], - device: Device) -> List[Block]: + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device, + block_hashes: List[Optional[int]], + ) -> List[Block]: pass @abstractmethod @@ -306,13 +309,6 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: pass @abstractmethod - def get_allocated_cached_blocks( - self, block_hashes: List[int], device: Device - ) -> List[int]: + def find_cached_blocks_prefix(self, block_hashes: List[int], + allocated: bool) -> List[int]: pass - - @abstractmethod - def find_cached_blocks_prefix( - self, block_hashes: List[int], allocated: bool, device: Device - ) -> List[int]: - pass \ No newline at end of file diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 83725faef6d39..9bc4a316555a3 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -85,10 +85,10 @@ def allocate_immutable_block( return block def allocate_immutable_blocks( 
- self, - prev_block: Optional[Block], - block_token_ids: List[List[int]], - block_hashes: Optional[List[Optional[int]]] = None, + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + block_hashes: List[Optional[int]], # Not used ) -> List[Block]: num_blocks = len(block_token_ids) @@ -335,12 +335,9 @@ def swap_in(self, blocks: List[Block]) -> None: def get_prefix_cache_hit_rate(self) -> float: return -1 - def get_allocated_cached_blocks(self, block_hashes: List[int]) -> List[int]: - return [] - - def find_cached_blocks_prefix( - self, block_hashes: List[int], allocated: bool = False - ) -> List[int]: + def find_cached_blocks_prefix(self, + block_hashes: List[int], + allocated: bool = False) -> List[int]: return [] @@ -382,9 +379,9 @@ def __init__(self, self._append_token_ids_no_cow(token_ids) - def append_token_ids( - self, token_ids: List[int], block_hash: Optional[int] = None - ) -> None: + def append_token_ids(self, + token_ids: List[int], + block_hash: Optional[int] = None) -> None: """Appends the given token IDs to the block and performs a copy-on-write if necessary. @@ -465,7 +462,6 @@ def prev_block(self) -> Optional["Block"]: def content_hash(self) -> Optional[int]: return None - def set_content_hash(self, content_hash: int) -> None: + def set_content_hash(self, content_hash: Optional[int]) -> None: raise NotImplementedError( - "Setting content hash is not supported for naive block" - ) + "Setting content hash is not supported for naive block") diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 035a9a8dd92f9..61e1fd56ce970 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -4,7 +4,13 @@ from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) -from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device +from vllm.core.block.interfaces import ( + Block, + BlockAllocator, + BlockId, + Device, + DeviceAwareBlockAllocator, +) from vllm.core.block.naive_block import (BlockPool, NaiveBlock, NaiveBlockAllocator) from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor @@ -153,7 +159,8 @@ def allocate_immutable_block( Returns: Block: The allocated immutable block. """ - assert len(token_ids) == self._block_size, "An immutable block should be full" + assert (len(token_ids) == self._block_size + ), "An immutable block should be full" assert ( block_hash is not None ), "An immutable block should have a content hash for prefix caching" @@ -405,7 +412,8 @@ def get_num_free_blocks(self, device: Optional[Device] = None) -> int: assert device is None # The number of free blocks is the number of hashless free blocks # plus the number of blocks evictor could free from its list. 
- return self._hashless_allocator.get_num_free_blocks() + self.evictor.num_blocks + return (self._hashless_allocator.get_num_free_blocks() + + self.evictor.num_blocks) def get_num_total_blocks(self) -> int: return self._hashless_allocator.get_num_total_blocks() @@ -638,18 +646,16 @@ def swap_in(self, blocks: List[Block]) -> None: # and the block_id is assigned to "block" to allow reusing the # existing "block" object if block.is_full: - assert ( - block.content_hash is not None - ), "Block is full but has no content hash" + assert (block.content_hash + is not None), "Block is full but has no content hash" tmp_block = self.allocate_immutable_block( prev_block=block.prev_block, token_ids=block.token_ids, block_hash=block.content_hash, ) else: - assert ( - block.content_hash is None - ), "Block is not full but has content hash" + assert (block.content_hash is + None), "Block is not full but has content hash" tmp_block = self.allocate_mutable_block( prev_block=block.prev_block) tmp_block.append_token_ids(block.token_ids, block_hash=None) @@ -659,9 +665,9 @@ def swap_in(self, blocks: List[Block]) -> None: block.block_id = block_id # Assign block_id - def find_cached_blocks_prefix( - self, block_hashes: List[PrefixHash], allocated: bool = False - ) -> List[PrefixHash]: + def find_cached_blocks_prefix(self, + block_hashes: List[PrefixHash], + allocated: bool = False) -> List[PrefixHash]: """ Return the prefix of the block hashes that are already computed and cached. @@ -686,9 +692,9 @@ def block_is_cached(block_hash: PrefixHash) -> bool: # Look for the first block that's not cached, and returns the prefix # , i.e. blocks that are cached. - idx = bisect_left( - block_hashes, True, key=lambda x: not block_is_cached(x) - ) + idx = bisect_left(block_hashes, + True, + key=lambda x: not block_is_cached(x)) return block_hashes[:idx] @@ -724,9 +730,7 @@ def __init__( assert isinstance(allocator, PrefixCachingBlockAllocator), ( "Currently this class is only tested with " "PrefixCachingBlockAllocator. Got instead allocator = {}".format( - allocator - ) - ) + allocator)) assert_prefix_caching_block_or_none(prev_block) self._prev_block = prev_block @@ -788,9 +792,9 @@ def last_accessed(self) -> float: def last_accessed(self, last_accessed_ts: float): self._last_accessed = last_accessed_ts - def append_token_ids( - self, token_ids: List[int], block_hash: Optional[int] = None - ) -> None: + def append_token_ids(self, + token_ids: List[int], + block_hash: Optional[int] = None) -> None: """Appends the given token IDs to the block and registers the block as immutable if the block becomes full. @@ -860,9 +864,8 @@ def set_content_hash(self, content_hash: Optional[int]) -> None: assert self.content_hash is None, "Content hash already set" if content_hash is None: # This could happen when forking a mutable block. - assert ( - not self.is_full - ), "Block should not be full when new content hash is None" + assert (not self.is_full + ), "Block should not be full when new content hash is None" # No op. return assert self.is_full, "Block is not full when setting content hash" @@ -901,7 +904,10 @@ class ComputedBlocksTracker: """ def __init__( - self, allocator: BlockAllocator, block_size: int, enable_caching: bool + self, + allocator: DeviceAwareBlockAllocator, + block_size: int, + enable_caching: bool, ): self._allocator = allocator self._block_size = block_size @@ -950,28 +956,26 @@ def update_seq(self, seq: Sequence) -> None: f" already recorded {cur_num_blocks_recorded} blocks. 
" "This should not happen since we assume blocks are " "only added. When the sequence is recomputed, we should have " - "removed the info of the old blocks." - ) + "removed the info of the old blocks.") # Update the computed block hashes for the sequence num_total_blocks = len(token_ids) // self._block_size # We need to know the hash of the previous block to compute the hash of # the current block so that blocks could be uniquely identified across # sequences of prefixes. - prev_block_hash = ( - None if cur_num_blocks_recorded == 0 else block_hashes[-1] - ) + prev_block_hash = (None if cur_num_blocks_recorded == 0 else + block_hashes[-1]) # Only update the computed block hashes for the new blocks for i in range(cur_num_blocks_recorded, num_total_blocks): - block_hash = seq.hash_of_block( - prev_block_hash=prev_block_hash, cur_block_idx=i - ) + block_hash = seq.hash_of_block(prev_block_hash=prev_block_hash, + cur_block_idx=i) block_hashes.append(block_hash) prev_block_hash = block_hash - def get_block_hashes( - self, seq: Sequence, start_block_idx: int = 0, end_block_idx: int = -1 - ) -> List[Optional[int]]: + def get_block_hashes(self, + seq: Sequence, + start_block_idx: int = 0, + end_block_idx: int = -1) -> List[Optional[int]]: """ Returns the list of block hashes for the sequence. If the last block is not yet full, its hash is None. @@ -994,17 +998,16 @@ def get_block_hashes( if seq.seq_id not in self._full_blocks_hashes: self.update_seq(seq) - seq_block_hashes = self._full_blocks_hashes[seq.seq_id] + seq_block_hashes: List[Optional[int]] = self._full_blocks_hashes[ + seq.seq_id] # type: ignore num_full_blocks = len(seq_block_hashes) assert num_blocks - num_full_blocks <= 1, ( "There should only be at most one block in the end of the " f"sequence that's not yet computed and full. Got {num_blocks} " - f"blocks, {num_full_blocks} full blocks." - ) + f"blocks, {num_full_blocks} full blocks.") # Add the None block if the last block is not yet full. - seq_block_hashes = seq_block_hashes + [None] * ( - num_blocks - num_full_blocks - ) + seq_block_hashes = seq_block_hashes + [None] * (num_blocks - + num_full_blocks) return seq_block_hashes[start_block_idx:end_block_idx] def get_num_tokens_computed(self, seq: Sequence, allocated: bool) -> int: @@ -1031,8 +1034,7 @@ def get_num_tokens_computed(self, seq: Sequence, allocated: bool) -> int: self.update_seq(seq) num_computed_tokens_prev = self._num_tokens_computed.get( - (seq.seq_id, allocated), None - ) + (seq.seq_id, allocated), None) if num_computed_tokens_prev is not None and seq.is_prefill(): # For a sequence that is still in prefill, we don't have to # recompute the number of cached tokens. @@ -1046,8 +1048,7 @@ def get_num_tokens_computed(self, seq: Sequence, allocated: bool) -> int: # This is currently O(logN), where N is the number of blocks. 
num_cached_blocks = len( - self._allocator.find_cached_blocks_prefix(block_hashes, allocated) - ) + self._allocator.find_cached_blocks_prefix(block_hashes, allocated)) num_cached_tokens = num_cached_blocks * self._block_size self._num_tokens_computed[(seq.seq_id, allocated)] = num_cached_tokens return num_cached_tokens diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 51a5a66928cbd..f80166ca28789 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -5,6 +5,7 @@ from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.interfaces import Block from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, LastAccessBlocksTracker) from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec @@ -100,14 +101,13 @@ def __init__( self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} self._computed_blocks_tracker = ComputedBlocksTracker( - self.block_allocator, self.block_size, enable_caching - ) + self.block_allocator, self.block_size, enable_caching) self._last_access_blocks_tracker = LastAccessBlocksTracker( self.block_allocator) - def _get_num_blocks_to_allocate( - self, seq: Sequence, num_lookahead_slots: int = 0 - ) -> int: + def _get_num_blocks_to_allocate(self, + seq: Sequence, + num_lookahead_slots: int = 0) -> int: """ Get the number of new blocks to allocate for a sequence. @@ -127,7 +127,7 @@ def _get_num_blocks_to_allocate( # allocated, and so we should not count them towards the number of # blocks we need to allocate. # - # consider a scenario with a seqeuence of 3 token blocks, and a block + # consider a scenario with a sequence of 3 token blocks, and a block # pool of only 2 blocks: [b0, b1, b2], where b0 and b1 are computed # but evicted, b2 is not computed. So b0, b1 are the 2 free blocks, # in evictor. When deciding how many more blocks need to be allocated for @@ -135,15 +135,13 @@ def _get_num_blocks_to_allocate( # just 1 block (b2). 
num_cached_tokens = ( self._computed_blocks_tracker.get_num_tokens_computed( - seq, allocated=True - ) - ) + seq, allocated=True)) - assert ( - num_cached_tokens % self.block_size == 0 - ), "Cached tokens must be a multiple of block size" + assert (num_cached_tokens % self.block_size == 0 + ), "Cached tokens must be a multiple of block size" num_cached_blocks = cdiv(num_cached_tokens, self.block_size) - num_required_blocks = cdiv(seq.get_len() + num_lookahead_slots, self.block_size) + num_required_blocks = cdiv(seq.get_len() + num_lookahead_slots, + self.block_size) return num_required_blocks - num_cached_blocks @@ -162,8 +160,7 @@ def get_num_cached_tokens(self, seq: Sequence, allocated: bool) -> int: """ return self._computed_blocks_tracker.get_num_tokens_computed( - seq, allocated - ) + seq, allocated) def can_allocate(self, seq_group: SequenceGroup, @@ -177,25 +174,21 @@ def can_allocate(self, self._computed_blocks_tracker.update_seq(seq) num_blocks_to_allocate = self._get_num_blocks_to_allocate( - seq, num_lookahead_slots - ) + seq, num_lookahead_slots) if seq_group.is_encoder_decoder(): encoder_seq = seq_group.get_encoder_seq() assert encoder_seq is not None self._computed_blocks_tracker.update_seq(encoder_seq) num_blocks_to_allocate += self._get_num_blocks_to_allocate( - encoder_seq, num_lookahead_slots=0 - ) + encoder_seq, num_lookahead_slots=0) if self.max_block_sliding_window is not None: - num_blocks_to_allocate = min( - num_blocks_to_allocate, self.max_block_sliding_window - ) + num_blocks_to_allocate = min(num_blocks_to_allocate, + self.max_block_sliding_window) num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( - device=Device.GPU - ) + device=Device.GPU) if self.num_total_gpu_blocks - num_blocks_to_allocate < self.watermark_blocks: return AllocStatus.NEVER if num_free_gpu_blocks - num_blocks_to_allocate >= self.watermark_blocks: @@ -217,8 +210,7 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: block_table.allocate( token_ids=seq.get_token_ids(), block_hashes=self._computed_blocks_tracker.get_block_hashes( - seq - ), + seq), ) return block_table @@ -285,14 +277,15 @@ def can_append_slots(self, seq_group: SequenceGroup, for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): self._computed_blocks_tracker.update_seq(seq) block_table = self.block_tables[seq.seq_id] - num_touched_blocks += block_table.get_num_blocks_touched_by_append_slots( - token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), - num_lookahead_slots=num_lookahead_slots, - ) + num_touched_blocks += ( + block_table.get_num_blocks_touched_by_append_slots( + token_ids=block_table.get_unseen_token_ids( + seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + )) num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( - device=Device.GPU - ) + device=Device.GPU) return num_touched_blocks <= num_free_gpu_blocks def append_slots( @@ -305,7 +298,8 @@ def append_slots( # Extend the block table for any new decoded tokens, as well as reserve # space for lookahead slots. 
- unseen_token_ids = block_table.get_unseen_token_ids(seq.get_token_ids()) + unseen_token_ids = block_table.get_unseen_token_ids( + seq.get_token_ids()) if len(unseen_token_ids) == 0: block_hashes = [] else: @@ -396,9 +390,7 @@ def get_common_computed_block_ids( all_blocks = self.block_tables[seq.seq_id].physical_block_ids num_cached_tokens = ( self._computed_blocks_tracker.get_num_tokens_computed( - seq, allocated=True - ) - ) + seq, allocated=True)) assert num_cached_tokens % self.block_size == 0 num_cached_blocks = num_cached_tokens // self.block_size computed_block_ids = all_blocks[:num_cached_blocks] @@ -581,10 +573,8 @@ def _can_swap(self, if self.block_allocator.get_num_total_blocks( device) < num_blocks_touched: return AllocStatus.NEVER - elif ( - self.block_allocator.get_num_free_blocks(device=device) - num_blocks_touched - >= watermark_blocks - ): + elif (self.block_allocator.get_num_free_blocks(device=device) - + num_blocks_touched >= watermark_blocks): return AllocStatus.OK else: return AllocStatus.LATER diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 7b0a81778c07f..ca269bceeefa4 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -561,16 +561,14 @@ def _schedule_running( while running_queue: seq_group = running_queue[0] num_running_tokens, num_running_tokens_cached = ( - self._get_num_new_tokens_to_schedule( - seq_group, SequenceStatus.RUNNING, enable_chunking, budget - ) - ) + self._get_num_new_tokens_to_schedule(seq_group, + SequenceStatus.RUNNING, + enable_chunking, budget)) assert num_running_tokens_cached == 0, ( "No tokens should have been cached for running seq groups " "(be it in continuous prefill with chunked prefill or in " - "decode)" - ) + "decode)") if num_running_tokens == 0: # No budget => Stop @@ -745,14 +743,13 @@ def _schedule_swapped( # exceed the maximum number of sequences. 
num_new_seqs = seq_group.get_max_num_running_seqs() num_new_tokens_uncached, num_new_tokens_cached = ( - self._get_num_new_tokens_to_schedule( - seq_group, SequenceStatus.SWAPPED, enable_chunking, budget - ) - ) + self._get_num_new_tokens_to_schedule(seq_group, + SequenceStatus.SWAPPED, + enable_chunking, budget)) if num_new_tokens_uncached == 0 or not budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, ): break @@ -766,10 +763,9 @@ def _schedule_swapped( prefill_seq_groups.append( ScheduledSequenceGroup( seq_group, - token_chunk_size=num_new_tokens_uncached - + num_new_tokens_cached, - ) - ) + token_chunk_size=num_new_tokens_uncached + + num_new_tokens_cached, + )) else: decode_seq_groups.append( ScheduledSequenceGroup(seq_group, token_chunk_size=1)) @@ -845,43 +841,34 @@ def _schedule_priority_preemption( seq_group = waiting_queue.popleft() num_new_seqs = seq_group.get_max_num_running_seqs() num_new_tokens_uncached, _ = self._get_num_new_tokens_to_schedule( - seq_group, SequenceStatus.WAITING, False, budget - ) + seq_group, SequenceStatus.WAITING, False, budget) # Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): # Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) - if ( - num_new_tokens_uncached - and can_allocate == AllocStatus.OK - and budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, - ) - ): + if (num_new_tokens_uncached and can_allocate == AllocStatus.OK + and budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + )): break # Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() num_running_tokens_uncached, num_running_tokens_cached = ( self._get_num_new_tokens_to_schedule( - vseq_group, SequenceStatus.RUNNING, False, budget - ) - ) + vseq_group, SequenceStatus.RUNNING, False, budget)) assert num_running_tokens_cached == 0, ( "No tokens should have been cached for running seq " "groups (be it in continuous prefill with chunked prefill " - "or in decode)" - ) + "or in decode)") budget.subtract_num_batched_tokens( - vseq_group.request_id, num_running_tokens_uncached - ) + vseq_group.request_id, num_running_tokens_uncached) num_running_seqs = vseq_group.get_max_num_running_seqs() - budget.subtract_num_seqs( - vseq_group.request_id, num_running_seqs - ) + budget.subtract_num_seqs(vseq_group.request_id, + num_running_seqs) # Preempt out the victim sequence group self._preempt(vseq_group, blocks_to_swap_out, @@ -946,8 +933,7 @@ def _schedule_prefills( SequenceStatus.WAITING, enable_chunking, budget, - ) - ) + )) num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached if not enable_chunking: # Sanity check that the number of new tokens is correct. @@ -1006,8 +992,8 @@ def _schedule_prefills( # We have new tokens but they might be cached. if not budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, ): # No more budget for new tokens. break @@ -1030,16 +1016,15 @@ def _schedule_prefills( else: seq_group.init_multi_step_from_lookahead_slots( num_lookahead_slots, - num_scheduler_steps=self.scheduler_config.num_scheduler_steps, + num_scheduler_steps=self.scheduler_config. 
+ num_scheduler_steps, is_multi_step=self.scheduler_config.is_multi_step, enable_chunking=enable_chunking, ) seq_groups.append( - ScheduledSequenceGroup( - seq_group=seq_group, token_chunk_size=num_new_tokens - ) - ) + ScheduledSequenceGroup(seq_group=seq_group, + token_chunk_size=num_new_tokens)) budget.add_num_batched_tokens( seq_group.request_id, num_batched_tokens=num_new_tokens_uncached, @@ -1056,8 +1041,7 @@ def _schedule_prefills( seq_groups=seq_groups, ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking - ), + is_prefill=True, enable_chunking=enable_chunking), ) def _schedule_default(self) -> SchedulerOutputs: @@ -1076,20 +1060,13 @@ def _schedule_default(self) -> SchedulerOutputs: # Make sure we include num running seqs before scheduling prefill, # so that we don't schedule beyond max_num_seqs for prefill. for seq_group in self.running: - budget.add_num_seqs( - seq_group.request_id, seq_group.get_max_num_running_seqs() - ) + budget.add_num_seqs(seq_group.request_id, + seq_group.get_max_num_running_seqs()) assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs - curr_loras = ( - set( - seq_group.lora_int_id - for seq_group in self.running - if seq_group.lora_int_id > 0 - ) - if self.lora_enabled - else None - ) + curr_loras = (set( + seq_group.lora_int_id for seq_group in self.running + if seq_group.lora_int_id > 0) if self.lora_enabled else None) prefills = SchedulerPrefillOutputs.create_empty() running_scheduled = SchedulerRunningOutputs.create_empty() @@ -1097,37 +1074,30 @@ def _schedule_default(self) -> SchedulerOutputs: # If any requests are swapped, prioritized swapped requests. if not self.swapped: - prefills = self._schedule_prefills( - budget, curr_loras, enable_chunking=False - ) + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=False) - if ( - len(prefills.seq_groups) == 0 - and self.scheduler_config.policy == "priority" - ): + if (len(prefills.seq_groups) == 0 + and self.scheduler_config.policy == "priority"): self._schedule_priority_preemption(budget) # Don't schedule decodes if prefills are scheduled. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # only contains decode requests, not chunked prefills. if len(prefills.seq_groups) == 0: - running_scheduled = self._schedule_running( - budget, curr_loras, enable_chunking=False - ) + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=False) # If any sequence group is preempted, do not swap in any sequence # group. because it means there's no slot for new running requests. - if ( - len(running_scheduled.preempted) - + len(running_scheduled.swapped_out) - == 0 - ): + if (len(running_scheduled.preempted) + + len(running_scheduled.swapped_out) == 0): swapped_in = self._schedule_swapped(budget, curr_loras) - assert ( - budget.num_batched_tokens - <= self.scheduler_config.max_num_batched_tokens - ) + assert (budget.num_batched_tokens <= + self.scheduler_config.max_num_batched_tokens) assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs # Update waiting requests. @@ -1140,14 +1110,12 @@ def _schedule_default(self) -> SchedulerOutputs: if len(swapped_in.decode_seq_groups) > 0: self.running.extend( - [s.seq_group for s in swapped_in.decode_seq_groups] - ) + [s.seq_group for s in swapped_in.decode_seq_groups]) # Update swapped requests. 
self.swapped.extend(running_scheduled.swapped_out) preempted = len(running_scheduled.preempted) + len( - running_scheduled.swapped_out - ) + running_scheduled.swapped_out) # There should be no prefill from running queue because this policy # doesn't allow chunked prefills. @@ -1207,27 +1175,24 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: swapped_in = SchedulerSwappedInOutputs.create_empty() # Decoding should be always scheduled first by fcfs. - running_scheduled = self._schedule_running( - budget, curr_loras, enable_chunking=True - ) + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=True) # Schedule swapped out requests. # If preemption happens, it means we don't have space for swap-in. - if ( - len(running_scheduled.preempted) - + len(running_scheduled.swapped_out) - == 0 - ): + if (len(running_scheduled.preempted) + + len(running_scheduled.swapped_out) == 0): swapped_in = self._schedule_swapped(budget, curr_loras) # Schedule new prefills. - prefills = self._schedule_prefills( - budget, curr_loras, enable_chunking=True - ) + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=True) assert ( - budget.num_batched_tokens - <= self.scheduler_config.max_num_batched_tokens + budget.num_batched_tokens <= + self.scheduler_config.max_num_batched_tokens ), f"{budget.num_batched_tokens=}, {self.scheduler_config.max_num_batched_tokens=}" assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs @@ -1238,47 +1203,39 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # By default, vLLM scheduler prioritizes prefills. # Once chunked prefill is enabled, # the policy is changed to prioritize decode requests. - self.running.extend([s.seq_group for s in swapped_in.decode_seq_groups]) self.running.extend( - [s.seq_group for s in swapped_in.prefill_seq_groups] - ) + [s.seq_group for s in swapped_in.decode_seq_groups]) self.running.extend( - [s.seq_group for s in running_scheduled.decode_seq_groups] - ) + [s.seq_group for s in swapped_in.prefill_seq_groups]) self.running.extend( - [s.seq_group for s in running_scheduled.prefill_seq_groups] - ) + [s.seq_group for s in running_scheduled.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.prefill_seq_groups]) self.running.extend([s.seq_group for s in prefills.seq_groups]) # Update swapped requests. 
self.swapped.extend(running_scheduled.swapped_out) return SchedulerOutputs( - scheduled_seq_groups=( - prefills.seq_groups - + running_scheduled.prefill_seq_groups - + swapped_in.prefill_seq_groups - + running_scheduled.decode_seq_groups - + swapped_in.decode_seq_groups - ), - num_prefill_groups=( - len(prefills.seq_groups) - + len(swapped_in.prefill_seq_groups) - + len(running_scheduled.prefill_seq_groups) - ), + scheduled_seq_groups=(prefills.seq_groups + + running_scheduled.prefill_seq_groups + + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups), + num_prefill_groups=(len(prefills.seq_groups) + + len(swapped_in.prefill_seq_groups) + + len(running_scheduled.prefill_seq_groups)), num_batched_tokens=budget.num_batched_and_cached_tokens, num_batched_tokens_from_budget=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=running_scheduled.blocks_to_copy - + swapped_in.blocks_to_copy, - ignored_seq_groups=prefills.ignored_seq_groups - + swapped_in.infeasible_seq_groups, + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, + ignored_seq_groups=prefills.ignored_seq_groups + + swapped_in.infeasible_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, running_queue_size=len(self.running), - preempted=( - len(running_scheduled.preempted) - + len(running_scheduled.swapped_out) - ), + preempted=(len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)), ) def _schedule(self) -> SchedulerOutputs: @@ -1288,25 +1245,21 @@ def _schedule(self) -> SchedulerOutputs: else: return self._schedule_default() - def _can_append_slots( - self, seq_group: SequenceGroup, enable_chunking: bool - ) -> bool: + def _can_append_slots(self, seq_group: SequenceGroup, + enable_chunking: bool) -> bool: """Determine whether or not we have enough space in the KV cache to continue generation of the sequence group. """ # It is True only for testing case to trigger artificial preemption. - if ( - self.enable_artificial_preemption - and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB - and self.artificial_preempt_cnt > 0 - ): + if (self.enable_artificial_preemption + and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB + and self.artificial_preempt_cnt > 0): self.artificial_preempt_cnt -= 1 return False is_prefill = seq_group.is_prefill() num_lookahead_slots = self._get_num_lookahead_slots( - is_prefill, enable_chunking - ) + is_prefill, enable_chunking) if is_prefill and num_lookahead_slots > 0: # Appending prefill slots only happens multi-step and @@ -1314,20 +1267,17 @@ def _can_append_slots( assert self.scheduler_config.is_multi_step and enable_chunking return self.block_manager.can_append_slots( - seq_group=seq_group, num_lookahead_slots=num_lookahead_slots - ) + seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: # async_output_proc is allowed only when we have a single sequence # in the sequence group no_single_seq = seq_group.sampling_params is None or ( - seq_group.sampling_params.n == 1 - ) + seq_group.sampling_params.n == 1) return no_single_seq def schedule( - self, - ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: + self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: # Schedule sequence groups. 
# This function call changes the internal states of the scheduler # such as self.running, self.swapped, and self.waiting. @@ -1344,15 +1294,13 @@ def schedule( # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] for i, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups - ): + scheduler_outputs.scheduled_seq_groups): seq_group = scheduled_seq_group.seq_group token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.maybe_set_first_scheduled_time(now) seq_group_metadata = self._seq_group_metadata_cache[ - self.cache_id - ].get_object() + self.cache_id].get_object() seq_group_metadata.seq_data.clear() seq_group_metadata.block_tables.clear() @@ -1369,8 +1317,7 @@ def schedule( # Block table for cross-attention # Also managed at SequenceGroup level cross_block_table = self.block_manager.get_cross_block_table( - seq_group - ) + seq_group) else: encoder_seq_data = None cross_block_table = None @@ -1384,9 +1331,7 @@ def schedule( if self.cache_config.enable_prefix_caching: common_computed_block_nums = ( self.block_manager.get_common_computed_block_ids( - seq_group.get_seqs(status=SequenceStatus.RUNNING) - ) - ) + seq_group.get_seqs(status=SequenceStatus.RUNNING))) do_sample = True is_prompt = seq_group.is_prefill() @@ -1404,10 +1349,8 @@ def schedule( # NOTE: We use get_len instead of get_prompt_len because when # a sequence is preempted, prefill includes previous generated # output tokens. - if ( - token_chunk_size + num_computed_tokens - < seqs[0].data.get_len() - ): + if (token_chunk_size + num_computed_tokens < + seqs[0].data.get_len()): do_sample = False # It assumes the scheduled_seq_groups is ordered by @@ -1432,8 +1375,7 @@ def schedule( # the subsequent comms can still use delta, but # `multi_modal_data` will be None. multi_modal_data=seq_group.multi_modal_data - if scheduler_outputs.num_prefill_groups > 0 - else None, + if scheduler_outputs.num_prefill_groups > 0 else None, mm_processor_kwargs=seq_group.mm_processor_kwargs, prompt_adapter_request=seq_group.prompt_adapter_request, ) @@ -1456,8 +1398,7 @@ def schedule( if allow_async_output_proc: allow_async_output_proc = self._allow_async_output_proc( - seq_group - ) + seq_group) # Now that the batch has been created, we can assume all blocks in the # batch will have been computed before the next scheduling invocation. @@ -1564,8 +1505,7 @@ def _append_slots( """ is_prefill: bool = seq_group.is_prefill() num_lookahead_slots: int = self._get_num_lookahead_slots( - is_prefill, enable_chunking - ) + is_prefill, enable_chunking) seq_group.init_multi_step_from_lookahead_slots( num_lookahead_slots, @@ -1672,8 +1612,7 @@ def _swap_out( # entire engine. raise RuntimeError( "Aborted due to the lack of CPU swap space. Please increase " - "the swap space to avoid this error." 
- ) + "the swap space to avoid this error.") mapping = self.block_manager.swap_out(seq_group) blocks_to_swap_out.extend(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): @@ -1686,18 +1625,16 @@ def _passed_delay(self, now: float) -> bool: # Delay scheduling prompts to let waiting queue fill up if self.scheduler_config.delay_factor > 0 and self.waiting: earliest_arrival_time = min( - [e.metrics.arrival_time for e in self.waiting] - ) + [e.metrics.arrival_time for e in self.waiting]) passed_delay = (now - earliest_arrival_time) > ( - self.scheduler_config.delay_factor * self.last_prompt_latency - ) or not self.running + self.scheduler_config.delay_factor * + self.last_prompt_latency) or not self.running else: passed_delay = True return passed_delay - def _get_num_lookahead_slots( - self, is_prefill: bool, enable_chunking: bool - ) -> int: + def _get_num_lookahead_slots(self, is_prefill: bool, + enable_chunking: bool) -> int: """The number of slots to allocate per sequence per step, beyond known token ids. Speculative decoding uses these slots to store KV activations of tokens which may or may not be accepted. @@ -1777,10 +1714,8 @@ def _get_num_new_tokens_to_schedule( # So we set allocated=False. if self.cache_config.enable_prefix_caching: num_cached_tokens_seq = ( - self.block_manager.get_num_cached_tokens( - seq, allocated=False - ) - ) + self.block_manager.get_num_cached_tokens(seq, + allocated=False)) # Any computed token should have been cached too. if not self.scheduler_config.chunked_prefill_enabled: @@ -1788,19 +1723,16 @@ def _get_num_new_tokens_to_schedule( f"Number of cached tokens ({num_cached_tokens_seq}) " "should be no less than the number of computed " f"tokens ({num_computed_tokens_seq}) when chunked prefill " - "is disabled" - ) + "is disabled") num_new_tokens_cached_seq = max( - 0, num_cached_tokens_seq - num_computed_tokens_seq - ) + 0, num_cached_tokens_seq - num_computed_tokens_seq) else: num_cached_tokens_seq = 0 num_new_tokens_cached_seq = 0 all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq - num_new_tokens_uncached_seq = ( - all_num_new_tokens_seq - num_new_tokens_cached_seq - ) + num_new_tokens_uncached_seq = (all_num_new_tokens_seq - + num_new_tokens_cached_seq) num_uncached_new_tokens += num_new_tokens_uncached_seq num_cached_new_tokens += num_new_tokens_cached_seq @@ -1840,11 +1772,8 @@ def _chunk_new_tokens_to_schedule( # the sequence. 
pass else: - num_new_tokens = ( - 0 - if num_new_tokens > remaining_token_budget - else num_new_tokens - ) + num_new_tokens = (0 if num_new_tokens > remaining_token_budget + else num_new_tokens) elif self.cache_config.enable_prefix_caching: # When prefix caching is enabled, we always allocate # the number of new tokens that is dividable by the block @@ -1852,139 +1781,18 @@ def _chunk_new_tokens_to_schedule( block_size = self.cache_config.block_size remainder = budget.token_budget % block_size if remainder != 0: - raise ValueError( - "When enabling chunked prefill and " - "prefix caching, max_num_batched_tokens " - "(chunk size) must be dividable by " - "block size, but got chunk_size " - f"({budget.token_budget}) % block_size " - f"({block_size}) = {remainder}" - ) + raise ValueError("When enabling chunked prefill and " + "prefix caching, max_num_batched_tokens " + "(chunk size) must be dividable by " + "block size, but got chunk_size " + f"({budget.token_budget}) % block_size " + f"({block_size}) = {remainder}") # Round down to block - remaining_token_budget = ( - remaining_token_budget // block_size * block_size - ) + remaining_token_budget = (remaining_token_budget // block_size * + block_size) # Calculate the number of new tokens that are not cached with chunk cap. num_new_tokens = min(num_new_tokens, remaining_token_budget) else: num_new_tokens = min(num_new_tokens, remaining_token_budget) return num_new_tokens - - # def _get_num_new_tokens( - # self, - # seq_group: SequenceGroup, - # status: SequenceStatus, - # enable_chunking: bool, - # budget: SchedulingBudget, - # ) -> int: - # """Get the next new tokens to compute for a given sequence group - # that's in a given `status`. - - # The API could chunk the number of tokens to compute based on `budget` - # if `enable_chunking` is True. If a sequence group has multiple - # sequences (e.g., running beam search), it means it is in decoding - # phase, so chunking doesn't happen. - - # Returns 0 if the new token cannot be computed due to token budget. - # """ - # num_new_tokens = 0 - # seqs = seq_group.get_seqs(status=status) - # for seq in seqs: - # num_new_tokens += seq.get_num_new_tokens() - # assert num_new_tokens > 0 - # # Chunk if a running request cannot fit in the given budget. - # # If number of seq > 1, it means it is doing beam search - # # in a decode phase. Do not chunk. - # if enable_chunking and len(seqs) == 1: - # remaining_token_budget = budget.remaining_token_budget() - # seq = seqs[0] - # if self.scheduler_config.is_multi_step: - # # The current multi-step + chunked prefill capability does - # # not actually support chunking prompts. - # # - # # Therefore, `num_new_tokens` is computed in the same fashion - # # for both multi-step+chunked-prefill & - # # multi-step+chunked-prefill+APC - # # - # # Prompts with more tokens than the current remaining budget - # # are postponed to future scheduler steps - # if num_new_tokens > self._get_prompt_limit(seq_group): - # # If the seq_group is in prompt-stage, pass the - # # num_new_tokens as-is so the caller can ignore - # # the sequence. - # pass - # else: - # num_new_tokens = 0 \ - # if num_new_tokens > remaining_token_budget \ - # else num_new_tokens - # elif self.cache_config.enable_prefix_caching: - # # When prefix caching is enabled, we always allocate - # # the number of new tokens that is dividable by the block - # # size to avoid partial block matching. 
- # block_size = self.cache_config.block_size - # remainder = budget.token_budget % block_size - # if remainder != 0: - # raise ValueError("When enabling chunked prefill and " - # "prefix caching, max_num_batched_tokens " - # "(chunk size) must be dividable by " - # "block size, but got chunk_size " - # f"({budget.token_budget}) % block_size " - # f"({block_size}) = {remainder}") - # num_new_tokens_cached = seq.get_num_cached_tokens() - seq.get_num_computed_tokens() - # num_new_tokens_cached = max(0, num_new_tokens_cached) - # # Round down to block - # remaining_token_budget = remaining_token_budget // block_size * block_size - - # # Calculate the number of new tokens that are not cached with chunk cap. - # num_new_tokens_uncached = min(num_new_tokens - num_new_tokens_cached, remaining_token_budget) - # if num_new_tokens_uncached == 0: - # # No more budget for new tokens, don't include any cached tokens too. - # return 0 - # num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached - # else: - # num_new_tokens = min(num_new_tokens, remaining_token_budget) - # return num_new_tokens - # - # def _update_prefix_cached_tokens(self, seq: Sequence): - # """ - # Update the number of prefix cached tokens for a sequence. - - # This function takes O(log(n)) time, where n is the number of blocks - # in the sequence. - # """ - # num_prefix_cached_tokens = self.block_manager.get_num_computed_tokens(seq) - # seq.set_num_prefix_cached_tokens(num_prefix_cached_tokens) - - # def _get_num_new_tokens_exclude_cached( - # self, num_new_tokens: int, seq: Sequence - # ) -> int: - # """ - # Get the number of new tokens to compute for a sequence, excluding - # cached tokens. - - # Args: - # num_new_tokens: The number of new tokens to compute. - # seq: The sequence to compute the new tokens for. - - # Returns: - # Given `num_new_tokens`, returns the number of uncached tokens. - # """ - - # # If a decode sequence, new tokens are always not computed/cached. - # if not seq.is_prefill(): - # return num_new_tokens - - # # If a prefill sequence, we need to exclude the number of cached tokens. 
- # num_computed_tokens = seq.get_num_computed_tokens() - # num_cached_tokens = seq.get_num_cached_tokens() - - # # We subtract the number of cached tokens from the number of new tokens - # num_computed_tokens_new = num_new_tokens + num_computed_tokens - # num_new_tokens_exclude_cached = max( - # 0, num_computed_tokens_new - num_cached_tokens - # ) - # assert ( - # num_new_tokens_exclude_cached <= num_new_tokens - # ), "Number of new tokens exclude cached should be less than or equal to the number of new tokens" - # return num_new_tokens_exclude_cached diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index f9ad930e0a7c3..a46625eff1e4a 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -475,9 +475,8 @@ def _log_prometheus(self, stats: Stats) -> None: stats.num_preemption_iter) self._log_counter(self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter) - self._log_counter( - self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter - ) + self._log_counter(self.metrics.counter_generation_tokens, + stats.num_generation_tokens_iter) self._log_histogram(self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter) self._log_histogram(self.metrics.histogram_time_per_output_token, diff --git a/vllm/sequence.py b/vllm/sequence.py index 5d68371ce7d6c..c73aeb36974a7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -591,22 +591,21 @@ def _update_block_hashes(self): token_ids = self.get_token_ids() # All token ids in the sequence num_full_blocks = len(token_ids) // self.block_size cur_num_full_blocks = len(self._computed_block_hashes) - prev_block_hash = ( - None if cur_num_full_blocks == 0 else self._computed_block_hashes[-1] - ) + prev_block_hash = (None if cur_num_full_blocks == 0 else + self._computed_block_hashes[-1]) for i in range(cur_num_full_blocks, num_full_blocks): - block_token_ids = token_ids[i * self.block_size : (i + 1) * self.block_size] + block_token_ids = token_ids[i * self.block_size:(i + 1) * + self.block_size] assert len(block_token_ids) == self.block_size - block_hash = hash( - ( - prev_block_hash, # Previous block hash - self.from_decoder_prompt, # Whether the sequence is decoder-only - # LoRA int id since the attention output will depend on - # LoRA with same token ids. - self.lora_int_id, - *block_token_ids, # The block token ids - ) - ) + block_hash = hash(( + prev_block_hash, # Previous block hash + self. + from_decoder_prompt, # Whether the sequence is decoder-only + # LoRA int id since the attention output will depend on + # LoRA with same token ids. + self.lora_int_id, + *block_token_ids, # The block token ids + )) self._computed_block_hashes.append(block_hash) prev_block_hash = block_hash @@ -619,15 +618,13 @@ def _reset_block_hashes(self): """ num_full_prompt_blocks = self.get_prompt_len() // self.block_size self._computed_block_hashes = self._computed_block_hashes[ - num_full_prompt_blocks: - ] + num_full_prompt_blocks:] def set_num_prefix_cached_tokens(self, num_prefix_cached_tokens: int): self.data.set_num_prefix_cached_tokens(num_prefix_cached_tokens) - def hash_of_block( - self, prev_block_hash: Optional[int], cur_block_idx: int - ) -> int: + def hash_of_block(self, prev_block_hash: Optional[int], + cur_block_idx: int) -> int: """ Get the hash of a given block of the sequence. @@ -642,19 +639,15 @@ def hash_of_block( token_ids = self.get_token_ids() assert (cur_block_idx + 1) * self.block_size <= len(token_ids), ( f"Invalid block index: {cur_block_idx}. 
The sequence only has " - f"{len(token_ids) // self.block_size} blocks." - ) - block_token_ids = token_ids[ - cur_block_idx * self.block_size : (cur_block_idx + 1) - * self.block_size - ] - return hash( - ( - prev_block_hash, - self.lora_int_id, - *block_token_ids, - ) - ) + f"{len(token_ids) // self.block_size} blocks.") + block_token_ids = token_ids[cur_block_idx * + self.block_size:(cur_block_idx + 1) * + self.block_size] + return hash(( + prev_block_hash, + self.lora_int_id, + *block_token_ids, + )) def reset_state_for_recompute(self): """Reset the sequence states for recomputation."""
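# ---------------------------------------------------------------------------
# Illustrative sketch, separate from the patch above: a minimal, standalone
# demonstration of the two ideas the patch relies on -- chaining each full
# block's hash with the previous block's hash (as Sequence.hash_of_block /
# _update_block_hashes do in vllm/sequence.py) and finding the already-cached
# prefix of a block-hash list with a binary search (as
# PrefixCachingBlockAllocator.find_cached_blocks_prefix does). The names
# chain_block_hashes, cached_prefix, and BLOCK_SIZE are hypothetical and the
# hash inputs are simplified; this is not the project's actual API.
# Requires Python 3.10+ for bisect_left(..., key=...).
from bisect import bisect_left
from typing import List, Optional, Set

BLOCK_SIZE = 4  # toy block size; the real block size is configurable


def chain_block_hashes(token_ids: List[int],
                       lora_int_id: int = 0) -> List[int]:
    """Hash each *full* block, mixing in the previous block's hash so that a
    block hash uniquely identifies its entire token prefix."""
    hashes: List[int] = []
    prev_hash: Optional[int] = None
    num_full_blocks = len(token_ids) // BLOCK_SIZE
    for i in range(num_full_blocks):
        block = tuple(token_ids[i * BLOCK_SIZE:(i + 1) * BLOCK_SIZE])
        prev_hash = hash((prev_hash, lora_int_id) + block)
        hashes.append(prev_hash)
    return hashes


def cached_prefix(block_hashes: List[int], cached: Set[int]) -> List[int]:
    """Return the longest prefix of block_hashes that is already cached.

    Because hashes are chained, cached blocks always form a prefix, so we can
    bisect on "is not cached" (False sorts before True), mirroring the
    bisect_left call in find_cached_blocks_prefix."""
    idx = bisect_left(block_hashes, True, key=lambda h: h not in cached)
    return block_hashes[:idx]


if __name__ == "__main__":
    tokens = list(range(10))        # 2 full blocks of 4 tokens + 2 leftover
    hashes = chain_block_hashes(tokens)
    cache = {hashes[0]}             # pretend only the first block is cached
    assert cached_prefix(hashes, cache) == hashes[:1]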