From a5875e8927120963dd530a7170811d47d132d902 Mon Sep 17 00:00:00 2001 From: rickyx Date: Thu, 10 Oct 2024 23:39:53 +0000 Subject: [PATCH 01/12] Tests passing --- tests/core/block/test_block_manager_v2.py | 4 +- tests/core/block/test_block_table.py | 98 ++++++++-- tests/core/utils.py | 24 +++ vllm/core/block/block_table.py | 68 ++++--- vllm/core/block/common.py | 32 +++- vllm/core/block/cpu_gpu_block_allocator.py | 31 +++- vllm/core/block/interfaces.py | 32 +++- vllm/core/block/naive_block.py | 53 ++++-- vllm/core/block/prefix_caching_block.py | 205 ++++++++++++++------- vllm/core/block_manager_v2.py | 66 ++++--- vllm/sequence.py | 64 ++++++- 11 files changed, 508 insertions(+), 169 deletions(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index e67883367879f..c4c41e714a02d 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -326,8 +326,10 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots, num_gpu_blocks, watermark=0, enable_caching=enable_caching) + print(f"prompt_length={(num_gpu_blocks - 1) * block_size - 1}") prompt, seq_group = create_dummy_prompt( - "1", prompt_length=(num_gpu_blocks - 1) * block_size - 1) + "1", prompt_length=(num_gpu_blocks - 1) * block_size - 1, block_size=block_size + ) prompt.status = SequenceStatus.WAITING block_manager.allocate(seq_group) prompt.status = SequenceStatus.RUNNING diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py index e2391a5680b36..06f1e297695a9 100644 --- a/tests/core/block/test_block_table.py +++ b/tests/core/block/test_block_table.py @@ -1,12 +1,33 @@ -from typing import List +from typing import List, Optional import pytest +from tests.core.utils import create_dummy_sequence from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.sequence import Logprob from vllm.utils import Device, cdiv, chunk_list +def make_sequence( + request_id: int, + token_ids: List[int], + block_size: int, + num_output_tokens: int = 0, + output_tokens: Optional[List[int]] = None, +): + if output_tokens is None: + output_tokens = list(range(num_output_tokens)) + + seq = create_dummy_sequence( + sequence_id=request_id, + prompt_tokens=token_ids, + block_size=block_size, + output_tokens=output_tokens, + ) + return seq + + @pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("sequence_len", [1, 16, 129]) def test_allocate_naive(block_size: int, sequence_len: int): @@ -35,12 +56,13 @@ def test_allocate_naive(block_size: int, sequence_len: int): assert allocator.get_num_free_blocks( device=Device.GPU) == num_gpu_blocks - i * num_blocks_per_alloc + seq = make_sequence(i, token_ids, block_size) block_tables.append( BlockTable( block_size=block_size, block_allocator=allocator, )) - block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU) + block_tables[-1].allocate(seq=seq, device=Device.GPU) @pytest.mark.parametrize("block_size", [16]) @@ -82,8 +104,11 @@ def test_allocate_prefix_caching(block_size: int, sequence_len: int): BlockTable( block_size=block_size, block_allocator=allocator, - )) - block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU) + enable_prefix_caching=True, + ) + ) + seq = make_sequence(alloc_i, token_ids, block_size) + block_tables[-1].allocate(seq=seq, device=Device.GPU) # Expect all sequences to share allocations, except for their last block # (which may be mutable). 
@@ -123,10 +148,12 @@ def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str, block_table = BlockTable( block_size=block_size, block_allocator=allocator, + enable_prefix_caching=True if allocator_type == "prefix_caching" else False, ) for i in range(5): - block_table.allocate(token_ids=token_ids, device=device) + seq = make_sequence(i, token_ids, block_size) + block_table.allocate(seq=seq, device=device) assert allocator.get_num_free_blocks( device) == num_device_blocks - num_blocks_per_alloc assert all(block_id is not None @@ -166,6 +193,7 @@ def test_append_token_ids_allocation(block_size: int, sequence_len: int, block_table = BlockTable( block_size=block_size, block_allocator=allocator, + enable_prefix_caching=True if allocator_type == "prefix_caching" else False, ) num_expected_blocks_before_append = len( @@ -174,11 +202,18 @@ def test_append_token_ids_allocation(block_size: int, sequence_len: int, list(chunk_list(token_ids + token_ids_to_append, block_size))) - num_expected_blocks_before_append - block_table.allocate(token_ids=token_ids, device=Device.GPU) + seq = make_sequence(0, token_ids, block_size) + + block_table.allocate(seq=seq, device=Device.GPU) assert len( block_table.physical_block_ids) == num_expected_blocks_before_append - block_table.append_token_ids(token_ids_to_append) + + # Update the sequence. + for token_id in token_ids_to_append: + seq.append_token_id(token_id, {token_id: Logprob(0.0)}) + + block_table.append_slots(seq=seq, num_lookahead_slots=0) assert len( block_table.physical_block_ids ) == num_expected_blocks_before_append + num_expected_appended_blocks @@ -215,6 +250,7 @@ def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int, block_table = BlockTable( block_size=block_size, block_allocator=allocator, + enable_prefix_caching=True if allocator_type == "prefix_caching" else False, ) num_expected_blocks_before_append = len( @@ -223,7 +259,9 @@ def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int, list(chunk_list(token_ids + [-1] * num_empty_slots, block_size))) - num_expected_blocks_before_append - block_table.allocate(token_ids=token_ids, device=Device.GPU) + seq = make_sequence(0, token_ids, block_size) + + block_table.allocate(seq=seq, device=Device.GPU) # Assert that the empty slots consume the expected number of additional # blocks. @@ -236,7 +274,10 @@ def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int, # Now, ensure no additional blocks consumed as we fill up the empty slots. 
num_free_blocks = allocator.get_num_free_blocks(device=Device.GPU) - block_table.append_token_ids(token_ids=list(range(num_empty_slots))) + tokens_to_append = list(range(num_empty_slots)) + for token_id in tokens_to_append: + seq.append_token_id(token_id, {token_id: Logprob(0.0)}) + block_table.append_slots(seq=seq, num_lookahead_slots=0) assert num_free_blocks == allocator.get_num_free_blocks(device=Device.GPU) @@ -267,12 +308,17 @@ def test_append_token_ids_correct_content(block_size: int, sequence_len: int, block_table = BlockTable( block_size=block_size, block_allocator=allocator, + enable_prefix_caching=True if allocator_type == "prefix_caching" else False, ) - block_table.allocate(token_ids=token_ids, device=Device.GPU) + seq = make_sequence(0, token_ids, block_size) + block_table.allocate(seq=seq, device=Device.GPU) appended_so_far: List[int] = [] for append in chunk_list(token_ids_to_append, append_size): - block_table.append_token_ids(append) + for token_id in append: + seq.append_token_id(token_id, {token_id: Logprob(0.0)}) + block_table.append_slots(seq=seq, num_lookahead_slots=0) + appended_so_far.extend(append) assert block_table._get_all_token_ids() == token_ids + appended_so_far @@ -307,9 +353,11 @@ def test_fork(seq_len: int, block_size: int, allocator_type: str): block_table = BlockTable( block_size=block_size, block_allocator=allocator, + enable_prefix_caching=True if allocator_type == "prefix_caching" else False, ) - block_table.allocate(token_ids) + seq = make_sequence(0, token_ids, block_size) + block_table.allocate(seq=seq, device=Device.GPU) num_free_blocks_before_fork = allocator.get_num_free_blocks( device=Device.GPU) @@ -366,13 +414,15 @@ def test_cow(block_size: int, sequence_len: int, append_len: int, original_block_table = BlockTable( block_size=block_size, block_allocator=allocator, + enable_prefix_caching=True if allocator_type == "prefix_caching" else False, ) num_expected_non_cow_blocks = cdiv(sequence_len, block_size) num_expected_cow_blocks = cdiv(sequence_len + append_len, block_size) - (sequence_len // block_size) - original_block_table.allocate(token_ids=token_ids, device=Device.GPU) + seq = make_sequence(0, token_ids, block_size) + original_block_table.allocate(seq=seq, device=Device.GPU) original_block_ids = original_block_table.physical_block_ids[:] print("original_block_ids = {}".format(original_block_ids)) @@ -392,7 +442,9 @@ def test_cow(block_size: int, sequence_len: int, append_len: int, raise ValueError(f"unknown test config {appender=}") # Write tokens. - appender_block_table.append_token_ids(token_ids_to_append) + for token_id in token_ids_to_append: + seq.append_token_id(token_id, {token_id: Logprob(0.0)}) + appender_block_table.append_slots(seq=seq, num_lookahead_slots=0) # Expect the non-appending block table to have no change. assert static_block_table.physical_block_ids == original_block_ids @@ -452,9 +504,11 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int, original_block_table = BlockTable( block_size=block_size, block_allocator=allocator, + enable_prefix_caching=True if allocator_type == "prefix_caching" else False, ) - original_block_table.allocate(token_ids=token_ids, device=Device.GPU) + seq = make_sequence(0, token_ids, block_size) + original_block_table.allocate(seq=seq, device=Device.GPU) # Allocate lookahead slots. 
original_block_table.ensure_num_empty_slots(lookahead_slots) @@ -472,7 +526,9 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int, raise ValueError(f"unknown test config {appender=}") # Write tokens. - appender_block_table.append_token_ids(token_ids_to_append) + for token_id in token_ids_to_append: + seq.append_token_id(token_id, {token_id: Logprob(0.0)}) + appender_block_table.append_slots(seq=seq, num_lookahead_slots=0) # Expect the non-appending block table to have no change. assert static_block_table.physical_block_ids == original_block_ids @@ -534,9 +590,10 @@ def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int, block_table = BlockTable( block_size=block_size, block_allocator=allocator, + enable_prefix_caching=True if allocator_type == "prefix_caching" else False, ) - - block_table.allocate(token_ids=token_ids, device=Device.GPU) + seq = make_sequence(0, token_ids, block_size) + block_table.allocate(seq=seq, device=Device.GPU) # Add lookahead before fork so both sequences have the same lookahead # blocks. @@ -556,7 +613,10 @@ def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int, # # We expect append_token_ids to CoW all mutated blocks that have refcount>1. num_free_blocks_before_append = allocator.get_num_free_blocks(Device.GPU) - block_table.append_token_ids(token_ids_to_append, num_lookahead_slots) + for token_id in token_ids_to_append: + seq.append_token_id(token_id, {token_id: Logprob(0.0)}) + block_table.append_slots(seq=seq, num_lookahead_slots=num_lookahead_slots) + num_consumed_blocks = (num_free_blocks_before_append - allocator.get_num_free_blocks(Device.GPU)) diff --git a/tests/core/utils.py b/tests/core/utils.py index a95a573db7cd3..a2c2df029cb19 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -44,6 +44,30 @@ def create_dummy_prompt( return prompt, seq_group +def create_dummy_sequence( + sequence_id: int, + prompt_tokens: List[int], + block_size: int, + output_tokens: Optional[List[int]] = None, +): + if output_tokens is None: + output_tokens = [] + + seq = Sequence( + sequence_id, + inputs={ + "prompt": " ".join([str(t) for t in prompt_tokens]), + "prompt_token_ids": prompt_tokens, + }, + block_size=block_size, + ) + + for token_id in output_tokens: + seq.append_token_id(token_id, {token_id: Logprob(0.0)}) + + return seq + + def create_dummy_prompt_encoder_decoder( request_id: str, decoder_prompt_length: int, diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index d10cb29ef4a7c..64267e09e1316 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -3,6 +3,7 @@ from vllm.core.block.common import BlockList from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator +from vllm.sequence import Sequence from vllm.utils import Device, cdiv, chunk_list @@ -44,6 +45,7 @@ def __init__( block_allocator: DeviceAwareBlockAllocator, _blocks: Optional[List[Block]] = None, max_block_sliding_window: Optional[int] = None, + enable_prefix_caching: bool = False, ): self._block_size = block_size self._allocator = block_allocator @@ -54,6 +56,9 @@ def __init__( self._max_block_sliding_window = max_block_sliding_window self._num_full_slots = self._get_num_token_ids() + # Whether to enable prefix caching. 
+ self._enable_prefix_caching = enable_prefix_caching + @staticmethod def get_num_required_blocks(token_ids: List[int], block_size: int, @@ -78,26 +83,24 @@ def get_num_required_blocks(token_ids: List[int], """ return cdiv(len(token_ids) + num_lookahead_slots, block_size) - def allocate(self, - token_ids: List[int], - device: Device = Device.GPU) -> None: + def allocate(self, seq: Sequence, device: Device = Device.GPU) -> None: """Allocates memory blocks for storing the given sequence of token IDs. This method allocates the required number of blocks to store the given sequence of token IDs. Args: - token_ids (List[int]): The sequence of token IDs to be stored. + seq (Sequence): The sequence to allocate blocks for. device (Device, optional): The device on which the blocks should be allocated. Defaults to Device.GPU. """ assert not self._is_allocated - assert token_ids - blocks = self._allocate_blocks_for_token_ids(prev_block=None, - token_ids=token_ids, - device=device) + if not seq.get_token_ids(): + return + + blocks = self._allocate_blocks_for_token_ids(seq=seq, device=device) self.update(blocks) - self._num_full_slots = len(token_ids) + self._num_full_slots = len(seq.get_token_ids()) def update(self, blocks: List[Block]) -> None: """Resets the table to the newly provided blocks @@ -105,10 +108,11 @@ def update(self, blocks: List[Block]) -> None: """ self._blocks.update(blocks) - def append_token_ids(self, - token_ids: List[int], - num_lookahead_slots: int = 0, - num_computed_slots: Optional[int] = None) -> None: + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int = 0, + ) -> None: """Appends a sequence of token IDs to the existing blocks in the BlockTable. @@ -134,6 +138,9 @@ def append_token_ids(self, assert self._is_allocated, "no blocks have been allocated" assert len(self._blocks) > 0 + token_ids = self.get_unseen_token_ids(seq.get_token_ids()) + num_computed_slots = seq.data.get_num_computed_tokens() + # Drop blocks that are no longer needed due to sliding window if self._max_block_sliding_window is not None: null_block = self._allocator.allocate_or_get_null_block() @@ -156,7 +163,11 @@ def append_token_ids(self, token_blocks = self._chunk_token_blocks_for_append(token_ids) for i, token_block in enumerate(token_blocks): - self._blocks.append_token_ids(first_block_idx + i, token_block) + if self._enable_prefix_caching: + block_hash: Optional[int] = seq.get_block_hash(first_block_idx + i) + else: + block_hash = None + self._blocks.append_token_ids(first_block_idx + i, token_block, block_hash) self._num_full_slots += len(token_ids) @@ -210,6 +221,7 @@ def fork(self) -> "BlockTable": block_allocator=self._allocator, _blocks=forked_blocks, max_block_sliding_window=self._max_block_sliding_window, + enable_prefix_caching=self._enable_prefix_caching, ) def free(self) -> None: @@ -259,33 +271,47 @@ def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: # ones after the appended ones. 
return sequence_token_ids[self.num_full_slots:] - def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], - token_ids: List[int], - device: Device) -> List[Block]: + def _allocate_blocks_for_token_ids( + self, seq: Sequence, device: Device + ) -> List[Block]: blocks: List[Block] = [] + block_hashes: List[Optional[int]] = [] + prev_block: Optional[Block] = None block_token_ids = [] tail_token_ids = [] - for cur_token_ids in chunk_list(token_ids, self._block_size): + token_ids = seq.get_token_ids() + chunked_block_token_ids = chunk_list(token_ids, self._block_size) + for block_idx, cur_token_ids in enumerate(chunked_block_token_ids): if len(cur_token_ids) == self._block_size: block_token_ids.append(cur_token_ids) + if self._enable_prefix_caching: + block_hashes.append(seq.get_block_hash(block_idx)) + else: + block_hashes.append(None) else: tail_token_ids.append(cur_token_ids) + block_hashes.append(None) if block_token_ids: blocks.extend( self._allocator.allocate_immutable_blocks( - prev_block, block_token_ids=block_token_ids, - device=device)) + prev_block, + block_token_ids=block_token_ids, + block_hashes=block_hashes, + device=device, + ) + ) prev_block = blocks[-1] if tail_token_ids: assert len(tail_token_ids) == 1 + assert block_hashes[-1] is None cur_token_ids = tail_token_ids[0] block = self._allocator.allocate_mutable_block( prev_block=prev_block, device=device) - block.append_token_ids(cur_token_ids) + block.append_token_ids(cur_token_ids, block_hash=None) blocks.append(block) diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index eb190adfbe802..191eb6c8fdd15 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -196,8 +196,25 @@ def increase_pool(self): allocator=self._allocator, block_id=None)) - def init_block(self, prev_block: Optional[Block], token_ids: List[int], - block_size: int, physical_block_id: Optional[int]) -> Block: + # TODO(rickyx): This should take in kwargs for flexible initialization of different types of blocks + # Right now, we update explicitly blocks with other args after initialization, e.g. block_hash + # computed for the prefix caching block. + def init_block( + self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + physical_block_id: Optional[int], + ) -> Block: + """Initializes a block with the given parameters. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + token_ids (List[int]): The token IDs to be stored in the block. + block_size (int): The size of the block. + physical_block_id (Optional[int]): The physical block ID. + block_hash (Optional[int]): The hash of the block's content. 
+ """ if len(self._free_ids) == 0: self.increase_pool() assert len(self._free_ids) > 0 @@ -209,8 +226,9 @@ def init_block(self, prev_block: Optional[Block], token_ids: List[int], prev_block=prev_block, token_ids=token_ids, block_size=block_size, - allocator=block._allocator, # type: ignore[attr-defined] - block_id=physical_block_id) + allocator=block._allocator, # type: ignore[attr-defined] + block_id=physical_block_id, + ) block.pool_id = pool_id # type: ignore[attr-defined] return block @@ -248,11 +266,13 @@ def update(self, blocks: List[Block]): for block in self._blocks: self._add_block_id(block.block_id) - def append_token_ids(self, block_index: int, token_ids: List[int]) -> None: + def append_token_ids( + self, block_index: int, token_ids: List[int], block_hash: Optional[int] + ) -> None: block = self._blocks[block_index] prev_block_id = block.block_id - block.append_token_ids(token_ids) + block.append_token_ids(token_ids, block_hash) # CoW or promotion may update the internal block_id if prev_block_id != block.block_id: diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 6eda5f99aa1c8..d2d3ee61a1dde 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -4,6 +4,7 @@ DeviceAwareBlockAllocator) from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator +from vllm.sequence import Sequence from vllm.utils import Device @@ -116,8 +117,11 @@ def allocate_or_get_null_block(self) -> Block: self.allocate_mutable_block(None, Device.GPU)) return self._null_block - def allocate_mutable_block(self, prev_block: Optional[Block], - device: Device) -> Block: + def allocate_mutable_block( + self, + prev_block: Optional[Block], + device: Device, + ) -> Block: """Allocates a new mutable block on the specified device. Args: @@ -130,9 +134,13 @@ def allocate_mutable_block(self, prev_block: Optional[Block], """ return self._allocators[device].allocate_mutable_block(prev_block) - def allocate_immutable_blocks(self, prev_block: Optional[Block], - block_token_ids: List[List[int]], - device: Device) -> List[Block]: + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device, + block_hashes: Optional[List[Optional[int]]] = None, + ) -> List[Block]: """Allocates a new group of immutable blocks with the provided block token IDs on the specified device. @@ -148,7 +156,15 @@ def allocate_immutable_blocks(self, prev_block: Optional[Block], containing the provided block token IDs. 
""" return self._allocators[device].allocate_immutable_blocks( - prev_block, block_token_ids) + prev_block, block_token_ids, block_hashes + ) + + def get_cached_blocks( + self, + block_hashes: List[int], + device: Device, + ) -> List[int]: + return self._allocators[device].get_cached_blocks(block_hashes) def allocate_immutable_block(self, prev_block: Optional[Block], token_ids: List[int], @@ -402,3 +418,6 @@ def last_accessed(self, last_accessed_ts: float): @property def content_hash(self): return self._proxy.content_hash + + def set_content_hash(self, content_hash: Optional[int]) -> None: + raise NotImplementedError("NullBlock does not support set_content_hash") diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 72bbab1dcea5d..11d22eadcca6d 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple +from vllm.sequence import Sequence from vllm.utils import Device BlockId = int @@ -9,7 +10,9 @@ class Block(ABC): @abstractmethod - def append_token_ids(self, token_ids: List[int]) -> None: + def append_token_ids( + self, token_ids: List[int], block_hash: Optional[int] = None + ) -> None: pass @property @@ -95,6 +98,10 @@ def content_hash(self) -> Optional[int]: """ return None + @abstractmethod + def set_content_hash(self, content_hash: int) -> None: + pass + class BlockAllocator(ABC): @@ -103,14 +110,21 @@ def allocate_mutable_block(self, prev_block: Optional[Block]) -> Block: pass @abstractmethod - def allocate_immutable_block(self, prev_block: Optional[Block], - token_ids: List[int]) -> Block: + def allocate_immutable_block( + self, + prev_block: Optional[Block], + token_ids: List[int], + block_hash: Optional[int] = None, + ) -> Block: pass @abstractmethod def allocate_immutable_blocks( - self, prev_block: Optional[Block], - block_token_ids: List[List[int]]) -> List[Block]: + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + block_hashes: Optional[List[Optional[int]]] = None, + ) -> List[Block]: pass @abstractmethod @@ -189,6 +203,10 @@ def get_prefix_cache_hit_rate(self) -> float: """Prefix cache hit rate. -1 means not supported or disabled.""" pass + @abstractmethod + def get_cached_blocks(self, block_hashes: List[int]) -> List[int]: + pass + class NoFreeBlocksError(ValueError): pass @@ -284,3 +302,7 @@ def allocate_or_get_null_block(self) -> Block: def get_prefix_cache_hit_rate(self, device: Device) -> float: """Prefix cache hit rate. -1 means not supported or disabled.""" pass + + @abstractmethod + def get_cached_blocks(self, block_hashes: List[int], device: Device) -> List[int]: + pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 9341a518d11c6..756d97bee402e 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -60,10 +60,13 @@ def __init__( # a block pool between allocators self._block_pool = block_pool - def allocate_immutable_block(self, - prev_block: Optional[Block], - token_ids: List[int], - device: Optional[Device] = None) -> Block: + def allocate_immutable_block( + self, + prev_block: Optional[Block], + token_ids: List[int], + device: Optional[Device] = None, + block_hash: Optional[int] = None, + ) -> Block: """Allocates a new immutable block with the given token IDs, linked to the previous block. @@ -77,15 +80,19 @@ def allocate_immutable_block(self, Block: The newly allocated immutable block. 
""" assert device is None + assert block_hash is None + block = self.allocate_mutable_block(prev_block=prev_block) block.append_token_ids(token_ids) return block def allocate_immutable_blocks( - self, - prev_block: Optional[Block], - block_token_ids: List[List[int]], - device: Optional[Device] = None) -> List[Block]: + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + block_hashes: Optional[List[Optional[int]]] = None, + device: Optional[Device] = None, + ) -> List[Block]: assert device is None num_blocks = len(block_token_ids) @@ -104,9 +111,12 @@ def allocate_immutable_blocks( return blocks - def allocate_mutable_block(self, - prev_block: Optional[Block], - device: Optional[Device] = None) -> Block: + def allocate_mutable_block( + self, + prev_block: Optional[Block], + device: Optional[Device] = None, + block_hash: Optional[int] = None, + ) -> Block: """Allocates a new mutable block, linked to the previous block. Args: @@ -118,6 +128,8 @@ def allocate_mutable_block(self, Block: The newly allocated mutable block. """ assert device is None + assert block_hash is None + block_id = self._allocate_block_id() block = self._block_pool.init_block(prev_block=prev_block, token_ids=[], @@ -318,7 +330,7 @@ def swap_in(self, blocks: List[Block]) -> None: else: tmp_block = self.allocate_mutable_block( prev_block=block.prev_block) - tmp_block.append_token_ids(block.token_ids) + tmp_block.append_token_ids(block.token_ids, block_hash=None) block_id = tmp_block.block_id tmp_block.block_id = None @@ -329,6 +341,9 @@ def swap_in(self, blocks: List[Block]) -> None: def get_prefix_cache_hit_rate(self) -> float: return -1 + def get_cached_blocks(self, block_hashes: List[int]) -> List[int]: + return [] + class NaiveBlock(Block): """An implementation of the Block class that does not support prefix @@ -368,14 +383,17 @@ def __init__(self, self._append_token_ids_no_cow(token_ids) - def append_token_ids(self, token_ids: List[int]) -> None: - """Appends the given token IDs to the block and performs a + def append_token_ids( + self, token_ids: List[int], block_hash: Optional[int] = None + ) -> None: + """Appends the given token IDs to the block and performs a copy-on-write if necessary. Args: - token_ids (Optional[List[int]]): The token IDs to be appended + token_ids (Optional[List[int]]): The token IDs to be appended to the block. """ + assert block_hash is None self._append_token_ids_no_cow(token_ids) if self._block_id is not None: @@ -447,3 +465,8 @@ def prev_block(self) -> Optional["Block"]: @property def content_hash(self) -> Optional[int]: return None + + def set_content_hash(self, content_hash: int) -> None: + raise NotImplementedError( + "Setting content hash is not supported for naive block" + ) diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 7c8a2bc493513..84330270a203a 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -134,54 +134,70 @@ def _create_block( computed=computed, ) - def allocate_immutable_block(self, - prev_block: Optional[Block], - token_ids: List[int], - device: Optional[Device] = None) -> Block: + def allocate_immutable_block( + self, + prev_block: Optional[Block], + token_ids: List[int], + block_hash: Optional[int] = None, + device: Optional[Device] = None, + ) -> Block: """Allocates an immutable block with the given token IDs, reusing cached blocks if possible. Args: prev_block (Optional[Block]): The previous block in the sequence. 
token_ids (List[int]): The token IDs to be stored in the block. + block_hash (int): The hash of the block's content. Returns: Block: The allocated immutable block. """ assert device is None + assert len(token_ids) == self._block_size, "An immutable block should be full" + assert ( + block_hash is not None + ), "An immutable block should have a content hash for prefix caching" assert_prefix_caching_block_or_none(prev_block) - # First, try to create a block that points to cached data - block = self._block_pool.init_block(prev_block=prev_block, - token_ids=token_ids, - block_size=self._block_size, - physical_block_id=None) - assert block.content_hash is not None - - cached_block_id = self._cached_blocks.get(block.content_hash, None) + cached_block_id = self._cached_blocks.get(block_hash, None) if cached_block_id is not None: + # Initialize a block that points to cached data + block: Block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=token_ids, + block_size=self._block_size, + physical_block_id=cached_block_id, + ) + block.set_content_hash(block_hash) self.metric_data.query(hit=True) - block.block_id = cached_block_id self._incr_refcount_cached_block(block) return block self.metric_data.query(hit=False) - self._block_pool.free_block(block) # No cached block => Allocate a new block block = self.allocate_mutable_block(prev_block) - block.append_token_ids(token_ids) + block.append_token_ids(token_ids, block_hash=block_hash) return block def allocate_immutable_blocks( - self, - prev_block: Optional[Block], - block_token_ids: List[List[int]], - device: Optional[Device] = None) -> List[Block]: + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + block_hashes: Optional[List[int]] = None, + device: Optional[Device] = None, + ) -> List[Block]: blocks = [] - for token_ids in block_token_ids: - prev_block = self.allocate_immutable_block(prev_block=prev_block, - token_ids=token_ids, - device=device) + assert ( + block_hashes is not None + ), "block_hashes must be provided for immutable prefix cache blocks" + + for token_ids, block_hash in zip(block_token_ids, block_hashes): + prev_block = self.allocate_immutable_block( + prev_block=prev_block, + token_ids=token_ids, + block_hash=block_hash, + device=device, + ) blocks.append(prev_block) return blocks @@ -284,7 +300,7 @@ def _allocate_block_id(self) -> BlockId: def _maybe_allocate_hashless_block_id(self) -> Optional[BlockId]: try: - # Allocate mutable block and extract its block_id + # Allocate mutable block and extrct its block_id block = self._hashless_allocator.allocate_mutable_block( prev_block=None) block_id = block.block_id @@ -373,11 +389,14 @@ def fork(self, last_block: Block) -> List[Block]: assert refcount != 1, "can't fork free'd block_id = {}".format( block_id) - forked_block = self._block_pool.init_block( + forked_block: Block = self._block_pool.init_block( prev_block=prev_block, token_ids=block.token_ids, block_size=self._block_size, - physical_block_id=block_id) + physical_block_id=block_id, + ) + + forked_block.set_content_hash(block.content_hash) forked_blocks.append(forked_block) prev_block = forked_blocks[-1] @@ -622,18 +641,35 @@ def swap_in(self, blocks: List[Block]) -> None: # and the block_id is assigned to "block" to allow reusing the # existing "block" object if block.is_full: + assert ( + block.content_hash is not None + ), "Block is full but has no content hash" tmp_block = self.allocate_immutable_block( - prev_block=block.prev_block, token_ids=block.token_ids) + 
prev_block=block.prev_block, + token_ids=block.token_ids, + block_hash=block.content_hash, + ) else: + assert ( + block.content_hash is None + ), "Block is not full but has content hash" tmp_block = self.allocate_mutable_block( prev_block=block.prev_block) - tmp_block.append_token_ids(block.token_ids) + tmp_block.append_token_ids(block.token_ids, block_hash=None) block_id = tmp_block.block_id self._block_pool.free_block(tmp_block) block.block_id = block_id # Assign block_id + def get_cached_blocks(self, block_hashes: List[PrefixHash]) -> List[PrefixHash]: + # Search for the longest prefix in `block_hashes` that are present cached blocks. + # TODO(rickyx): this could be made to binary search. + for i, block_hash in enumerate(block_hashes): + if block_hash not in self._cached_blocks: + return block_hashes[:i] + return block_hashes + class PrefixCachingBlock(Block): """A block implementation that supports prefix caching. @@ -685,13 +721,16 @@ def __init__( token_ids=token_ids, block_size=block_size, block_id=block_id, - allocator=self._allocator) + allocator=self._allocator, + ) else: - self._block = NaiveBlock(prev_block=prev_block, - token_ids=token_ids, - block_size=block_size, - block_id=block_id, - allocator=self._allocator) + self._block = NaiveBlock( + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=self._allocator, + ) self._update_num_tokens_total() @@ -726,12 +765,15 @@ def last_accessed(self) -> float: def last_accessed(self, last_accessed_ts: float): self._last_accessed = last_accessed_ts - def append_token_ids(self, token_ids: List[int]) -> None: + def append_token_ids( + self, token_ids: List[int], block_hash: Optional[int] = None + ) -> None: """Appends the given token IDs to the block and registers the block as immutable if the block becomes full. Args: token_ids (List[int]): The token IDs to be appended to the block. + block_hash (Optional[int]): The content hash of the block. None if the block is not full. """ # Ensure this is mutable block (not promoted) assert self.content_hash is None @@ -744,14 +786,13 @@ def append_token_ids(self, token_ids: List[int]) -> None: assert token_ids, "Got token_ids = {}".format(token_ids) # Naive block handles CoW. - self._block.append_token_ids(token_ids) + self._block.append_token_ids(token_ids, block_hash=None) self._update_num_tokens_total() - # If the content hash is present, then the block can be made immutable. - # Register ourselves with the allocator, potentially replacing the - # physical block index. - if self.content_hash is not None: - self.block_id = self._allocator.promote_to_immutable_block(self) + # Promote the block to an immutable block if it is full. + if block_hash is not None: + self.set_content_hash(block_hash) + self._allocator.promote_to_immutable_block(self) @property def block_id(self) -> Optional[int]: @@ -785,38 +826,56 @@ def token_ids(self) -> List[int]: def prev_block(self) -> Optional[Block]: return self._prev_block + # @property + # def content_hash(self) -> Optional[int]: + # """Return the content-based hash of the current block, or None if it is + # not yet defined. + + # For the content-based hash to be defined, the current block must be + # full. + # """ + # # If the hash is already computed, return it. + # if self._cached_content_hash is not None: # return self._cached_content_hash + + # # We cannot compute a hash for the current block because it is not full. 
+ # if not self.is_full: + # return None + + # is_first_block = self._prev_block is None + # prev_block_hash = ( + # None if is_first_block else + # self._prev_block.content_hash # type: ignore + # ) + + # # Previous block exists but does not yet have a hash. + # # Return no hash in this case. + # if prev_block_hash is None and not is_first_block: + # return None + + # self._cached_content_hash = PrefixCachingBlock.hash_block_tokens( + # is_first_block, + # prev_block_hash, + # cur_block_token_ids=self.token_ids) + # return self._cached_content_hash + @property def content_hash(self) -> Optional[int]: - """Return the content-based hash of the current block, or None if it is - not yet defined. + return self._cached_content_hash - For the content-based hash to be defined, the current block must be - full. + def set_content_hash(self, content_hash: Optional[int]) -> None: """ - # If the hash is already computed, return it. - if self._cached_content_hash is not None: - return self._cached_content_hash - - # We cannot compute a hash for the current block because it is not full. - if not self.is_full: - return None - - is_first_block = self._prev_block is None - prev_block_hash = ( - None if is_first_block else - self._prev_block.content_hash # type: ignore - ) - - # Previous block exists but does not yet have a hash. - # Return no hash in this case. - if prev_block_hash is None and not is_first_block: - return None - - self._cached_content_hash = PrefixCachingBlock.hash_block_tokens( - is_first_block, - prev_block_hash, - cur_block_token_ids=self.token_ids) - return self._cached_content_hash + Set the content hash of the block. + """ + assert self.content_hash is None, "Content hash already set" + if content_hash is None: + # This could happen when forking a mutable block. + assert ( + not self.is_full + ), "Block should not be full when new content hash is None" + # No op. + return + assert self.is_full, "Block is not full when setting content hash" + self._cached_content_hash = content_hash @staticmethod def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], @@ -878,6 +937,12 @@ def remove_seq(self, seq_id: int) -> None: assert seq_id in self._cached_computed_seq_blocks del self._cached_computed_seq_blocks[seq_id] + def update_seq(self, seq_id: int, computed_tokens: List[int]): + pass + + def get_cached_computed_blocks(self, seq_id: int) -> List[int]: + pass + def get_cached_computed_blocks_and_update( self, seq_id: int, block_ids: List[int]) -> List[int]: """ Look at the class documentation for details @@ -915,6 +980,8 @@ def get_cached_computed_blocks_and_update( True, # We skip last block id to avoid caching of full seq ) + # QQ(rickyx): why is it possible to actually have a gap? 
+ # Detect if there is a "gap" has_gap = len(computed_block_ids) < num_cur_blocks diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index c7ee6609306d7..8c85bdc05fe64 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -11,7 +11,7 @@ from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus -from vllm.utils import Device +from vllm.utils import Device, cdiv SeqId = int EncoderSeqId = str @@ -106,6 +106,19 @@ def __init__( self._last_access_blocks_tracker = LastAccessBlocksTracker( self.block_allocator) + def _get_num_blocks_to_allocate( + self, seq: Sequence, num_lookahead_slots: int = 0 + ) -> int: + seq_blocks = seq.get_block_hashes() + cached_seq_blocks = self.block_allocator.get_cached_blocks( + block_hashes=seq_blocks, + device=Device.GPU, + ) + + num_required_blocks = cdiv(seq.get_len() + num_lookahead_slots, self.block_size) + + return num_required_blocks - len(cached_seq_blocks) + def can_allocate(self, seq_group: SequenceGroup, num_lookahead_slots: int = 0) -> AllocStatus: @@ -115,32 +128,29 @@ def can_allocate(self, check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - num_required_blocks = BlockTable.get_num_required_blocks( - seq.get_token_ids(), - block_size=self.block_size, - num_lookahead_slots=num_lookahead_slots, + num_blocks_to_allocate = self._get_num_blocks_to_allocate( + seq, num_lookahead_slots ) if seq_group.is_encoder_decoder(): encoder_seq = seq_group.get_encoder_seq() assert encoder_seq is not None - num_required_blocks += BlockTable.get_num_required_blocks( - encoder_seq.get_token_ids(), - block_size=self.block_size, + num_blocks_to_allocate += self._get_num_blocks_to_allocate( + encoder_seq, num_lookahead_slots=0 ) if self.max_block_sliding_window is not None: - num_required_blocks = min(num_required_blocks, - self.max_block_sliding_window) + num_blocks_to_allocate = min( + num_blocks_to_allocate, self.max_block_sliding_window + ) num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( device=Device.GPU) # Use watermark to avoid frequent cache eviction. - if (self.num_total_gpu_blocks - num_required_blocks < - self.watermark_blocks): + if self.num_total_gpu_blocks - num_blocks_to_allocate < self.watermark_blocks: return AllocStatus.NEVER - if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: + if num_free_gpu_blocks - num_blocks_to_allocate >= self.watermark_blocks: return AllocStatus.OK else: return AllocStatus.LATER @@ -150,10 +160,11 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: block_size=self.block_size, block_allocator=self.block_allocator, max_block_sliding_window=self.max_block_sliding_window, + enable_prefix_caching=self.enable_caching, ) if seq.get_token_ids(): # Add blocks to the block table only if the sequence is non empty. - block_table.allocate(seq.get_token_ids()) + block_table.allocate(seq) return block_table @@ -214,20 +225,22 @@ def can_append_slots(self, seq_group: SequenceGroup, This is used by speculative decoding when speculating future tokens. 
""" - num_touched_blocks = 0 + num_blocks_to_allocate = 0 for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - block_table = self.block_tables[seq.seq_id] + num_blocks_to_allocate += self._get_num_blocks_to_allocate( + seq, num_lookahead_slots=num_lookahead_slots + ) - num_touched_blocks += ( - block_table.get_num_blocks_touched_by_append_slots( - token_ids=block_table.get_unseen_token_ids( - seq.get_token_ids()), - num_lookahead_slots=num_lookahead_slots, - )) + # num_touched_blocks += ( + # block_table.get_num_blocks_touched_by_append_slots( + # token_ids=block_table.get_unseen_token_ids( + # seq.get_token_ids()), + # num_lookahead_slots=num_lookahead_slots, + # )) num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( Device.GPU) - return num_touched_blocks <= num_free_gpu_blocks + return num_blocks_to_allocate <= num_free_gpu_blocks def append_slots( self, @@ -237,10 +250,11 @@ def append_slots( block_table = self.block_tables[seq.seq_id] - block_table.append_token_ids( - token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), + # Extend the block table for any new decoded tokens, as well as reserve + # space for lookahead slots. + block_table.append_slots( + seq=seq, num_lookahead_slots=num_lookahead_slots, - num_computed_slots=seq.data.get_num_computed_tokens(), ) # Return any new copy-on-writes. new_cows = self.block_allocator.clear_copy_on_writes() diff --git a/vllm/sequence.py b/vllm/sequence.py index 9116408a001ff..65855e48cc16c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -447,6 +447,8 @@ def __init__( # Input + output tokens self.tokens: Optional[List[str]] = None + self._computed_block_hashes: List[int] = [] + @property def n_blocks(self) -> int: return (self.get_len() + self.block_size - 1) // self.block_size @@ -530,6 +532,61 @@ def get_output_token_ids_to_return( return self.data._cached_all_token_ids[-num_new_tokens:] + def get_block_hash(self, block_idx: int) -> Optional[int]: + + # Lazy update the block hashes on the first invocation. + if block_idx >= len(self._computed_block_hashes): + self._update_block_hashes() + + if block_idx < len(self._computed_block_hashes): + return self._computed_block_hashes[block_idx] + return None + + def get_block_hashes(self) -> List[int]: + # TODO(rickyx): maybe better to have an API to track if the computed hash is updated. + self._update_block_hashes() + return self._computed_block_hashes + + def _update_block_hashes(self): + """ + Update the block hashes for all the full blocks in the sequence. + + It skips the blocks that have already been computed. + """ + token_ids = self.get_token_ids() # All token ids in the sequence + num_full_blocks = len(token_ids) // self.block_size + cur_num_full_blocks = len(self._computed_block_hashes) + prev_block_hash = ( + None if cur_num_full_blocks == 0 else self._computed_block_hashes[-1] + ) + for i in range(cur_num_full_blocks, num_full_blocks): + block_token_ids = token_ids[i * self.block_size : (i + 1) * self.block_size] + assert len(block_token_ids) == self.block_size + block_hash = hash( + ( + prev_block_hash, # Previous block hash + self.from_decoder_prompt, # Whether the sequence is decoder-only + # LoRA int id since the attention output will depend on + # LoRA with same token ids. + self.lora_int_id, + *block_token_ids, # The block token ids + ) + ) + self._computed_block_hashes.append(block_hash) + prev_block_hash = block_hash + + def _reset_block_hashes(self): + """ + Clear all the block hashes from output tokens. 
The full blocks for the + prompt tokens should not be cleared. + + This is used when the sequence is recomputed. + """ + num_full_prompt_blocks = self.get_prompt_len() // self.block_size + self._computed_block_hashes = self._computed_block_hashes[ + num_full_prompt_blocks: + ] + def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size @@ -546,6 +603,7 @@ def num_hashed_tokens_of_block(self, logical_idx: int): def reset_state_for_recompute(self): """Reset the sequence states for recomputation.""" self.data.reset_state_for_recompute() + self._reset_block_hashes() def append_token_id(self, token_id: int, logprobs: Dict[int, Logprob]) -> None: @@ -553,6 +611,10 @@ def append_token_id(self, token_id: int, logprobs: Dict[int, self.output_logprobs.append(logprobs) self.data.append_token_id(token_id, logprobs[token_id].logprob) + def update_num_computed_tokens(self, num_tokens: int): + self.data.update_num_computed_tokens(num_tokens) + self._update_block_hashes() + def get_len(self) -> int: return self.data.get_len() @@ -837,7 +899,7 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int): """Update number of tokens computed so far.""" for seq in self.seqs: if not seq.is_finished(): - seq.data.update_num_computed_tokens(num_new_computed_tokens) + seq.update_num_computed_tokens(num_new_computed_tokens) def get_num_uncomputed_tokens(self) -> int: num_uncomputed_tokens = 0 From add282a8a3a7468ff4d98873d7d261b176bf4341 Mon Sep 17 00:00:00 2001 From: rickyx Date: Fri, 11 Oct 2024 22:41:58 +0000 Subject: [PATCH 02/12] staging --- benchmarks/benchmark_prefix_caching.py | 24 ++- tests/core/block/test_block_manager_v2.py | 137 ++++++++++++++ tests/core/block/test_prefix_caching_block.py | 70 +++++++ tests/core/utils.py | 21 ++- vllm/config.py | 4 + vllm/core/block/block_table.py | 11 +- vllm/core/block/cpu_gpu_block_allocator.py | 4 +- vllm/core/block/interfaces.py | 6 +- vllm/core/block/naive_block.py | 2 +- vllm/core/block/prefix_caching_block.py | 74 ++++++-- vllm/core/block_manager_v2.py | 125 ++++++++++--- vllm/core/scheduler.py | 174 ++++++++++++++---- vllm/engine/llm_engine.py | 15 +- vllm/engine/metrics.py | 5 +- vllm/engine/metrics_types.py | 1 + vllm/sequence.py | 23 ++- 16 files changed, 591 insertions(+), 105 deletions(-) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index eeb43a692076e..da668982579b2 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -115,8 +115,9 @@ def main(args): input_length_range = tuple(map(int, args.input_length_range.split(':'))) random.seed(args.seed) if args.dataset_path is not None: - print(f"Start to sample {args.num_prompts} prompts" - "from {args.dataset_path}") + print( + f"Start to sample {args.num_prompts} prompts " f"from {args.dataset_path}" + ) filtered_datasets = sample_requests( dataset_path=args.dataset_path, num_requests=args.num_prompts, @@ -129,13 +130,18 @@ def main(args): filtered_datasets = [(PROMPT, prompt_len, args.output_len) ] * args.num_prompts - llm = LLM(model=args.model, - tokenizer_mode='auto', - trust_remote_code=True, - enforce_eager=True, - use_v2_block_manager=args.use_v2_block_manager, - tensor_parallel_size=args.tensor_parallel_size, - enable_prefix_caching=args.enable_prefix_caching) + llm = LLM( + model=args.model, + tokenizer_mode="auto", + trust_remote_code=True, + enforce_eager=True, + use_v2_block_manager=args.use_v2_block_manager, + 
tensor_parallel_size=args.tensor_parallel_size, + enable_prefix_caching=args.enable_prefix_caching, + disable_log_stats=False, + max_num_batched_tokens=4096 * 2, + enable_chunked_prefill=True, + ) sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index c4c41e714a02d..e8b87a89f70ee 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -1,3 +1,4 @@ +import math import pytest from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, @@ -205,6 +206,142 @@ def test_can_allocate_encoder_decoder_fails_with_prefix_cache( assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE +@pytest.mark.parametrize("block_size", [1, 4]) +@pytest.mark.parametrize("num_prefill_tokens", [1, 2, 4, 5, 6, 8, 10]) +@pytest.mark.parametrize("prefix_shared_percentage", [0.0, 0.3, 0.5, 0.7, 1.0]) +def test_can_allocate_with_prefix_cache( + block_size: int, + num_prefill_tokens: int, + prefix_shared_percentage: float, +): + num_seqs_fittable = 1.5 + num_blocks_required_seq = math.ceil(num_prefill_tokens / block_size) + num_gpu_blocks = math.ceil(num_seqs_fittable * num_blocks_required_seq) + + num_tokens_shared = int(num_prefill_tokens * prefix_shared_percentage) + num_blocks_shared = num_tokens_shared // block_size + + tokens_1 = list(range(num_prefill_tokens)) + tokens_2 = tokens_1[:num_tokens_shared] + [ + t + 10 for t in tokens_1[num_tokens_shared:] + ] + + print(f"tokens_1: {tokens_1}") + print(f"tokens_2: {tokens_2}") + print(f"num_blocks_shared: {num_blocks_shared}") + print(f"num_blocks_required_seq: {num_blocks_required_seq}") + print(f"num_gpu_blocks: {num_gpu_blocks}") + + # Num blocks needed for 2 seqs, minus the number of blocks shared. + num_blocks_required_with_sharing = 2 * num_blocks_required_seq - num_blocks_shared + + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=0, + enable_caching=True, # Prefix cache + ) + + seq_group_1 = create_seq_group( + seq_output_lens=[0], + request_id="0", + seq_id_start=0, + prompt_token_ids=tokens_1, + block_size=block_size, + ) + assert block_manager.can_allocate(seq_group_1) == AllocStatus.OK + # Allocate the seq 1 + block_manager.allocate(seq_group_1) + + # Test if allocatable of seq 2. + seq_group_2 = create_seq_group( + seq_output_lens=[0], + request_id="1", + seq_id_start=1, + prompt_token_ids=tokens_2, + block_size=block_size, + ) + if num_blocks_required_with_sharing <= num_gpu_blocks: + assert block_manager.can_allocate(seq_group_2) == AllocStatus.OK + block_manager.allocate(seq_group_2) + else: + assert block_manager.can_allocate(seq_group_2) == AllocStatus.LATER + + +@pytest.mark.skip(reason="Not correct yet") +@pytest.mark.parametrize("block_size", [1, 4]) +@pytest.mark.parametrize("num_prefill_tokens", [1, 2, 4, 5, 6, 8, 10]) +@pytest.mark.parametrize("prefix_shared_percentage", [0.0, 0.3, 0.5, 0.7, 1.0]) +def test_can_append_with_prefix_cache( + block_size: int, + num_prefill_tokens: int, + prefix_shared_percentage: float, +): + num_seqs_allocable = 1.5 + + num_blocks_required_seq_1 = math.ceil(num_prefill_tokens / block_size) + num_gpu_blocks = math.ceil(num_blocks_required_seq_1 * num_seqs_allocable) + + num_tokens_shared = int( + 1 + (num_prefill_tokens - 1) * prefix_shared_percentage + ) # We will always share the first token. 
+ num_blocks_shared = num_tokens_shared // block_size + + tokens_1 = list(range(num_prefill_tokens)) + tokens_2 = tokens_1[:num_tokens_shared] + [ + t + 10 for t in tokens_1[num_tokens_shared:] + ] + + print(f"tokens_1: {tokens_1}") + print(f"tokens_2: {tokens_2}") + print(f"num_blocks_shared: {num_blocks_shared}") + print(f"num_blocks_required_seq_1: {num_blocks_required_seq_1}") + print(f"num_gpu_blocks: {num_gpu_blocks}") + + num_blocks_required_with_sharing = 2 * num_blocks_required_seq_1 - num_blocks_shared + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=0, + enable_caching=True, # Prefix cache + ) + + # Allocate seq 1. + seq_group_1 = create_seq_group( + seq_output_lens=[0], + request_id="0", + seq_id_start=0, + prompt_token_ids=tokens_1, + block_size=block_size, + ) + assert block_manager.can_allocate(seq_group_1) == AllocStatus.OK + block_manager.allocate(seq_group_1) + + # Allocate seq 2. + seq_group_2 = create_seq_group( + seq_output_lens=[0], + request_id="1", + seq_id_start=1, + prompt_token_ids=tokens_2[:1], # Just one token for prefill. + block_size=block_size, + ) + assert block_manager.can_allocate(seq_group_2) == AllocStatus.OK + block_manager.allocate(seq_group_2) + + # Test if append is possible. + seq = seq_group_2.get_seqs()[0] + for token_id in tokens_2[1:]: + seq.append_token_id(token_id, {token_id: Logprob(0.0)}) + + seq.update_num_computed_tokens(len(tokens_2[1:])) + seq.status = SequenceStatus.RUNNING + + if num_blocks_required_with_sharing <= num_gpu_blocks: + assert block_manager.can_append_slots(seq_group_2, 0) + else: + assert not block_manager.can_append_slots(seq_group_2, 0) + + @pytest.mark.parametrize("block_size", [1, 8]) @pytest.mark.parametrize("prompt_len", [1, 7, 8]) @pytest.mark.parametrize("num_slots_to_append", [1, 8, 129]) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 1a6e17ef7b445..8c8690bc23d1b 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -764,3 +764,73 @@ def create_immutable_chain( blocks.append(prev_block) return blocks + + @staticmethod + def create_immutable_chain_with_hashes( + block_size: int, + block_hashes: List[int], + allocator: PrefixCachingBlockAllocator, + ) -> List[PrefixCachingBlock]: + """Helper method which creates a chain of blocks with explicit hashes.""" + blocks: List[Block] = [] + + prev_block = None + for block_hash in block_hashes: + # We don't need actual token_ids, so we pass an empty list + prev_block = allocator.allocate_immutable_block( + prev_block=prev_block, + token_ids=list(range(block_size)), # Just placeholder + block_hash=block_hash, + ) + blocks.append(prev_block) + + return blocks + + @staticmethod + def test_get_cached_blocks(): + """ + This test checks that `get_cached_blocks` returns the correct prefix of + block hashes for a given list of block hashes. + """ + + block_size = 16 + num_blocks = 5 + allocator = PrefixCachingBlockAllocator( + block_size=block_size, num_blocks=num_blocks + ) + + # 1. 
Allocate a list of blocks + block_hashes = [random.randint(1, 1000000) for _ in range(num_blocks)] + blocks = TestPrefixCachingBlockAllocator.create_immutable_chain_with_hashes( + block_size=block_size, + block_hashes=block_hashes, + allocator=allocator, + ) + + # Verify that all blocks have been allocated + assert len(blocks) == num_blocks + assert all(isinstance(block, PrefixCachingBlock) for block in blocks) + assert all(block.content_hash is not None for block in blocks) + + # 2. Test different prefixes of cached blocks + test_cases = [ + ([], []), # No blocks cached + ([block_hashes[0]], [block_hashes[0]]), # First block cached + (block_hashes[:3], block_hashes[0:3]), # First three blocks cached + (block_hashes, block_hashes), # All blocks cached + ] + + for cached_hashes, expected_cached_blocks in test_cases: + # Check if get_cached_blocks returns the correct prefix + result = allocator.get_cached_blocks(cached_hashes) + assert ( + result == expected_cached_blocks + ), f"Expected {expected_cached_blocks}, but got {result}, with test case {cached_hashes}. blcok hashes = {block_hashes}" + + # Test with some non-existent hashes + non_existent_hash = max(block_hashes) + 1 + test_hashes = block_hashes[:3] + [non_existent_hash] + block_hashes[3:] + result = allocator.get_cached_blocks(test_hashes) + assert ( + result == block_hashes[0:3] + ), f"Expected {block_hashes[0:3]}, but got {result}" diff --git a/tests/core/utils.py b/tests/core/utils.py index a2c2df029cb19..6a539db9fc3f7 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -115,25 +115,28 @@ def create_dummy_prompt_encoder_decoder( def create_seq_group( - seq_prompt_len: int = 1024, - seq_output_lens: GenericSequence[int] = (128, ), - request_id: str = '0', - seq_id_start: int = 0, - sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: - - assert len(seq_output_lens) > 0 + seq_prompt_len: int = 1024, + seq_output_lens: GenericSequence[int] = (128,), + request_id: str = "0", + seq_id_start: int = 0, + sampling_params: Optional[SamplingParams] = None, + prompt_token_ids: Optional[List[int]] = None, + block_size: int = 16, +) -> SequenceGroup: if sampling_params is None: sampling_params = SamplingParams() - prompt_token_ids = [0] * seq_prompt_len + if prompt_token_ids is None: + assert seq_prompt_len > 0 + prompt_token_ids = [0] * seq_prompt_len seqs: List[Sequence] = [] for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, inputs={"prompt_token_ids": prompt_token_ids}, - block_size=16, + block_size=block_size, ) for i in range(output_len): diff --git a/vllm/config.py b/vllm/config.py index 7b3996dc90b94..fbc18bd972a27 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1008,6 +1008,10 @@ def __init__(self, self.policy = policy self._verify_args() + print( + f"max_num_batched_tokens: {self.max_num_batched_tokens}, max_num_seqs: {self.max_num_seqs}" + ) + def _verify_args(self) -> None: if (self.max_num_batched_tokens < self.max_model_len and not self.chunked_prefill_enabled): diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index 64267e09e1316..9a3a66523bd34 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -1,5 +1,5 @@ import math -from typing import List, Optional +from typing import Dict, List, Optional from vllm.core.block.common import BlockList from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator @@ -308,9 +308,12 @@ def _allocate_blocks_for_token_ids( assert 
len(tail_token_ids) == 1 assert block_hashes[-1] is None cur_token_ids = tail_token_ids[0] - - block = self._allocator.allocate_mutable_block( - prev_block=prev_block, device=device) + try: + block = self._allocator.allocate_mutable_block( + prev_block=prev_block, device=device + ) + except Exception as e: + breakpoint() block.append_token_ids(cur_token_ids, block_hash=None) blocks.append(block) diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index d2d3ee61a1dde..9294e668a07bf 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -159,12 +159,12 @@ def allocate_immutable_blocks( prev_block, block_token_ids, block_hashes ) - def get_cached_blocks( + def get_allocated_cached_blocks( self, block_hashes: List[int], device: Device, ) -> List[int]: - return self._allocators[device].get_cached_blocks(block_hashes) + return self._allocators[device].get_allocated_cached_blocks(block_hashes) def allocate_immutable_block(self, prev_block: Optional[Block], token_ids: List[int], diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 11d22eadcca6d..378d273a271ab 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -204,7 +204,7 @@ def get_prefix_cache_hit_rate(self) -> float: pass @abstractmethod - def get_cached_blocks(self, block_hashes: List[int]) -> List[int]: + def get_allocated_cached_blocks(self, block_hashes: List[int]) -> List[int]: pass class NoFreeBlocksError(ValueError): @@ -304,5 +304,7 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: pass @abstractmethod - def get_cached_blocks(self, block_hashes: List[int], device: Device) -> List[int]: + def get_allocated_cached_blocks( + self, block_hashes: List[int], device: Device + ) -> List[int]: pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 756d97bee402e..2ffdb94e904a0 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -341,7 +341,7 @@ def swap_in(self, blocks: List[Block]) -> None: def get_prefix_cache_hit_rate(self) -> float: return -1 - def get_cached_blocks(self, block_hashes: List[int]) -> List[int]: + def get_allocated_cached_blocks(self, block_hashes: List[int]) -> List[int]: return [] diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 84330270a203a..84d66cca4dd46 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -8,6 +8,7 @@ from vllm.core.block.naive_block import (BlockPool, NaiveBlock, NaiveBlockAllocator) from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor +from vllm.sequence import Sequence PrefixHash = int @@ -162,6 +163,9 @@ def allocate_immutable_block( cached_block_id = self._cached_blocks.get(block_hash, None) if cached_block_id is not None: # Initialize a block that points to cached data + # print( + # f"reuse block_hash={block_hash} from cached_block_id: {cached_block_id}" + # ) block: Block = self._block_pool.init_block( prev_block=prev_block, token_ids=token_ids, @@ -172,6 +176,10 @@ def allocate_immutable_block( self.metric_data.query(hit=True) self._incr_refcount_cached_block(block) return block + + # print( + # f"alloc from new block(block_hash: {block_hash}), get_num_free_blocks: {self.get_num_free_blocks()}" + # ) self.metric_data.query(hit=False) # No cached block => Allocate a new block @@ -216,7 +224,9 @@ def allocate_mutable_block(self, """ assert 
device is None
         assert_prefix_caching_block_or_none(prev_block)
-
+        # print(
+        #     f"Allocating mutable block: get_num_free_blocks: {self.get_num_free_blocks()}"
+        # )
         block_id = self._allocate_block_id()
         block = self._block_pool.init_block(prev_block=prev_block,
                                             token_ids=[],
@@ -287,6 +297,7 @@ def _allocate_block_id(self) -> BlockId:
         """First tries to allocate a block id from the hashless allocator,
         and if there are no blocks, then tries to evict an unused cached block.
         """
+        # print(f"allocating block_id: get_num_free_blocks: {self.get_num_free_blocks()}")
         hashless_block_id = self._maybe_allocate_hashless_block_id()
         if hashless_block_id is not None:
             return hashless_block_id
@@ -407,8 +418,9 @@ def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
         assert device is None
         # The number of free blocks is the number of hashless free blocks
         # plus the number of blocks evictor could free from its list.
-        return self._hashless_allocator.get_num_free_blocks(
-        ) + self.evictor.num_blocks
+        return self._hashless_allocator.get_num_free_blocks() + (
+            self.evictor.num_blocks
+        )

     def get_num_total_blocks(self) -> int:
         return self._hashless_allocator.get_num_total_blocks()
@@ -499,6 +511,9 @@ def cow_block_if_not_appendable(self, block: Block) -> BlockId:
             return src_block_id

         self._free_block_id(block)
+        # print(
+        #     f"Allocating block for COW: get_num_free_blocks: {self.get_num_free_blocks()}"
+        # )
         trg_block_id = self._allocate_block_id()

         self._cow_tracker.record_cow(src_block_id, trg_block_id)
@@ -662,13 +677,50 @@ def swap_in(self, blocks: List[Block]) -> None:

             block.block_id = block_id  # Assign block_id

-    def get_cached_blocks(self, block_hashes: List[PrefixHash]) -> List[PrefixHash]:
+    def get_allocated_cached_blocks(self, block_hashes: List[PrefixHash]) -> List[PrefixHash]:
+        """
+        Get the list of blocks that are already computed and allocated so that they can
+        be shared by multiple sequences, and do not need to be allocated again.
+
+        INVARIANT:
+        For a sequence of blocks, it's also guaranteed that if a block is allocated (i.e.
+        block_is_active_computed(block_hash) == True), then the previous block must also be
+        allocated (i.e. block_is_active_computed(prev_block_hash) == True).
+
+        This is because we allocate and free an entire sequence of blocks atomically (no partial
+        sequence is allocated or freed). Therefore, because a block's hash includes the previous
+        block's hash, if the current block is allocated, the previous block must also
+        be allocated.
+
+        NOTE: we exclude computed blocks in the evictor because they are already freed even if they
+        are cached. They would still have to be allocated by a sequence. If not, consider a
+        scenario with a sequence of 3 token blocks, and a block pool of only 2 blocks:
+        [b0, b1, b2], where b0 and b1 are computed but evicted, b2 is not computed.
+        So b0 and b1 are the 2 free blocks, held in the evictor.
+        When deciding how many more blocks need to be allocated for this sequence, it should be
+        all 3 blocks (b0, b1, b2) rather than just 1 block (b2).
+        """
         # Search for the longest prefix in `block_hashes` that are present cached blocks.
-        # TODO(rickyx): this could be made to binary search.
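# A minimal standalone sketch (not part of this patch) of the lookup described in the
# docstring above. Because of the invariant, the "allocated and computed" flags over a
# sequence's block hashes are a run of True followed by a run of False, so the cached
# prefix can be found with a simple scan, or equivalently with a binary search over the
# negated predicate, which is what the bisect_left call below does.
# `is_active_computed` is a hypothetical stand-in for the real predicate.
def cached_prefix(block_hashes, is_active_computed):
    prefix = []
    for block_hash in block_hashes:
        if not is_active_computed(block_hash):
            break
        prefix.append(block_hash)
    return prefix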
- for i, block_hash in enumerate(block_hashes): + def block_is_active_computed(block_hash: PrefixHash) -> bool: if block_hash not in self._cached_blocks: - return block_hashes[:i] - return block_hashes + return False + + cached_block_id = self._cached_blocks[block_hash] + if cached_block_id in self.evictor: + return False + + # We only consider the blocks that are marked as computed. + if not self._block_tracker[cached_block_id].computed: + return False + + return True + + from bisect import bisect_left + + idx = bisect_left( + block_hashes, True, key=lambda x: not block_is_active_computed(x) + ) + return block_hashes[:idx] class PrefixCachingBlock(Block): @@ -937,12 +989,6 @@ def remove_seq(self, seq_id: int) -> None: assert seq_id in self._cached_computed_seq_blocks del self._cached_computed_seq_blocks[seq_id] - def update_seq(self, seq_id: int, computed_tokens: List[int]): - pass - - def get_cached_computed_blocks(self, seq_id: int) -> List[int]: - pass - def get_cached_computed_blocks_and_update( self, seq_id: int, block_ids: List[int]) -> List[int]: """ Look at the class documentation for details diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 8c85bdc05fe64..5d2f226465b07 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -109,15 +109,69 @@ def __init__( def _get_num_blocks_to_allocate( self, seq: Sequence, num_lookahead_slots: int = 0 ) -> int: + num_cached_tokens = seq.get_num_cached_tokens() + + assert ( + num_cached_tokens % self.block_size == 0 + ), "Cached tokens must be a multiple of block size" + num_cached_blocks = cdiv(num_cached_tokens, self.block_size) + num_required_blocks = cdiv(seq.get_len() + num_lookahead_slots, self.block_size) + + return num_required_blocks - num_cached_blocks + + def get_num_computed_tokens(self, seq: Sequence) -> int: seq_blocks = seq.get_block_hashes() - cached_seq_blocks = self.block_allocator.get_cached_blocks( + cached_seq_blocks = self.block_allocator.get_allocated_cached_blocks( block_hashes=seq_blocks, device=Device.GPU, ) + return len(cached_seq_blocks) * self.block_size - num_required_blocks = cdiv(seq.get_len() + num_lookahead_slots, self.block_size) + # def get_num_computed_blocks(self, seq_group: SequenceGroup) -> Dict[SeqId, int]: + # num_computed_blocks = {} + # for seq in seq_group.get_seqs(): + # num_computed_blocks[seq.seq_id] = self._get_num_computed_tokens(seq) + # return num_computed_blocks + + def can_allocate_old( + self, seq_group: SequenceGroup, num_lookahead_slots: int = 0 + ) -> AllocStatus: + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. This may not be true for preempted sequences. 
+ + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) - return num_required_blocks - len(cached_seq_blocks) + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + num_required_blocks = BlockTable.get_num_required_blocks( + seq.get_token_ids(), + block_size=self.block_size, + num_lookahead_slots=num_lookahead_slots, + ) + + if seq_group.is_encoder_decoder(): + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + num_required_blocks += BlockTable.get_num_required_blocks( + encoder_seq.get_token_ids(), + block_size=self.block_size, + ) + + if self.max_block_sliding_window is not None: + num_required_blocks = min( + num_required_blocks, self.max_block_sliding_window + ) + + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( + device=Device.GPU + ) + + # Use watermark to avoid frequent cache eviction. + if self.num_total_gpu_blocks - num_required_blocks < self.watermark_blocks: + return AllocStatus.NEVER + if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER def can_allocate(self, seq_group: SequenceGroup, @@ -145,9 +199,16 @@ def can_allocate(self, ) num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( - device=Device.GPU) + device=Device.GPU + ) + # print(f"num_blocks_to_allocate: {num_blocks_to_allocate}") + # print(f"num_free_gpu_blocks: {num_free_gpu_blocks}") + # print(f"watermark_blocks: {self.watermark_blocks}") # Use watermark to avoid frequent cache eviction. + # old_can_allocate = self.can_allocate_old(seq_group, num_lookahead_slots) + + can_allocate = None if self.num_total_gpu_blocks - num_blocks_to_allocate < self.watermark_blocks: return AllocStatus.NEVER if num_free_gpu_blocks - num_blocks_to_allocate >= self.watermark_blocks: @@ -155,6 +216,11 @@ def can_allocate(self, else: return AllocStatus.LATER + # if old_can_allocate != can_allocate: + # print(f"old_can_allocate: {old_can_allocate}, can_allocate: {can_allocate}") + + return can_allocate + def _allocate_sequence(self, seq: Sequence) -> BlockTable: block_table = BlockTable( block_size=self.block_size, @@ -225,22 +291,21 @@ def can_append_slots(self, seq_group: SequenceGroup, This is used by speculative decoding when speculating future tokens. """ - num_blocks_to_allocate = 0 + num_touched_blocks = 0 + # TODO(rickyx): Potentially there's could be cached blocks reuse here. i.e + # The newly appended tokens might create one or more full blocks, which + # could be reused from the cache. 
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - num_blocks_to_allocate += self._get_num_blocks_to_allocate( - seq, num_lookahead_slots=num_lookahead_slots + block_table = self.block_tables[seq.seq_id] + num_touched_blocks += block_table.get_num_blocks_touched_by_append_slots( + token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, ) - # num_touched_blocks += ( - # block_table.get_num_blocks_touched_by_append_slots( - # token_ids=block_table.get_unseen_token_ids( - # seq.get_token_ids()), - # num_lookahead_slots=num_lookahead_slots, - # )) - num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( - Device.GPU) - return num_blocks_to_allocate <= num_free_gpu_blocks + device=Device.GPU + ) + return num_touched_blocks <= num_free_gpu_blocks def append_slots( self, @@ -329,11 +394,23 @@ def get_common_computed_block_ids( """ computed_seq_block_ids = [] for seq in seqs: - computed_seq_block_ids.append( - self._computed_blocks_tracker. - get_cached_computed_blocks_and_update( - seq.seq_id, - self.block_tables[seq.seq_id].physical_block_ids)) + all_blocks = self.block_tables[seq.seq_id].physical_block_ids + num_cached_token = seq.get_num_cached_tokens() + assert num_cached_token % self.block_size == 0 + num_cached_block = num_cached_token // self.block_size + computed_block_ids = all_blocks[:num_cached_block] + computed_seq_block_ids.append(computed_block_ids) + + # old_computed_block_ids = ( + # self._computed_blocks_tracker.get_cached_computed_blocks_and_update( + # seq.seq_id, all_blocks + # ) + # ) + # if old_computed_block_ids != computed_block_ids: + # print( + # f"old_computed_block_ids: \n{old_computed_block_ids}\n, computed_block_ids: \n{computed_block_ids}\n" + # ) + # print(f"seq: {seq}") # NOTE(sang): This assumes seq_block_ids doesn't contain any None. 
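# A tiny worked illustration (made-up values) of the slicing above: with
# block_size = 16 and seq.get_num_cached_tokens() == 32, a sequence whose
# physical_block_ids are [7, 3, 9] contributes
#     [7, 3, 9][: 32 // 16]  ->  [7, 3]
# as its computed block ids; the tail block is left out until it, too, is counted
# as cached.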
return self.block_allocator.get_common_computed_block_ids( @@ -513,8 +590,10 @@ def _can_swap(self, if self.block_allocator.get_num_total_blocks( device) < num_blocks_touched: return AllocStatus.NEVER - elif self.block_allocator.get_num_free_blocks( - device) - num_blocks_touched >= watermark_blocks: + elif ( + self.block_allocator.get_num_free_blocks(device=device) - num_blocks_touched + >= watermark_blocks + ): return AllocStatus.OK else: return AllocStatus.LATER diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c57e6cd716405..e25af4a501748 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -56,6 +56,7 @@ class SchedulingBudget: _request_ids_num_batched_tokens: Set[str] = field(default_factory=set) _request_ids_num_curr_seqs: Set[str] = field(default_factory=set) _num_batched_tokens: int = 0 + _num_batched_and_cached_tokens: int = 0 _num_curr_seqs: int = 0 def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): @@ -67,18 +68,33 @@ def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): def remaining_token_budget(self): return self.token_budget - self.num_batched_tokens - def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int): + def add_num_batched_tokens( + self, + req_id: str, + num_batched_tokens: int, + num_batched_and_cached_tokens: Optional[int] = None, + ): if req_id in self._request_ids_num_batched_tokens: return self._request_ids_num_batched_tokens.add(req_id) self._num_batched_tokens += num_batched_tokens + if num_batched_and_cached_tokens is None: + num_batched_and_cached_tokens = num_batched_tokens + self._num_batched_and_cached_tokens += num_batched_and_cached_tokens - def subtract_num_batched_tokens(self, req_id: str, - num_batched_tokens: int): + def subtract_num_batched_tokens( + self, + req_id: str, + num_batched_tokens: int, + num_batched_and_cached_tokens: Optional[int] = None, + ): if req_id in self._request_ids_num_batched_tokens: self._request_ids_num_batched_tokens.remove(req_id) self._num_batched_tokens -= num_batched_tokens + if num_batched_and_cached_tokens is None: + num_batched_and_cached_tokens = num_batched_tokens + self._num_batched_and_cached_tokens -= num_batched_and_cached_tokens def add_num_seqs(self, req_id: str, num_curr_seqs: int): if req_id in self._request_ids_num_curr_seqs: @@ -96,6 +112,10 @@ def subtract_num_seqs(self, req_id: str, num_curr_seqs: int): def num_batched_tokens(self): return self._num_batched_tokens + @property + def num_batched_and_cached_tokens(self): + return self._num_batched_and_cached_tokens + @property def num_curr_seqs(self): return self._num_curr_seqs @@ -120,6 +140,7 @@ class SchedulerOutputs: num_prefill_groups: int # Total number of batched tokens. num_batched_tokens: int + num_batched_tokens_from_budget: int # Blocks to swap in. List of CPU -> GPU block number. blocks_to_swap_in: List[Tuple[int, int]] # Blocks to swap out. List of GPU -> CPU block number. 
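# A minimal usage sketch of the split accounting introduced above (the numbers and the
# request id are made up; it assumes only the methods shown in this hunk plus
# SchedulingBudget's dataclass fields token_budget and max_num_seqs): a 1024-token
# prompt whose first 768 tokens hit the prefix cache charges only the 256 uncached
# tokens against token_budget, while the cached+uncached total is still tracked.
from vllm.core.scheduler import SchedulingBudget

budget = SchedulingBudget(token_budget=2048, max_num_seqs=256)
budget.add_num_batched_tokens(
    req_id="req-0",
    num_batched_tokens=256,               # uncached tokens, counted against the budget
    num_batched_and_cached_tokens=1024,   # everything scheduled this step
)
assert budget.num_batched_tokens == 256
assert budget.num_batched_and_cached_tokens == 1024
assert budget.remaining_token_budget() == 2048 - 256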
@@ -614,6 +635,9 @@ def _schedule_running( self._append_slots(seq_group, blocks_to_copy, enable_chunking) is_prefill = seq_group.is_prefill() + if seq_group.is_prefill() and not enable_chunking: + breakpoint() + scheduled_seq_group: ScheduledSequenceGroup = \ self._scheduled_seq_group_cache[self.cache_id].get_object() scheduled_seq_group.seq_group = seq_group @@ -807,17 +831,18 @@ def _schedule_priority_preemption( SequenceStatus.WAITING, False, budget) - #Only preempt if priority inversion exists + # Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): - #Only preempt if waiting sequence cannot be allocated + # Only preempt if waiting sequence cannot be allocated + assert False can_allocate = self.block_manager.can_allocate(seq_group) if (num_new_tokens and can_allocate == AllocStatus.OK and budget.can_schedule(num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs)): break - #Adjust budget to remove the victim sequence group + # Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() num_running_tokens = self._get_num_new_tokens( vseq_group, SequenceStatus.RUNNING, False, budget) @@ -827,12 +852,12 @@ def _schedule_priority_preemption( budget.subtract_num_seqs(vseq_group.request_id, num_running_seqs) - #Preempt out the victim sequence group + # Preempt out the victim sequence group self._preempt(vseq_group, blocks_to_swap_out, PreemptionMode.RECOMPUTE) waiting_queue.appendleft(vseq_group) force_preemption_count += 1 - #Put the sequence back into the waiting queue + # Put the sequence back into the waiting queue waiting_queue.appendleft(seq_group) waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) @@ -883,9 +908,18 @@ def _schedule_prefills( assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.WAITING, - enable_chunking, budget) + # self._update_prefix_cached_tokens(waiting_seqs[0]) + num_new_tokens = self._get_num_new_tokens( + seq_group, + SequenceStatus.WAITING, + enable_chunking, + budget, + ) + + num_new_tokens_exclude_cached = self._get_num_new_tokens_exclude_cached( + num_new_tokens, waiting_seqs[0] + ) + if not enable_chunking: num_prompt_tokens = waiting_seqs[0].get_len() assert num_new_tokens == num_prompt_tokens @@ -909,6 +943,16 @@ def _schedule_prefills( # If the sequence group cannot be allocated, stop. can_allocate = self.block_manager.can_allocate( seq_group, num_lookahead_slots=num_lookahead_slots) + + old_can_allocate = self.block_manager.can_allocate_old( + seq_group, num_lookahead_slots=num_lookahead_slots + ) + + if can_allocate != old_can_allocate: + print( + f"can_allocate: {can_allocate}, old_can_allocate: {old_can_allocate}" + ) + if can_allocate == AllocStatus.LATER: break elif can_allocate == AllocStatus.NEVER: @@ -937,16 +981,28 @@ def _schedule_prefills( continue num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + if num_new_tokens == 0: + # TODO(rickyx): this could be made earlier? + # No more new tokens to schedule. + break + + assert num_new_tokens > 0 + # We have new tokens but they might be cached. + if num_new_tokens_exclude_cached > 0 and not budget.can_schedule( + num_new_tokens=num_new_tokens_exclude_cached, + num_new_seqs=num_new_seqs, + ): + # No more budget for new tokens. 
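# An illustration of the check above (made-up numbers): if a prefill has 32 new tokens
# and all 32 are prefix-cached, then num_new_tokens_exclude_cached == 0, so the
# budget check as written in this hunk is bypassed and the request is scheduled
# without consuming any token budget.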
break # Can schedule this request. if curr_loras is not None and lora_int_id > 0: curr_loras.add(lora_int_id) waiting_queue.popleft() - self._allocate_and_set_running(seq_group) + try: + self._allocate_and_set_running(seq_group) + except Exception as e: + breakpoint() if enable_chunking and self.scheduler_config.is_multi_step: blocks_to_copy: List[Tuple[int, int]] = [] @@ -968,7 +1024,11 @@ def _schedule_prefills( seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, token_chunk_size=num_new_tokens)) - budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_exclude_cached, + num_batched_and_cached_tokens=num_new_tokens, + ) budget.add_num_seqs(seq_group.request_id, num_new_seqs) # Queue requests that couldn't be scheduled. @@ -1055,6 +1115,11 @@ def _schedule_default(self) -> SchedulerOutputs: # There should be no prefill from running queue because this policy # doesn't allow chunked prefills. + if len(running_scheduled.prefill_seq_groups) > 0: + print( + f"running_scheduled.prefill_seq_groups: {running_scheduled.prefill_seq_groups}" + ) + breakpoint() assert len(running_scheduled.prefill_seq_groups) == 0 assert len(swapped_in.prefill_seq_groups) == 0 @@ -1076,7 +1141,8 @@ def _schedule_default(self) -> SchedulerOutputs: return SchedulerOutputs( scheduled_seq_groups=scheduled_seq_groups, num_prefill_groups=num_prefill_groups, - num_batched_tokens=budget.num_batched_tokens, + num_batched_tokens=budget.num_batched_and_cached_tokens, + num_batched_tokens_from_budget=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, blocks_to_copy=blocks_to_copy, @@ -1149,25 +1215,30 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # Update swapped requests. 
self.swapped.extend(running_scheduled.swapped_out) return SchedulerOutputs( - scheduled_seq_groups=(prefills.seq_groups + - running_scheduled.prefill_seq_groups + - swapped_in.prefill_seq_groups + - running_scheduled.decode_seq_groups + - swapped_in.decode_seq_groups), - num_prefill_groups=(len(prefills.seq_groups) + - len(swapped_in.prefill_seq_groups) + - len(running_scheduled.prefill_seq_groups)), - num_batched_tokens=budget.num_batched_tokens, + scheduled_seq_groups=( + prefills.seq_groups + + running_scheduled.prefill_seq_groups + + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups + ), + num_prefill_groups=( + len(prefills.seq_groups) + + len(swapped_in.prefill_seq_groups) + + len(running_scheduled.prefill_seq_groups) + ), + num_batched_tokens=budget.num_batched_and_cached_tokens, + num_batched_tokens_from_budget=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=running_scheduled.blocks_to_copy + - swapped_in.blocks_to_copy, - ignored_seq_groups=prefills.ignored_seq_groups + - swapped_in.infeasible_seq_groups, + blocks_to_copy=running_scheduled.blocks_to_copy + swapped_in.blocks_to_copy, + ignored_seq_groups=prefills.ignored_seq_groups + + swapped_in.infeasible_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, running_queue_size=len(self.running), - preempted=(len(running_scheduled.preempted) + - len(running_scheduled.swapped_out)), + preempted=( + len(running_scheduled.preempted) + len(running_scheduled.swapped_out) + ), ) def _schedule(self) -> SchedulerOutputs: @@ -1477,7 +1548,7 @@ def _preempt( else: preemption_mode = PreemptionMode.RECOMPUTE - if self.num_cumulative_preemption % 50 == 0: + if self.num_cumulative_preemption % 5 == 0: logger.warning( "Sequence group %s is preempted by %s mode because there is " "not enough KV cache space. This can affect the end-to-end " @@ -1584,9 +1655,13 @@ def _get_num_lookahead_slots(self, is_prefill: bool, return self.scheduler_config.num_lookahead_slots - def _get_num_new_tokens(self, seq_group: SequenceGroup, - status: SequenceStatus, enable_chunking: bool, - budget: SchedulingBudget) -> int: + def _get_num_new_tokens( + self, + seq_group: SequenceGroup, + status: SequenceStatus, + enable_chunking: bool, + budget: SchedulingBudget, + ) -> int: """Get the next new tokens to compute for a given sequence group that's in a given `status`. @@ -1600,6 +1675,7 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, num_new_tokens = 0 seqs = seq_group.get_seqs(status=status) for seq in seqs: + self._update_prefix_cached_tokens(seq) num_new_tokens += seq.get_num_new_tokens() assert num_new_tokens > 0 # Chunk if a running request cannot fit in the given budget. @@ -1645,3 +1721,29 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, else: num_new_tokens = min(num_new_tokens, remaining_token_budget) return num_new_tokens + + def _update_prefix_cached_tokens(self, seq: Sequence): + num_prefix_cached_tokens = self.block_manager.get_num_computed_tokens(seq) + seq.set_num_prefix_cached_tokens(num_prefix_cached_tokens) + + def _get_num_new_tokens_exclude_cached( + self, num_new_tokens: int, seq: Sequence + ) -> int: + + # If a decode sequence, new tokens are always not computed/cached. + if not seq.is_prefill(): + return num_new_tokens + + # If a prefill sequence, we need to exclude the number of cached tokens. 
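# Worked example (made-up numbers) for the computation below: a fresh 48-token prefill
# with 0 tokens computed so far and 32 tokens found in the prefix cache yields
# (48 + 0) - 32 = 16 tokens that actually consume token budget; the max(0, ...) guard
# covers the case where every token scheduled in this step is already cached.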
+ num_computed_tokens = seq.get_num_computed_tokens() + num_cached_tokens = seq.get_num_cached_tokens() + + # We subtract the number of cached tokens from the number of new tokens + num_computed_tokens_new = num_new_tokens + num_computed_tokens + num_new_tokens_exclude_cached = max( + 0, num_computed_tokens_new - num_cached_tokens + ) + assert ( + num_new_tokens_exclude_cached <= num_new_tokens + ), "Number of new tokens exclude cached should be less than or equal to the number of new tokens" + return num_new_tokens_exclude_cached diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6372d4b5d2117..1d85ac8088f84 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1594,6 +1594,7 @@ def _get_stats(self, # Iteration stats num_prompt_tokens_iter = 0 num_generation_tokens_iter = 0 + num_extra_batched_tokens_iter = 0 time_to_first_tokens_iter: List[float] = [] time_per_output_tokens_iter: List[float] = [] num_preemption_iter = (0 if scheduler_outputs is None else @@ -1616,6 +1617,17 @@ def _get_stats(self, # not counted (to avoid double counting) actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore + num_extra_batched_tokens_iter = ( + actual_num_batched_tokens + - scheduler_outputs.num_batched_tokens_from_budget + ) + if num_extra_batched_tokens_iter > 0: + print( + f"num_extra_batched_tokens_iter: {num_extra_batched_tokens_iter}, " + f"actual_num_batched_tokens: {actual_num_batched_tokens}, " + f"num_batched_tokens_from_budget: {scheduler_outputs.num_batched_tokens_from_budget}" + ) + num_generation_tokens_from_prefill_groups = 0. # NOTE: if scheduler_outputs.num_prefill_groups > 0 and # the len of scheduler_outputs.scheduled_seq_groups is != @@ -1715,7 +1727,6 @@ def _get_stats(self, # Prefix Cache Hit Rate cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, - # Iteration stats num_prompt_tokens_iter=num_prompt_tokens_iter, num_generation_tokens_iter=num_generation_tokens_iter, @@ -1723,7 +1734,7 @@ def _get_stats(self, time_per_output_tokens_iter=time_per_output_tokens_iter, spec_decode_metrics=spec_decode_metrics, num_preemption_iter=num_preemption_iter, - + num_extra_batched_tokens_iter=num_extra_batched_tokens_iter, # Request stats # Latency time_e2e_requests=time_e2e_requests, diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 74277cae7c8ef..68d90c0057fd9 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -450,8 +450,9 @@ def _log_prometheus(self, stats: Stats) -> None: stats.num_preemption_iter) self._log_counter(self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter) - self._log_counter(self.metrics.counter_generation_tokens, - stats.num_generation_tokens_iter) + self._log_counter( + self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter + ) self._log_histogram(self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter) self._log_histogram(self.metrics.histogram_time_per_output_token, diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index 1eccb23593408..1deee8f70ba7d 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -42,6 +42,7 @@ class Stats: time_to_first_tokens_iter: List[float] time_per_output_tokens_iter: List[float] num_preemption_iter: int + num_extra_batched_tokens_iter: int # Request stats (should have _requests suffix) # Latency diff --git a/vllm/sequence.py b/vllm/sequence.py index 65855e48cc16c..e76154560f258 100644 --- 
a/vllm/sequence.py +++ b/vllm/sequence.py @@ -162,6 +162,7 @@ class SequenceData(msgspec.Struct, ...] = msgspec.field(default_factory=tuple) # The number of tokens that are computed (that run against the model). _num_computed_tokens: int = 0 + _num_prefix_cached_tokens: int = 0 _stage: SequenceStage = SequenceStage.PREFILL _cached_all_token_ids: List[int] = msgspec.field(default_factory=list) @@ -296,6 +297,12 @@ def get_num_computed_tokens(self) -> int: """Return the number of prefill tokens that are already computed.""" return self._num_computed_tokens + def set_num_prefix_cached_tokens(self, num_new_cached_tokens: int): + self._num_prefix_cached_tokens = num_new_cached_tokens + + def get_num_prefix_cached_tokens(self) -> int: + return self._num_prefix_cached_tokens + def update_num_computed_tokens(self, num_new_computed_tokens: int): """Update number of tokens computed so far.""" self._num_computed_tokens += num_new_computed_tokens @@ -587,6 +594,9 @@ def _reset_block_hashes(self): num_full_prompt_blocks: ] + def set_num_prefix_cached_tokens(self, num_prefix_cached_tokens: int): + self.data.set_num_prefix_cached_tokens(num_prefix_cached_tokens) + def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size @@ -647,6 +657,9 @@ def fork(self, new_seq_id: int) -> "Sequence": new_seq.seq_id = new_seq_id return new_seq + def get_num_computed_tokens(self) -> int: + return self.data.get_num_computed_tokens() + def get_num_new_tokens(self) -> int: """Get the number of new tokens to be computed. @@ -656,11 +669,19 @@ def get_num_new_tokens(self) -> int: """ if self.data.stage == SequenceStage.DECODE: return 1 - return self.data.get_num_uncomputed_tokens() + + num_computed_tokens = self.data.get_num_computed_tokens() + return self.data.get_len() - num_computed_tokens + + def get_num_cached_tokens(self) -> int: + return self.data.get_num_prefix_cached_tokens() def is_prefill(self) -> bool: return self.data.stage == SequenceStage.PREFILL + def is_from_decoder_prompt(self) -> bool: + return self.from_decoder_prompt + def __repr__(self) -> str: return (f"Sequence(seq_id={self.seq_id}, " f"status={self.status.name}, " From d2303dbab85a3e6f4024263c4d2f01a8f43ecf2f Mon Sep 17 00:00:00 2001 From: rickyx Date: Mon, 28 Oct 2024 23:15:26 +0000 Subject: [PATCH 03/12] fix --- vllm/core/scheduler.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 6e3e4db831b65..03c8c6790dc89 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -908,7 +908,9 @@ def _schedule_prefills( assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") - # self._update_prefix_cached_tokens(waiting_seqs[0]) + seq = waiting_seqs[0] + self._update_prefix_cached_tokens(seq) + num_cached_tokens = seq.get_num_cached_tokens() num_new_tokens = self._get_num_new_tokens( seq_group, SequenceStatus.WAITING, @@ -917,11 +919,11 @@ def _schedule_prefills( ) num_new_tokens_exclude_cached = self._get_num_new_tokens_exclude_cached( - num_new_tokens, waiting_seqs[0] + num_new_tokens, seq ) if not enable_chunking: - num_prompt_tokens = waiting_seqs[0].get_len() + num_prompt_tokens = seq.get_len() assert num_new_tokens == num_prompt_tokens prompt_limit = self._get_prompt_limit(seq_group) @@ -929,8 +931,7 @@ def _schedule_prefills( logger.warning( "Input prompt (%d tokens) is too long" " and exceeds limit of %d", num_new_tokens, prompt_limit) - for seq 
in waiting_seqs: - seq.status = SequenceStatus.FINISHED_IGNORED + seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) waiting_queue.popleft() continue @@ -960,8 +961,7 @@ def _schedule_prefills( "Input prompt (%d tokens) + lookahead slots (%d) is " "too long and exceeds the capacity of block_manager", num_new_tokens, num_lookahead_slots) - for seq in waiting_seqs: - seq.status = SequenceStatus.FINISHED_IGNORED + seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) waiting_queue.popleft() continue @@ -1004,6 +1004,12 @@ def _schedule_prefills( except Exception as e: breakpoint() + # NOTE(rickyx): We are updating this again since some of the previously + # cached blocks that were in evictor might now become active again. + # Therefore, the actual number of tokens cached might have changed. + self._update_prefix_cached_tokens(seq) + num_cached_tokens = max(num_cached_tokens, seq.get_num_cached_tokens()) + if enable_chunking and self.scheduler_config.is_multi_step: blocks_to_copy: List[Tuple[int, int]] = [] # init_multi_step_from_lookahead_slots happens in append_slots @@ -1677,7 +1683,7 @@ def _get_num_new_tokens( num_new_tokens = 0 seqs = seq_group.get_seqs(status=status) for seq in seqs: - self._update_prefix_cached_tokens(seq) + # self._update_prefix_cached_tokens(seq) num_new_tokens += seq.get_num_new_tokens() assert num_new_tokens > 0 # Chunk if a running request cannot fit in the given budget. From fc3d044f83e78785633478021ec1cb245d557c4c Mon Sep 17 00:00:00 2001 From: rickyx Date: Tue, 29 Oct 2024 00:26:03 +0000 Subject: [PATCH 04/12] fix --- vllm/core/scheduler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 03c8c6790dc89..1915dea08cb3d 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1008,7 +1008,9 @@ def _schedule_prefills( # cached blocks that were in evictor might now become active again. # Therefore, the actual number of tokens cached might have changed. 
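# An illustration of the note above (made-up sizes): with block_size = 16, a sequence
# whose first two blocks are cached but currently sitting in the evictor reports
# 0 cached tokens before allocation, because evicted blocks are excluded; once
# allocation makes those blocks active again, the refreshed count becomes 2 * 16 = 32.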
self._update_prefix_cached_tokens(seq) - num_cached_tokens = max(num_cached_tokens, seq.get_num_cached_tokens()) + num_new_tokens_exclude_cached = self._get_num_new_tokens_exclude_cached( + num_new_tokens, seq + ) if enable_chunking and self.scheduler_config.is_multi_step: blocks_to_copy: List[Tuple[int, int]] = [] From aac5ae6fa1def0d18345faf87bb95aa34a9b3e8a Mon Sep 17 00:00:00 2001 From: rickyx Date: Tue, 29 Oct 2024 19:11:40 +0000 Subject: [PATCH 05/12] up --- vllm/core/scheduler.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 1915dea08cb3d..225c1010b51e6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -61,7 +61,7 @@ class SchedulingBudget: _num_curr_seqs: int = 0 def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): - assert num_new_tokens != 0 + # assert num_new_tokens != 0 assert num_new_seqs != 0 return (self.num_batched_tokens + num_new_tokens <= self.token_budget and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) @@ -659,6 +659,7 @@ def _schedule_running( if enable_chunking: num_running_seqs = seq_group.get_max_num_running_seqs() budget.add_num_seqs(seq_group.request_id, num_running_seqs) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs if curr_loras is not None and seq_group.lora_int_id > 0: curr_loras.add(seq_group.lora_int_id) @@ -763,6 +764,7 @@ def _schedule_swapped( ScheduledSequenceGroup(seq_group, token_chunk_size=1)) budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) budget.add_num_seqs(seq_group.request_id, num_new_seqs) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs swapped_queue.extendleft(leftover_swapped) @@ -910,7 +912,6 @@ def _schedule_prefills( "sequence.") seq = waiting_seqs[0] self._update_prefix_cached_tokens(seq) - num_cached_tokens = seq.get_num_cached_tokens() num_new_tokens = self._get_num_new_tokens( seq_group, SequenceStatus.WAITING, @@ -988,7 +989,7 @@ def _schedule_prefills( assert num_new_tokens > 0 # We have new tokens but they might be cached. 
- if num_new_tokens_exclude_cached > 0 and not budget.can_schedule( + if not budget.can_schedule( num_new_tokens=num_new_tokens_exclude_cached, num_new_seqs=num_new_seqs, ): @@ -1068,6 +1069,8 @@ def _schedule_default(self) -> SchedulerOutputs: for seq_group in self.running: budget.add_num_seqs(seq_group.request_id, seq_group.get_max_num_running_seqs()) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + curr_loras = set( seq_group.lora_int_id for seq_group in self.running if seq_group.lora_int_id > 0) if self.lora_enabled else None From 661b890cb89f748494f2240d24108f3bbb087cb4 Mon Sep 17 00:00:00 2001 From: rickyx Date: Wed, 30 Oct 2024 00:06:19 +0000 Subject: [PATCH 06/12] fix --- vllm/core/scheduler.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 225c1010b51e6..f2d52f1b819f0 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -75,6 +75,7 @@ def add_num_batched_tokens( num_batched_tokens: int, num_batched_and_cached_tokens: Optional[int] = None, ): + assert num_batched_tokens >= 0 if req_id in self._request_ids_num_batched_tokens: return @@ -84,6 +85,8 @@ def add_num_batched_tokens( num_batched_and_cached_tokens = num_batched_tokens self._num_batched_and_cached_tokens += num_batched_and_cached_tokens + assert self._num_batched_tokens <= self.token_budget, f"{self._num_batched_tokens} > {self.token_budget}" + def subtract_num_batched_tokens( self, req_id: str, @@ -923,6 +926,7 @@ def _schedule_prefills( num_new_tokens, seq ) + # print(f"[{seq_group.request_id=}] {num_new_tokens=} {num_new_tokens_exclude_cached}, budget: {budget.num_batched_tokens}") if not enable_chunking: num_prompt_tokens = seq.get_len() assert num_new_tokens == num_prompt_tokens @@ -1009,6 +1013,12 @@ def _schedule_prefills( # cached blocks that were in evictor might now become active again. # Therefore, the actual number of tokens cached might have changed. self._update_prefix_cached_tokens(seq) + num_new_tokens = self._get_num_new_tokens( + seq_group, + SequenceStatus.RUNNING, + enable_chunking, + budget, + ) num_new_tokens_exclude_cached = self._get_num_new_tokens_exclude_cached( num_new_tokens, seq ) @@ -1033,6 +1043,7 @@ def _schedule_prefills( seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, token_chunk_size=num_new_tokens)) + # print(f"[{seq_group.request_id}] {num_new_tokens=} {num_new_tokens_exclude_cached=}, budget: {budget.num_batched_tokens}") budget.add_num_batched_tokens( seq_group.request_id, num_batched_tokens=num_new_tokens_exclude_cached, @@ -1203,7 +1214,7 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: enable_chunking=True) assert (budget.num_batched_tokens <= - self.scheduler_config.max_num_batched_tokens) + self.scheduler_config.max_num_batched_tokens), f"{budget.num_batched_tokens=}, {self.scheduler_config.max_num_batched_tokens=}" assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs # Update waiting requests. @@ -1696,6 +1707,7 @@ def _get_num_new_tokens( # in a decode phase. Do not chunk. if enable_chunking and len(seqs) == 1: remaining_token_budget = budget.remaining_token_budget() + seq = seqs[0] if self.scheduler_config.is_multi_step: # The current multi-step + chunked prefill capability does # not actually support chunking prompts. 
@@ -1728,9 +1740,15 @@ def _get_num_new_tokens( "block size, but got chunk_size " f"({budget.token_budget}) % block_size " f"({block_size}) = {remainder}") - if remaining_token_budget < num_new_tokens: - num_new_tokens = (remaining_token_budget // - block_size) * block_size + num_new_tokens_cached = seq.get_num_cached_tokens() - seq.get_num_computed_tokens() + num_new_tokens_cached = max(0, num_new_tokens_cached) + # Round down to block + remaining_token_budget = remaining_token_budget // block_size * block_size + + # Calculate the number of new tokens that are not cached with chunk cap. + num_new_tokens_uncached = min(num_new_tokens - num_new_tokens_cached, remaining_token_budget) + num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached + # print(f"[{seq_group.request_id}] {num_new_tokens=} {num_new_tokens_uncached=} {num_new_tokens_cached=}, budget: {budget.num_batched_tokens}") else: num_new_tokens = min(num_new_tokens, remaining_token_budget) return num_new_tokens From 2cffdf9621de966d11be89c46cc36139d87f172a Mon Sep 17 00:00:00 2001 From: rickyx Date: Wed, 30 Oct 2024 18:08:49 +0000 Subject: [PATCH 07/12] fix issues with chunked prefill running schedule --- vllm/core/scheduler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f2d52f1b819f0..51a88f338dc3c 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1747,6 +1747,9 @@ def _get_num_new_tokens( # Calculate the number of new tokens that are not cached with chunk cap. num_new_tokens_uncached = min(num_new_tokens - num_new_tokens_cached, remaining_token_budget) + if num_new_tokens_uncached == 0: + # No more budget for new tokens, don't include any cached tokens too. + return 0 num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached # print(f"[{seq_group.request_id}] {num_new_tokens=} {num_new_tokens_uncached=} {num_new_tokens_cached=}, budget: {budget.num_batched_tokens}") else: From 0321fa79d494dd77d1095a4a4da51fe9071fb4a6 Mon Sep 17 00:00:00 2001 From: rickyx Date: Wed, 30 Oct 2024 23:06:28 +0000 Subject: [PATCH 08/12] clean up --- benchmarks/benchmark_prefix_caching.py | 24 +++--- tests/core/block/test_block_manager.py | 4 +- vllm/config.py | 4 - vllm/core/block/block_table.py | 15 ++-- vllm/core/block/common.py | 12 --- vllm/core/block/interfaces.py | 2 +- vllm/core/block/naive_block.py | 6 -- vllm/core/block/prefix_caching_block.py | 105 +----------------------- vllm/core/block_manager.py | 91 +++++--------------- vllm/core/scheduler.py | 83 ++++++++----------- vllm/engine/llm_engine.py | 16 +--- vllm/engine/metrics_types.py | 1 - vllm/sequence.py | 12 +-- 13 files changed, 91 insertions(+), 284 deletions(-) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index db9b7e1da2e46..5414036a37263 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -132,18 +132,18 @@ def main(args): filtered_datasets = [(PROMPT, prompt_len, args.output_len) ] * args.num_prompts - # llm = LLM( - # model=args.model, - # tokenizer_mode="auto", - # trust_remote_code=True, - # enforce_eager=True, - # use_v2_block_manager=args.use_v2_block_manager, - # tensor_parallel_size=args.tensor_parallel_size, - # enable_prefix_caching=args.enable_prefix_caching, - # disable_log_stats=False, - # max_num_batched_tokens=4096 * 2, - # enable_chunked_prefill=True, - # ) + llm = LLM( + model=args.model, + tokenizer_mode="auto", + trust_remote_code=True, + enforce_eager=True, + 
use_v2_block_manager=args.use_v2_block_manager, + tensor_parallel_size=args.tensor_parallel_size, + enable_prefix_caching=args.enable_prefix_caching, + disable_log_stats=False, + max_num_batched_tokens=4096 * 2, + enable_chunked_prefill=True, + ) engine_args = EngineArgs.from_cli_args(args) llm = LLM(**dataclasses.asdict(engine_args)) diff --git a/tests/core/block/test_block_manager.py b/tests/core/block/test_block_manager.py index e940d2a331d9b..1190fa2885298 100644 --- a/tests/core/block/test_block_manager.py +++ b/tests/core/block/test_block_manager.py @@ -235,7 +235,7 @@ def test_can_allocate_with_prefix_cache( # Num blocks needed for 2 seqs, minus the number of blocks shared. num_blocks_required_with_sharing = 2 * num_blocks_required_seq - num_blocks_shared - block_manager = BlockSpaceManagerV2( + block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0, @@ -299,7 +299,7 @@ def test_can_append_with_prefix_cache( print(f"num_gpu_blocks: {num_gpu_blocks}") num_blocks_required_with_sharing = 2 * num_blocks_required_seq_1 - num_blocks_shared - block_manager = BlockSpaceManagerV2( + block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0, diff --git a/vllm/config.py b/vllm/config.py index 98e0ea96f0d2c..99a82c8f1b40b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1074,10 +1074,6 @@ def __init__(self, self.policy = policy self._verify_args() - print( - f"max_num_batched_tokens: {self.max_num_batched_tokens}, max_num_seqs: {self.max_num_seqs}" - ) - def _verify_args(self) -> None: if (self.max_num_batched_tokens < self.max_model_len and not self.chunked_prefill_enabled): diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index 9a3a66523bd34..2beee5b9885a4 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -164,7 +164,9 @@ def append_slots( for i, token_block in enumerate(token_blocks): if self._enable_prefix_caching: - block_hash: Optional[int] = seq.get_block_hash(first_block_idx + i) + block_hash: Optional[int] = seq.update_and_get_block_hash( + first_block_idx + i + ) else: block_hash = None self._blocks.append_token_ids(first_block_idx + i, token_block, block_hash) @@ -286,7 +288,7 @@ def _allocate_blocks_for_token_ids( if len(cur_token_ids) == self._block_size: block_token_ids.append(cur_token_ids) if self._enable_prefix_caching: - block_hashes.append(seq.get_block_hash(block_idx)) + block_hashes.append(seq.update_and_get_block_hash(block_idx)) else: block_hashes.append(None) else: @@ -308,12 +310,9 @@ def _allocate_blocks_for_token_ids( assert len(tail_token_ids) == 1 assert block_hashes[-1] is None cur_token_ids = tail_token_ids[0] - try: - block = self._allocator.allocate_mutable_block( - prev_block=prev_block, device=device - ) - except Exception as e: - breakpoint() + block = self._allocator.allocate_mutable_block( + prev_block=prev_block, device=device + ) block.append_token_ids(cur_token_ids, block_hash=None) blocks.append(block) diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index 191eb6c8fdd15..c2117ccaaeb50 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -196,9 +196,6 @@ def increase_pool(self): allocator=self._allocator, block_id=None)) - # TODO(rickyx): This should take in kwargs for flexible initialization of different types of blocks - # Right now, we update explicitly blocks with other args after initialization, e.g. 
block_hash - # computed for the prefix caching block. def init_block( self, prev_block: Optional[Block], @@ -206,15 +203,6 @@ def init_block( block_size: int, physical_block_id: Optional[int], ) -> Block: - """Initializes a block with the given parameters. - - Args: - prev_block (Optional[Block]): The previous block in the sequence. - token_ids (List[int]): The token IDs to be stored in the block. - block_size (int): The size of the block. - physical_block_id (Optional[int]): The physical block ID. - block_hash (Optional[int]): The hash of the block's content. - """ if len(self._free_ids) == 0: self.increase_pool() assert len(self._free_ids) > 0 diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 378d273a271ab..8cec110cca663 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -99,7 +99,7 @@ def content_hash(self) -> Optional[int]: return None @abstractmethod - def set_content_hash(self, content_hash: int) -> None: + def set_content_hash(self, content_hash: Optional[int]) -> None: pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 2ffdb94e904a0..e396ae3a3f4ff 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -64,7 +64,6 @@ def allocate_immutable_block( self, prev_block: Optional[Block], token_ids: List[int], - device: Optional[Device] = None, block_hash: Optional[int] = None, ) -> Block: """Allocates a new immutable block with the given token IDs, linked to @@ -79,7 +78,6 @@ def allocate_immutable_block( Returns: Block: The newly allocated immutable block. """ - assert device is None assert block_hash is None block = self.allocate_mutable_block(prev_block=prev_block) @@ -91,9 +89,7 @@ def allocate_immutable_blocks( prev_block: Optional[Block], block_token_ids: List[List[int]], block_hashes: Optional[List[Optional[int]]] = None, - device: Optional[Device] = None, ) -> List[Block]: - assert device is None num_blocks = len(block_token_ids) block_ids = [] @@ -114,7 +110,6 @@ def allocate_immutable_blocks( def allocate_mutable_block( self, prev_block: Optional[Block], - device: Optional[Device] = None, block_hash: Optional[int] = None, ) -> Block: """Allocates a new mutable block, linked to the previous block. @@ -127,7 +122,6 @@ def allocate_mutable_block( Returns: Block: The newly allocated mutable block. """ - assert device is None assert block_hash is None block_id = self._allocate_block_id() diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index feecf3d45ec3c..bdd12ebcc8be8 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -140,7 +140,6 @@ def allocate_immutable_block( prev_block: Optional[Block], token_ids: List[int], block_hash: Optional[int] = None, - device: Optional[Device] = None, ) -> Block: """Allocates an immutable block with the given token IDs, reusing cached blocks if possible. @@ -153,7 +152,6 @@ def allocate_immutable_block( Returns: Block: The allocated immutable block. 
""" - assert device is None assert len(token_ids) == self._block_size, "An immutable block should be full" assert ( block_hash is not None @@ -163,9 +161,6 @@ def allocate_immutable_block( cached_block_id = self._cached_blocks.get(block_hash, None) if cached_block_id is not None: # Initialize a block that points to cached data - # print( - # f"reuse block_hash={block_hash} from cached_block_id: {cached_block_id}" - # ) block: Block = self._block_pool.init_block( prev_block=prev_block, token_ids=token_ids, @@ -177,9 +172,6 @@ def allocate_immutable_block( self._incr_refcount_cached_block(block) return block - # print( - # f"alloc from new block(block_hash: {block_hash}), get_num_free_blocks: {self.get_num_free_blocks()}" - # ) self.metric_data.query(hit=False) # No cached block => Allocate a new block @@ -192,7 +184,6 @@ def allocate_immutable_blocks( prev_block: Optional[Block], block_token_ids: List[List[int]], block_hashes: Optional[List[int]] = None, - device: Optional[Device] = None, ) -> List[Block]: blocks = [] assert ( @@ -204,7 +195,6 @@ def allocate_immutable_blocks( prev_block=prev_block, token_ids=token_ids, block_hash=block_hash, - device=device, ) blocks.append(prev_block) return blocks @@ -224,9 +214,6 @@ def allocate_mutable_block(self, """ assert device is None assert_prefix_caching_block_or_none(prev_block) - # print( - # f"Allocating mutable block: get_num_free_blocks: {self.get_num_free_blocks()}" - # ) block_id = self._allocate_block_id() block = self._block_pool.init_block(prev_block=prev_block, token_ids=[], @@ -297,7 +284,6 @@ def _allocate_block_id(self) -> BlockId: """First tries to allocate a block id from the hashless allocator, and if there are no blocks, then tries to evict an unused cached block. """ - # print(f"allocating block_id: get_num_free_blocks: {self.get_num_free_blocks()}") hashless_block_id = self._maybe_allocate_hashless_block_id() if hashless_block_id is not None: return hashless_block_id @@ -418,9 +404,7 @@ def get_num_free_blocks(self, device: Optional[Device] = None) -> int: assert device is None # The number of free blocks is the number of hashless free blocks # plus the number of blocks evictor could free from its list. - return self._hashless_allocator.get_num_free_blocks() + ( - self.evictor.num_blocks - ) + return self._hashless_allocator.get_num_free_blocks() + self.evictor.num_blocks def get_num_total_blocks(self) -> int: return self._hashless_allocator.get_num_total_blocks() @@ -511,9 +495,6 @@ def cow_block_if_not_appendable(self, block: Block) -> BlockId: return src_block_id self._free_block_id(block) - # print( - # f"Allocating block for COW: get_num_free_blocks: {self.get_num_free_blocks()}" - # ) trg_block_id = self._allocate_block_id() self._cow_tracker.record_cow(src_block_id, trg_block_id) @@ -878,38 +859,6 @@ def token_ids(self) -> List[int]: def prev_block(self) -> Optional[Block]: return self._prev_block - # @property - # def content_hash(self) -> Optional[int]: - # """Return the content-based hash of the current block, or None if it is - # not yet defined. - - # For the content-based hash to be defined, the current block must be - # full. - # """ - # # If the hash is already computed, return it. - # if self._cached_content_hash is not None: # return self._cached_content_hash - - # # We cannot compute a hash for the current block because it is not full. 
- # if not self.is_full: - # return None - - # is_first_block = self._prev_block is None - # prev_block_hash = ( - # None if is_first_block else - # self._prev_block.content_hash # type: ignore - # ) - - # # Previous block exists but does not yet have a hash. - # # Return no hash in this case. - # if prev_block_hash is None and not is_first_block: - # return None - - # self._cached_content_hash = PrefixCachingBlock.hash_block_tokens( - # is_first_block, - # prev_block_hash, - # cur_block_token_ids=self.token_ids) - # return self._cached_content_hash - @property def content_hash(self) -> Optional[int]: return self._cached_content_hash @@ -952,7 +901,9 @@ def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], assert (prev_block_hash is None) == is_first_block return hash((is_first_block, prev_block_hash, *cur_block_token_ids)) - +# TODO(rickyx): This is not used anymore. Or it could be used to track +# cached blocks for a sequence, so the sequence would be decoupled from the computed +# block hash calculation. class ComputedBlocksTracker: """Handles caching of per-sequence computed block ids. When a sequence appears for the first time, it traverses all of the @@ -989,54 +940,6 @@ def remove_seq(self, seq_id: int) -> None: assert seq_id in self._cached_computed_seq_blocks del self._cached_computed_seq_blocks[seq_id] - def get_cached_computed_blocks_and_update( - self, seq_id: int, block_ids: List[int]) -> List[int]: - """ Look at the class documentation for details - """ - # Ensure seq_id is already tracked - assert seq_id in self._cached_computed_seq_blocks - - # Get cached data (may be empty on the first time) - prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[ - seq_id] - - if has_gap: - # When gap is detected, we do not add more computed blocks at this - # sequence iteration - return prev_computed_block_ids - - # We do not consider the last block id for caching purposes. - num_cur_blocks = len(block_ids) - 1 - assert num_cur_blocks >= 0 - - if len(prev_computed_block_ids) >= num_cur_blocks: - # Cache HIT - assert len(prev_computed_block_ids) == num_cur_blocks - return prev_computed_block_ids - - # If here, then we may possibly add more computed blocks. As a result, - # traverse the additional blocks after prev_computed_block_ids to - # detect more computed blocks and add them. - - # Incremental init for seq_id => Look only at the new blocks - computed_block_ids = self._allocator.get_computed_block_ids( # noqa: E501 - prev_computed_block_ids, - block_ids, - skip_last_block_id= - True, # We skip last block id to avoid caching of full seq - ) - - # QQ(rickyx): why is it possible to actually have a gap? - - # Detect if there is a "gap" - has_gap = len(computed_block_ids) < num_cur_blocks - - # Record - self._cached_computed_seq_blocks[seq_id] = (computed_block_ids, - has_gap) - - return computed_block_ids - class LastAccessBlocksTracker: """Manages the last access time of the tracked sequences, in order to allow diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 998b706296fd2..df09a3a30743c 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -108,6 +108,16 @@ def __init__( def _get_num_blocks_to_allocate( self, seq: Sequence, num_lookahead_slots: int = 0 ) -> int: + """ + Get the number of new blocks to allocate for a sequence. + + Args: + seq (Sequence): The sequence to allocate blocks for. + num_lookahead_slots (int): The number of lookahead slots to allocate. 
+ + Returns: + int: The number of new blocks to allocate. + """ num_cached_tokens = seq.get_num_cached_tokens() assert ( @@ -119,6 +129,17 @@ def _get_num_blocks_to_allocate( return num_required_blocks - num_cached_blocks def get_num_computed_tokens(self, seq: Sequence) -> int: + """ + Get the number of computed tokens for a sequence. + + NOTE: This only returns tokens in blocks that are BOTH cached and allocated (active). + + Args: + seq (Sequence): The sequence to get the number of computed tokens for. + + Returns: + int: The number of allocated and cached computed tokens. + """ seq_blocks = seq.get_block_hashes() cached_seq_blocks = self.block_allocator.get_allocated_cached_blocks( block_hashes=seq_blocks, @@ -126,52 +147,6 @@ def get_num_computed_tokens(self, seq: Sequence) -> int: ) return len(cached_seq_blocks) * self.block_size - # def get_num_computed_blocks(self, seq_group: SequenceGroup) -> Dict[SeqId, int]: - # num_computed_blocks = {} - # for seq in seq_group.get_seqs(): - # num_computed_blocks[seq.seq_id] = self._get_num_computed_tokens(seq) - # return num_computed_blocks - - def can_allocate_old( - self, seq_group: SequenceGroup, num_lookahead_slots: int = 0 - ) -> AllocStatus: - # FIXME(woosuk): Here we assume that all sequences in the group share - # the same prompt. This may not be true for preempted sequences. - - check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) - - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - num_required_blocks = BlockTable.get_num_required_blocks( - seq.get_token_ids(), - block_size=self.block_size, - num_lookahead_slots=num_lookahead_slots, - ) - - if seq_group.is_encoder_decoder(): - encoder_seq = seq_group.get_encoder_seq() - assert encoder_seq is not None - num_required_blocks += BlockTable.get_num_required_blocks( - encoder_seq.get_token_ids(), - block_size=self.block_size, - ) - - if self.max_block_sliding_window is not None: - num_required_blocks = min( - num_required_blocks, self.max_block_sliding_window - ) - - num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( - device=Device.GPU - ) - - # Use watermark to avoid frequent cache eviction. - if self.num_total_gpu_blocks - num_required_blocks < self.watermark_blocks: - return AllocStatus.NEVER - if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: - return AllocStatus.OK - else: - return AllocStatus.LATER - def can_allocate(self, seq_group: SequenceGroup, num_lookahead_slots: int = 0) -> AllocStatus: @@ -200,14 +175,6 @@ def can_allocate(self, num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( device=Device.GPU ) - # print(f"num_blocks_to_allocate: {num_blocks_to_allocate}") - # print(f"num_free_gpu_blocks: {num_free_gpu_blocks}") - # print(f"watermark_blocks: {self.watermark_blocks}") - - # Use watermark to avoid frequent cache eviction. 
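# A small worked sketch of the blocks-to-allocate arithmetic above
# (illustrative helper, not a vLLM API). It assumes the required block count
# is cdiv(sequence_len + num_lookahead_slots, block_size), matching
# BlockTable.get_num_required_blocks, and that cached tokens always cover
# whole blocks, matching the assertion in _get_num_blocks_to_allocate.
def blocks_to_allocate(seq_len: int, num_cached_tokens: int, block_size: int,
                       num_lookahead_slots: int = 0) -> int:
    assert num_cached_tokens % block_size == 0
    num_required_blocks = -(-(seq_len + num_lookahead_slots) // block_size)
    num_cached_blocks = num_cached_tokens // block_size
    return num_required_blocks - num_cached_blocks


# 33 prompt tokens at block_size=16 need 3 blocks; with the first 16 tokens
# cached in blocks that are still allocated, only 2 new blocks are needed.
assert blocks_to_allocate(seq_len=33, num_cached_tokens=16, block_size=16) == 2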
- # old_can_allocate = self.can_allocate_old(seq_group, num_lookahead_slots) - - can_allocate = None if self.num_total_gpu_blocks - num_blocks_to_allocate < self.watermark_blocks: return AllocStatus.NEVER if num_free_gpu_blocks - num_blocks_to_allocate >= self.watermark_blocks: @@ -215,11 +182,6 @@ def can_allocate(self, else: return AllocStatus.LATER - # if old_can_allocate != can_allocate: - # print(f"old_can_allocate: {old_can_allocate}, can_allocate: {can_allocate}") - - return can_allocate - def _allocate_sequence(self, seq: Sequence) -> BlockTable: block_table = BlockTable( block_size=self.block_size, @@ -400,17 +362,6 @@ def get_common_computed_block_ids( computed_block_ids = all_blocks[:num_cached_block] computed_seq_block_ids.append(computed_block_ids) - # old_computed_block_ids = ( - # self._computed_blocks_tracker.get_cached_computed_blocks_and_update( - # seq.seq_id, all_blocks - # ) - # ) - # if old_computed_block_ids != computed_block_ids: - # print( - # f"old_computed_block_ids: \n{old_computed_block_ids}\n, computed_block_ids: \n{computed_block_ids}\n" - # ) - # print(f"seq: {seq}") - # NOTE(sang): This assumes seq_block_ids doesn't contain any None. return self.block_allocator.get_common_computed_block_ids( computed_seq_block_ids) # type: ignore diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 51a88f338dc3c..1434eb9b3115d 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -56,12 +56,14 @@ class SchedulingBudget: max_num_seqs: int _request_ids_num_batched_tokens: Set[str] = field(default_factory=set) _request_ids_num_curr_seqs: Set[str] = field(default_factory=set) + # Number of batched tokens that are strictly not cached. _num_batched_tokens: int = 0 - _num_batched_and_cached_tokens: int = 0 + # Number of batched tokens that are cached. 
+ _num_cached_tokens: int = 0 _num_curr_seqs: int = 0 def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): - # assert num_new_tokens != 0 + assert num_new_tokens >= 0 assert num_new_seqs != 0 return (self.num_batched_tokens + num_new_tokens <= self.token_budget and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) @@ -73,32 +75,25 @@ def add_num_batched_tokens( self, req_id: str, num_batched_tokens: int, - num_batched_and_cached_tokens: Optional[int] = None, + num_cached_tokens: int = 0, ): assert num_batched_tokens >= 0 + assert num_cached_tokens >= 0 if req_id in self._request_ids_num_batched_tokens: return self._request_ids_num_batched_tokens.add(req_id) self._num_batched_tokens += num_batched_tokens - if num_batched_and_cached_tokens is None: - num_batched_and_cached_tokens = num_batched_tokens - self._num_batched_and_cached_tokens += num_batched_and_cached_tokens - - assert self._num_batched_tokens <= self.token_budget, f"{self._num_batched_tokens} > {self.token_budget}" + self._num_cached_tokens += num_cached_tokens def subtract_num_batched_tokens( self, req_id: str, num_batched_tokens: int, - num_batched_and_cached_tokens: Optional[int] = None, ): if req_id in self._request_ids_num_batched_tokens: self._request_ids_num_batched_tokens.remove(req_id) self._num_batched_tokens -= num_batched_tokens - if num_batched_and_cached_tokens is None: - num_batched_and_cached_tokens = num_batched_tokens - self._num_batched_and_cached_tokens -= num_batched_and_cached_tokens def add_num_seqs(self, req_id: str, num_curr_seqs: int): if req_id in self._request_ids_num_curr_seqs: @@ -118,7 +113,7 @@ def num_batched_tokens(self): @property def num_batched_and_cached_tokens(self): - return self._num_batched_and_cached_tokens + return self._num_batched_tokens + self._num_cached_tokens @property def num_curr_seqs(self): @@ -638,9 +633,6 @@ def _schedule_running( self._append_slots(seq_group, blocks_to_copy, enable_chunking) is_prefill = seq_group.is_prefill() - if seq_group.is_prefill() and not enable_chunking: - breakpoint() - scheduled_seq_group: ScheduledSequenceGroup = \ self._scheduled_seq_group_cache[self.cache_id].get_object() scheduled_seq_group.seq_group = seq_group @@ -840,7 +832,6 @@ def _schedule_priority_preemption( while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): # Only preempt if waiting sequence cannot be allocated - assert False can_allocate = self.block_manager.can_allocate(seq_group) if (num_new_tokens and can_allocate == AllocStatus.OK and budget.can_schedule(num_new_tokens=num_new_tokens, @@ -926,7 +917,6 @@ def _schedule_prefills( num_new_tokens, seq ) - # print(f"[{seq_group.request_id=}] {num_new_tokens=} {num_new_tokens_exclude_cached}, budget: {budget.num_batched_tokens}") if not enable_chunking: num_prompt_tokens = seq.get_len() assert num_new_tokens == num_prompt_tokens @@ -950,15 +940,6 @@ def _schedule_prefills( can_allocate = self.block_manager.can_allocate( seq_group, num_lookahead_slots=num_lookahead_slots) - old_can_allocate = self.block_manager.can_allocate_old( - seq_group, num_lookahead_slots=num_lookahead_slots - ) - - if can_allocate != old_can_allocate: - print( - f"can_allocate: {can_allocate}, old_can_allocate: {old_can_allocate}" - ) - if can_allocate == AllocStatus.LATER: break elif can_allocate == AllocStatus.NEVER: @@ -1004,13 +985,10 @@ def _schedule_prefills( if curr_loras is not None and lora_int_id > 0: curr_loras.add(lora_int_id) waiting_queue.popleft() - try: - 
self._allocate_and_set_running(seq_group) - except Exception as e: - breakpoint() + self._allocate_and_set_running(seq_group) # NOTE(rickyx): We are updating this again since some of the previously - # cached blocks that were in evictor might now become active again. + # cached blocks that were in evictor might now become active again. # Therefore, the actual number of tokens cached might have changed. self._update_prefix_cached_tokens(seq) num_new_tokens = self._get_num_new_tokens( @@ -1019,8 +997,8 @@ def _schedule_prefills( enable_chunking, budget, ) - num_new_tokens_exclude_cached = self._get_num_new_tokens_exclude_cached( - num_new_tokens, seq + num_new_tokens_uncached = self._get_num_new_tokens_exclude_cached( + num_new_tokens, seq ) if enable_chunking and self.scheduler_config.is_multi_step: @@ -1041,13 +1019,14 @@ def _schedule_prefills( enable_chunking=enable_chunking) seq_groups.append( - ScheduledSequenceGroup(seq_group=seq_group, - token_chunk_size=num_new_tokens)) - # print(f"[{seq_group.request_id}] {num_new_tokens=} {num_new_tokens_exclude_cached=}, budget: {budget.num_batched_tokens}") + ScheduledSequenceGroup( + seq_group=seq_group, token_chunk_size=num_new_tokens + ) + ) budget.add_num_batched_tokens( seq_group.request_id, - num_batched_tokens=num_new_tokens_exclude_cached, - num_batched_and_cached_tokens=num_new_tokens, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens - num_new_tokens_uncached, ) budget.add_num_seqs(seq_group.request_id, num_new_seqs) @@ -1137,11 +1116,6 @@ def _schedule_default(self) -> SchedulerOutputs: # There should be no prefill from running queue because this policy # doesn't allow chunked prefills. - if len(running_scheduled.prefill_seq_groups) > 0: - print( - f"running_scheduled.prefill_seq_groups: {running_scheduled.prefill_seq_groups}" - ) - breakpoint() assert len(running_scheduled.prefill_seq_groups) == 0 assert len(swapped_in.prefill_seq_groups) == 0 @@ -1572,7 +1546,7 @@ def _preempt( else: preemption_mode = PreemptionMode.RECOMPUTE - if self.num_cumulative_preemption % 5 == 0: + if self.num_cumulative_preemption % 50 == 0: logger.warning( "Sequence group %s is preempted by %s mode because there is " "not enough KV cache space. This can affect the end-to-end " @@ -1699,7 +1673,6 @@ def _get_num_new_tokens( num_new_tokens = 0 seqs = seq_group.get_seqs(status=status) for seq in seqs: - # self._update_prefix_cached_tokens(seq) num_new_tokens += seq.get_num_new_tokens() assert num_new_tokens > 0 # Chunk if a running request cannot fit in the given budget. @@ -1751,18 +1724,34 @@ def _get_num_new_tokens( # No more budget for new tokens, don't include any cached tokens too. return 0 num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached - # print(f"[{seq_group.request_id}] {num_new_tokens=} {num_new_tokens_uncached=} {num_new_tokens_cached=}, budget: {budget.num_batched_tokens}") else: num_new_tokens = min(num_new_tokens, remaining_token_budget) return num_new_tokens def _update_prefix_cached_tokens(self, seq: Sequence): + """ + Update the number of prefix cached tokens for a sequence. + + This function takes O(log(n)) time, where n is the number of blocks + in the sequence. + """ num_prefix_cached_tokens = self.block_manager.get_num_computed_tokens(seq) seq.set_num_prefix_cached_tokens(num_prefix_cached_tokens) def _get_num_new_tokens_exclude_cached( self, num_new_tokens: int, seq: Sequence ) -> int: + """ + Get the number of new tokens to compute for a sequence, excluding + cached tokens. 
+ + Args: + num_new_tokens: The number of new tokens to compute. + seq: The sequence to compute the new tokens for. + + Returns: + Given `num_new_tokens`, returns the number of uncached tokens. + """ # If a decode sequence, new tokens are always not computed/cached. if not seq.is_prefill(): diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 17476dd78d1f6..34694a37405c0 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1637,7 +1637,6 @@ def _get_stats(self, # Iteration stats num_prompt_tokens_iter = 0 num_generation_tokens_iter = 0 - num_extra_batched_tokens_iter = 0 time_to_first_tokens_iter: List[float] = [] time_per_output_tokens_iter: List[float] = [] num_preemption_iter = (0 if scheduler_outputs is None else @@ -1678,17 +1677,6 @@ def _get_stats(self, # not counted (to avoid double counting) actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore - num_extra_batched_tokens_iter = ( - actual_num_batched_tokens - - scheduler_outputs.num_batched_tokens_from_budget - ) - if num_extra_batched_tokens_iter > 0: - print( - f"num_extra_batched_tokens_iter: {num_extra_batched_tokens_iter}, " - f"actual_num_batched_tokens: {actual_num_batched_tokens}, " - f"num_batched_tokens_from_budget: {scheduler_outputs.num_batched_tokens_from_budget}" - ) - num_generation_tokens_from_prefill_groups = 0. # NOTE: if scheduler_outputs.num_prefill_groups > 0 and # the len of scheduler_outputs.scheduled_seq_groups is != @@ -1802,7 +1790,6 @@ def _get_stats(self, time_per_output_tokens_iter=time_per_output_tokens_iter, spec_decode_metrics=spec_decode_metrics, num_preemption_iter=num_preemption_iter, - num_extra_batched_tokens_iter=num_extra_batched_tokens_iter, # Request stats # Latency time_e2e_requests=time_e2e_requests, @@ -1813,7 +1800,8 @@ def _get_stats(self, finished_reason_requests=finished_reason_requests, max_lora=str(max_lora_stat), waiting_lora_adapters=list(waiting_lora_adapters.keys()), - running_lora_adapters=list(running_lora_adapters.keys())) + running_lora_adapters=list(running_lora_adapters.keys()), + ) def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_executor.add_lora(lora_request) diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index d9bc9dfddfee1..e9a5bd3b586be 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -42,7 +42,6 @@ class Stats: time_to_first_tokens_iter: List[float] time_per_output_tokens_iter: List[float] num_preemption_iter: int - num_extra_batched_tokens_iter: int # Request stats (should have _requests suffix) # Latency diff --git a/vllm/sequence.py b/vllm/sequence.py index 0af3b7acdf3ab..a3d6c0b1492ad 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -566,8 +566,11 @@ def get_output_token_ids_to_return( return self.data._cached_all_token_ids[-num_new_tokens:] - def get_block_hash(self, block_idx: int) -> Optional[int]: - + def update_and_get_block_hash(self, block_idx: int) -> Optional[int]: + """ + Get the block hash for a given block index. + Optionally update the block hashes if not computed yet. + """ # Lazy update the block hashes on the first invocation. if block_idx >= len(self._computed_block_hashes): self._update_block_hashes() @@ -577,14 +580,12 @@ def get_block_hash(self, block_idx: int) -> Optional[int]: return None def get_block_hashes(self) -> List[int]: - # TODO(rickyx): maybe better to have an API to track if the computed hash is updated. 
self._update_block_hashes() return self._computed_block_hashes def _update_block_hashes(self): """ Update the block hashes for all the full blocks in the sequence. - It skips the blocks that have already been computed. """ token_ids = self.get_token_ids() # All token ids in the sequence @@ -697,8 +698,7 @@ def get_num_new_tokens(self) -> int: if self.data.stage == SequenceStage.DECODE: return 1 - num_computed_tokens = self.data.get_num_computed_tokens() - return self.data.get_len() - num_computed_tokens + return self.data.get_num_uncomputed_tokens() def get_num_cached_tokens(self) -> int: return self.data.get_num_prefix_cached_tokens() From f8b488d7482d389e86c1d442001747ff75bda8f0 Mon Sep 17 00:00:00 2001 From: rickyx Date: Wed, 30 Oct 2024 23:08:04 +0000 Subject: [PATCH 09/12] up --- benchmarks/benchmark_prefix_caching.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 5414036a37263..b904135b202c8 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -117,9 +117,7 @@ def main(args): input_length_range = tuple(map(int, args.input_length_range.split(':'))) random.seed(args.seed) if args.dataset_path is not None: - print( - f"Start to sample {args.num_prompts} prompts " f"from {args.dataset_path}" - ) + print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}") filtered_datasets = sample_requests( dataset_path=args.dataset_path, num_requests=args.num_prompts, From 4df83d0777e97e66eec1575988641602c411cba9 Mon Sep 17 00:00:00 2001 From: rickyx Date: Wed, 30 Oct 2024 23:08:28 +0000 Subject: [PATCH 10/12] up --- benchmarks/benchmark_prefix_caching.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index b904135b202c8..8637e4ff21718 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -130,18 +130,6 @@ def main(args): filtered_datasets = [(PROMPT, prompt_len, args.output_len) ] * args.num_prompts - llm = LLM( - model=args.model, - tokenizer_mode="auto", - trust_remote_code=True, - enforce_eager=True, - use_v2_block_manager=args.use_v2_block_manager, - tensor_parallel_size=args.tensor_parallel_size, - enable_prefix_caching=args.enable_prefix_caching, - disable_log_stats=False, - max_num_batched_tokens=4096 * 2, - enable_chunked_prefill=True, - ) engine_args = EngineArgs.from_cli_args(args) llm = LLM(**dataclasses.asdict(engine_args)) From 417760af59e678aeea6c1c315adaa4934ebd0a99 Mon Sep 17 00:00:00 2001 From: rickyx Date: Wed, 6 Nov 2024 22:56:57 +0000 Subject: [PATCH 11/12] up --- benchmarks/benchmark_prefix_caching.py | 1 + tests/core/block/test_block_manager.py | 83 +-- tests/prefix_caching/test_prefix_caching.py | 9 +- vllm/core/block/block_table.py | 116 ++-- vllm/core/block/cpu_gpu_block_allocator.py | 9 + vllm/core/block/interfaces.py | 12 +- vllm/core/block/naive_block.py | 5 + vllm/core/block/prefix_caching_block.py | 240 +++++-- vllm/core/block_manager.py | 91 ++- vllm/core/scheduler.py | 672 +++++++++++++------- vllm/sequence.py | 61 +- 11 files changed, 826 insertions(+), 473 deletions(-) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 8637e4ff21718..8eb6c8ad7606b 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -131,6 +131,7 @@ def main(args): ] * args.num_prompts engine_args 
= EngineArgs.from_cli_args(args) + engine_args.enable_chunked_prefill = True llm = LLM(**dataclasses.asdict(engine_args)) diff --git a/tests/core/block/test_block_manager.py b/tests/core/block/test_block_manager.py index 1190fa2885298..16e2876c015b1 100644 --- a/tests/core/block/test_block_manager.py +++ b/tests/core/block/test_block_manager.py @@ -234,6 +234,9 @@ def test_can_allocate_with_prefix_cache( # Num blocks needed for 2 seqs, minus the number of blocks shared. num_blocks_required_with_sharing = 2 * num_blocks_required_seq - num_blocks_shared + print( + f"num_blocks_required_with_sharing: {num_blocks_required_with_sharing}" + ) block_manager = SelfAttnBlockSpaceManager( block_size=block_size, @@ -253,6 +256,11 @@ def test_can_allocate_with_prefix_cache( # Allocate the seq 1 block_manager.allocate(seq_group_1) + # Mark the seq 1 as computed (This shoudl be done by the scheduler in reality) + block_manager.mark_blocks_as_computed( + seq_group=seq_group_1, token_chunk_size=len(tokens_1) + ) + # Test if allocatable of seq 2. seq_group_2 = create_seq_group( seq_output_lens=[0], @@ -268,80 +276,6 @@ def test_can_allocate_with_prefix_cache( assert block_manager.can_allocate(seq_group_2) == AllocStatus.LATER -@pytest.mark.skip(reason="Not correct yet") -@pytest.mark.parametrize("block_size", [1, 4]) -@pytest.mark.parametrize("num_prefill_tokens", [1, 2, 4, 5, 6, 8, 10]) -@pytest.mark.parametrize("prefix_shared_percentage", [0.0, 0.3, 0.5, 0.7, 1.0]) -def test_can_append_with_prefix_cache( - block_size: int, - num_prefill_tokens: int, - prefix_shared_percentage: float, -): - num_seqs_allocable = 1.5 - - num_blocks_required_seq_1 = math.ceil(num_prefill_tokens / block_size) - num_gpu_blocks = math.ceil(num_blocks_required_seq_1 * num_seqs_allocable) - - num_tokens_shared = int( - 1 + (num_prefill_tokens - 1) * prefix_shared_percentage - ) # We will always share the first token. - num_blocks_shared = num_tokens_shared // block_size - - tokens_1 = list(range(num_prefill_tokens)) - tokens_2 = tokens_1[:num_tokens_shared] + [ - t + 10 for t in tokens_1[num_tokens_shared:] - ] - - print(f"tokens_1: {tokens_1}") - print(f"tokens_2: {tokens_2}") - print(f"num_blocks_shared: {num_blocks_shared}") - print(f"num_blocks_required_seq_1: {num_blocks_required_seq_1}") - print(f"num_gpu_blocks: {num_gpu_blocks}") - - num_blocks_required_with_sharing = 2 * num_blocks_required_seq_1 - num_blocks_shared - block_manager = SelfAttnBlockSpaceManager( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=0, - enable_caching=True, # Prefix cache - ) - - # Allocate seq 1. - seq_group_1 = create_seq_group( - seq_output_lens=[0], - request_id="0", - seq_id_start=0, - prompt_token_ids=tokens_1, - block_size=block_size, - ) - assert block_manager.can_allocate(seq_group_1) == AllocStatus.OK - block_manager.allocate(seq_group_1) - - # Allocate seq 2. - seq_group_2 = create_seq_group( - seq_output_lens=[0], - request_id="1", - seq_id_start=1, - prompt_token_ids=tokens_2[:1], # Just one token for prefill. - block_size=block_size, - ) - assert block_manager.can_allocate(seq_group_2) == AllocStatus.OK - block_manager.allocate(seq_group_2) - - # Test if append is possible. 
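# A quick numeric check of the block-sharing arithmetic used in the test
# above, with assumed example values (block_size=4, a 10-token prompt whose
# first 8 tokens, i.e. 2 full blocks, are shared between the two sequences).
import math

block_size = 4
num_prefill_tokens = 10
num_blocks_required_seq = math.ceil(num_prefill_tokens / block_size)  # 3
num_blocks_shared = 8 // block_size  # 2 full blocks of shared prefix
num_blocks_required_with_sharing = (2 * num_blocks_required_seq
                                    - num_blocks_shared)
assert num_blocks_required_with_sharing == 4  # 6 blocks total, 2 reused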
- seq = seq_group_2.get_seqs()[0] - for token_id in tokens_2[1:]: - seq.append_token_id(token_id, {token_id: Logprob(0.0)}) - - seq.update_num_computed_tokens(len(tokens_2[1:])) - seq.status = SequenceStatus.RUNNING - - if num_blocks_required_with_sharing <= num_gpu_blocks: - assert block_manager.can_append_slots(seq_group_2, 0) - else: - assert not block_manager.can_append_slots(seq_group_2, 0) - - @pytest.mark.parametrize("block_size", [1, 8]) @pytest.mark.parametrize("prompt_len", [1, 7, 8]) @pytest.mark.parametrize("num_slots_to_append", [1, 8, 129]) @@ -622,6 +556,7 @@ def num_blocks(num_tokens): for token_id in range(num_slots_to_append): seq.append_token_id(token_id, {token_id: Logprob(0.0)}) seq.data.update_num_computed_tokens(1) + block_manager._computed_blocks_tracker.update_seq(seq) block_manager.append_slots(seq, num_lookahead_slots=0) if prompt_len < sliding_window + 10: check_used(0, sliding_blocks + 1) diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 366b030eaa399..9b3782e97fe72 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -18,6 +18,7 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("cached_position", [0, 1]) +@pytest.mark.parametrize("enable_chunked_prefill", [True, False]) def test_mixed_requests( hf_runner, vllm_runner, @@ -27,6 +28,7 @@ def test_mixed_requests( dtype: str, max_tokens: int, cached_position: int, + enable_chunked_prefill: bool, monkeypatch, ) -> None: """ @@ -41,9 +43,10 @@ def test_mixed_requests( cached_prompt = example_prompts[cached_position] with vllm_runner( - model, - dtype=dtype, - enable_prefix_caching=True, + model, + dtype=dtype, + enable_prefix_caching=True, + enable_chunked_prefill=enable_chunked_prefill, ) as vllm_model: # Run the first prompt so the cache is populated vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens) diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index 2beee5b9885a4..74962541dfb22 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -45,7 +45,6 @@ def __init__( block_allocator: DeviceAwareBlockAllocator, _blocks: Optional[List[Block]] = None, max_block_sliding_window: Optional[int] = None, - enable_prefix_caching: bool = False, ): self._block_size = block_size self._allocator = block_allocator @@ -56,51 +55,56 @@ def __init__( self._max_block_sliding_window = max_block_sliding_window self._num_full_slots = self._get_num_token_ids() - # Whether to enable prefix caching. - self._enable_prefix_caching = enable_prefix_caching - - @staticmethod - def get_num_required_blocks(token_ids: List[int], - block_size: int, - num_lookahead_slots: int = 0) -> int: - """Calculates the minimum number of blocks required to store a given - sequence of token IDs along with any look-ahead slots that may be - required (like in multi-step + chunked-prefill). - - This assumes worst-case scenario, where every block requires a new - allocation (e.g. ignoring prefix caching). - - Args: - token_ids (List[int]): The sequence of token IDs to be stored. - block_size (int): The maximum number of tokens that can be stored in - a single block. - num_lookahead_slots (int): look-ahead slots that the sequence may - require. - - Returns: - int: The minimum number of blocks required to store the given - sequence of token IDs along with any required look-ahead slots. 
- """ - return cdiv(len(token_ids) + num_lookahead_slots, block_size) - - def allocate(self, seq: Sequence, device: Device = Device.GPU) -> None: - """Allocates memory blocks for storing the given sequence of token IDs. + # @staticmethod + # def get_num_required_blocks(token_ids: List[int], + # block_size: int, + # num_lookahead_slots: int = 0) -> int: + # """Calculates the minimum number of blocks required to store a given + # sequence of token IDs along with any look-ahead slots that may be + # required (like in multi-step + chunked-prefill). + + # This assumes worst-case scenario, where every block requires a new + # allocation (e.g. ignoring prefix caching). + + # Args: + # token_ids (List[int]): The sequence of token IDs to be stored. + # block_size (int): The maximum number of tokens that can be stored in + # a single block. + # num_lookahead_slots (int): look-ahead slots that the sequence may + # require. + + # Returns: + # int: The minimum number of blocks required to store the given + # sequence of token IDs along with any required look-ahead slots. + # """ + # return cdiv(len(token_ids) + num_lookahead_slots, block_size) + + def allocate( + self, + token_ids: List[int], + block_hashes: List[Optional[int]], + device: Device = Device.GPU, + ) -> None: + """Allocates memory blocks for storing the given sequence. This method allocates the required number of blocks to store the given sequence of token IDs. Args: - seq (Sequence): The sequence to allocate blocks for. + token_ids (List[int]): The sequence of token IDs to be stored. + block_hashes (List[int]): The list of block hashes for the sequence. device (Device, optional): The device on which the blocks should be allocated. Defaults to Device.GPU. """ assert not self._is_allocated - if not seq.get_token_ids(): + if not token_ids: return - blocks = self._allocate_blocks_for_token_ids(seq=seq, device=device) + blocks = self._allocate_blocks_for_token_ids( + token_ids, block_hashes, device + ) self.update(blocks) - self._num_full_slots = len(seq.get_token_ids()) + self._num_full_slots = len(token_ids) def update(self, blocks: List[Block]) -> None: """Resets the table to the newly provided blocks @@ -110,7 +114,9 @@ def update(self, blocks: List[Block]) -> None: def append_slots( self, - seq: Sequence, + token_ids: List[int], + block_hashes: List[Optional[int]], + num_computed_slots: Optional[int], num_lookahead_slots: int = 0, ) -> None: """Appends a sequence of token IDs to the existing blocks in the @@ -138,9 +144,6 @@ def append_slots( assert self._is_allocated, "no blocks have been allocated" assert len(self._blocks) > 0 - token_ids = self.get_unseen_token_ids(seq.get_token_ids()) - num_computed_slots = seq.data.get_num_computed_tokens() - # Drop blocks that are no longer needed due to sliding window if self._max_block_sliding_window is not None: null_block = self._allocator.allocate_or_get_null_block() @@ -162,13 +165,15 @@ def append_slots( first_block_idx = self._num_full_slots // self._block_size token_blocks = self._chunk_token_blocks_for_append(token_ids) + if len(token_blocks) != len(block_hashes): + breakpoint() + + assert len(token_blocks) == len( + block_hashes + ), "chunked token_ids and block_hashes must have the same length" + for i, token_block in enumerate(token_blocks): - if self._enable_prefix_caching: - block_hash: Optional[int] = seq.update_and_get_block_hash( - first_block_idx + i - ) - else: - block_hash = None + block_hash = block_hashes[i] self._blocks.append_token_ids(first_block_idx + i, token_block, 
block_hash) self._num_full_slots += len(token_ids) @@ -223,7 +228,6 @@ def fork(self) -> "BlockTable": block_allocator=self._allocator, _blocks=forked_blocks, max_block_sliding_window=self._max_block_sliding_window, - enable_prefix_caching=self._enable_prefix_caching, ) def free(self) -> None: @@ -274,33 +278,33 @@ def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: return sequence_token_ids[self.num_full_slots:] def _allocate_blocks_for_token_ids( - self, seq: Sequence, device: Device + self, + token_ids: List[int], + block_hashes: List[Optional[int]], + device: Device, ) -> List[Block]: blocks: List[Block] = [] - block_hashes: List[Optional[int]] = [] prev_block: Optional[Block] = None block_token_ids = [] tail_token_ids = [] - token_ids = seq.get_token_ids() - chunked_block_token_ids = chunk_list(token_ids, self._block_size) - for block_idx, cur_token_ids in enumerate(chunked_block_token_ids): + chunked_block_token_ids = list(chunk_list(token_ids, self._block_size)) + assert len(block_hashes) == len( + chunked_block_token_ids + ), "block_hashes and chunked token_ids must have the same length" + + for cur_token_ids in chunked_block_token_ids: if len(cur_token_ids) == self._block_size: block_token_ids.append(cur_token_ids) - if self._enable_prefix_caching: - block_hashes.append(seq.update_and_get_block_hash(block_idx)) - else: - block_hashes.append(None) else: tail_token_ids.append(cur_token_ids) - block_hashes.append(None) if block_token_ids: blocks.extend( self._allocator.allocate_immutable_blocks( prev_block, block_token_ids=block_token_ids, - block_hashes=block_hashes, + block_hashes=block_hashes[: len(block_token_ids)], device=device, ) ) diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 9294e668a07bf..1674514f354e7 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -353,6 +353,15 @@ def get_and_reset_swaps(self) -> List[Tuple[int, int]]: self._swap_mapping.clear() return list(mapping.items()) + def find_cached_blocks_prefix( + self, block_hashes: List[int], allocated: bool + ) -> List[int]: + # Prefix caching only supported on GPU. 
+ device = Device.GPU + return self._allocators[device].find_cached_blocks_prefix( + block_hashes, allocated + ) + class NullBlock(Block): """ diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 8cec110cca663..f55133f92ee0c 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -123,7 +123,7 @@ def allocate_immutable_blocks( self, prev_block: Optional[Block], block_token_ids: List[List[int]], - block_hashes: Optional[List[Optional[int]]] = None, + block_hashes: List[Optional[int]], ) -> List[Block]: pass @@ -204,7 +204,9 @@ def get_prefix_cache_hit_rate(self) -> float: pass @abstractmethod - def get_allocated_cached_blocks(self, block_hashes: List[int]) -> List[int]: + def find_cached_blocks_prefix( + self, block_hashes: List[int], allocated: bool + ) -> List[int]: pass class NoFreeBlocksError(ValueError): @@ -308,3 +310,9 @@ def get_allocated_cached_blocks( self, block_hashes: List[int], device: Device ) -> List[int]: pass + + @abstractmethod + def find_cached_blocks_prefix( + self, block_hashes: List[int], allocated: bool, device: Device + ) -> List[int]: + pass \ No newline at end of file diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index e396ae3a3f4ff..83725faef6d39 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -338,6 +338,11 @@ def get_prefix_cache_hit_rate(self) -> float: def get_allocated_cached_blocks(self, block_hashes: List[int]) -> List[int]: return [] + def find_cached_blocks_prefix( + self, block_hashes: List[int], allocated: bool = False + ) -> List[int]: + return [] + class NaiveBlock(Block): """An implementation of the Block class that does not support prefix diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index bdd12ebcc8be8..035a9a8dd92f9 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -9,6 +9,7 @@ NaiveBlockAllocator) from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor from vllm.sequence import Sequence +from vllm.utils import cdiv PrefixHash = int @@ -183,12 +184,12 @@ def allocate_immutable_blocks( self, prev_block: Optional[Block], block_token_ids: List[List[int]], - block_hashes: Optional[List[int]] = None, + block_hashes: List[Optional[int]], ) -> List[Block]: blocks = [] - assert ( - block_hashes is not None - ), "block_hashes must be provided for immutable prefix cache blocks" + assert len(block_token_ids) == len( + block_hashes + ), "block_token_ids and block_hashes must have the same length" for token_ids, block_hash in zip(block_token_ids, block_hashes): prev_block = self.allocate_immutable_block( @@ -658,48 +659,35 @@ def swap_in(self, blocks: List[Block]) -> None: block.block_id = block_id # Assign block_id - def get_allocated_cached_blocks(self, block_hashes: List[PrefixHash]) -> List[PrefixHash]: + def find_cached_blocks_prefix( + self, block_hashes: List[PrefixHash], allocated: bool = False + ) -> List[PrefixHash]: """ - Get the list of blocks that are already computed and allocated so that they can - be shared by multiple sequences, and no needed to be allocated again. - - INVARIANCE: - For a sequence of blocks, it's also guaranteed that if a block is allocated (i.e. - block_is_active_computed(block_hash) == True), then the previous block must also be - allocated (i.e. block_is_active_computed(prev_block_hash) == True). 
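# A minimal sketch of why a binary search is sufficient here, assuming the
# invariant described above: along a sequence's chained block hashes,
# "cached" can only flip from True to False once, so cached blocks always
# form a prefix. The helper below is illustrative (a plain set stands in for
# the allocator's cached-block map); bisect_left's key= needs Python 3.10+.
from bisect import bisect_left
from typing import List, Set


def cached_prefix(block_hashes: List[int], cached: Set[int]) -> List[int]:
    # Locate the first hash that is NOT cached; everything before it is.
    idx = bisect_left(block_hashes, True, key=lambda h: h not in cached)
    return block_hashes[:idx]


hashes = [101, 202, 303, 404]
assert cached_prefix(hashes, cached={101, 202}) == [101, 202]
assert cached_prefix(hashes, cached=set()) == []
assert cached_prefix(hashes, cached={101, 202, 303, 404}) == hashes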
- - This is because we allocate and free entire sequence of blocks atomically (no partial - sequence is allocated or freed). Therefore, because block hash includes the previous - block's hash, if a current block is allocated, this means the previous block must also - be allocated. - - NOTE: we exclude computed blocks in evictor because they are already freed even if they - are cached. They would still have to be allocated by a sequence. If not, consider a - scenario with a seqeuence of 3 token blocks, and a block pool of only 2 blocks: - [b0, b1, b2], where b0 and b1 are computed but evicted, b2 is not computed. - So b0, b1 are the 2 free blocks, in evictor. - When deciding how many more blocks need to be allocated for this sequence, it should be - all 3 blocks (b0, b1, b2) rather than just 1 block (b2). + Return the prefix of the block hashes that are already computed and + cached. + + When `allocated` is True, only return the blocks that are allocated. """ - # Search for the longest prefix in `block_hashes` that are present cached blocks. - def block_is_active_computed(block_hash: PrefixHash) -> bool: + + def block_is_cached(block_hash: PrefixHash) -> bool: if block_hash not in self._cached_blocks: return False cached_block_id = self._cached_blocks[block_hash] - if cached_block_id in self.evictor: + if allocated and cached_block_id in self.evictor: + # When we require the block to be allocated even if it's cached, + # it must not be in evictor. return False # We only consider the blocks that are marked as computed. - if not self._block_tracker[cached_block_id].computed: - return False - - return True + return self.block_is_computed(cached_block_id) from bisect import bisect_left + # Look for the first block that's not cached, and returns the prefix + # , i.e. blocks that are cached. idx = bisect_left( - block_hashes, True, key=lambda x: not block_is_active_computed(x) + block_hashes, True, key=lambda x: not block_is_cached(x) ) return block_hashes[:idx] @@ -736,7 +724,9 @@ def __init__( assert isinstance(allocator, PrefixCachingBlockAllocator), ( "Currently this class is only tested with " "PrefixCachingBlockAllocator. Got instead allocator = {}".format( - allocator)) + allocator + ) + ) assert_prefix_caching_block_or_none(prev_block) self._prev_block = prev_block @@ -879,8 +869,11 @@ def set_content_hash(self, content_hash: Optional[int]) -> None: self._cached_content_hash = content_hash @staticmethod - def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], - cur_block_token_ids: List[int]) -> int: + def hash_block_tokens( + is_first_block: bool, + prev_block_hash: Optional[int], + cur_block_token_ids: List[int], + ) -> int: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. @@ -901,44 +894,163 @@ def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], assert (prev_block_hash is None) == is_first_block return hash((is_first_block, prev_block_hash, *cur_block_token_ids)) -# TODO(rickyx): This is not used anymore. Or it could be used to track -# cached blocks for a sequence, so the sequence would be decoupled from the computed -# block hash calculation. + class ComputedBlocksTracker: - """Handles caching of per-sequence computed block ids. - When a sequence appears for the first time, it traverses all of the - blocks and detects the prefix of blocks that is computed. 
On the - subsequent times, it only traverses the new blocks that were added - and updates the already recorded prefix of blocks with the newly - computed blocks. - - To avoid redundant traversals, the algorithm also detects when there - is a "gap" in the computed prefix. For example, if we have blocks = - [1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then - we won't try to add more computed blocks to [1,2,3] in this sequence - iteration, and will add more computed blocks only after the sequence is - freed and reused again. - - Note that currently, for a given sequence, we also skip the last - block id for caching purposes, to avoid caching of a full sequence + """ + Tracks the computed blocks for a sequence. """ - def __init__(self, allocator): + def __init__( + self, allocator: BlockAllocator, block_size: int, enable_caching: bool + ): self._allocator = allocator - self._cached_computed_seq_blocks: Dict[int, Tuple[List[int], - bool]] = {} + self._block_size = block_size + self._enable_caching = enable_caching + # A map from seq_id to the list of block hashes for the + # sequence. This is so that we don't have to recompute the block hashes + # for the sequence when we need to check if the sequence is cached. + self._full_blocks_hashes: Dict[int, List[int]] = {} - def add_seq(self, seq_id: int) -> None: + # A map from (seq_id, and allocated status) to the number of tokens + # that are cached for the sequence. + self._num_tokens_computed: Dict[Tuple[int, bool], int] = {} + + def _add_seq(self, seq: Sequence) -> None: """Start tracking seq_id """ - assert seq_id not in self._cached_computed_seq_blocks - self._cached_computed_seq_blocks[seq_id] = ([], False) + if not self._enable_caching: + return + assert seq.seq_id not in self._full_blocks_hashes + self._full_blocks_hashes[seq.seq_id] = [] - def remove_seq(self, seq_id: int) -> None: + def remove_seq(self, seq: Sequence) -> None: """Stop tracking seq_id """ - assert seq_id in self._cached_computed_seq_blocks - del self._cached_computed_seq_blocks[seq_id] + if not self._enable_caching: + return + assert seq.seq_id in self._full_blocks_hashes + del self._full_blocks_hashes[seq.seq_id] + if (seq.seq_id, True) in self._num_tokens_computed: + del self._num_tokens_computed[(seq.seq_id, True)] + if (seq.seq_id, False) in self._num_tokens_computed: + del self._num_tokens_computed[(seq.seq_id, False)] + + def update_seq(self, seq: Sequence) -> None: + if not self._enable_caching: + return + + if seq.seq_id not in self._full_blocks_hashes: + self._add_seq(seq) + + block_hashes = self._full_blocks_hashes[seq.seq_id] + cur_num_blocks_recorded = len(block_hashes) + token_ids = seq.get_token_ids() + assert len(token_ids) >= cur_num_blocks_recorded * self._block_size, ( + f"The sequence has {len(token_ids)} tokens, but" + f" already recorded {cur_num_blocks_recorded} blocks. " + "This should not happen since we assume blocks are " + "only added. When the sequence is recomputed, we should have " + "removed the info of the old blocks." + ) + # Update the computed block hashes for the sequence + num_total_blocks = len(token_ids) // self._block_size + + # We need to know the hash of the previous block to compute the hash of + # the current block so that blocks could be uniquely identified across + # sequences of prefixes. 
+ prev_block_hash = ( + None if cur_num_blocks_recorded == 0 else block_hashes[-1] + ) + # Only update the computed block hashes for the new blocks + for i in range(cur_num_blocks_recorded, num_total_blocks): + block_hash = seq.hash_of_block( + prev_block_hash=prev_block_hash, cur_block_idx=i + ) + block_hashes.append(block_hash) + prev_block_hash = block_hash + + def get_block_hashes( + self, seq: Sequence, start_block_idx: int = 0, end_block_idx: int = -1 + ) -> List[Optional[int]]: + """ + Returns the list of block hashes for the sequence. + If the last block is not yet full, its hash is None. + + Args: + start_block_idx: The index of the first block to return. + end_block_idx: The index of the last block to return. + + Returns: + The list of block hashes for the sequence. + """ + num_blocks = cdiv(len(seq.get_token_ids()), self._block_size) + if end_block_idx == -1: + end_block_idx = num_blocks + + if not self._enable_caching: + return [None] * (end_block_idx - start_block_idx) + + # Prefix caching is enabled. + if seq.seq_id not in self._full_blocks_hashes: + self.update_seq(seq) + + seq_block_hashes = self._full_blocks_hashes[seq.seq_id] + num_full_blocks = len(seq_block_hashes) + assert num_blocks - num_full_blocks <= 1, ( + "There should only be at most one block in the end of the " + f"sequence that's not yet computed and full. Got {num_blocks} " + f"blocks, {num_full_blocks} full blocks." + ) + # Add the None block if the last block is not yet full. + seq_block_hashes = seq_block_hashes + [None] * ( + num_blocks - num_full_blocks + ) + return seq_block_hashes[start_block_idx:end_block_idx] + + def get_num_tokens_computed(self, seq: Sequence, allocated: bool) -> int: + """ + Returns the number of tokens that are computed for the sequence. + + When `allocated` is True, only the blocks that are allocated are + returned. (Excluding those blocks that are currently in the evictor.) + + This routine is not O(1) because it needs to search through the + list of blocks for the sequence. Caller should cache the result + if possible. + + Args: + allocated: Whether the returned blocks are allocated. + + Returns: + The number of tokens that are computed for the sequence. + """ + if not self._enable_caching: + return 0 + + if seq.seq_id not in self._full_blocks_hashes: + self.update_seq(seq) + + num_computed_tokens_prev = self._num_tokens_computed.get( + (seq.seq_id, allocated), None + ) + if num_computed_tokens_prev is not None and seq.is_prefill(): + # For a sequence that is still in prefill, we don't have to + # recompute the number of cached tokens. + # This also handles correctly chunked prefill since currently + # we mark blocks as computed even if the sequence is still partially + # prefilled. So a continuously prefilled sequence should not + # see its cached token count change while running. + return num_computed_tokens_prev + + block_hashes = self._full_blocks_hashes[seq.seq_id] + + # This is currently O(logN), where N is the number of blocks. 
+ num_cached_blocks = len( + self._allocator.find_cached_blocks_prefix(block_hashes, allocated) + ) + num_cached_tokens = num_cached_blocks * self._block_size + self._num_tokens_computed[(seq.seq_id, allocated)] = num_cached_tokens + return num_cached_tokens class LastAccessBlocksTracker: diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index df09a3a30743c..51a5a66928cbd 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -5,7 +5,6 @@ from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.core.block.interfaces import Block from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, LastAccessBlocksTracker) from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec @@ -101,7 +100,8 @@ def __init__( self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} self._computed_blocks_tracker = ComputedBlocksTracker( - self.block_allocator) + self.block_allocator, self.block_size, enable_caching + ) self._last_access_blocks_tracker = LastAccessBlocksTracker( self.block_allocator) @@ -118,7 +118,26 @@ def _get_num_blocks_to_allocate( Returns: int: The number of new blocks to allocate. """ - num_cached_tokens = seq.get_num_cached_tokens() + + # It's important to exclude the blocks that are currently in the + # evictor for cached blocks when calculating the number of new blocks + # to allocate. + # + # This is because the blocks that are in the evictor are not yet + # allocated, and so we should not count them towards the number of + # blocks we need to allocate. + # + # consider a scenario with a seqeuence of 3 token blocks, and a block + # pool of only 2 blocks: [b0, b1, b2], where b0 and b1 are computed + # but evicted, b2 is not computed. So b0, b1 are the 2 free blocks, + # in evictor. When deciding how many more blocks need to be allocated for + # this sequence, it should be all 3 blocks (b0, b1, b2) rather than + # just 1 block (b2). + num_cached_tokens = ( + self._computed_blocks_tracker.get_num_tokens_computed( + seq, allocated=True + ) + ) assert ( num_cached_tokens % self.block_size == 0 @@ -128,24 +147,23 @@ def _get_num_blocks_to_allocate( return num_required_blocks - num_cached_blocks - def get_num_computed_tokens(self, seq: Sequence) -> int: + def get_num_cached_tokens(self, seq: Sequence, allocated: bool) -> int: """ - Get the number of computed tokens for a sequence. - - NOTE: This only returns tokens in blocks that are BOTH cached and allocated (active). + Get the number of cached tokens for a sequence (which might be + unscheduled yet). Args: - seq (Sequence): The sequence to get the number of computed tokens for. + seq (Sequence): The sequence to get the number of cached tokens for. + allocated (bool): Whether the cached tokens should be in blocks that + are allocated. Returns: - int: The number of allocated and cached computed tokens. + int: The number of cached tokens. 
""" - seq_blocks = seq.get_block_hashes() - cached_seq_blocks = self.block_allocator.get_allocated_cached_blocks( - block_hashes=seq_blocks, - device=Device.GPU, + + return self._computed_blocks_tracker.get_num_tokens_computed( + seq, allocated ) - return len(cached_seq_blocks) * self.block_size def can_allocate(self, seq_group: SequenceGroup, @@ -156,6 +174,8 @@ def can_allocate(self, check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + self._computed_blocks_tracker.update_seq(seq) + num_blocks_to_allocate = self._get_num_blocks_to_allocate( seq, num_lookahead_slots ) @@ -163,6 +183,7 @@ def can_allocate(self, if seq_group.is_encoder_decoder(): encoder_seq = seq_group.get_encoder_seq() assert encoder_seq is not None + self._computed_blocks_tracker.update_seq(encoder_seq) num_blocks_to_allocate += self._get_num_blocks_to_allocate( encoder_seq, num_lookahead_slots=0 ) @@ -187,11 +208,18 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: block_size=self.block_size, block_allocator=self.block_allocator, max_block_sliding_window=self.max_block_sliding_window, - enable_prefix_caching=self.enable_caching, ) + # TODO(rickyx): With chunked prefill, we actually still allocate the + # entire sequence of blocks here. It's possible to allocate only the + # blocks that are actually being prefilled. if seq.get_token_ids(): # Add blocks to the block table only if the sequence is non empty. - block_table.allocate(seq) + block_table.allocate( + token_ids=seq.get_token_ids(), + block_hashes=self._computed_blocks_tracker.get_block_hashes( + seq + ), + ) return block_table @@ -209,7 +237,6 @@ def allocate(self, seq_group: SequenceGroup) -> None: self.block_tables[seq.seq_id] = block_table # Track seq - self._computed_blocks_tracker.add_seq(seq.seq_id) self._last_access_blocks_tracker.add_seq(seq.seq_id) # Assign the block table for each sequence. @@ -217,7 +244,6 @@ def allocate(self, seq_group: SequenceGroup) -> None: self.block_tables[seq.seq_id] = block_table.fork() # Track seq - self._computed_blocks_tracker.add_seq(seq.seq_id) self._last_access_blocks_tracker.add_seq(seq.seq_id) # Allocate cross-attention block table for encoder sequence @@ -257,6 +283,7 @@ def can_append_slots(self, seq_group: SequenceGroup, # The newly appended tokens might create one or more full blocks, which # could be reused from the cache. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + self._computed_blocks_tracker.update_seq(seq) block_table = self.block_tables[seq.seq_id] num_touched_blocks += block_table.get_num_blocks_touched_by_append_slots( token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), @@ -278,8 +305,19 @@ def append_slots( # Extend the block table for any new decoded tokens, as well as reserve # space for lookahead slots. + unseen_token_ids = block_table.get_unseen_token_ids(seq.get_token_ids()) + if len(unseen_token_ids) == 0: + block_hashes = [] + else: + block_hashes = self._computed_blocks_tracker.get_block_hashes( + seq, + start_block_idx=block_table.num_full_slots // self.block_size, + ) + block_table.append_slots( - seq=seq, + token_ids=unseen_token_ids, + block_hashes=block_hashes, + num_computed_slots=seq.data.get_num_computed_tokens(), num_lookahead_slots=num_lookahead_slots, ) # Return any new copy-on-writes. 
@@ -299,7 +337,7 @@ def free(self, seq: Sequence) -> None: # Untrack seq self._last_access_blocks_tracker.remove_seq(seq_id) - self._computed_blocks_tracker.remove_seq(seq_id) + self._computed_blocks_tracker.remove_seq(seq) # Free table/blocks self.block_tables[seq_id].free() @@ -356,10 +394,14 @@ def get_common_computed_block_ids( computed_seq_block_ids = [] for seq in seqs: all_blocks = self.block_tables[seq.seq_id].physical_block_ids - num_cached_token = seq.get_num_cached_tokens() - assert num_cached_token % self.block_size == 0 - num_cached_block = num_cached_token // self.block_size - computed_block_ids = all_blocks[:num_cached_block] + num_cached_tokens = ( + self._computed_blocks_tracker.get_num_tokens_computed( + seq, allocated=True + ) + ) + assert num_cached_tokens % self.block_size == 0 + num_cached_blocks = num_cached_tokens // self.block_size + computed_block_ids = all_blocks[:num_cached_blocks] computed_seq_block_ids.append(computed_block_ids) # NOTE(sang): This assumes seq_block_ids doesn't contain any None. @@ -374,7 +416,6 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: self.block_tables[child_seq.seq_id] = src_block_table.fork() # Track child seq - self._computed_blocks_tracker.add_seq(child_seq.seq_id) self._last_access_blocks_tracker.add_seq(child_seq.seq_id) def can_swap_in(self, seq_group: SequenceGroup, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 1434eb9b3115d..7b0a81778c07f 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -560,8 +560,17 @@ def _schedule_running( assert len(self._async_stopped) == 0 while running_queue: seq_group = running_queue[0] - num_running_tokens = self._get_num_new_tokens( - seq_group, SequenceStatus.RUNNING, enable_chunking, budget) + num_running_tokens, num_running_tokens_cached = ( + self._get_num_new_tokens_to_schedule( + seq_group, SequenceStatus.RUNNING, enable_chunking, budget + ) + ) + + assert num_running_tokens_cached == 0, ( + "No tokens should have been cached for running seq groups " + "(be it in continuous prefill with chunked prefill or in " + "decode)" + ) if num_running_tokens == 0: # No budget => Stop @@ -735,13 +744,16 @@ def _schedule_swapped( # The total number of sequences in the RUNNING state should not # exceed the maximum number of sequences. 
num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.SWAPPED, - enable_chunking, budget) + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_tokens_to_schedule( + seq_group, SequenceStatus.SWAPPED, enable_chunking, budget + ) + ) - if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + if num_new_tokens_uncached == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + ): break if lora_int_id > 0 and curr_loras is not None: @@ -752,12 +764,20 @@ def _schedule_swapped( is_prefill = seq_group.is_prefill() if is_prefill: prefill_seq_groups.append( - ScheduledSequenceGroup(seq_group, - token_chunk_size=num_new_tokens)) + ScheduledSequenceGroup( + seq_group, + token_chunk_size=num_new_tokens_uncached + + num_new_tokens_cached, + ) + ) else: decode_seq_groups.append( ScheduledSequenceGroup(seq_group, token_chunk_size=1)) - budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, + ) budget.add_num_seqs(seq_group.request_id, num_new_seqs) assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs @@ -824,29 +844,44 @@ def _schedule_priority_preemption( if waiting_queue: seq_group = waiting_queue.popleft() num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.WAITING, - False, budget) + num_new_tokens_uncached, _ = self._get_num_new_tokens_to_schedule( + seq_group, SequenceStatus.WAITING, False, budget + ) # Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): # Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) - if (num_new_tokens and can_allocate == AllocStatus.OK - and budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + if ( + num_new_tokens_uncached + and can_allocate == AllocStatus.OK + and budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + ) + ): break # Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() - num_running_tokens = self._get_num_new_tokens( - vseq_group, SequenceStatus.RUNNING, False, budget) - budget.subtract_num_batched_tokens(vseq_group.request_id, - num_running_tokens) + num_running_tokens_uncached, num_running_tokens_cached = ( + self._get_num_new_tokens_to_schedule( + vseq_group, SequenceStatus.RUNNING, False, budget + ) + ) + assert num_running_tokens_cached == 0, ( + "No tokens should have been cached for running seq " + "groups (be it in continuous prefill with chunked prefill " + "or in decode)" + ) + budget.subtract_num_batched_tokens( + vseq_group.request_id, num_running_tokens_uncached + ) num_running_seqs = vseq_group.get_max_num_running_seqs() - budget.subtract_num_seqs(vseq_group.request_id, - num_running_seqs) + budget.subtract_num_seqs( + vseq_group.request_id, num_running_seqs + ) # Preempt out the victim sequence group self._preempt(vseq_group, blocks_to_swap_out, @@ -905,19 +940,17 @@ def _schedule_prefills( "Waiting sequence group should have only one prompt " "sequence.") seq = waiting_seqs[0] - self._update_prefix_cached_tokens(seq) - num_new_tokens = self._get_num_new_tokens( - 
seq_group, - SequenceStatus.WAITING, - enable_chunking, - budget, - ) - - num_new_tokens_exclude_cached = self._get_num_new_tokens_exclude_cached( - num_new_tokens, seq + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_tokens_to_schedule( + seq_group, + SequenceStatus.WAITING, + enable_chunking, + budget, + ) ) - + num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached if not enable_chunking: + # Sanity check that the number of new tokens is correct. num_prompt_tokens = seq.get_len() assert num_new_tokens == num_prompt_tokens @@ -968,14 +1001,12 @@ def _schedule_prefills( num_new_seqs = seq_group.get_max_num_running_seqs() if num_new_tokens == 0: - # TODO(rickyx): this could be made earlier? - # No more new tokens to schedule. + # No more new tokens for prefill. break - assert num_new_tokens > 0 # We have new tokens but they might be cached. if not budget.can_schedule( - num_new_tokens=num_new_tokens_exclude_cached, + num_new_tokens=num_new_tokens_uncached, num_new_seqs=num_new_seqs, ): # No more budget for new tokens. @@ -987,20 +1018,6 @@ def _schedule_prefills( waiting_queue.popleft() self._allocate_and_set_running(seq_group) - # NOTE(rickyx): We are updating this again since some of the previously - # cached blocks that were in evictor might now become active again. - # Therefore, the actual number of tokens cached might have changed. - self._update_prefix_cached_tokens(seq) - num_new_tokens = self._get_num_new_tokens( - seq_group, - SequenceStatus.RUNNING, - enable_chunking, - budget, - ) - num_new_tokens_uncached = self._get_num_new_tokens_exclude_cached( - num_new_tokens, seq - ) - if enable_chunking and self.scheduler_config.is_multi_step: blocks_to_copy: List[Tuple[int, int]] = [] # init_multi_step_from_lookahead_slots happens in append_slots @@ -1013,10 +1030,10 @@ def _schedule_prefills( else: seq_group.init_multi_step_from_lookahead_slots( num_lookahead_slots, - num_scheduler_steps=self.scheduler_config. - num_scheduler_steps, + num_scheduler_steps=self.scheduler_config.num_scheduler_steps, is_multi_step=self.scheduler_config.is_multi_step, - enable_chunking=enable_chunking) + enable_chunking=enable_chunking, + ) seq_groups.append( ScheduledSequenceGroup( @@ -1026,7 +1043,7 @@ def _schedule_prefills( budget.add_num_batched_tokens( seq_group.request_id, num_batched_tokens=num_new_tokens_uncached, - num_cached_tokens=num_new_tokens - num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, ) budget.add_num_seqs(seq_group.request_id, num_new_seqs) @@ -1039,11 +1056,13 @@ def _schedule_prefills( seq_groups=seq_groups, ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking)) + is_prefill=True, enable_chunking=enable_chunking + ), + ) def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. - + The current policy is designed to optimize the throughput. First, it batches as many prefill requests as possible. And it schedules decodes. If there's a pressure on GPU memory, decode requests can @@ -1057,13 +1076,20 @@ def _schedule_default(self) -> SchedulerOutputs: # Make sure we include num running seqs before scheduling prefill, # so that we don't schedule beyond max_num_seqs for prefill. 
for seq_group in self.running: - budget.add_num_seqs(seq_group.request_id, - seq_group.get_max_num_running_seqs()) + budget.add_num_seqs( + seq_group.request_id, seq_group.get_max_num_running_seqs() + ) assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs - curr_loras = set( - seq_group.lora_int_id for seq_group in self.running - if seq_group.lora_int_id > 0) if self.lora_enabled else None + curr_loras = ( + set( + seq_group.lora_int_id + for seq_group in self.running + if seq_group.lora_int_id > 0 + ) + if self.lora_enabled + else None + ) prefills = SchedulerPrefillOutputs.create_empty() running_scheduled = SchedulerRunningOutputs.create_empty() @@ -1071,30 +1097,37 @@ def _schedule_default(self) -> SchedulerOutputs: # If any requests are swapped, prioritized swapped requests. if not self.swapped: - prefills = self._schedule_prefills(budget, - curr_loras, - enable_chunking=False) + prefills = self._schedule_prefills( + budget, curr_loras, enable_chunking=False + ) - if len(prefills.seq_groups - ) == 0 and self.scheduler_config.policy == "priority": + if ( + len(prefills.seq_groups) == 0 + and self.scheduler_config.policy == "priority" + ): self._schedule_priority_preemption(budget) # Don't schedule decodes if prefills are scheduled. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # only contains decode requests, not chunked prefills. if len(prefills.seq_groups) == 0: - running_scheduled = self._schedule_running(budget, - curr_loras, - enable_chunking=False) + running_scheduled = self._schedule_running( + budget, curr_loras, enable_chunking=False + ) # If any sequence group is preempted, do not swap in any sequence # group. because it means there's no slot for new running requests. - if len(running_scheduled.preempted) + len( - running_scheduled.swapped_out) == 0: + if ( + len(running_scheduled.preempted) + + len(running_scheduled.swapped_out) + == 0 + ): swapped_in = self._schedule_swapped(budget, curr_loras) - assert (budget.num_batched_tokens <= - self.scheduler_config.max_num_batched_tokens) + assert ( + budget.num_batched_tokens + <= self.scheduler_config.max_num_batched_tokens + ) assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs # Update waiting requests. @@ -1107,12 +1140,14 @@ def _schedule_default(self) -> SchedulerOutputs: if len(swapped_in.decode_seq_groups) > 0: self.running.extend( - [s.seq_group for s in swapped_in.decode_seq_groups]) + [s.seq_group for s in swapped_in.decode_seq_groups] + ) # Update swapped requests. self.swapped.extend(running_scheduled.swapped_out) - preempted = (len(running_scheduled.preempted) + - len(running_scheduled.swapped_out)) + preempted = len(running_scheduled.preempted) + len( + running_scheduled.swapped_out + ) # There should be no prefill from running queue because this policy # doesn't allow chunked prefills. @@ -1150,7 +1185,7 @@ def _schedule_default(self) -> SchedulerOutputs: def _schedule_chunked_prefill(self) -> SchedulerOutputs: """Schedule queued requests. - + Chunked prefill allows to chunk prefill requests, batch them together with decode requests. This policy 1. schedule as many decoding requests as possible. 2. schedule chunked prefill requests that are not @@ -1172,23 +1207,28 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: swapped_in = SchedulerSwappedInOutputs.create_empty() # Decoding should be always scheduled first by fcfs. 
- running_scheduled = self._schedule_running(budget, - curr_loras, - enable_chunking=True) + running_scheduled = self._schedule_running( + budget, curr_loras, enable_chunking=True + ) # Schedule swapped out requests. # If preemption happens, it means we don't have space for swap-in. - if len(running_scheduled.preempted) + len( - running_scheduled.swapped_out) == 0: + if ( + len(running_scheduled.preempted) + + len(running_scheduled.swapped_out) + == 0 + ): swapped_in = self._schedule_swapped(budget, curr_loras) # Schedule new prefills. - prefills = self._schedule_prefills(budget, - curr_loras, - enable_chunking=True) + prefills = self._schedule_prefills( + budget, curr_loras, enable_chunking=True + ) - assert (budget.num_batched_tokens <= - self.scheduler_config.max_num_batched_tokens), f"{budget.num_batched_tokens=}, {self.scheduler_config.max_num_batched_tokens=}" + assert ( + budget.num_batched_tokens + <= self.scheduler_config.max_num_batched_tokens + ), f"{budget.num_batched_tokens=}, {self.scheduler_config.max_num_batched_tokens=}" assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs # Update waiting requests. @@ -1198,14 +1238,16 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # By default, vLLM scheduler prioritizes prefills. # Once chunked prefill is enabled, # the policy is changed to prioritize decode requests. + self.running.extend([s.seq_group for s in swapped_in.decode_seq_groups]) self.running.extend( - [s.seq_group for s in swapped_in.decode_seq_groups]) - self.running.extend( - [s.seq_group for s in swapped_in.prefill_seq_groups]) + [s.seq_group for s in swapped_in.prefill_seq_groups] + ) self.running.extend( - [s.seq_group for s in running_scheduled.decode_seq_groups]) + [s.seq_group for s in running_scheduled.decode_seq_groups] + ) self.running.extend( - [s.seq_group for s in running_scheduled.prefill_seq_groups]) + [s.seq_group for s in running_scheduled.prefill_seq_groups] + ) self.running.extend([s.seq_group for s in prefills.seq_groups]) # Update swapped requests. @@ -1227,13 +1269,15 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: num_batched_tokens_from_budget=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=running_scheduled.blocks_to_copy + swapped_in.blocks_to_copy, + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups + swapped_in.infeasible_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, running_queue_size=len(self.running), preempted=( - len(running_scheduled.preempted) + len(running_scheduled.swapped_out) + len(running_scheduled.preempted) + + len(running_scheduled.swapped_out) ), ) @@ -1244,21 +1288,25 @@ def _schedule(self) -> SchedulerOutputs: else: return self._schedule_default() - def _can_append_slots(self, seq_group: SequenceGroup, - enable_chunking: bool) -> bool: + def _can_append_slots( + self, seq_group: SequenceGroup, enable_chunking: bool + ) -> bool: """Determine whether or not we have enough space in the KV cache to continue generation of the sequence group. """ # It is True only for testing case to trigger artificial preemption. 
- if (self.enable_artificial_preemption - and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB - and self.artificial_preempt_cnt > 0): + if ( + self.enable_artificial_preemption + and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB + and self.artificial_preempt_cnt > 0 + ): self.artificial_preempt_cnt -= 1 return False is_prefill = seq_group.is_prefill() num_lookahead_slots = self._get_num_lookahead_slots( - is_prefill, enable_chunking) + is_prefill, enable_chunking + ) if is_prefill and num_lookahead_slots > 0: # Appending prefill slots only happens multi-step and @@ -1266,17 +1314,19 @@ def _can_append_slots(self, seq_group: SequenceGroup, assert self.scheduler_config.is_multi_step and enable_chunking return self.block_manager.can_append_slots( - seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) + seq_group=seq_group, num_lookahead_slots=num_lookahead_slots + ) def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: # async_output_proc is allowed only when we have a single sequence # in the sequence group no_single_seq = seq_group.sampling_params is None or ( - seq_group.sampling_params.n == 1) + seq_group.sampling_params.n == 1 + ) return no_single_seq def schedule( - self + self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: # Schedule sequence groups. # This function call changes the internal states of the scheduler @@ -1294,13 +1344,15 @@ def schedule( # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] for i, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups): + scheduler_outputs.scheduled_seq_groups + ): seq_group = scheduled_seq_group.seq_group token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.maybe_set_first_scheduled_time(now) seq_group_metadata = self._seq_group_metadata_cache[ - self.cache_id].get_object() + self.cache_id + ].get_object() seq_group_metadata.seq_data.clear() seq_group_metadata.block_tables.clear() @@ -1317,7 +1369,8 @@ def schedule( # Block table for cross-attention # Also managed at SequenceGroup level cross_block_table = self.block_manager.get_cross_block_table( - seq_group) + seq_group + ) else: encoder_seq_data = None cross_block_table = None @@ -1331,7 +1384,9 @@ def schedule( if self.cache_config.enable_prefix_caching: common_computed_block_nums = ( self.block_manager.get_common_computed_block_ids( - seq_group.get_seqs(status=SequenceStatus.RUNNING))) + seq_group.get_seqs(status=SequenceStatus.RUNNING) + ) + ) do_sample = True is_prompt = seq_group.is_prefill() @@ -1349,8 +1404,10 @@ def schedule( # NOTE: We use get_len instead of get_prompt_len because when # a sequence is preempted, prefill includes previous generated # output tokens. - if (token_chunk_size + num_computed_tokens < - seqs[0].data.get_len()): + if ( + token_chunk_size + num_computed_tokens + < seqs[0].data.get_len() + ): do_sample = False # It assumes the scheduled_seq_groups is ordered by @@ -1375,7 +1432,8 @@ def schedule( # the subsequent comms can still use delta, but # `multi_modal_data` will be None. 
multi_modal_data=seq_group.multi_modal_data - if scheduler_outputs.num_prefill_groups > 0 else None, + if scheduler_outputs.num_prefill_groups > 0 + else None, mm_processor_kwargs=seq_group.mm_processor_kwargs, prompt_adapter_request=seq_group.prompt_adapter_request, ) @@ -1398,7 +1456,8 @@ def schedule( if allow_async_output_proc: allow_async_output_proc = self._allow_async_output_proc( - seq_group) + seq_group + ) # Now that the batch has been created, we can assume all blocks in the # batch will have been computed before the next scheduling invocation. @@ -1407,7 +1466,8 @@ def schedule( for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: self.block_manager.mark_blocks_as_computed( scheduled_seq_group.seq_group, - scheduled_seq_group.token_chunk_size) + scheduled_seq_group.token_chunk_size, + ) self._seq_group_metadata_cache[self.next_cache_id].reset() @@ -1426,8 +1486,11 @@ def schedule( self.cache_id = self.next_cache_id # Return results - return (seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc) + return ( + seq_group_metadata_list, + scheduler_outputs, + allow_async_output_proc, + ) def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: self.block_manager.fork(parent_seq, child_seq) @@ -1481,10 +1544,12 @@ def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING - def _append_slots(self, - seq_group: SequenceGroup, - blocks_to_copy: List[Tuple[int, int]], - enable_chunking: bool = False) -> None: + def _append_slots( + self, + seq_group: SequenceGroup, + blocks_to_copy: List[Tuple[int, int]], + enable_chunking: bool = False, + ) -> None: """Appends new slots to the sequences in the given sequence group. Args: @@ -1499,13 +1564,15 @@ def _append_slots(self, """ is_prefill: bool = seq_group.is_prefill() num_lookahead_slots: int = self._get_num_lookahead_slots( - is_prefill, enable_chunking) + is_prefill, enable_chunking + ) seq_group.init_multi_step_from_lookahead_slots( num_lookahead_slots, num_scheduler_steps=self.scheduler_config.num_scheduler_steps, is_multi_step=self.scheduler_config.is_multi_step, - enable_chunking=enable_chunking) + enable_chunking=enable_chunking, + ) seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING if self.scheduler_config.is_multi_step and enable_chunking: @@ -1552,8 +1619,11 @@ def _preempt( "not enough KV cache space. This can affect the end-to-end " "performance. Increase gpu_memory_utilization or " "tensor_parallel_size to provide more KV cache memory. " - "total_num_cumulative_preemption=%d", seq_group.request_id, - preemption_mode, self.num_cumulative_preemption + 1) + "total_num_cumulative_preemption=%d", + seq_group.request_id, + preemption_mode, + self.num_cumulative_preemption + 1, + ) self.num_cumulative_preemption += 1 if preemption_mode == PreemptionMode.RECOMPUTE: @@ -1602,7 +1672,8 @@ def _swap_out( # entire engine. raise RuntimeError( "Aborted due to the lack of CPU swap space. Please increase " - "the swap space to avoid this error.") + "the swap space to avoid this error." 
+ ) mapping = self.block_manager.swap_out(seq_group) blocks_to_swap_out.extend(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): @@ -1615,17 +1686,18 @@ def _passed_delay(self, now: float) -> bool: # Delay scheduling prompts to let waiting queue fill up if self.scheduler_config.delay_factor > 0 and self.waiting: earliest_arrival_time = min( - [e.metrics.arrival_time for e in self.waiting]) - passed_delay = ( - (now - earliest_arrival_time) > - (self.scheduler_config.delay_factor * self.last_prompt_latency) - or not self.running) + [e.metrics.arrival_time for e in self.waiting] + ) + passed_delay = (now - earliest_arrival_time) > ( + self.scheduler_config.delay_factor * self.last_prompt_latency + ) or not self.running else: passed_delay = True return passed_delay - def _get_num_lookahead_slots(self, is_prefill: bool, - enable_chunking: bool) -> int: + def _get_num_lookahead_slots( + self, is_prefill: bool, enable_chunking: bool + ) -> int: """The number of slots to allocate per sequence per step, beyond known token ids. Speculative decoding uses these slots to store KV activations of tokens which may or may not be accepted. @@ -1653,120 +1725,266 @@ def _get_num_lookahead_slots(self, is_prefill: bool, return self.scheduler_config.num_lookahead_slots - def _get_num_new_tokens( + def _get_num_new_tokens_to_schedule( self, seq_group: SequenceGroup, status: SequenceStatus, enable_chunking: bool, budget: SchedulingBudget, - ) -> int: - """Get the next new tokens to compute for a given sequence group - that's in a given `status`. + ) -> Tuple[int, int]: + """ + Returns the number of new uncomputed tokens to schedule and the number of + cached tokens for a given sequence group that's in a given `status`. The API could chunk the number of tokens to compute based on `budget` if `enable_chunking` is True. If a sequence group has multiple sequences (e.g., running beam search), it means it is in decoding phase, so chunking doesn't happen. - Returns 0 if the new token cannot be computed due to token budget. + Returns (0, 0) if the new token cannot be computed due to token budget. + + Args: + seq_group: The sequence group to get the number of new tokens to + schedule. + status: The status of the sequences to get the number of new tokens + to schedule. + enable_chunking: Whether to chunk the number of tokens to compute. + budget: The budget to chunk the number of tokens to compute. + + + Returns: + A tuple of two ints. The first int is the number of new uncached + tokens to schedule. The second int is the number of cached tokens to schedule. + + If no more new tokens can be scheduled, returns (0, 0). """ - num_new_tokens = 0 + num_cached_new_tokens = 0 + num_uncached_new_tokens = 0 + seqs = seq_group.get_seqs(status=status) for seq in seqs: - num_new_tokens += seq.get_num_new_tokens() - assert num_new_tokens > 0 - # Chunk if a running request cannot fit in the given budget. - # If number of seq > 1, it means it is doing beam search - # in a decode phase. Do not chunk. - if enable_chunking and len(seqs) == 1: - remaining_token_budget = budget.remaining_token_budget() - seq = seqs[0] - if self.scheduler_config.is_multi_step: - # The current multi-step + chunked prefill capability does - # not actually support chunking prompts. 
- # - # Therefore, `num_new_tokens` is computed in the same fashion - # for both multi-step+chunked-prefill & - # multi-step+chunked-prefill+APC - # - # Prompts with more tokens than the current remaining budget - # are postponed to future scheduler steps - if num_new_tokens > self._get_prompt_limit(seq_group): - # If the seq_group is in prompt-stage, pass the - # num_new_tokens as-is so the caller can ignore - # the sequence. - pass - else: - num_new_tokens = 0 \ - if num_new_tokens > remaining_token_budget \ - else num_new_tokens - elif self.cache_config.enable_prefix_caching: - # When prefix caching is enabled, we always allocate - # the number of new tokens that is dividable by the block - # size to avoid partial block matching. - block_size = self.cache_config.block_size - remainder = budget.token_budget % block_size - if remainder != 0: - raise ValueError("When enabling chunked prefill and " - "prefix caching, max_num_batched_tokens " - "(chunk size) must be dividable by " - "block size, but got chunk_size " - f"({budget.token_budget}) % block_size " - f"({block_size}) = {remainder}") - num_new_tokens_cached = seq.get_num_cached_tokens() - seq.get_num_computed_tokens() - num_new_tokens_cached = max(0, num_new_tokens_cached) - # Round down to block - remaining_token_budget = remaining_token_budget // block_size * block_size - - # Calculate the number of new tokens that are not cached with chunk cap. - num_new_tokens_uncached = min(num_new_tokens - num_new_tokens_cached, remaining_token_budget) - if num_new_tokens_uncached == 0: - # No more budget for new tokens, don't include any cached tokens too. - return 0 - num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached - else: - num_new_tokens = min(num_new_tokens, remaining_token_budget) - return num_new_tokens + if not seq.is_prefill(): + # Decode sequences should always just have 1 uncached token + # TODO(rickyx): Actually is this still correct for multi-step? + num_uncached_new_tokens += 1 + continue - def _update_prefix_cached_tokens(self, seq: Sequence): - """ - Update the number of prefix cached tokens for a sequence. + # We will need to consider cached tokens for prefill sequences. + num_computed_tokens_seq = seq.get_num_computed_tokens() + # We don't care about whether the blocks are allocated or not + # because we're only interested in the number of cached tokens. + # The blocks should be eventually allocated if the seq is scheduled. + # So we set allocated=False. + if self.cache_config.enable_prefix_caching: + num_cached_tokens_seq = ( + self.block_manager.get_num_cached_tokens( + seq, allocated=False + ) + ) - This function takes O(log(n)) time, where n is the number of blocks - in the sequence. - """ - num_prefix_cached_tokens = self.block_manager.get_num_computed_tokens(seq) - seq.set_num_prefix_cached_tokens(num_prefix_cached_tokens) + # Any computed token should have been cached too. 
+ if not self.scheduler_config.chunked_prefill_enabled: + assert num_cached_tokens_seq >= num_computed_tokens_seq, ( + f"Number of cached tokens ({num_cached_tokens_seq}) " + "should be no less than the number of computed " + f"tokens ({num_computed_tokens_seq}) when chunked prefill " + "is disabled" + ) + num_new_tokens_cached_seq = max( + 0, num_cached_tokens_seq - num_computed_tokens_seq + ) + else: + num_cached_tokens_seq = 0 + num_new_tokens_cached_seq = 0 - def _get_num_new_tokens_exclude_cached( - self, num_new_tokens: int, seq: Sequence - ) -> int: - """ - Get the number of new tokens to compute for a sequence, excluding - cached tokens. + all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq + num_new_tokens_uncached_seq = ( + all_num_new_tokens_seq - num_new_tokens_cached_seq + ) - Args: - num_new_tokens: The number of new tokens to compute. - seq: The sequence to compute the new tokens for. + num_uncached_new_tokens += num_new_tokens_uncached_seq + num_cached_new_tokens += num_new_tokens_cached_seq - Returns: - Given `num_new_tokens`, returns the number of uncached tokens. - """ + if enable_chunking and len(seqs) == 1: + # Chunk if a running request cannot fit in the given budget. + # If number of seq > 1, it means it is doing beam search + # in a decode phase. Do not chunk. + num_uncached_new_tokens = self._chunk_new_tokens_to_schedule( + budget, + self._get_prompt_limit(seq_group), + num_uncached_new_tokens, + ) + + return num_uncached_new_tokens, num_cached_new_tokens - # If a decode sequence, new tokens are always not computed/cached. - if not seq.is_prefill(): - return num_new_tokens + def _chunk_new_tokens_to_schedule( + self, + budget: SchedulingBudget, + prompt_limit: int, + num_new_tokens: int, + ) -> int: + remaining_token_budget = budget.remaining_token_budget() + if self.scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + # + # Therefore, `num_new_tokens` is computed in the same fashion + # for both multi-step+chunked-prefill & + # multi-step+chunked-prefill+APC + # + # Prompts with more tokens than the current remaining budget + # are postponed to future scheduler steps + if num_new_tokens > prompt_limit: + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. + pass + else: + num_new_tokens = ( + 0 + if num_new_tokens > remaining_token_budget + else num_new_tokens + ) + elif self.cache_config.enable_prefix_caching: + # When prefix caching is enabled, we always allocate + # the number of new tokens that is dividable by the block + # size to avoid partial block matching. + block_size = self.cache_config.block_size + remainder = budget.token_budget % block_size + if remainder != 0: + raise ValueError( + "When enabling chunked prefill and " + "prefix caching, max_num_batched_tokens " + "(chunk size) must be dividable by " + "block size, but got chunk_size " + f"({budget.token_budget}) % block_size " + f"({block_size}) = {remainder}" + ) + # Round down to block + remaining_token_budget = ( + remaining_token_budget // block_size * block_size + ) + # Calculate the number of new tokens that are not cached with chunk cap. + num_new_tokens = min(num_new_tokens, remaining_token_budget) + else: + num_new_tokens = min(num_new_tokens, remaining_token_budget) - # If a prefill sequence, we need to exclude the number of cached tokens. 
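To make the cached/uncached split and the block-aligned chunking above concrete, here is a small worked example with hypothetical numbers; the helper below is illustrative and not part of the patch:

```python
# Worked example (hypothetical numbers) of the per-sequence split computed
# above: new tokens are divided into cached and uncached, and with prefix
# caching the remaining budget is rounded down to a whole number of blocks.
from typing import Tuple


def split_and_chunk(seq_len: int, num_computed: int, num_cached: int,
                    remaining_budget: int, block_size: int) -> Tuple[int, int]:
    """Return (uncached_to_schedule, cached_to_schedule) for one prefill."""
    new_cached = max(0, num_cached - num_computed)
    new_uncached = (seq_len - num_computed) - new_cached
    # Only schedule whole blocks of the budget to avoid partial-block matches.
    block_aligned_budget = remaining_budget // block_size * block_size
    return min(new_uncached, block_aligned_budget), new_cached


# 100-token prompt, nothing computed yet, 64 tokens already cached, 50 tokens
# of budget left, block size 16 -> the budget rounds down to 48, so all 36
# uncached tokens fit and the 64 cached tokens ride along for free.
assert split_and_chunk(100, 0, 64, 50, 16) == (36, 64)
```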
- num_computed_tokens = seq.get_num_computed_tokens() - num_cached_tokens = seq.get_num_cached_tokens() + return num_new_tokens - # We subtract the number of cached tokens from the number of new tokens - num_computed_tokens_new = num_new_tokens + num_computed_tokens - num_new_tokens_exclude_cached = max( - 0, num_computed_tokens_new - num_cached_tokens - ) - assert ( - num_new_tokens_exclude_cached <= num_new_tokens - ), "Number of new tokens exclude cached should be less than or equal to the number of new tokens" - return num_new_tokens_exclude_cached + # def _get_num_new_tokens( + # self, + # seq_group: SequenceGroup, + # status: SequenceStatus, + # enable_chunking: bool, + # budget: SchedulingBudget, + # ) -> int: + # """Get the next new tokens to compute for a given sequence group + # that's in a given `status`. + + # The API could chunk the number of tokens to compute based on `budget` + # if `enable_chunking` is True. If a sequence group has multiple + # sequences (e.g., running beam search), it means it is in decoding + # phase, so chunking doesn't happen. + + # Returns 0 if the new token cannot be computed due to token budget. + # """ + # num_new_tokens = 0 + # seqs = seq_group.get_seqs(status=status) + # for seq in seqs: + # num_new_tokens += seq.get_num_new_tokens() + # assert num_new_tokens > 0 + # # Chunk if a running request cannot fit in the given budget. + # # If number of seq > 1, it means it is doing beam search + # # in a decode phase. Do not chunk. + # if enable_chunking and len(seqs) == 1: + # remaining_token_budget = budget.remaining_token_budget() + # seq = seqs[0] + # if self.scheduler_config.is_multi_step: + # # The current multi-step + chunked prefill capability does + # # not actually support chunking prompts. + # # + # # Therefore, `num_new_tokens` is computed in the same fashion + # # for both multi-step+chunked-prefill & + # # multi-step+chunked-prefill+APC + # # + # # Prompts with more tokens than the current remaining budget + # # are postponed to future scheduler steps + # if num_new_tokens > self._get_prompt_limit(seq_group): + # # If the seq_group is in prompt-stage, pass the + # # num_new_tokens as-is so the caller can ignore + # # the sequence. + # pass + # else: + # num_new_tokens = 0 \ + # if num_new_tokens > remaining_token_budget \ + # else num_new_tokens + # elif self.cache_config.enable_prefix_caching: + # # When prefix caching is enabled, we always allocate + # # the number of new tokens that is dividable by the block + # # size to avoid partial block matching. + # block_size = self.cache_config.block_size + # remainder = budget.token_budget % block_size + # if remainder != 0: + # raise ValueError("When enabling chunked prefill and " + # "prefix caching, max_num_batched_tokens " + # "(chunk size) must be dividable by " + # "block size, but got chunk_size " + # f"({budget.token_budget}) % block_size " + # f"({block_size}) = {remainder}") + # num_new_tokens_cached = seq.get_num_cached_tokens() - seq.get_num_computed_tokens() + # num_new_tokens_cached = max(0, num_new_tokens_cached) + # # Round down to block + # remaining_token_budget = remaining_token_budget // block_size * block_size + + # # Calculate the number of new tokens that are not cached with chunk cap. + # num_new_tokens_uncached = min(num_new_tokens - num_new_tokens_cached, remaining_token_budget) + # if num_new_tokens_uncached == 0: + # # No more budget for new tokens, don't include any cached tokens too. 
+ # return 0 + # num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached + # else: + # num_new_tokens = min(num_new_tokens, remaining_token_budget) + # return num_new_tokens + # + # def _update_prefix_cached_tokens(self, seq: Sequence): + # """ + # Update the number of prefix cached tokens for a sequence. + + # This function takes O(log(n)) time, where n is the number of blocks + # in the sequence. + # """ + # num_prefix_cached_tokens = self.block_manager.get_num_computed_tokens(seq) + # seq.set_num_prefix_cached_tokens(num_prefix_cached_tokens) + + # def _get_num_new_tokens_exclude_cached( + # self, num_new_tokens: int, seq: Sequence + # ) -> int: + # """ + # Get the number of new tokens to compute for a sequence, excluding + # cached tokens. + + # Args: + # num_new_tokens: The number of new tokens to compute. + # seq: The sequence to compute the new tokens for. + + # Returns: + # Given `num_new_tokens`, returns the number of uncached tokens. + # """ + + # # If a decode sequence, new tokens are always not computed/cached. + # if not seq.is_prefill(): + # return num_new_tokens + + # # If a prefill sequence, we need to exclude the number of cached tokens. + # num_computed_tokens = seq.get_num_computed_tokens() + # num_cached_tokens = seq.get_num_cached_tokens() + + # # We subtract the number of cached tokens from the number of new tokens + # num_computed_tokens_new = num_new_tokens + num_computed_tokens + # num_new_tokens_exclude_cached = max( + # 0, num_computed_tokens_new - num_cached_tokens + # ) + # assert ( + # num_new_tokens_exclude_cached <= num_new_tokens + # ), "Number of new tokens exclude cached should be less than or equal to the number of new tokens" + # return num_new_tokens_exclude_cached diff --git a/vllm/sequence.py b/vllm/sequence.py index a3d6c0b1492ad..5d68371ce7d6c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -625,23 +625,40 @@ def _reset_block_hashes(self): def set_num_prefix_cached_tokens(self, num_prefix_cached_tokens: int): self.data.set_num_prefix_cached_tokens(num_prefix_cached_tokens) - def hash_of_block(self, logical_idx: int) -> int: - # TODO This can produce incorrect hash when block size > prompt size + def hash_of_block( + self, prev_block_hash: Optional[int], cur_block_idx: int + ) -> int: + """ + Get the hash of a given block of the sequence. - # Compute the number of tokens in the sequence - # TODO: The current hashing function is O(L^2). We should optimize - # this in the future. - num_tokens = self.num_hashed_tokens_of_block(logical_idx) - hashed_tokens = self.data.get_prefix_token_ids(num_tokens) - return hash((hashed_tokens, self.lora_int_id)) + Args: + prev_block_hash: The hash of the previous block. + block_idx: The index of the block. It should be a valid block index + of the sequence, i.e. it's a full block. - def num_hashed_tokens_of_block(self, logical_idx: int): - return logical_idx * self.block_size + self.block_size + Returns: + The hash of the block. + """ + token_ids = self.get_token_ids() + assert (cur_block_idx + 1) * self.block_size <= len(token_ids), ( + f"Invalid block index: {cur_block_idx}. The sequence only has " + f"{len(token_ids) // self.block_size} blocks." 
+ ) + block_token_ids = token_ids[ + cur_block_idx * self.block_size : (cur_block_idx + 1) + * self.block_size + ] + return hash( + ( + prev_block_hash, + self.lora_int_id, + *block_token_ids, + ) + ) def reset_state_for_recompute(self): """Reset the sequence states for recomputation.""" self.data.reset_state_for_recompute() - self._reset_block_hashes() def append_token_id(self, token_id: int, logprobs: Dict[int, Logprob]) -> None: @@ -688,20 +705,20 @@ def fork(self, new_seq_id: int) -> "Sequence": def get_num_computed_tokens(self) -> int: return self.data.get_num_computed_tokens() - def get_num_new_tokens(self) -> int: - """Get the number of new tokens to be computed. + # def get_num_new_tokens(self) -> int: + # """Get the number of new tokens to be computed. - Returns: - The new number of tokens to be computed. I.e., 1 for decode, or - the remaining prompt size for prefill. - """ - if self.data.stage == SequenceStage.DECODE: - return 1 + # Returns: + # The new number of tokens to be computed. I.e., 1 for decode, or + # the remaining prompt size for prefill. + # """ + # if self.data.stage == SequenceStage.DECODE: + # return 1 - return self.data.get_num_uncomputed_tokens() + # return self.data.get_num_uncomputed_tokens() - def get_num_cached_tokens(self) -> int: - return self.data.get_num_prefix_cached_tokens() + # def get_num_cached_tokens(self) -> int: + # return self.data.get_num_prefix_cached_tokens() def is_prefill(self) -> bool: return self.data.stage == SequenceStage.PREFILL From 8d8853e53015360f9f21f75562a5f86644882c65 Mon Sep 17 00:00:00 2001 From: rickyx Date: Thu, 7 Nov 2024 17:53:46 +0000 Subject: [PATCH 12/12] lint --- benchmarks/benchmark_prefix_caching.py | 3 +- tests/core/block/test_block_manager.py | 11 +- tests/core/block/test_block_table.py | 27 +- tests/core/block/test_prefix_caching_block.py | 12 +- tests/core/utils.py | 2 +- tests/prefix_caching/test_prefix_caching.py | 8 +- vllm/core/block/block_table.py | 46 +- vllm/core/block/common.py | 5 +- vllm/core/block/cpu_gpu_block_allocator.py | 27 +- vllm/core/block/interfaces.py | 32 +- vllm/core/block/naive_block.py | 28 +- vllm/core/block/prefix_caching_block.py | 93 ++-- vllm/core/block_manager.py | 70 ++- vllm/core/scheduler.py | 430 +++++------------- vllm/engine/metrics.py | 5 +- vllm/sequence.py | 57 +-- 16 files changed, 304 insertions(+), 552 deletions(-) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 8eb6c8ad7606b..6ee03dc2258b8 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -117,7 +117,8 @@ def main(args): input_length_range = tuple(map(int, args.input_length_range.split(':'))) random.seed(args.seed) if args.dataset_path is not None: - print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}") + print(f"Start to sample {args.num_prompts} prompts " + f"from {args.dataset_path}") filtered_datasets = sample_requests( dataset_path=args.dataset_path, num_requests=args.num_prompts, diff --git a/tests/core/block/test_block_manager.py b/tests/core/block/test_block_manager.py index 16e2876c015b1..742ebf870ac5a 100644 --- a/tests/core/block/test_block_manager.py +++ b/tests/core/block/test_block_manager.py @@ -256,10 +256,9 @@ def test_can_allocate_with_prefix_cache( # Allocate the seq 1 block_manager.allocate(seq_group_1) - # Mark the seq 1 as computed (This shoudl be done by the scheduler in reality) - block_manager.mark_blocks_as_computed( - seq_group=seq_group_1, token_chunk_size=len(tokens_1) 
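Aside on the chained `hash_of_block` added to `vllm/sequence.py` above: each full block's hash folds in the previous block's hash and the LoRA id, so two sequences that share a prefix produce identical hash chains, and a partial tail block is never hashed. A standalone sketch of the same idea, for illustration only:

```python
# Standalone sketch of chained block hashing: each full block's hash folds in
# the previous block's hash (and the LoRA id), so equal prefixes yield equal
# hash chains; a partial tail block is never hashed.
from typing import List, Optional


def chained_block_hashes(token_ids: List[int], block_size: int,
                         lora_int_id: int = 0) -> List[int]:
    hashes: List[int] = []
    prev_hash: Optional[int] = None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = tuple(token_ids[start:start + block_size])
        prev_hash = hash((prev_hash, lora_int_id, *block))
        hashes.append(prev_hash)
    return hashes


a = chained_block_hashes(list(range(32)) + [7, 8], block_size=16)
b = chained_block_hashes(list(range(32)) + [9, 9, 9], block_size=16)
# Same two full blocks -> same two hashes; the differing tails are partial
# blocks and therefore not hashed.
assert a == b and len(a) == 2
```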
- ) + # Mark the seq 1 as computed (This should be done by the scheduler in reality) + block_manager.mark_blocks_as_computed(seq_group=seq_group_1, + token_chunk_size=len(tokens_1)) # Test if allocatable of seq 2. seq_group_2 = create_seq_group( @@ -399,7 +398,9 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots, watermark=0, enable_caching=enable_caching) prompt, seq_group = create_dummy_prompt( - "1", prompt_length=(num_gpu_blocks - 1) * block_size - 1, block_size=block_size + "1", + prompt_length=(num_gpu_blocks - 1) * block_size - 1, + block_size=block_size, ) prompt.status = SequenceStatus.WAITING block_manager.allocate(seq_group) diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py index 06f1e297695a9..54c49657b737d 100644 --- a/tests/core/block/test_block_table.py +++ b/tests/core/block/test_block_table.py @@ -105,8 +105,7 @@ def test_allocate_prefix_caching(block_size: int, sequence_len: int): block_size=block_size, block_allocator=allocator, enable_prefix_caching=True, - ) - ) + )) seq = make_sequence(alloc_i, token_ids, block_size) block_tables[-1].allocate(seq=seq, device=Device.GPU) @@ -148,7 +147,8 @@ def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str, block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) for i in range(5): @@ -193,7 +193,8 @@ def test_append_token_ids_allocation(block_size: int, sequence_len: int, block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) num_expected_blocks_before_append = len( @@ -250,7 +251,8 @@ def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int, block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) num_expected_blocks_before_append = len( @@ -308,7 +310,8 @@ def test_append_token_ids_correct_content(block_size: int, sequence_len: int, block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) seq = make_sequence(0, token_ids, block_size) block_table.allocate(seq=seq, device=Device.GPU) @@ -353,7 +356,8 @@ def test_fork(seq_len: int, block_size: int, allocator_type: str): block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) seq = make_sequence(0, token_ids, block_size) @@ -414,7 +418,8 @@ def test_cow(block_size: int, sequence_len: int, append_len: int, original_block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) num_expected_non_cow_blocks = cdiv(sequence_len, block_size) @@ -504,7 +509,8 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int, 
original_block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) seq = make_sequence(0, token_ids, block_size) @@ -590,7 +596,8 @@ def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int, block_table = BlockTable( block_size=block_size, block_allocator=allocator, - enable_prefix_caching=True if allocator_type == "prefix_caching" else False, + enable_prefix_caching=True + if allocator_type == "prefix_caching" else False, ) seq = make_sequence(0, token_ids, block_size) block_table.allocate(seq=seq, device=Device.GPU) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 8c8690bc23d1b..8195888f6bef8 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -795,9 +795,8 @@ def test_get_cached_blocks(): block_size = 16 num_blocks = 5 - allocator = PrefixCachingBlockAllocator( - block_size=block_size, num_blocks=num_blocks - ) + allocator = PrefixCachingBlockAllocator(block_size=block_size, + num_blocks=num_blocks) # 1. Allocate a list of blocks block_hashes = [random.randint(1, 1000000) for _ in range(num_blocks)] @@ -825,12 +824,11 @@ def test_get_cached_blocks(): result = allocator.get_cached_blocks(cached_hashes) assert ( result == expected_cached_blocks - ), f"Expected {expected_cached_blocks}, but got {result}, with test case {cached_hashes}. blcok hashes = {block_hashes}" + ), f"Expected {expected_cached_blocks}, but got {result}, with test case {cached_hashes}. block hashes = {block_hashes}" # Test with some non-existent hashes non_existent_hash = max(block_hashes) + 1 test_hashes = block_hashes[:3] + [non_existent_hash] + block_hashes[3:] result = allocator.get_cached_blocks(test_hashes) - assert ( - result == block_hashes[0:3] - ), f"Expected {block_hashes[0:3]}, but got {result}" + assert (result == block_hashes[0:3] + ), f"Expected {block_hashes[0:3]}, but got {result}" diff --git a/tests/core/utils.py b/tests/core/utils.py index 6a539db9fc3f7..5c2e9e9c9ca1e 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -116,7 +116,7 @@ def create_dummy_prompt_encoder_decoder( def create_seq_group( seq_prompt_len: int = 1024, - seq_output_lens: GenericSequence[int] = (128,), + seq_output_lens: GenericSequence[int] = (128, ), request_id: str = "0", seq_id_start: int = 0, sampling_params: Optional[SamplingParams] = None, diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 9b3782e97fe72..6b6ab7a85a5b9 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -43,10 +43,10 @@ def test_mixed_requests( cached_prompt = example_prompts[cached_position] with vllm_runner( - model, - dtype=dtype, - enable_prefix_caching=True, - enable_chunked_prefill=enable_chunked_prefill, + model, + dtype=dtype, + enable_prefix_caching=True, + enable_chunked_prefill=enable_chunked_prefill, ) as vllm_model: # Run the first prompt so the cache is populated vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens) diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index 74962541dfb22..1554f306903e9 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -55,30 +55,6 @@ def __init__( self._max_block_sliding_window = 
max_block_sliding_window self._num_full_slots = self._get_num_token_ids() - # @staticmethod - # def get_num_required_blocks(token_ids: List[int], - # block_size: int, - # num_lookahead_slots: int = 0) -> int: - # """Calculates the minimum number of blocks required to store a given - # sequence of token IDs along with any look-ahead slots that may be - # required (like in multi-step + chunked-prefill). - - # This assumes worst-case scenario, where every block requires a new - # allocation (e.g. ignoring prefix caching). - - # Args: - # token_ids (List[int]): The sequence of token IDs to be stored. - # block_size (int): The maximum number of tokens that can be stored in - # a single block. - # num_lookahead_slots (int): look-ahead slots that the sequence may - # require. - - # Returns: - # int: The minimum number of blocks required to store the given - # sequence of token IDs along with any required look-ahead slots. - # """ - # return cdiv(len(token_ids) + num_lookahead_slots, block_size) - def allocate( self, token_ids: List[int], @@ -100,14 +76,13 @@ def allocate( if not token_ids: return - blocks = self._allocate_blocks_for_token_ids( - token_ids, block_hashes, device - ) + blocks = self._allocate_blocks_for_token_ids(token_ids, block_hashes, + device) self.update(blocks) self._num_full_slots = len(token_ids) def update(self, blocks: List[Block]) -> None: - """Resets the table to the newly provided blocks + """Resets the table to the newly provided blocks (with their corresponding block ids) """ self._blocks.update(blocks) @@ -164,17 +139,14 @@ def append_slots( # Update the blocks with the new tokens first_block_idx = self._num_full_slots // self._block_size token_blocks = self._chunk_token_blocks_for_append(token_ids) - - if len(token_blocks) != len(block_hashes): - breakpoint() - assert len(token_blocks) == len( block_hashes ), "chunked token_ids and block_hashes must have the same length" for i, token_block in enumerate(token_blocks): block_hash = block_hashes[i] - self._blocks.append_token_ids(first_block_idx + i, token_block, block_hash) + self._blocks.append_token_ids(first_block_idx + i, token_block, + block_hash) self._num_full_slots += len(token_ids) @@ -304,10 +276,9 @@ def _allocate_blocks_for_token_ids( self._allocator.allocate_immutable_blocks( prev_block, block_token_ids=block_token_ids, - block_hashes=block_hashes[: len(block_token_ids)], + block_hashes=block_hashes[:len(block_token_ids)], device=device, - ) - ) + )) prev_block = blocks[-1] if tail_token_ids: @@ -315,8 +286,7 @@ def _allocate_blocks_for_token_ids( assert block_hashes[-1] is None cur_token_ids = tail_token_ids[0] block = self._allocator.allocate_mutable_block( - prev_block=prev_block, device=device - ) + prev_block=prev_block, device=device) block.append_token_ids(cur_token_ids, block_hash=None) blocks.append(block) diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index c2117ccaaeb50..6815a3808876e 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -254,9 +254,8 @@ def update(self, blocks: List[Block]): for block in self._blocks: self._add_block_id(block.block_id) - def append_token_ids( - self, block_index: int, token_ids: List[int], block_hash: Optional[int] - ) -> None: + def append_token_ids(self, block_index: int, token_ids: List[int], + block_hash: Optional[int]) -> None: block = self._blocks[block_index] prev_block_id = block.block_id diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 
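Aside on the `BlockTable.allocate` / `append_slots` changes above: full blocks are paired one-to-one with precomputed hashes, and a trailing partial block carries `None` and stays mutable. A small sketch of that pairing, using a hypothetical helper that is not part of the patch:

```python
# Hypothetical helper illustrating the pairing enforced above: token ids are
# chunked into block_size pieces, each full chunk carries a precomputed hash,
# and a trailing partial chunk carries None (it stays a mutable block).
from typing import List, Optional, Tuple


def chunk_with_hashes(
    token_ids: List[int],
    block_hashes: List[Optional[int]],
    block_size: int,
) -> List[Tuple[List[int], Optional[int]]]:
    chunks = [token_ids[i:i + block_size]
              for i in range(0, len(token_ids), block_size)]
    assert len(chunks) == len(block_hashes), (
        "chunked token_ids and block_hashes must have the same length")
    return list(zip(chunks, block_hashes))


pairs = chunk_with_hashes(list(range(40)), [111, 222, None], block_size=16)
assert [h for _, h in pairs] == [111, 222, None]
assert len(pairs[-1][0]) == 8  # partial tail block, hence hash is None
```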
1674514f354e7..0c8c00a461e99 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -139,7 +139,7 @@ def allocate_immutable_blocks( prev_block: Optional[Block], block_token_ids: List[List[int]], device: Device, - block_hashes: Optional[List[Optional[int]]] = None, + block_hashes: List[Optional[int]], ) -> List[Block]: """Allocates a new group of immutable blocks with the provided block token IDs on the specified device. @@ -156,15 +156,7 @@ def allocate_immutable_blocks( containing the provided block token IDs. """ return self._allocators[device].allocate_immutable_blocks( - prev_block, block_token_ids, block_hashes - ) - - def get_allocated_cached_blocks( - self, - block_hashes: List[int], - device: Device, - ) -> List[int]: - return self._allocators[device].get_allocated_cached_blocks(block_hashes) + prev_block, block_token_ids, block_hashes) def allocate_immutable_block(self, prev_block: Optional[Block], token_ids: List[int], @@ -353,14 +345,12 @@ def get_and_reset_swaps(self) -> List[Tuple[int, int]]: self._swap_mapping.clear() return list(mapping.items()) - def find_cached_blocks_prefix( - self, block_hashes: List[int], allocated: bool - ) -> List[int]: + def find_cached_blocks_prefix(self, block_hashes: List[int], + allocated: bool) -> List[int]: # Prefix caching only supported on GPU. device = Device.GPU return self._allocators[device].find_cached_blocks_prefix( - block_hashes, allocated - ) + block_hashes, allocated) class NullBlock(Block): @@ -376,7 +366,9 @@ def __init__(self, proxy: Block): super().__init__() self._proxy = proxy - def append_token_ids(self, token_ids: List[BlockId]): + def append_token_ids(self, + token_ids: List[BlockId], + block_hash: Optional[int] = None) -> None: raise ValueError("null block should not be modified") @property @@ -429,4 +421,5 @@ def content_hash(self): return self._proxy.content_hash def set_content_hash(self, content_hash: Optional[int]) -> None: - raise NotImplementedError("NullBlock does not support set_content_hash") + raise NotImplementedError( + "NullBlock does not support set_content_hash") diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index f55133f92ee0c..08149a0aabb17 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -10,9 +10,9 @@ class Block(ABC): @abstractmethod - def append_token_ids( - self, token_ids: List[int], block_hash: Optional[int] = None - ) -> None: + def append_token_ids(self, + token_ids: List[int], + block_hash: Optional[int] = None) -> None: pass @property @@ -204,9 +204,8 @@ def get_prefix_cache_hit_rate(self) -> float: pass @abstractmethod - def find_cached_blocks_prefix( - self, block_hashes: List[int], allocated: bool - ) -> List[int]: + def find_cached_blocks_prefix(self, block_hashes: List[int], + allocated: bool) -> List[int]: pass class NoFreeBlocksError(ValueError): @@ -227,9 +226,13 @@ def allocate_immutable_block(self, prev_block: Optional[Block], pass @abstractmethod - def allocate_immutable_blocks(self, prev_block: Optional[Block], - block_token_ids: List[List[int]], - device: Device) -> List[Block]: + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device, + block_hashes: List[Optional[int]], + ) -> List[Block]: pass @abstractmethod @@ -306,13 +309,6 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: pass @abstractmethod - def get_allocated_cached_blocks( - self, block_hashes: List[int], device: Device - ) -> 
List[int]: + def find_cached_blocks_prefix(self, block_hashes: List[int], + allocated: bool) -> List[int]: pass - - @abstractmethod - def find_cached_blocks_prefix( - self, block_hashes: List[int], allocated: bool, device: Device - ) -> List[int]: - pass \ No newline at end of file diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 83725faef6d39..9bc4a316555a3 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -85,10 +85,10 @@ def allocate_immutable_block( return block def allocate_immutable_blocks( - self, - prev_block: Optional[Block], - block_token_ids: List[List[int]], - block_hashes: Optional[List[Optional[int]]] = None, + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + block_hashes: List[Optional[int]], # Not used ) -> List[Block]: num_blocks = len(block_token_ids) @@ -335,12 +335,9 @@ def swap_in(self, blocks: List[Block]) -> None: def get_prefix_cache_hit_rate(self) -> float: return -1 - def get_allocated_cached_blocks(self, block_hashes: List[int]) -> List[int]: - return [] - - def find_cached_blocks_prefix( - self, block_hashes: List[int], allocated: bool = False - ) -> List[int]: + def find_cached_blocks_prefix(self, + block_hashes: List[int], + allocated: bool = False) -> List[int]: return [] @@ -382,9 +379,9 @@ def __init__(self, self._append_token_ids_no_cow(token_ids) - def append_token_ids( - self, token_ids: List[int], block_hash: Optional[int] = None - ) -> None: + def append_token_ids(self, + token_ids: List[int], + block_hash: Optional[int] = None) -> None: """Appends the given token IDs to the block and performs a copy-on-write if necessary. @@ -465,7 +462,6 @@ def prev_block(self) -> Optional["Block"]: def content_hash(self) -> Optional[int]: return None - def set_content_hash(self, content_hash: int) -> None: + def set_content_hash(self, content_hash: Optional[int]) -> None: raise NotImplementedError( - "Setting content hash is not supported for naive block" - ) + "Setting content hash is not supported for naive block") diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 035a9a8dd92f9..61e1fd56ce970 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -4,7 +4,13 @@ from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) -from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device +from vllm.core.block.interfaces import ( + Block, + BlockAllocator, + BlockId, + Device, + DeviceAwareBlockAllocator, +) from vllm.core.block.naive_block import (BlockPool, NaiveBlock, NaiveBlockAllocator) from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor @@ -153,7 +159,8 @@ def allocate_immutable_block( Returns: Block: The allocated immutable block. """ - assert len(token_ids) == self._block_size, "An immutable block should be full" + assert (len(token_ids) == self._block_size + ), "An immutable block should be full" assert ( block_hash is not None ), "An immutable block should have a content hash for prefix caching" @@ -405,7 +412,8 @@ def get_num_free_blocks(self, device: Optional[Device] = None) -> int: assert device is None # The number of free blocks is the number of hashless free blocks # plus the number of blocks evictor could free from its list. 
- return self._hashless_allocator.get_num_free_blocks() + self.evictor.num_blocks + return (self._hashless_allocator.get_num_free_blocks() + + self.evictor.num_blocks) def get_num_total_blocks(self) -> int: return self._hashless_allocator.get_num_total_blocks() @@ -638,18 +646,16 @@ def swap_in(self, blocks: List[Block]) -> None: # and the block_id is assigned to "block" to allow reusing the # existing "block" object if block.is_full: - assert ( - block.content_hash is not None - ), "Block is full but has no content hash" + assert (block.content_hash + is not None), "Block is full but has no content hash" tmp_block = self.allocate_immutable_block( prev_block=block.prev_block, token_ids=block.token_ids, block_hash=block.content_hash, ) else: - assert ( - block.content_hash is None - ), "Block is not full but has content hash" + assert (block.content_hash is + None), "Block is not full but has content hash" tmp_block = self.allocate_mutable_block( prev_block=block.prev_block) tmp_block.append_token_ids(block.token_ids, block_hash=None) @@ -659,9 +665,9 @@ def swap_in(self, blocks: List[Block]) -> None: block.block_id = block_id # Assign block_id - def find_cached_blocks_prefix( - self, block_hashes: List[PrefixHash], allocated: bool = False - ) -> List[PrefixHash]: + def find_cached_blocks_prefix(self, + block_hashes: List[PrefixHash], + allocated: bool = False) -> List[PrefixHash]: """ Return the prefix of the block hashes that are already computed and cached. @@ -686,9 +692,9 @@ def block_is_cached(block_hash: PrefixHash) -> bool: # Look for the first block that's not cached, and returns the prefix # , i.e. blocks that are cached. - idx = bisect_left( - block_hashes, True, key=lambda x: not block_is_cached(x) - ) + idx = bisect_left(block_hashes, + True, + key=lambda x: not block_is_cached(x)) return block_hashes[:idx] @@ -724,9 +730,7 @@ def __init__( assert isinstance(allocator, PrefixCachingBlockAllocator), ( "Currently this class is only tested with " "PrefixCachingBlockAllocator. Got instead allocator = {}".format( - allocator - ) - ) + allocator)) assert_prefix_caching_block_or_none(prev_block) self._prev_block = prev_block @@ -788,9 +792,9 @@ def last_accessed(self) -> float: def last_accessed(self, last_accessed_ts: float): self._last_accessed = last_accessed_ts - def append_token_ids( - self, token_ids: List[int], block_hash: Optional[int] = None - ) -> None: + def append_token_ids(self, + token_ids: List[int], + block_hash: Optional[int] = None) -> None: """Appends the given token IDs to the block and registers the block as immutable if the block becomes full. @@ -860,9 +864,8 @@ def set_content_hash(self, content_hash: Optional[int]) -> None: assert self.content_hash is None, "Content hash already set" if content_hash is None: # This could happen when forking a mutable block. - assert ( - not self.is_full - ), "Block should not be full when new content hash is None" + assert (not self.is_full + ), "Block should not be full when new content hash is None" # No op. return assert self.is_full, "Block is not full when setting content hash" @@ -901,7 +904,10 @@ class ComputedBlocksTracker: """ def __init__( - self, allocator: BlockAllocator, block_size: int, enable_caching: bool + self, + allocator: DeviceAwareBlockAllocator, + block_size: int, + enable_caching: bool, ): self._allocator = allocator self._block_size = block_size @@ -950,28 +956,26 @@ def update_seq(self, seq: Sequence) -> None: f" already recorded {cur_num_blocks_recorded} blocks. 
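Aside on `find_cached_blocks_prefix` above: the lookup assumes cached-ness is prefix-monotonic over the hash list (once a block misses, all later blocks miss), so the boundary can be found with a binary search rather than a linear scan. A standalone sketch of the same lookup, for illustration only; note the `key=` form of `bisect_left` requires Python 3.10+:

```python
# Standalone sketch of the binary-search lookup: cached-ness is assumed to be
# prefix-monotonic (once a block misses, all later blocks miss), so the first
# miss can be found with bisect_left instead of a linear scan.
from bisect import bisect_left
from typing import List, Set


def cached_prefix(block_hashes: List[int], cached: Set[int]) -> List[int]:
    # bisect_left's key= argument requires Python 3.10+.
    idx = bisect_left(block_hashes, True, key=lambda h: h not in cached)
    return block_hashes[:idx]


hashes = [11, 22, 33, 44]
assert cached_prefix(hashes, cached={11, 22}) == [11, 22]
assert cached_prefix(hashes, cached=set()) == []
assert cached_prefix(hashes, cached={11, 22, 33, 44}) == hashes
```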
" "This should not happen since we assume blocks are " "only added. When the sequence is recomputed, we should have " - "removed the info of the old blocks." - ) + "removed the info of the old blocks.") # Update the computed block hashes for the sequence num_total_blocks = len(token_ids) // self._block_size # We need to know the hash of the previous block to compute the hash of # the current block so that blocks could be uniquely identified across # sequences of prefixes. - prev_block_hash = ( - None if cur_num_blocks_recorded == 0 else block_hashes[-1] - ) + prev_block_hash = (None if cur_num_blocks_recorded == 0 else + block_hashes[-1]) # Only update the computed block hashes for the new blocks for i in range(cur_num_blocks_recorded, num_total_blocks): - block_hash = seq.hash_of_block( - prev_block_hash=prev_block_hash, cur_block_idx=i - ) + block_hash = seq.hash_of_block(prev_block_hash=prev_block_hash, + cur_block_idx=i) block_hashes.append(block_hash) prev_block_hash = block_hash - def get_block_hashes( - self, seq: Sequence, start_block_idx: int = 0, end_block_idx: int = -1 - ) -> List[Optional[int]]: + def get_block_hashes(self, + seq: Sequence, + start_block_idx: int = 0, + end_block_idx: int = -1) -> List[Optional[int]]: """ Returns the list of block hashes for the sequence. If the last block is not yet full, its hash is None. @@ -994,17 +998,16 @@ def get_block_hashes( if seq.seq_id not in self._full_blocks_hashes: self.update_seq(seq) - seq_block_hashes = self._full_blocks_hashes[seq.seq_id] + seq_block_hashes: List[Optional[int]] = self._full_blocks_hashes[ + seq.seq_id] # type: ignore num_full_blocks = len(seq_block_hashes) assert num_blocks - num_full_blocks <= 1, ( "There should only be at most one block in the end of the " f"sequence that's not yet computed and full. Got {num_blocks} " - f"blocks, {num_full_blocks} full blocks." - ) + f"blocks, {num_full_blocks} full blocks.") # Add the None block if the last block is not yet full. - seq_block_hashes = seq_block_hashes + [None] * ( - num_blocks - num_full_blocks - ) + seq_block_hashes = seq_block_hashes + [None] * (num_blocks - + num_full_blocks) return seq_block_hashes[start_block_idx:end_block_idx] def get_num_tokens_computed(self, seq: Sequence, allocated: bool) -> int: @@ -1031,8 +1034,7 @@ def get_num_tokens_computed(self, seq: Sequence, allocated: bool) -> int: self.update_seq(seq) num_computed_tokens_prev = self._num_tokens_computed.get( - (seq.seq_id, allocated), None - ) + (seq.seq_id, allocated), None) if num_computed_tokens_prev is not None and seq.is_prefill(): # For a sequence that is still in prefill, we don't have to # recompute the number of cached tokens. @@ -1046,8 +1048,7 @@ def get_num_tokens_computed(self, seq: Sequence, allocated: bool) -> int: # This is currently O(logN), where N is the number of blocks. 
num_cached_blocks = len( - self._allocator.find_cached_blocks_prefix(block_hashes, allocated) - ) + self._allocator.find_cached_blocks_prefix(block_hashes, allocated)) num_cached_tokens = num_cached_blocks * self._block_size self._num_tokens_computed[(seq.seq_id, allocated)] = num_cached_tokens return num_cached_tokens diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 51a5a66928cbd..f80166ca28789 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -5,6 +5,7 @@ from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.interfaces import Block from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, LastAccessBlocksTracker) from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec @@ -100,14 +101,13 @@ def __init__( self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} self._computed_blocks_tracker = ComputedBlocksTracker( - self.block_allocator, self.block_size, enable_caching - ) + self.block_allocator, self.block_size, enable_caching) self._last_access_blocks_tracker = LastAccessBlocksTracker( self.block_allocator) - def _get_num_blocks_to_allocate( - self, seq: Sequence, num_lookahead_slots: int = 0 - ) -> int: + def _get_num_blocks_to_allocate(self, + seq: Sequence, + num_lookahead_slots: int = 0) -> int: """ Get the number of new blocks to allocate for a sequence. @@ -127,7 +127,7 @@ def _get_num_blocks_to_allocate( # allocated, and so we should not count them towards the number of # blocks we need to allocate. # - # consider a scenario with a seqeuence of 3 token blocks, and a block + # consider a scenario with a sequence of 3 token blocks, and a block # pool of only 2 blocks: [b0, b1, b2], where b0 and b1 are computed # but evicted, b2 is not computed. So b0, b1 are the 2 free blocks, # in evictor. When deciding how many more blocks need to be allocated for @@ -135,15 +135,13 @@ def _get_num_blocks_to_allocate( # just 1 block (b2). 
num_cached_tokens = ( self._computed_blocks_tracker.get_num_tokens_computed( - seq, allocated=True - ) - ) + seq, allocated=True)) - assert ( - num_cached_tokens % self.block_size == 0 - ), "Cached tokens must be a multiple of block size" + assert (num_cached_tokens % self.block_size == 0 + ), "Cached tokens must be a multiple of block size" num_cached_blocks = cdiv(num_cached_tokens, self.block_size) - num_required_blocks = cdiv(seq.get_len() + num_lookahead_slots, self.block_size) + num_required_blocks = cdiv(seq.get_len() + num_lookahead_slots, + self.block_size) return num_required_blocks - num_cached_blocks @@ -162,8 +160,7 @@ def get_num_cached_tokens(self, seq: Sequence, allocated: bool) -> int: """ return self._computed_blocks_tracker.get_num_tokens_computed( - seq, allocated - ) + seq, allocated) def can_allocate(self, seq_group: SequenceGroup, @@ -177,25 +174,21 @@ def can_allocate(self, self._computed_blocks_tracker.update_seq(seq) num_blocks_to_allocate = self._get_num_blocks_to_allocate( - seq, num_lookahead_slots - ) + seq, num_lookahead_slots) if seq_group.is_encoder_decoder(): encoder_seq = seq_group.get_encoder_seq() assert encoder_seq is not None self._computed_blocks_tracker.update_seq(encoder_seq) num_blocks_to_allocate += self._get_num_blocks_to_allocate( - encoder_seq, num_lookahead_slots=0 - ) + encoder_seq, num_lookahead_slots=0) if self.max_block_sliding_window is not None: - num_blocks_to_allocate = min( - num_blocks_to_allocate, self.max_block_sliding_window - ) + num_blocks_to_allocate = min(num_blocks_to_allocate, + self.max_block_sliding_window) num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( - device=Device.GPU - ) + device=Device.GPU) if self.num_total_gpu_blocks - num_blocks_to_allocate < self.watermark_blocks: return AllocStatus.NEVER if num_free_gpu_blocks - num_blocks_to_allocate >= self.watermark_blocks: @@ -217,8 +210,7 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: block_table.allocate( token_ids=seq.get_token_ids(), block_hashes=self._computed_blocks_tracker.get_block_hashes( - seq - ), + seq), ) return block_table @@ -285,14 +277,15 @@ def can_append_slots(self, seq_group: SequenceGroup, for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): self._computed_blocks_tracker.update_seq(seq) block_table = self.block_tables[seq.seq_id] - num_touched_blocks += block_table.get_num_blocks_touched_by_append_slots( - token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), - num_lookahead_slots=num_lookahead_slots, - ) + num_touched_blocks += ( + block_table.get_num_blocks_touched_by_append_slots( + token_ids=block_table.get_unseen_token_ids( + seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + )) num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( - device=Device.GPU - ) + device=Device.GPU) return num_touched_blocks <= num_free_gpu_blocks def append_slots( @@ -305,7 +298,8 @@ def append_slots( # Extend the block table for any new decoded tokens, as well as reserve # space for lookahead slots. 
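# A compact sketch of the three-way admission decision made in can_allocate
# above, with a stand-in status enum: NEVER if the request cannot fit even in
# an empty cache, OK if it fits right now while keeping the watermark as
# headroom, otherwise LATER. Parameter names are illustrative.
from enum import Enum


class ToyAllocStatus(Enum):
    OK = 1
    LATER = 2
    NEVER = 3


def admission_status(num_blocks_to_allocate: int, num_total_blocks: int,
                     num_free_blocks: int,
                     watermark_blocks: int) -> ToyAllocStatus:
    if num_total_blocks - num_blocks_to_allocate < watermark_blocks:
        return ToyAllocStatus.NEVER
    if num_free_blocks - num_blocks_to_allocate >= watermark_blocks:
        return ToyAllocStatus.OK
    return ToyAllocStatus.LATER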
- unseen_token_ids = block_table.get_unseen_token_ids(seq.get_token_ids()) + unseen_token_ids = block_table.get_unseen_token_ids( + seq.get_token_ids()) if len(unseen_token_ids) == 0: block_hashes = [] else: @@ -396,9 +390,7 @@ def get_common_computed_block_ids( all_blocks = self.block_tables[seq.seq_id].physical_block_ids num_cached_tokens = ( self._computed_blocks_tracker.get_num_tokens_computed( - seq, allocated=True - ) - ) + seq, allocated=True)) assert num_cached_tokens % self.block_size == 0 num_cached_blocks = num_cached_tokens // self.block_size computed_block_ids = all_blocks[:num_cached_blocks] @@ -581,10 +573,8 @@ def _can_swap(self, if self.block_allocator.get_num_total_blocks( device) < num_blocks_touched: return AllocStatus.NEVER - elif ( - self.block_allocator.get_num_free_blocks(device=device) - num_blocks_touched - >= watermark_blocks - ): + elif (self.block_allocator.get_num_free_blocks(device=device) - + num_blocks_touched >= watermark_blocks): return AllocStatus.OK else: return AllocStatus.LATER diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 7b0a81778c07f..ca269bceeefa4 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -561,16 +561,14 @@ def _schedule_running( while running_queue: seq_group = running_queue[0] num_running_tokens, num_running_tokens_cached = ( - self._get_num_new_tokens_to_schedule( - seq_group, SequenceStatus.RUNNING, enable_chunking, budget - ) - ) + self._get_num_new_tokens_to_schedule(seq_group, + SequenceStatus.RUNNING, + enable_chunking, budget)) assert num_running_tokens_cached == 0, ( "No tokens should have been cached for running seq groups " "(be it in continuous prefill with chunked prefill or in " - "decode)" - ) + "decode)") if num_running_tokens == 0: # No budget => Stop @@ -745,14 +743,13 @@ def _schedule_swapped( # exceed the maximum number of sequences. 
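# A tiny sketch of how the common computed block ids are derived above: the
# cached token count (always a whole number of blocks) is turned into a block
# count and used to take a prefix of the sequence's physical block ids. The
# helper name is illustrative.
from typing import List


def computed_block_ids(physical_block_ids: List[int], num_cached_tokens: int,
                       block_size: int) -> List[int]:
    assert num_cached_tokens % block_size == 0
    return physical_block_ids[:num_cached_tokens // block_size]


# With 32 tokens cached and 16-token blocks, the first two block ids count
# as computed.
assert computed_block_ids([7, 3, 9], 32, 16) == [7, 3]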
num_new_seqs = seq_group.get_max_num_running_seqs() num_new_tokens_uncached, num_new_tokens_cached = ( - self._get_num_new_tokens_to_schedule( - seq_group, SequenceStatus.SWAPPED, enable_chunking, budget - ) - ) + self._get_num_new_tokens_to_schedule(seq_group, + SequenceStatus.SWAPPED, + enable_chunking, budget)) if num_new_tokens_uncached == 0 or not budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, ): break @@ -766,10 +763,9 @@ def _schedule_swapped( prefill_seq_groups.append( ScheduledSequenceGroup( seq_group, - token_chunk_size=num_new_tokens_uncached - + num_new_tokens_cached, - ) - ) + token_chunk_size=num_new_tokens_uncached + + num_new_tokens_cached, + )) else: decode_seq_groups.append( ScheduledSequenceGroup(seq_group, token_chunk_size=1)) @@ -845,43 +841,34 @@ def _schedule_priority_preemption( seq_group = waiting_queue.popleft() num_new_seqs = seq_group.get_max_num_running_seqs() num_new_tokens_uncached, _ = self._get_num_new_tokens_to_schedule( - seq_group, SequenceStatus.WAITING, False, budget - ) + seq_group, SequenceStatus.WAITING, False, budget) # Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): # Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) - if ( - num_new_tokens_uncached - and can_allocate == AllocStatus.OK - and budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, - ) - ): + if (num_new_tokens_uncached and can_allocate == AllocStatus.OK + and budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + )): break # Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() num_running_tokens_uncached, num_running_tokens_cached = ( self._get_num_new_tokens_to_schedule( - vseq_group, SequenceStatus.RUNNING, False, budget - ) - ) + vseq_group, SequenceStatus.RUNNING, False, budget)) assert num_running_tokens_cached == 0, ( "No tokens should have been cached for running seq " "groups (be it in continuous prefill with chunked prefill " - "or in decode)" - ) + "or in decode)") budget.subtract_num_batched_tokens( - vseq_group.request_id, num_running_tokens_uncached - ) + vseq_group.request_id, num_running_tokens_uncached) num_running_seqs = vseq_group.get_max_num_running_seqs() - budget.subtract_num_seqs( - vseq_group.request_id, num_running_seqs - ) + budget.subtract_num_seqs(vseq_group.request_id, + num_running_seqs) # Preempt out the victim sequence group self._preempt(vseq_group, blocks_to_swap_out, @@ -946,8 +933,7 @@ def _schedule_prefills( SequenceStatus.WAITING, enable_chunking, budget, - ) - ) + )) num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached if not enable_chunking: # Sanity check that the number of new tokens is correct. @@ -1006,8 +992,8 @@ def _schedule_prefills( # We have new tokens but they might be cached. if not budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, ): # No more budget for new tokens. break @@ -1030,16 +1016,15 @@ def _schedule_prefills( else: seq_group.init_multi_step_from_lookahead_slots( num_lookahead_slots, - num_scheduler_steps=self.scheduler_config.num_scheduler_steps, + num_scheduler_steps=self.scheduler_config. 
+ num_scheduler_steps, is_multi_step=self.scheduler_config.is_multi_step, enable_chunking=enable_chunking, ) seq_groups.append( - ScheduledSequenceGroup( - seq_group=seq_group, token_chunk_size=num_new_tokens - ) - ) + ScheduledSequenceGroup(seq_group=seq_group, + token_chunk_size=num_new_tokens)) budget.add_num_batched_tokens( seq_group.request_id, num_batched_tokens=num_new_tokens_uncached, @@ -1056,8 +1041,7 @@ def _schedule_prefills( seq_groups=seq_groups, ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking - ), + is_prefill=True, enable_chunking=enable_chunking), ) def _schedule_default(self) -> SchedulerOutputs: @@ -1076,20 +1060,13 @@ def _schedule_default(self) -> SchedulerOutputs: # Make sure we include num running seqs before scheduling prefill, # so that we don't schedule beyond max_num_seqs for prefill. for seq_group in self.running: - budget.add_num_seqs( - seq_group.request_id, seq_group.get_max_num_running_seqs() - ) + budget.add_num_seqs(seq_group.request_id, + seq_group.get_max_num_running_seqs()) assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs - curr_loras = ( - set( - seq_group.lora_int_id - for seq_group in self.running - if seq_group.lora_int_id > 0 - ) - if self.lora_enabled - else None - ) + curr_loras = (set( + seq_group.lora_int_id for seq_group in self.running + if seq_group.lora_int_id > 0) if self.lora_enabled else None) prefills = SchedulerPrefillOutputs.create_empty() running_scheduled = SchedulerRunningOutputs.create_empty() @@ -1097,37 +1074,30 @@ def _schedule_default(self) -> SchedulerOutputs: # If any requests are swapped, prioritized swapped requests. if not self.swapped: - prefills = self._schedule_prefills( - budget, curr_loras, enable_chunking=False - ) + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=False) - if ( - len(prefills.seq_groups) == 0 - and self.scheduler_config.policy == "priority" - ): + if (len(prefills.seq_groups) == 0 + and self.scheduler_config.policy == "priority"): self._schedule_priority_preemption(budget) # Don't schedule decodes if prefills are scheduled. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # only contains decode requests, not chunked prefills. if len(prefills.seq_groups) == 0: - running_scheduled = self._schedule_running( - budget, curr_loras, enable_chunking=False - ) + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=False) # If any sequence group is preempted, do not swap in any sequence # group. because it means there's no slot for new running requests. - if ( - len(running_scheduled.preempted) - + len(running_scheduled.swapped_out) - == 0 - ): + if (len(running_scheduled.preempted) + + len(running_scheduled.swapped_out) == 0): swapped_in = self._schedule_swapped(budget, curr_loras) - assert ( - budget.num_batched_tokens - <= self.scheduler_config.max_num_batched_tokens - ) + assert (budget.num_batched_tokens <= + self.scheduler_config.max_num_batched_tokens) assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs # Update waiting requests. @@ -1140,14 +1110,12 @@ def _schedule_default(self) -> SchedulerOutputs: if len(swapped_in.decode_seq_groups) > 0: self.running.extend( - [s.seq_group for s in swapped_in.decode_seq_groups] - ) + [s.seq_group for s in swapped_in.decode_seq_groups]) # Update swapped requests. 
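# A minimal sketch of the budget bookkeeping implied above, assuming the
# split into uncached and cached new tokens computed earlier in the schedule
# pass: only uncached tokens are charged against the token budget, while
# cached tokens are still recorded so the full batch size can be reported
# downstream. This toy class is illustrative and not the actual
# SchedulingBudget API.
from dataclasses import dataclass


@dataclass
class ToySchedulingBudget:
    token_budget: int
    max_num_seqs: int
    num_batched_tokens: int = 0  # uncached tokens charged to the budget
    num_cached_tokens: int = 0   # cached tokens, tracked but not charged
    num_curr_seqs: int = 0

    def can_schedule(self, num_new_tokens: int, num_new_seqs: int) -> bool:
        return (self.num_batched_tokens + num_new_tokens <= self.token_budget
                and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)

    def add(self, num_uncached: int, num_cached: int, num_seqs: int) -> None:
        self.num_batched_tokens += num_uncached
        self.num_cached_tokens += num_cached
        self.num_curr_seqs += num_seqs

    @property
    def num_batched_and_cached_tokens(self) -> int:
        return self.num_batched_tokens + self.num_cached_tokens


# A prefill whose first 32 tokens are cached consumes only 16 tokens of budget.
budget = ToySchedulingBudget(token_budget=64, max_num_seqs=8)
assert budget.can_schedule(num_new_tokens=16, num_new_seqs=1)
budget.add(num_uncached=16, num_cached=32, num_seqs=1)
assert budget.num_batched_and_cached_tokens == 48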
self.swapped.extend(running_scheduled.swapped_out) preempted = len(running_scheduled.preempted) + len( - running_scheduled.swapped_out - ) + running_scheduled.swapped_out) # There should be no prefill from running queue because this policy # doesn't allow chunked prefills. @@ -1207,27 +1175,24 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: swapped_in = SchedulerSwappedInOutputs.create_empty() # Decoding should be always scheduled first by fcfs. - running_scheduled = self._schedule_running( - budget, curr_loras, enable_chunking=True - ) + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=True) # Schedule swapped out requests. # If preemption happens, it means we don't have space for swap-in. - if ( - len(running_scheduled.preempted) - + len(running_scheduled.swapped_out) - == 0 - ): + if (len(running_scheduled.preempted) + + len(running_scheduled.swapped_out) == 0): swapped_in = self._schedule_swapped(budget, curr_loras) # Schedule new prefills. - prefills = self._schedule_prefills( - budget, curr_loras, enable_chunking=True - ) + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=True) assert ( - budget.num_batched_tokens - <= self.scheduler_config.max_num_batched_tokens + budget.num_batched_tokens <= + self.scheduler_config.max_num_batched_tokens ), f"{budget.num_batched_tokens=}, {self.scheduler_config.max_num_batched_tokens=}" assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs @@ -1238,47 +1203,39 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # By default, vLLM scheduler prioritizes prefills. # Once chunked prefill is enabled, # the policy is changed to prioritize decode requests. - self.running.extend([s.seq_group for s in swapped_in.decode_seq_groups]) self.running.extend( - [s.seq_group for s in swapped_in.prefill_seq_groups] - ) + [s.seq_group for s in swapped_in.decode_seq_groups]) self.running.extend( - [s.seq_group for s in running_scheduled.decode_seq_groups] - ) + [s.seq_group for s in swapped_in.prefill_seq_groups]) self.running.extend( - [s.seq_group for s in running_scheduled.prefill_seq_groups] - ) + [s.seq_group for s in running_scheduled.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.prefill_seq_groups]) self.running.extend([s.seq_group for s in prefills.seq_groups]) # Update swapped requests. 
self.swapped.extend(running_scheduled.swapped_out) return SchedulerOutputs( - scheduled_seq_groups=( - prefills.seq_groups - + running_scheduled.prefill_seq_groups - + swapped_in.prefill_seq_groups - + running_scheduled.decode_seq_groups - + swapped_in.decode_seq_groups - ), - num_prefill_groups=( - len(prefills.seq_groups) - + len(swapped_in.prefill_seq_groups) - + len(running_scheduled.prefill_seq_groups) - ), + scheduled_seq_groups=(prefills.seq_groups + + running_scheduled.prefill_seq_groups + + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups), + num_prefill_groups=(len(prefills.seq_groups) + + len(swapped_in.prefill_seq_groups) + + len(running_scheduled.prefill_seq_groups)), num_batched_tokens=budget.num_batched_and_cached_tokens, num_batched_tokens_from_budget=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=running_scheduled.blocks_to_copy - + swapped_in.blocks_to_copy, - ignored_seq_groups=prefills.ignored_seq_groups - + swapped_in.infeasible_seq_groups, + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, + ignored_seq_groups=prefills.ignored_seq_groups + + swapped_in.infeasible_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, running_queue_size=len(self.running), - preempted=( - len(running_scheduled.preempted) - + len(running_scheduled.swapped_out) - ), + preempted=(len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)), ) def _schedule(self) -> SchedulerOutputs: @@ -1288,25 +1245,21 @@ def _schedule(self) -> SchedulerOutputs: else: return self._schedule_default() - def _can_append_slots( - self, seq_group: SequenceGroup, enable_chunking: bool - ) -> bool: + def _can_append_slots(self, seq_group: SequenceGroup, + enable_chunking: bool) -> bool: """Determine whether or not we have enough space in the KV cache to continue generation of the sequence group. """ # It is True only for testing case to trigger artificial preemption. - if ( - self.enable_artificial_preemption - and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB - and self.artificial_preempt_cnt > 0 - ): + if (self.enable_artificial_preemption + and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB + and self.artificial_preempt_cnt > 0): self.artificial_preempt_cnt -= 1 return False is_prefill = seq_group.is_prefill() num_lookahead_slots = self._get_num_lookahead_slots( - is_prefill, enable_chunking - ) + is_prefill, enable_chunking) if is_prefill and num_lookahead_slots > 0: # Appending prefill slots only happens multi-step and @@ -1314,20 +1267,17 @@ def _can_append_slots( assert self.scheduler_config.is_multi_step and enable_chunking return self.block_manager.can_append_slots( - seq_group=seq_group, num_lookahead_slots=num_lookahead_slots - ) + seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: # async_output_proc is allowed only when we have a single sequence # in the sequence group no_single_seq = seq_group.sampling_params is None or ( - seq_group.sampling_params.n == 1 - ) + seq_group.sampling_params.n == 1) return no_single_seq def schedule( - self, - ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: + self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: # Schedule sequence groups. 
# This function call changes the internal states of the scheduler # such as self.running, self.swapped, and self.waiting. @@ -1344,15 +1294,13 @@ def schedule( # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] for i, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups - ): + scheduler_outputs.scheduled_seq_groups): seq_group = scheduled_seq_group.seq_group token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.maybe_set_first_scheduled_time(now) seq_group_metadata = self._seq_group_metadata_cache[ - self.cache_id - ].get_object() + self.cache_id].get_object() seq_group_metadata.seq_data.clear() seq_group_metadata.block_tables.clear() @@ -1369,8 +1317,7 @@ def schedule( # Block table for cross-attention # Also managed at SequenceGroup level cross_block_table = self.block_manager.get_cross_block_table( - seq_group - ) + seq_group) else: encoder_seq_data = None cross_block_table = None @@ -1384,9 +1331,7 @@ def schedule( if self.cache_config.enable_prefix_caching: common_computed_block_nums = ( self.block_manager.get_common_computed_block_ids( - seq_group.get_seqs(status=SequenceStatus.RUNNING) - ) - ) + seq_group.get_seqs(status=SequenceStatus.RUNNING))) do_sample = True is_prompt = seq_group.is_prefill() @@ -1404,10 +1349,8 @@ def schedule( # NOTE: We use get_len instead of get_prompt_len because when # a sequence is preempted, prefill includes previous generated # output tokens. - if ( - token_chunk_size + num_computed_tokens - < seqs[0].data.get_len() - ): + if (token_chunk_size + num_computed_tokens < + seqs[0].data.get_len()): do_sample = False # It assumes the scheduled_seq_groups is ordered by @@ -1432,8 +1375,7 @@ def schedule( # the subsequent comms can still use delta, but # `multi_modal_data` will be None. multi_modal_data=seq_group.multi_modal_data - if scheduler_outputs.num_prefill_groups > 0 - else None, + if scheduler_outputs.num_prefill_groups > 0 else None, mm_processor_kwargs=seq_group.mm_processor_kwargs, prompt_adapter_request=seq_group.prompt_adapter_request, ) @@ -1456,8 +1398,7 @@ def schedule( if allow_async_output_proc: allow_async_output_proc = self._allow_async_output_proc( - seq_group - ) + seq_group) # Now that the batch has been created, we can assume all blocks in the # batch will have been computed before the next scheduling invocation. @@ -1564,8 +1505,7 @@ def _append_slots( """ is_prefill: bool = seq_group.is_prefill() num_lookahead_slots: int = self._get_num_lookahead_slots( - is_prefill, enable_chunking - ) + is_prefill, enable_chunking) seq_group.init_multi_step_from_lookahead_slots( num_lookahead_slots, @@ -1672,8 +1612,7 @@ def _swap_out( # entire engine. raise RuntimeError( "Aborted due to the lack of CPU swap space. Please increase " - "the swap space to avoid this error." 
- ) + "the swap space to avoid this error.") mapping = self.block_manager.swap_out(seq_group) blocks_to_swap_out.extend(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): @@ -1686,18 +1625,16 @@ def _passed_delay(self, now: float) -> bool: # Delay scheduling prompts to let waiting queue fill up if self.scheduler_config.delay_factor > 0 and self.waiting: earliest_arrival_time = min( - [e.metrics.arrival_time for e in self.waiting] - ) + [e.metrics.arrival_time for e in self.waiting]) passed_delay = (now - earliest_arrival_time) > ( - self.scheduler_config.delay_factor * self.last_prompt_latency - ) or not self.running + self.scheduler_config.delay_factor * + self.last_prompt_latency) or not self.running else: passed_delay = True return passed_delay - def _get_num_lookahead_slots( - self, is_prefill: bool, enable_chunking: bool - ) -> int: + def _get_num_lookahead_slots(self, is_prefill: bool, + enable_chunking: bool) -> int: """The number of slots to allocate per sequence per step, beyond known token ids. Speculative decoding uses these slots to store KV activations of tokens which may or may not be accepted. @@ -1777,10 +1714,8 @@ def _get_num_new_tokens_to_schedule( # So we set allocated=False. if self.cache_config.enable_prefix_caching: num_cached_tokens_seq = ( - self.block_manager.get_num_cached_tokens( - seq, allocated=False - ) - ) + self.block_manager.get_num_cached_tokens(seq, + allocated=False)) # Any computed token should have been cached too. if not self.scheduler_config.chunked_prefill_enabled: @@ -1788,19 +1723,16 @@ def _get_num_new_tokens_to_schedule( f"Number of cached tokens ({num_cached_tokens_seq}) " "should be no less than the number of computed " f"tokens ({num_computed_tokens_seq}) when chunked prefill " - "is disabled" - ) + "is disabled") num_new_tokens_cached_seq = max( - 0, num_cached_tokens_seq - num_computed_tokens_seq - ) + 0, num_cached_tokens_seq - num_computed_tokens_seq) else: num_cached_tokens_seq = 0 num_new_tokens_cached_seq = 0 all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq - num_new_tokens_uncached_seq = ( - all_num_new_tokens_seq - num_new_tokens_cached_seq - ) + num_new_tokens_uncached_seq = (all_num_new_tokens_seq - + num_new_tokens_cached_seq) num_uncached_new_tokens += num_new_tokens_uncached_seq num_cached_new_tokens += num_new_tokens_cached_seq @@ -1840,11 +1772,8 @@ def _chunk_new_tokens_to_schedule( # the sequence. 
pass else: - num_new_tokens = ( - 0 - if num_new_tokens > remaining_token_budget - else num_new_tokens - ) + num_new_tokens = (0 if num_new_tokens > remaining_token_budget + else num_new_tokens) elif self.cache_config.enable_prefix_caching: # When prefix caching is enabled, we always allocate # the number of new tokens that is dividable by the block @@ -1852,139 +1781,18 @@ def _chunk_new_tokens_to_schedule( block_size = self.cache_config.block_size remainder = budget.token_budget % block_size if remainder != 0: - raise ValueError( - "When enabling chunked prefill and " - "prefix caching, max_num_batched_tokens " - "(chunk size) must be dividable by " - "block size, but got chunk_size " - f"({budget.token_budget}) % block_size " - f"({block_size}) = {remainder}" - ) + raise ValueError("When enabling chunked prefill and " + "prefix caching, max_num_batched_tokens " + "(chunk size) must be dividable by " + "block size, but got chunk_size " + f"({budget.token_budget}) % block_size " + f"({block_size}) = {remainder}") # Round down to block - remaining_token_budget = ( - remaining_token_budget // block_size * block_size - ) + remaining_token_budget = (remaining_token_budget // block_size * + block_size) # Calculate the number of new tokens that are not cached with chunk cap. num_new_tokens = min(num_new_tokens, remaining_token_budget) else: num_new_tokens = min(num_new_tokens, remaining_token_budget) return num_new_tokens - - # def _get_num_new_tokens( - # self, - # seq_group: SequenceGroup, - # status: SequenceStatus, - # enable_chunking: bool, - # budget: SchedulingBudget, - # ) -> int: - # """Get the next new tokens to compute for a given sequence group - # that's in a given `status`. - - # The API could chunk the number of tokens to compute based on `budget` - # if `enable_chunking` is True. If a sequence group has multiple - # sequences (e.g., running beam search), it means it is in decoding - # phase, so chunking doesn't happen. - - # Returns 0 if the new token cannot be computed due to token budget. - # """ - # num_new_tokens = 0 - # seqs = seq_group.get_seqs(status=status) - # for seq in seqs: - # num_new_tokens += seq.get_num_new_tokens() - # assert num_new_tokens > 0 - # # Chunk if a running request cannot fit in the given budget. - # # If number of seq > 1, it means it is doing beam search - # # in a decode phase. Do not chunk. - # if enable_chunking and len(seqs) == 1: - # remaining_token_budget = budget.remaining_token_budget() - # seq = seqs[0] - # if self.scheduler_config.is_multi_step: - # # The current multi-step + chunked prefill capability does - # # not actually support chunking prompts. - # # - # # Therefore, `num_new_tokens` is computed in the same fashion - # # for both multi-step+chunked-prefill & - # # multi-step+chunked-prefill+APC - # # - # # Prompts with more tokens than the current remaining budget - # # are postponed to future scheduler steps - # if num_new_tokens > self._get_prompt_limit(seq_group): - # # If the seq_group is in prompt-stage, pass the - # # num_new_tokens as-is so the caller can ignore - # # the sequence. - # pass - # else: - # num_new_tokens = 0 \ - # if num_new_tokens > remaining_token_budget \ - # else num_new_tokens - # elif self.cache_config.enable_prefix_caching: - # # When prefix caching is enabled, we always allocate - # # the number of new tokens that is dividable by the block - # # size to avoid partial block matching. 
- # block_size = self.cache_config.block_size - # remainder = budget.token_budget % block_size - # if remainder != 0: - # raise ValueError("When enabling chunked prefill and " - # "prefix caching, max_num_batched_tokens " - # "(chunk size) must be dividable by " - # "block size, but got chunk_size " - # f"({budget.token_budget}) % block_size " - # f"({block_size}) = {remainder}") - # num_new_tokens_cached = seq.get_num_cached_tokens() - seq.get_num_computed_tokens() - # num_new_tokens_cached = max(0, num_new_tokens_cached) - # # Round down to block - # remaining_token_budget = remaining_token_budget // block_size * block_size - - # # Calculate the number of new tokens that are not cached with chunk cap. - # num_new_tokens_uncached = min(num_new_tokens - num_new_tokens_cached, remaining_token_budget) - # if num_new_tokens_uncached == 0: - # # No more budget for new tokens, don't include any cached tokens too. - # return 0 - # num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached - # else: - # num_new_tokens = min(num_new_tokens, remaining_token_budget) - # return num_new_tokens - # - # def _update_prefix_cached_tokens(self, seq: Sequence): - # """ - # Update the number of prefix cached tokens for a sequence. - - # This function takes O(log(n)) time, where n is the number of blocks - # in the sequence. - # """ - # num_prefix_cached_tokens = self.block_manager.get_num_computed_tokens(seq) - # seq.set_num_prefix_cached_tokens(num_prefix_cached_tokens) - - # def _get_num_new_tokens_exclude_cached( - # self, num_new_tokens: int, seq: Sequence - # ) -> int: - # """ - # Get the number of new tokens to compute for a sequence, excluding - # cached tokens. - - # Args: - # num_new_tokens: The number of new tokens to compute. - # seq: The sequence to compute the new tokens for. - - # Returns: - # Given `num_new_tokens`, returns the number of uncached tokens. - # """ - - # # If a decode sequence, new tokens are always not computed/cached. - # if not seq.is_prefill(): - # return num_new_tokens - - # # If a prefill sequence, we need to exclude the number of cached tokens. 
- # num_computed_tokens = seq.get_num_computed_tokens() - # num_cached_tokens = seq.get_num_cached_tokens() - - # # We subtract the number of cached tokens from the number of new tokens - # num_computed_tokens_new = num_new_tokens + num_computed_tokens - # num_new_tokens_exclude_cached = max( - # 0, num_computed_tokens_new - num_cached_tokens - # ) - # assert ( - # num_new_tokens_exclude_cached <= num_new_tokens - # ), "Number of new tokens exclude cached should be less than or equal to the number of new tokens" - # return num_new_tokens_exclude_cached diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index f9ad930e0a7c3..a46625eff1e4a 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -475,9 +475,8 @@ def _log_prometheus(self, stats: Stats) -> None: stats.num_preemption_iter) self._log_counter(self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter) - self._log_counter( - self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter - ) + self._log_counter(self.metrics.counter_generation_tokens, + stats.num_generation_tokens_iter) self._log_histogram(self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter) self._log_histogram(self.metrics.histogram_time_per_output_token, diff --git a/vllm/sequence.py b/vllm/sequence.py index 5d68371ce7d6c..c73aeb36974a7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -591,22 +591,21 @@ def _update_block_hashes(self): token_ids = self.get_token_ids() # All token ids in the sequence num_full_blocks = len(token_ids) // self.block_size cur_num_full_blocks = len(self._computed_block_hashes) - prev_block_hash = ( - None if cur_num_full_blocks == 0 else self._computed_block_hashes[-1] - ) + prev_block_hash = (None if cur_num_full_blocks == 0 else + self._computed_block_hashes[-1]) for i in range(cur_num_full_blocks, num_full_blocks): - block_token_ids = token_ids[i * self.block_size : (i + 1) * self.block_size] + block_token_ids = token_ids[i * self.block_size:(i + 1) * + self.block_size] assert len(block_token_ids) == self.block_size - block_hash = hash( - ( - prev_block_hash, # Previous block hash - self.from_decoder_prompt, # Whether the sequence is decoder-only - # LoRA int id since the attention output will depend on - # LoRA with same token ids. - self.lora_int_id, - *block_token_ids, # The block token ids - ) - ) + block_hash = hash(( + prev_block_hash, # Previous block hash + self. + from_decoder_prompt, # Whether the sequence is decoder-only + # LoRA int id since the attention output will depend on + # LoRA with same token ids. + self.lora_int_id, + *block_token_ids, # The block token ids + )) self._computed_block_hashes.append(block_hash) prev_block_hash = block_hash @@ -619,15 +618,13 @@ def _reset_block_hashes(self): """ num_full_prompt_blocks = self.get_prompt_len() // self.block_size self._computed_block_hashes = self._computed_block_hashes[ - num_full_prompt_blocks: - ] + num_full_prompt_blocks:] def set_num_prefix_cached_tokens(self, num_prefix_cached_tokens: int): self.data.set_num_prefix_cached_tokens(num_prefix_cached_tokens) - def hash_of_block( - self, prev_block_hash: Optional[int], cur_block_idx: int - ) -> int: + def hash_of_block(self, prev_block_hash: Optional[int], + cur_block_idx: int) -> int: """ Get the hash of a given block of the sequence. @@ -642,19 +639,15 @@ def hash_of_block( token_ids = self.get_token_ids() assert (cur_block_idx + 1) * self.block_size <= len(token_ids), ( f"Invalid block index: {cur_block_idx}. 
The sequence only has " - f"{len(token_ids) // self.block_size} blocks." - ) - block_token_ids = token_ids[ - cur_block_idx * self.block_size : (cur_block_idx + 1) - * self.block_size - ] - return hash( - ( - prev_block_hash, - self.lora_int_id, - *block_token_ids, - ) - ) + f"{len(token_ids) // self.block_size} blocks.") + block_token_ids = token_ids[cur_block_idx * + self.block_size:(cur_block_idx + 1) * + self.block_size] + return hash(( + prev_block_hash, + self.lora_int_id, + *block_token_ids, + )) def reset_state_for_recompute(self): """Reset the sequence states for recomputation."""