diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py
index db9b7e1da2e46..5414036a37263 100644
--- a/benchmarks/benchmark_prefix_caching.py
+++ b/benchmarks/benchmark_prefix_caching.py
@@ -132,18 +132,18 @@ def main(args):
         filtered_datasets = [(PROMPT, prompt_len, args.output_len)
                              ] * args.num_prompts
 
-    # llm = LLM(
-    #     model=args.model,
-    #     tokenizer_mode="auto",
-    #     trust_remote_code=True,
-    #     enforce_eager=True,
-    #     use_v2_block_manager=args.use_v2_block_manager,
-    #     tensor_parallel_size=args.tensor_parallel_size,
-    #     enable_prefix_caching=args.enable_prefix_caching,
-    #     disable_log_stats=False,
-    #     max_num_batched_tokens=4096 * 2,
-    #     enable_chunked_prefill=True,
-    # )
+    llm = LLM(
+        model=args.model,
+        tokenizer_mode="auto",
+        trust_remote_code=True,
+        enforce_eager=True,
+        use_v2_block_manager=args.use_v2_block_manager,
+        tensor_parallel_size=args.tensor_parallel_size,
+        enable_prefix_caching=args.enable_prefix_caching,
+        disable_log_stats=False,
+        max_num_batched_tokens=4096 * 2,
+        enable_chunked_prefill=True,
+    )
 
     engine_args = EngineArgs.from_cli_args(args)
     llm = LLM(**dataclasses.asdict(engine_args))
diff --git a/tests/core/block/test_block_manager.py b/tests/core/block/test_block_manager.py
index e940d2a331d9b..1190fa2885298 100644
--- a/tests/core/block/test_block_manager.py
+++ b/tests/core/block/test_block_manager.py
@@ -235,7 +235,7 @@ def test_can_allocate_with_prefix_cache(
     # Num blocks needed for 2 seqs, minus the number of blocks shared.
     num_blocks_required_with_sharing = 2 * num_blocks_required_seq - num_blocks_shared
 
-    block_manager = BlockSpaceManagerV2(
+    block_manager = SelfAttnBlockSpaceManager(
        block_size=block_size,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=0,
@@ -299,7 +299,7 @@ def test_can_append_with_prefix_cache(
    print(f"num_gpu_blocks: {num_gpu_blocks}")
    num_blocks_required_with_sharing = 2 * num_blocks_required_seq_1 - num_blocks_shared
 
-    block_manager = BlockSpaceManagerV2(
+    block_manager = SelfAttnBlockSpaceManager(
        block_size=block_size,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=0,
diff --git a/vllm/config.py b/vllm/config.py
index 98e0ea96f0d2c..99a82c8f1b40b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1074,10 +1074,6 @@ def __init__(self,
         self.policy = policy
         self._verify_args()
 
-        print(
-            f"max_num_batched_tokens: {self.max_num_batched_tokens}, max_num_seqs: {self.max_num_seqs}"
-        )
-
     def _verify_args(self) -> None:
         if (self.max_num_batched_tokens < self.max_model_len
                 and not self.chunked_prefill_enabled):
diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py
index 9a3a66523bd34..2beee5b9885a4 100644
--- a/vllm/core/block/block_table.py
+++ b/vllm/core/block/block_table.py
@@ -164,7 +164,9 @@ def append_slots(
 
         for i, token_block in enumerate(token_blocks):
             if self._enable_prefix_caching:
-                block_hash: Optional[int] = seq.get_block_hash(first_block_idx + i)
+                block_hash: Optional[int] = seq.update_and_get_block_hash(
+                    first_block_idx + i
+                )
             else:
                 block_hash = None
             self._blocks.append_token_ids(first_block_idx + i, token_block, block_hash)
@@ -286,7 +288,7 @@ def _allocate_blocks_for_token_ids(
             if len(cur_token_ids) == self._block_size:
                 block_token_ids.append(cur_token_ids)
                 if self._enable_prefix_caching:
-                    block_hashes.append(seq.get_block_hash(block_idx))
+                    block_hashes.append(seq.update_and_get_block_hash(block_idx))
                 else:
                     block_hashes.append(None)
             else:
@@ -308,12 +310,9 @@ def _allocate_blocks_for_token_ids(
             assert len(tail_token_ids) == 1
             assert block_hashes[-1] is None
             cur_token_ids = tail_token_ids[0]
 
-            try:
-                block = self._allocator.allocate_mutable_block(
-                    prev_block=prev_block, device=device
-                )
-            except Exception as e:
-                breakpoint()
+            block = self._allocator.allocate_mutable_block(
+                prev_block=prev_block, device=device
+            )
             block.append_token_ids(cur_token_ids, block_hash=None)
             blocks.append(block)
 
diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py
index 191eb6c8fdd15..c2117ccaaeb50 100644
--- a/vllm/core/block/common.py
+++ b/vllm/core/block/common.py
@@ -196,9 +196,6 @@ def increase_pool(self):
                       allocator=self._allocator,
                       block_id=None))
 
-    # TODO(rickyx): This should take in kwargs for flexible initialization of different types of blocks
-    # Right now, we update explicitly blocks with other args after initialization, e.g. block_hash
-    # computed for the prefix caching block.
     def init_block(
         self,
         prev_block: Optional[Block],
@@ -206,15 +203,6 @@ def init_block(
         self,
         token_ids: List[int],
         block_size: int,
         physical_block_id: Optional[int],
     ) -> Block:
-        """Initializes a block with the given parameters.
-
-        Args:
-            prev_block (Optional[Block]): The previous block in the sequence.
-            token_ids (List[int]): The token IDs to be stored in the block.
-            block_size (int): The size of the block.
-            physical_block_id (Optional[int]): The physical block ID.
-            block_hash (Optional[int]): The hash of the block's content.
-        """
         if len(self._free_ids) == 0:
             self.increase_pool()
             assert len(self._free_ids) > 0
diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py
index 378d273a271ab..8cec110cca663 100644
--- a/vllm/core/block/interfaces.py
+++ b/vllm/core/block/interfaces.py
@@ -99,7 +99,7 @@ def content_hash(self) -> Optional[int]:
         return None
 
     @abstractmethod
-    def set_content_hash(self, content_hash: int) -> None:
+    def set_content_hash(self, content_hash: Optional[int]) -> None:
         pass
diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py
index 2ffdb94e904a0..e396ae3a3f4ff 100644
--- a/vllm/core/block/naive_block.py
+++ b/vllm/core/block/naive_block.py
@@ -64,7 +64,6 @@ def allocate_immutable_block(
         self,
         prev_block: Optional[Block],
         token_ids: List[int],
-        device: Optional[Device] = None,
         block_hash: Optional[int] = None,
     ) -> Block:
         """Allocates a new immutable block with the given token IDs, linked to
@@ -79,7 +78,6 @@ def allocate_immutable_block(
         Returns:
             Block: The newly allocated immutable block.
         """
-        assert device is None
         assert block_hash is None
         block = self.allocate_mutable_block(prev_block=prev_block)
@@ -91,9 +89,7 @@ def allocate_immutable_blocks(
         prev_block: Optional[Block],
         block_token_ids: List[List[int]],
         block_hashes: Optional[List[Optional[int]]] = None,
-        device: Optional[Device] = None,
     ) -> List[Block]:
-        assert device is None
         num_blocks = len(block_token_ids)
 
         block_ids = []
@@ -114,7 +110,6 @@ def allocate_immutable_blocks(
     def allocate_mutable_block(
         self,
         prev_block: Optional[Block],
-        device: Optional[Device] = None,
         block_hash: Optional[int] = None,
     ) -> Block:
         """Allocates a new mutable block, linked to the previous block.
@@ -127,7 +122,6 @@ def allocate_mutable_block(
         Returns:
             Block: The newly allocated mutable block.
""" - assert device is None assert block_hash is None block_id = self._allocate_block_id() diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index feecf3d45ec3c..bdd12ebcc8be8 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -140,7 +140,6 @@ def allocate_immutable_block( prev_block: Optional[Block], token_ids: List[int], block_hash: Optional[int] = None, - device: Optional[Device] = None, ) -> Block: """Allocates an immutable block with the given token IDs, reusing cached blocks if possible. @@ -153,7 +152,6 @@ def allocate_immutable_block( Returns: Block: The allocated immutable block. """ - assert device is None assert len(token_ids) == self._block_size, "An immutable block should be full" assert ( block_hash is not None @@ -163,9 +161,6 @@ def allocate_immutable_block( cached_block_id = self._cached_blocks.get(block_hash, None) if cached_block_id is not None: # Initialize a block that points to cached data - # print( - # f"reuse block_hash={block_hash} from cached_block_id: {cached_block_id}" - # ) block: Block = self._block_pool.init_block( prev_block=prev_block, token_ids=token_ids, @@ -177,9 +172,6 @@ def allocate_immutable_block( self._incr_refcount_cached_block(block) return block - # print( - # f"alloc from new block(block_hash: {block_hash}), get_num_free_blocks: {self.get_num_free_blocks()}" - # ) self.metric_data.query(hit=False) # No cached block => Allocate a new block @@ -192,7 +184,6 @@ def allocate_immutable_blocks( prev_block: Optional[Block], block_token_ids: List[List[int]], block_hashes: Optional[List[int]] = None, - device: Optional[Device] = None, ) -> List[Block]: blocks = [] assert ( @@ -204,7 +195,6 @@ def allocate_immutable_blocks( prev_block=prev_block, token_ids=token_ids, block_hash=block_hash, - device=device, ) blocks.append(prev_block) return blocks @@ -224,9 +214,6 @@ def allocate_mutable_block(self, """ assert device is None assert_prefix_caching_block_or_none(prev_block) - # print( - # f"Allocating mutable block: get_num_free_blocks: {self.get_num_free_blocks()}" - # ) block_id = self._allocate_block_id() block = self._block_pool.init_block(prev_block=prev_block, token_ids=[], @@ -297,7 +284,6 @@ def _allocate_block_id(self) -> BlockId: """First tries to allocate a block id from the hashless allocator, and if there are no blocks, then tries to evict an unused cached block. """ - # print(f"allocating block_id: get_num_free_blocks: {self.get_num_free_blocks()}") hashless_block_id = self._maybe_allocate_hashless_block_id() if hashless_block_id is not None: return hashless_block_id @@ -418,9 +404,7 @@ def get_num_free_blocks(self, device: Optional[Device] = None) -> int: assert device is None # The number of free blocks is the number of hashless free blocks # plus the number of blocks evictor could free from its list. 
-        return self._hashless_allocator.get_num_free_blocks() + (
-            self.evictor.num_blocks
-        )
+        return self._hashless_allocator.get_num_free_blocks() + self.evictor.num_blocks
 
     def get_num_total_blocks(self) -> int:
         return self._hashless_allocator.get_num_total_blocks()
@@ -511,9 +495,6 @@ def cow_block_if_not_appendable(self, block: Block) -> BlockId:
             return src_block_id
 
         self._free_block_id(block)
-        # print(
-        #     f"Allocating block for COW: get_num_free_blocks: {self.get_num_free_blocks()}"
-        # )
         trg_block_id = self._allocate_block_id()
 
         self._cow_tracker.record_cow(src_block_id, trg_block_id)
@@ -878,38 +859,6 @@ def token_ids(self) -> List[int]:
     def prev_block(self) -> Optional[Block]:
         return self._prev_block
 
-    # @property
-    # def content_hash(self) -> Optional[int]:
-    #     """Return the content-based hash of the current block, or None if it is
-    #     not yet defined.
-
-    #     For the content-based hash to be defined, the current block must be
-    #     full.
-    #     """
-    #     # If the hash is already computed, return it.
-    #     if self._cached_content_hash is not None:
-    #         return self._cached_content_hash
-
-    #     # We cannot compute a hash for the current block because it is not full.
-    #     if not self.is_full:
-    #         return None
-
-    #     is_first_block = self._prev_block is None
-    #     prev_block_hash = (
-    #         None if is_first_block else
-    #         self._prev_block.content_hash  # type: ignore
-    #     )
-
-    #     # Previous block exists but does not yet have a hash.
-    #     # Return no hash in this case.
-    #     if prev_block_hash is None and not is_first_block:
-    #         return None
-
-    #     self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
-    #         is_first_block,
-    #         prev_block_hash,
-    #         cur_block_token_ids=self.token_ids)
-    #     return self._cached_content_hash
-
     @property
     def content_hash(self) -> Optional[int]:
         return self._cached_content_hash
@@ -952,7 +901,9 @@ def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
         assert (prev_block_hash is None) == is_first_block
         return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
 
-
+# TODO(rickyx): This is no longer used. Alternatively, it could be used to track
+# cached blocks for a sequence, so that the sequence is decoupled from the
+# computed block hash calculation.
class ComputedBlocksTracker:
     """Handles caching of per-sequence computed block ids. When a sequence
     appears for the first time, it traverses all of the
@@ -989,54 +940,6 @@ def remove_seq(self, seq_id: int) -> None:
         assert seq_id in self._cached_computed_seq_blocks
         del self._cached_computed_seq_blocks[seq_id]
 
-    def get_cached_computed_blocks_and_update(
-            self, seq_id: int, block_ids: List[int]) -> List[int]:
-        """ Look at the class documentation for details
-        """
-        # Ensure seq_id is already tracked
-        assert seq_id in self._cached_computed_seq_blocks
-
-        # Get cached data (may be empty on the first time)
-        prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[
-            seq_id]
-
-        if has_gap:
-            # When gap is detected, we do not add more computed blocks at this
-            # sequence iteration
-            return prev_computed_block_ids
-
-        # We do not consider the last block id for caching purposes.
-        num_cur_blocks = len(block_ids) - 1
-        assert num_cur_blocks >= 0
-
-        if len(prev_computed_block_ids) >= num_cur_blocks:
-            # Cache HIT
-            assert len(prev_computed_block_ids) == num_cur_blocks
-            return prev_computed_block_ids
-
-        # If here, then we may possibly add more computed blocks. As a result,
-        # traverse the additional blocks after prev_computed_block_ids to
-        # detect more computed blocks and add them.
-
-        # Incremental init for seq_id => Look only at the new blocks
-        computed_block_ids = self._allocator.get_computed_block_ids(  # noqa: E501
-            prev_computed_block_ids,
-            block_ids,
-            skip_last_block_id=
-            True,  # We skip last block id to avoid caching of full seq
-        )
-
-        # QQ(rickyx): why is it possible to actually have a gap?
-
-        # Detect if there is a "gap"
-        has_gap = len(computed_block_ids) < num_cur_blocks
-
-        # Record
-        self._cached_computed_seq_blocks[seq_id] = (computed_block_ids,
-                                                    has_gap)
-
-        return computed_block_ids
-
 
 class LastAccessBlocksTracker:
     """Manages the last access time of the tracked sequences, in order to allow
diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py
index 998b706296fd2..df09a3a30743c 100644
--- a/vllm/core/block_manager.py
+++ b/vllm/core/block_manager.py
@@ -108,6 +108,16 @@ def __init__(
     def _get_num_blocks_to_allocate(
         self, seq: Sequence, num_lookahead_slots: int = 0
     ) -> int:
+        """
+        Get the number of new blocks to allocate for a sequence.
+
+        Args:
+            seq (Sequence): The sequence to allocate blocks for.
+            num_lookahead_slots (int): The number of lookahead slots to allocate.
+
+        Returns:
+            int: The number of new blocks to allocate.
+        """
         num_cached_tokens = seq.get_num_cached_tokens()
 
         assert (
@@ -119,6 +129,17 @@ def _get_num_blocks_to_allocate(
         return num_required_blocks - num_cached_blocks
 
     def get_num_computed_tokens(self, seq: Sequence) -> int:
+        """
+        Get the number of computed tokens for a sequence.
+
+        NOTE: This only returns tokens in blocks that are BOTH cached and allocated (active).
+
+        Args:
+            seq (Sequence): The sequence to get the number of computed tokens for.
+
+        Returns:
+            int: The number of allocated and cached computed tokens.
+        """
         seq_blocks = seq.get_block_hashes()
         cached_seq_blocks = self.block_allocator.get_allocated_cached_blocks(
             block_hashes=seq_blocks,
@@ -126,52 +147,6 @@ def get_num_computed_tokens(self, seq: Sequence) -> int:
         )
         return len(cached_seq_blocks) * self.block_size
 
-    # def get_num_computed_blocks(self, seq_group: SequenceGroup) -> Dict[SeqId, int]:
-    #     num_computed_blocks = {}
-    #     for seq in seq_group.get_seqs():
-    #         num_computed_blocks[seq.seq_id] = self._get_num_computed_tokens(seq)
-    #     return num_computed_blocks
-
-    def can_allocate_old(
-        self, seq_group: SequenceGroup, num_lookahead_slots: int = 0
-    ) -> AllocStatus:
-        # FIXME(woosuk): Here we assume that all sequences in the group share
-        # the same prompt. This may not be true for preempted sequences.
-
-        check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
-
-        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
-        num_required_blocks = BlockTable.get_num_required_blocks(
-            seq.get_token_ids(),
-            block_size=self.block_size,
-            num_lookahead_slots=num_lookahead_slots,
-        )
-
-        if seq_group.is_encoder_decoder():
-            encoder_seq = seq_group.get_encoder_seq()
-            assert encoder_seq is not None
-            num_required_blocks += BlockTable.get_num_required_blocks(
-                encoder_seq.get_token_ids(),
-                block_size=self.block_size,
-            )
-
-        if self.max_block_sliding_window is not None:
-            num_required_blocks = min(
-                num_required_blocks, self.max_block_sliding_window
-            )
-
-        num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
-            device=Device.GPU
-        )
-
-        # Use watermark to avoid frequent cache eviction.
-        if self.num_total_gpu_blocks - num_required_blocks < self.watermark_blocks:
-            return AllocStatus.NEVER
-        if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
-            return AllocStatus.OK
-        else:
-            return AllocStatus.LATER
-
     def can_allocate(self,
                      seq_group: SequenceGroup,
                      num_lookahead_slots: int = 0) -> AllocStatus:
@@ -200,14 +175,6 @@ def can_allocate(self,
         num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
             device=Device.GPU
         )
-        # print(f"num_blocks_to_allocate: {num_blocks_to_allocate}")
-        # print(f"num_free_gpu_blocks: {num_free_gpu_blocks}")
-        # print(f"watermark_blocks: {self.watermark_blocks}")
-
-        # Use watermark to avoid frequent cache eviction.
-        # old_can_allocate = self.can_allocate_old(seq_group, num_lookahead_slots)
-
-        can_allocate = None
         if self.num_total_gpu_blocks - num_blocks_to_allocate < self.watermark_blocks:
             return AllocStatus.NEVER
         if num_free_gpu_blocks - num_blocks_to_allocate >= self.watermark_blocks:
@@ -215,11 +182,6 @@ def can_allocate(self,
         else:
             return AllocStatus.LATER
 
-        # if old_can_allocate != can_allocate:
-        #     print(f"old_can_allocate: {old_can_allocate}, can_allocate: {can_allocate}")
-
-        return can_allocate
-
     def _allocate_sequence(self, seq: Sequence) -> BlockTable:
         block_table = BlockTable(
             block_size=self.block_size,
@@ -400,17 +362,6 @@ def get_common_computed_block_ids(
             computed_block_ids = all_blocks[:num_cached_block]
             computed_seq_block_ids.append(computed_block_ids)
 
-            # old_computed_block_ids = (
-            #     self._computed_blocks_tracker.get_cached_computed_blocks_and_update(
-            #         seq.seq_id, all_blocks
-            #     )
-            # )
-            # if old_computed_block_ids != computed_block_ids:
-            #     print(
-            #         f"old_computed_block_ids: \n{old_computed_block_ids}\n, computed_block_ids: \n{computed_block_ids}\n"
-            #     )
-            #     print(f"seq: {seq}")
-
         # NOTE(sang): This assumes seq_block_ids doesn't contain any None.
         return self.block_allocator.get_common_computed_block_ids(
             computed_seq_block_ids)  # type: ignore
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 51a88f338dc3c..1434eb9b3115d 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -56,12 +56,14 @@ class SchedulingBudget:
     max_num_seqs: int
     _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
     _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
+    # Number of batched tokens that are strictly not cached.
     _num_batched_tokens: int = 0
-    _num_batched_and_cached_tokens: int = 0
+    # Number of batched tokens that are cached.
+    _num_cached_tokens: int = 0
     _num_curr_seqs: int = 0
 
     def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
-        # assert num_new_tokens != 0
+        assert num_new_tokens >= 0
         assert num_new_seqs != 0
         return (self.num_batched_tokens + num_new_tokens <= self.token_budget
                 and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
@@ -73,32 +75,25 @@ def add_num_batched_tokens(
         self,
         req_id: str,
         num_batched_tokens: int,
-        num_batched_and_cached_tokens: Optional[int] = None,
+        num_cached_tokens: int = 0,
     ):
         assert num_batched_tokens >= 0
+        assert num_cached_tokens >= 0
         if req_id in self._request_ids_num_batched_tokens:
             return
 
         self._request_ids_num_batched_tokens.add(req_id)
         self._num_batched_tokens += num_batched_tokens
-        if num_batched_and_cached_tokens is None:
-            num_batched_and_cached_tokens = num_batched_tokens
-        self._num_batched_and_cached_tokens += num_batched_and_cached_tokens
-
-        assert self._num_batched_tokens <= self.token_budget, f"{self._num_batched_tokens} > {self.token_budget}"
+        self._num_cached_tokens += num_cached_tokens
 
     def subtract_num_batched_tokens(
         self,
         req_id: str,
         num_batched_tokens: int,
-        num_batched_and_cached_tokens: Optional[int] = None,
     ):
         if req_id in self._request_ids_num_batched_tokens:
             self._request_ids_num_batched_tokens.remove(req_id)
             self._num_batched_tokens -= num_batched_tokens
-            if num_batched_and_cached_tokens is None:
-                num_batched_and_cached_tokens = num_batched_tokens
-            self._num_batched_and_cached_tokens -= num_batched_and_cached_tokens
 
     def add_num_seqs(self, req_id: str, num_curr_seqs: int):
         if req_id in self._request_ids_num_curr_seqs:
@@ -118,7 +113,7 @@ def num_batched_tokens(self):
 
     @property
     def num_batched_and_cached_tokens(self):
-        return self._num_batched_and_cached_tokens
+        return self._num_batched_tokens + self._num_cached_tokens
 
     @property
     def num_curr_seqs(self):
@@ -638,9 +633,6 @@ def _schedule_running(
                 self._append_slots(seq_group, blocks_to_copy, enable_chunking)
                 is_prefill = seq_group.is_prefill()
 
-                if seq_group.is_prefill() and not enable_chunking:
-                    breakpoint()
-
                 scheduled_seq_group: ScheduledSequenceGroup = \
                     self._scheduled_seq_group_cache[self.cache_id].get_object()
                 scheduled_seq_group.seq_group = seq_group
@@ -840,7 +832,6 @@ def _schedule_priority_preemption(
             while running_queue and self._get_priority(
                     running_queue[-1]) > self._get_priority(seq_group):
                 # Only preempt if waiting sequence cannot be allocated
-                assert False
                 can_allocate = self.block_manager.can_allocate(seq_group)
                 if (num_new_tokens and can_allocate == AllocStatus.OK
                         and budget.can_schedule(num_new_tokens=num_new_tokens,
@@ -926,7 +917,6 @@ def _schedule_prefills(
                 num_new_tokens, seq
             )
 
-            # print(f"[{seq_group.request_id=}] {num_new_tokens=} {num_new_tokens_exclude_cached}, budget: {budget.num_batched_tokens}")
             if not enable_chunking:
                 num_prompt_tokens = seq.get_len()
                 assert num_new_tokens == num_prompt_tokens
@@ -950,15 +940,6 @@ def _schedule_prefills(
             can_allocate = self.block_manager.can_allocate(
                 seq_group, num_lookahead_slots=num_lookahead_slots)
 
-            old_can_allocate = self.block_manager.can_allocate_old(
-                seq_group, num_lookahead_slots=num_lookahead_slots
-            )
-
-            if can_allocate != old_can_allocate:
-                print(
-                    f"can_allocate: {can_allocate}, old_can_allocate: {old_can_allocate}"
-                )
-
             if can_allocate == AllocStatus.LATER:
                 break
             elif can_allocate == AllocStatus.NEVER:
@@ -1004,13 +985,10 @@ def _schedule_prefills(
             if curr_loras is not None and lora_int_id > 0:
                 curr_loras.add(lora_int_id)
             waiting_queue.popleft()
-            try:
-                self._allocate_and_set_running(seq_group)
-            except Exception as e:
-                breakpoint()
+            self._allocate_and_set_running(seq_group)
 
             # NOTE(rickyx): We are updating this again since some of the previously
-            # cached blocks that were in evictor might now become active again. 
+            # cached blocks that were in evictor might now become active again.
             # Therefore, the actual number of tokens cached might have changed.
             self._update_prefix_cached_tokens(seq)
             num_new_tokens = self._get_num_new_tokens(
@@ -1019,8 +997,8 @@ def _schedule_prefills(
                 enable_chunking,
                 budget,
             )
-            num_new_tokens_exclude_cached = self._get_num_new_tokens_exclude_cached(
-                num_new_tokens, seq
+            num_new_tokens_uncached = self._get_num_new_tokens_exclude_cached(
+                num_new_tokens, seq
             )
 
             if enable_chunking and self.scheduler_config.is_multi_step:
@@ -1041,13 +1019,14 @@ def _schedule_prefills(
                     enable_chunking=enable_chunking)
 
             seq_groups.append(
-                ScheduledSequenceGroup(seq_group=seq_group,
-                                       token_chunk_size=num_new_tokens))
-            # print(f"[{seq_group.request_id}] {num_new_tokens=} {num_new_tokens_exclude_cached=}, budget: {budget.num_batched_tokens}")
+                ScheduledSequenceGroup(
+                    seq_group=seq_group, token_chunk_size=num_new_tokens
+                )
+            )
             budget.add_num_batched_tokens(
                 seq_group.request_id,
-                num_batched_tokens=num_new_tokens_exclude_cached,
-                num_batched_and_cached_tokens=num_new_tokens,
+                num_batched_tokens=num_new_tokens_uncached,
+                num_cached_tokens=num_new_tokens - num_new_tokens_uncached,
             )
             budget.add_num_seqs(seq_group.request_id, num_new_seqs)
@@ -1137,11 +1116,6 @@ def _schedule_default(self) -> SchedulerOutputs:
 
         # There should be no prefill from running queue because this policy
         # doesn't allow chunked prefills.
-        if len(running_scheduled.prefill_seq_groups) > 0:
-            print(
-                f"running_scheduled.prefill_seq_groups: {running_scheduled.prefill_seq_groups}"
-            )
-            breakpoint()
         assert len(running_scheduled.prefill_seq_groups) == 0
         assert len(swapped_in.prefill_seq_groups) == 0
 
@@ -1572,7 +1546,7 @@ def _preempt(
         else:
             preemption_mode = PreemptionMode.RECOMPUTE
 
-        if self.num_cumulative_preemption % 5 == 0:
+        if self.num_cumulative_preemption % 50 == 0:
             logger.warning(
                 "Sequence group %s is preempted by %s mode because there is "
                 "not enough KV cache space. This can affect the end-to-end "
@@ -1699,7 +1673,6 @@ def _get_num_new_tokens(
         num_new_tokens = 0
         seqs = seq_group.get_seqs(status=status)
         for seq in seqs:
-            # self._update_prefix_cached_tokens(seq)
             num_new_tokens += seq.get_num_new_tokens()
         assert num_new_tokens > 0
         # Chunk if a running request cannot fit in the given budget.
@@ -1751,18 +1724,34 @@ def _get_num_new_tokens(
                 # No more budget for new tokens, don't include any cached tokens too.
                 return 0
             num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached
-            # print(f"[{seq_group.request_id}] {num_new_tokens=} {num_new_tokens_uncached=} {num_new_tokens_cached=}, budget: {budget.num_batched_tokens}")
         else:
             num_new_tokens = min(num_new_tokens, remaining_token_budget)
         return num_new_tokens
 
     def _update_prefix_cached_tokens(self, seq: Sequence):
+        """
+        Update the number of prefix cached tokens for a sequence.
+
+        This function takes O(log(n)) time, where n is the number of blocks
+        in the sequence.
+        """
         num_prefix_cached_tokens = self.block_manager.get_num_computed_tokens(seq)
         seq.set_num_prefix_cached_tokens(num_prefix_cached_tokens)
 
     def _get_num_new_tokens_exclude_cached(
         self, num_new_tokens: int, seq: Sequence
     ) -> int:
+        """
+        Get the number of new tokens to compute for a sequence, excluding
+        cached tokens.
+
+        Args:
+            num_new_tokens: The number of new tokens to compute.
+            seq: The sequence to compute the new tokens for.
+
+        Returns:
+            The number of these new tokens that are not already cached.
+        """
         # If a decode sequence, new tokens are always not computed/cached.
         if not seq.is_prefill():
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 17476dd78d1f6..34694a37405c0 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1637,7 +1637,6 @@ def _get_stats(self,
         # Iteration stats
         num_prompt_tokens_iter = 0
         num_generation_tokens_iter = 0
-        num_extra_batched_tokens_iter = 0
         time_to_first_tokens_iter: List[float] = []
         time_per_output_tokens_iter: List[float] = []
         num_preemption_iter = (0 if scheduler_outputs is None else
@@ -1678,17 +1677,6 @@ def _get_stats(self,
             # not counted (to avoid double counting)
             actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore
 
-            num_extra_batched_tokens_iter = (
-                actual_num_batched_tokens -
-                scheduler_outputs.num_batched_tokens_from_budget
-            )
-            if num_extra_batched_tokens_iter > 0:
-                print(
-                    f"num_extra_batched_tokens_iter: {num_extra_batched_tokens_iter}, "
-                    f"actual_num_batched_tokens: {actual_num_batched_tokens}, "
-                    f"num_batched_tokens_from_budget: {scheduler_outputs.num_batched_tokens_from_budget}"
-                )
-
             num_generation_tokens_from_prefill_groups = 0.
             # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
             # the len of scheduler_outputs.scheduled_seq_groups is !=
@@ -1802,7 +1790,6 @@ def _get_stats(self,
             time_per_output_tokens_iter=time_per_output_tokens_iter,
             spec_decode_metrics=spec_decode_metrics,
             num_preemption_iter=num_preemption_iter,
-            num_extra_batched_tokens_iter=num_extra_batched_tokens_iter,
             # Request stats
             #   Latency
             time_e2e_requests=time_e2e_requests,
@@ -1813,7 +1800,8 @@ def _get_stats(self,
             finished_reason_requests=finished_reason_requests,
             max_lora=str(max_lora_stat),
             waiting_lora_adapters=list(waiting_lora_adapters.keys()),
-            running_lora_adapters=list(running_lora_adapters.keys()))
+            running_lora_adapters=list(running_lora_adapters.keys()),
+        )
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_executor.add_lora(lora_request)
diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py
index d9bc9dfddfee1..e9a5bd3b586be 100644
--- a/vllm/engine/metrics_types.py
+++ b/vllm/engine/metrics_types.py
@@ -42,7 +42,6 @@ class Stats:
     time_to_first_tokens_iter: List[float]
     time_per_output_tokens_iter: List[float]
     num_preemption_iter: int
-    num_extra_batched_tokens_iter: int
 
     # Request stats (should have _requests suffix)
     #   Latency
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 0af3b7acdf3ab..a3d6c0b1492ad 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -566,8 +566,11 @@ def get_output_token_ids_to_return(
 
         return self.data._cached_all_token_ids[-num_new_tokens:]
 
-    def get_block_hash(self, block_idx: int) -> Optional[int]:
-
+    def update_and_get_block_hash(self, block_idx: int) -> Optional[int]:
+        """
+        Get the block hash for a given block index, lazily computing and
+        caching any block hashes that have not been computed yet.
+        """
         # Lazy update the block hashes on the first invocation.
         if block_idx >= len(self._computed_block_hashes):
             self._update_block_hashes()
@@ -577,14 +580,12 @@ def get_block_hashes(self) -> List[int]:
         return None
 
     def get_block_hashes(self) -> List[int]:
-        # TODO(rickyx): maybe better to have an API to track if the computed hash is updated.
         self._update_block_hashes()
         return self._computed_block_hashes
 
     def _update_block_hashes(self):
         """
         Update the block hashes for all the full blocks in the sequence.
-        It skips the blocks that have already been computed.
         """
         token_ids = self.get_token_ids()  # All token ids in the sequence
@@ -697,8 +698,7 @@ def get_num_new_tokens(self) -> int:
         if self.data.stage == SequenceStage.DECODE:
             return 1
 
-        num_computed_tokens = self.data.get_num_computed_tokens()
-        return self.data.get_len() - num_computed_tokens
+        return self.data.get_num_uncomputed_tokens()
 
     def get_num_cached_tokens(self) -> int:
         return self.data.get_num_prefix_cached_tokens()
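
For readers who want to reason about the hashing scheme outside of the diff, the snippet below is a minimal, self-contained sketch of the chained content hashing that PrefixCachingBlock.hash_block_tokens implements and that Sequence.update_and_get_block_hash builds on. hash_block_tokens mirrors the function shown in the prefix_caching_block.py hunk above; the hash_full_blocks helper, the block size of 4, and the token ids are made up for illustration and are not part of this change.

# Minimal sketch of chained prefix-cache block hashing (illustrative only).
from typing import List, Optional


def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
                      cur_block_token_ids: List[int]) -> int:
    # Only the first block may lack a previous hash; every later hash chains
    # on the previous one, so equal hashes imply an equal full prefix.
    assert (prev_block_hash is None) == is_first_block
    return hash((is_first_block, prev_block_hash, *cur_block_token_ids))


def hash_full_blocks(token_ids: List[int], block_size: int) -> List[int]:
    # Hash every *full* block of the sequence; a trailing partial block gets no hash.
    hashes: List[int] = []
    prev_hash: Optional[int] = None
    num_full = len(token_ids) // block_size
    for i in range(num_full):
        block = token_ids[i * block_size:(i + 1) * block_size]
        prev_hash = hash_block_tokens(i == 0, prev_hash, block)
        hashes.append(prev_hash)
    return hashes


# Two prompts sharing their first 8 tokens share their first two block hashes,
# so the prefix cache can reuse those two KV-cache blocks; the third block differs.
a = hash_full_blocks([1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9], block_size=4)
b = hash_full_blocks([1, 2, 3, 4, 5, 6, 7, 8, 7, 7, 7, 7], block_size=4)
assert a[:2] == b[:2] and a[2] != b[2]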
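Similarly, a rough sketch of how the reworked SchedulingBudget splits uncached from cached tokens, so prefix-cache hits no longer consume the per-step token budget. The class name, numbers, and simplified fields here are illustrative; the real accounting (including request-id deduplication and sequence counts) is in the scheduler.py hunks above.

# Simplified mirror of the SchedulingBudget accounting after this diff (illustrative only).
from dataclasses import dataclass


@dataclass
class BudgetSketch:
    token_budget: int
    num_batched_tokens: int = 0  # strictly uncached tokens, charged to the budget
    num_cached_tokens: int = 0   # prefix-cached tokens, tracked but not charged

    def can_schedule(self, num_new_tokens: int) -> bool:
        return self.num_batched_tokens + num_new_tokens <= self.token_budget

    def add(self, num_uncached: int, num_cached: int) -> None:
        self.num_batched_tokens += num_uncached
        self.num_cached_tokens += num_cached


budget = BudgetSketch(token_budget=8192)
# A 4096-token prompt whose first 4080 tokens hit the prefix cache only
# charges the 16 uncached tokens against the token budget.
budget.add(num_uncached=16, num_cached=4080)
assert budget.can_schedule(num_new_tokens=8192 - 16)
assert budget.num_batched_tokens + budget.num_cached_tokens == 4096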