From 3b684a8a54096087e40f1e1558e27040e66fbb81 Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Wed, 18 Dec 2024 09:33:16 -0800 Subject: [PATCH] spec decode: streamline batch expansion tensor manipulation (#918) --- aphrodite/spec_decode/batch_expansion.py | 138 +++++++++++--------- aphrodite/spec_decode/spec_decode_worker.py | 25 ++-- aphrodite/spec_decode/util.py | 42 +++--- tests/spec_decode/test_utils.py | 31 ++--- 4 files changed, 114 insertions(+), 122 deletions(-) diff --git a/aphrodite/spec_decode/batch_expansion.py b/aphrodite/spec_decode/batch_expansion.py index ca6cb6114..80f417627 100644 --- a/aphrodite/spec_decode/batch_expansion.py +++ b/aphrodite/spec_decode/batch_expansion.py @@ -12,8 +12,7 @@ from aphrodite.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) -from aphrodite.spec_decode.util import (nvtx_range, sampler_output_to_torch, - split_batch_by_proposal_len) +from aphrodite.spec_decode.util import nvtx_range, split_batch_by_proposal_len from aphrodite.task_handler.worker_base import WorkerBase SeqId = int @@ -90,16 +89,24 @@ def score_proposals( assert len(target_sampler_output) == 1, "expected single-step output" target_sampler_output = target_sampler_output[0] - (all_tokens, all_probs, spec_logprobs, - all_hidden_states) = self._contract_batch( - contracted_bs=len(execute_model_req.seq_group_metadata_list), - target_sampler_output=target_sampler_output, - proposals=proposals, - num_scoring_tokens=num_scoring_tokens, - non_spec_indices=non_spec_indices, - spec_indices=spec_indices, - k=execute_model_req.num_lookahead_slots, - ) + if not non_spec_indices: + # All sequence groups in batch have spec decoding enabled + contracted = self._contract_batch_all_spec( + target_sampler_output=target_sampler_output, + proposals=proposals, + ) + else: + # Batch has a mix of spec decode enabled and disabled seq groups + contracted = self._contract_batch( + contracted_bs=len(execute_model_req.seq_group_metadata_list), + target_sampler_output=target_sampler_output, + proposals=proposals, + num_scoring_tokens=num_scoring_tokens, + non_spec_indices=non_spec_indices, + spec_indices=spec_indices, + k=execute_model_req.num_lookahead_slots, + ) + all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted return SpeculativeScores( probs=all_probs, @@ -123,14 +130,9 @@ def _expand_batch( # batch proposal len. This adds some complexity (splitting the batch # into spec and non spec sequences) and should be removed in the # future. It can be done by supporting per-sequence proposal lens. - spec_seqs, spec_indices = split_batch_by_proposal_len( - seq_group_metadata_list, - proposal_lens_list, - select_proposal_len_zero=False) - non_spec_seqs, non_spec_indices = split_batch_by_proposal_len( - seq_group_metadata_list, - proposal_lens_list, - select_proposal_len_zero=True) + (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \ + split_batch_by_proposal_len( + seq_group_metadata_list, proposal_lens_list) target_seq_group_metadata_list = self._create_scoring_model_input( seq_group_metadata_list=spec_seqs, @@ -173,7 +175,7 @@ def _contract_batch( # The number of tokens in the expanded batch used for speculation is # equal to the total expanded batch size minus the number of samples for # non-speculative sequences. 
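+        # (Illustrative numbers, not from this patch: with contracted_bs = 4,
+        # three spec sequences, one non-spec sequence and k = 3, the expanded
+        # batch has 3 * (3 + 1) + 1 = 13 rows, of which 13 - 1 = 12 are
+        # speculative scoring rows.)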
-        non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
+        non_spec_expanded_bs = len(non_spec_target_token_ids)
         spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
 
         target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
@@ -183,7 +185,7 @@ def _contract_batch(
 
         if target_hidden_states is not None:
             target_hidden_states = target_hidden_states.reshape(
-                spec_expanded_bs, k + 1, target_hidden_states.shape[-1])
+                *target_token_ids.shape, target_hidden_states.shape[-1])
 
         all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                                fill_value=-1)
         all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
@@ -195,14 +197,19 @@ def _contract_batch(
                 size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))
         else:
             all_hidden_states = None
-        if non_spec_indices:
-            all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
-            all_probs[non_spec_indices, :1, :] = non_spec_target_probs
-            all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
+        if non_spec_indices:
+            all_tokens[non_spec_indices, :1] = \
+                non_spec_target_token_ids.unsqueeze(1)
+            all_probs[non_spec_indices, :1, :] = \
+                non_spec_target_probs.unsqueeze(1)
+            all_logprobs[non_spec_indices, :1, :] = \
+                non_spec_target_logprobs.unsqueeze(1)
 
             if all_hidden_states is not None:
-                all_hidden_states[
-                    non_spec_indices, :1, :] = non_spec_target_hidden_states
+                assert non_spec_target_hidden_states is not None
+                all_hidden_states[non_spec_indices, :1, :] = \
+                    non_spec_target_hidden_states.unsqueeze(1)
+
         if spec_indices:
             all_tokens[spec_indices] = target_token_ids
             all_probs[spec_indices] = target_probs
@@ -212,6 +219,34 @@ def _contract_batch(
                 all_hidden_states[spec_indices] = target_hidden_states
 
         return all_tokens, all_probs, all_logprobs, all_hidden_states
 
+    def _contract_batch_all_spec(
+        self,
+        target_sampler_output: SamplerOutput,
+        proposals: SpeculativeProposals,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
+               Optional[torch.Tensor]]:
+        """Contract the expanded batch back into its original size.
+        This maps the scores of speculative tokens back to their original
+        sequences.
+        It assumes all sequences in the batch were previously expanded (i.e.
+        every sequence group has spec decoding enabled).
+        """
+        # Map distinct sequences used to score each token,
+        # of shape [batch_size * (k + 1)], back to [batch_size, k + 1].
+        contracted_bs, k = proposals.proposal_token_ids.shape
+        # Reshape tensors to original batch size
+        target_token_ids = target_sampler_output.sampled_token_ids.reshape(
+            contracted_bs, k + 1)
+        target_probs = target_sampler_output.sampled_token_probs.reshape(
+            *target_token_ids.shape, self._vocab_size)
+        target_logprobs = target_sampler_output.logprobs.reshape(
+            target_probs.shape)
+        target_hidden_states = target_sampler_output.hidden_states
+        if target_hidden_states is not None:
+            target_hidden_states = target_hidden_states.reshape(
+                *target_token_ids.shape, target_hidden_states.shape[-1])
+        return (target_token_ids, target_probs, target_logprobs,
+                target_hidden_states)
+
     def _create_scoring_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -343,8 +378,9 @@ def _create_single_target_seq_group_metadata(
             token_chunk_size=1,
         )
 
+    @staticmethod
     def _split_scoring_output(
-        self, sampler_output: SamplerOutput, num_scoring_tokens: int
+        sampler_output: SamplerOutput, num_scoring_tokens: int
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                Optional[torch.Tensor], torch.Tensor, torch.Tensor,
                torch.Tensor, Optional[torch.Tensor]]:
@@ -358,10 +394,9 @@ def _split_scoring_output(
         # future. 
It can be done by supporting per-sequence proposal lens. # First samples are from speculative scoring, latter samples are non- # speculative samples. - split_sizes = [ - num_scoring_tokens, - sampler_output.sampled_token_ids.numel() - num_scoring_tokens - ] + split_sizes = (num_scoring_tokens, + sampler_output.sampled_token_ids.numel() - + num_scoring_tokens) (spec_probs, non_spec_probs ) = sampler_output.sampled_token_probs.split(split_sizes) (spec_sampled_tokens, non_spec_sampled_tokens @@ -378,34 +413,15 @@ def _split_scoring_output( ) = sampler_output.hidden_states.split(split_sizes) else: spec_hidden_states, non_spec_hidden_states = None, None - # Convert scores to tensors. - sampler_output.sampled_token_probs = spec_probs - sampler_output.sampled_token_ids = spec_sampled_tokens - sampler_output.logprobs = spec_logprobs - sampler_output.hidden_states = spec_hidden_states - (target_token_ids, target_probs, target_logprobs, - target_hidden_states) = sampler_output_to_torch([sampler_output], - True) - - # Convert non-speculative output tokens to tensors. - sampler_output.sampled_token_probs = non_spec_probs - sampler_output.sampled_token_ids = non_spec_sampled_tokens - sampler_output.logprobs = non_spec_logprobs - sampler_output.hidden_states = non_spec_hidden_states - (non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs, - non_spec_target_hidden_states) = sampler_output_to_torch( - [sampler_output], True) - - return (target_token_ids, target_probs, target_logprobs, - target_hidden_states, non_spec_target_token_ids, - non_spec_target_probs, non_spec_target_logprobs, - non_spec_target_hidden_states) + return (spec_sampled_tokens, spec_probs, spec_logprobs, + spec_hidden_states, non_spec_sampled_tokens, non_spec_probs, + non_spec_logprobs, non_spec_hidden_states) + @staticmethod def _create_target_seq_id_iterator( - self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: + seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: """Create an iterator for creating target sequence ids. Target sequence ids are distinct from sequence ids because we create a distinct target sequence id for each proposal token to be scored. @@ -415,8 +431,8 @@ def _create_target_seq_id_iterator( """ return count(start=max(seq_ids) + 1) + @staticmethod def _get_token_ids_to_score( - self, full_spec_token_ids: List[TokenId] # shape: [k] ) -> List[List[TokenId]]: """Given an int tensor of proposal token ids, return a list of @@ -437,8 +453,6 @@ def _get_token_ids_to_score( empty_token_ids: List[TokenId] = [] token_ids_to_score = [empty_token_ids] - token_ids_to_score.extend([ - full_spec_token_ids[:i + 1] - for i in range(len(full_spec_token_ids)) - ]) + token_ids_to_score.extend(full_spec_token_ids[:i + 1] + for i in range(len(full_spec_token_ids))) return token_ids_to_score diff --git a/aphrodite/spec_decode/spec_decode_worker.py b/aphrodite/spec_decode/spec_decode_worker.py index 25940c9aa..601d814fa 100644 --- a/aphrodite/spec_decode/spec_decode_worker.py +++ b/aphrodite/spec_decode/spec_decode_worker.py @@ -363,12 +363,13 @@ def execute_model( # used during the prefill phase. # 2. Auto-disable enabled: The running queue size exceeds # the specified threshold. - # 3. No request: There are no requests in the batch. + # 3. No request: There are no requests in the batch, or + # none of the requests in the batch have spec decoding enabled. # In any of these cases, the proposer and scorer workers # are called normally. 
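+        # Note: all() over an empty generator is True, so both halves of
+        # case 3 (an empty batch, or a batch in which every sequence group
+        # has num_speculative_tokens == 0) fall into the no_spec path.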
- no_spec = num_lookahead_slots == 0 or len( - execute_model_req.seq_group_metadata_list - ) == 0 or disable_all_speculation + no_spec = num_lookahead_slots == 0 or disable_all_speculation or all( + sgm.num_speculative_tokens == 0 + for sgm in execute_model_req.seq_group_metadata_list) # Broadcast how many lookahead slots are scheduled for this step, and # whether all speculation is disabled, to all non-driver workers. @@ -412,10 +413,8 @@ def _should_disable_all_speculation( self, execute_model_req: ExecuteModelRequest) -> bool: # When the batch size is too large, disable speculative decoding # to stop trading off throughput for latency. - disable_all_speculation = (execute_model_req.running_queue_size >= - self.disable_by_batch_size) - - return disable_all_speculation + return (execute_model_req.running_queue_size >= + self.disable_by_batch_size) def _maybe_disable_speculative_tokens( self, disable_all_speculation: bool, @@ -614,14 +613,8 @@ def _verify_tokens( # batch proposal len. This adds some complexity (splitting the batch # into spec and non spec sequences) and should be removed in the # future. It can be done by supporting per-sequence proposal lens. - _, spec_indices = split_batch_by_proposal_len( - seq_group_metadata_list, - proposal_lens_list, - select_proposal_len_zero=False) - _, non_spec_indices = split_batch_by_proposal_len( - seq_group_metadata_list, - proposal_lens_list, - select_proposal_len_zero=True) + (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len( + seq_group_metadata_list, proposal_lens_list) original_indices = spec_indices + non_spec_indices # Get probabilities of target model, excluding bonus token. diff --git a/aphrodite/spec_decode/util.py b/aphrodite/spec_decode/util.py index 397e9d49f..5fee22c7a 100644 --- a/aphrodite/spec_decode/util.py +++ b/aphrodite/spec_decode/util.py @@ -1,6 +1,6 @@ import time from contextlib import contextmanager -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Sequence, Tuple import torch @@ -99,33 +99,26 @@ def create_sequence_group_output( def split_batch_by_proposal_len( seq_group_metadata_list: List[SequenceGroupMetadata], - proposal_lens: List[int], select_proposal_len_zero: bool -) -> Tuple[List[SequenceGroupMetadata], List[int]]: + proposal_lens: List[int], +) -> Tuple[Tuple[List[SequenceGroupMetadata], List[int]], Tuple[ + List[SequenceGroupMetadata], List[int]]]: """Utility function that splits a batch based on whether the proposal len is zero or not. We should remove this once Aphrodite supports per-sequence proposal lens in a batch. 
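+
+    Returns:
+        Two (sequence_groups, indices) tuples: the first for sequence groups
+        with a nonzero proposal len (speculation enabled), the second for
+        those with proposal len zero; indices refer to the original batch.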
""" - if select_proposal_len_zero: - predicate = lambda proposal_len: proposal_len == 0 - else: - predicate = lambda proposal_len: proposal_len != 0 - - indices = [ - i for i, (_, proposal_len - ) in enumerate(zip(seq_group_metadata_list, proposal_lens)) - if predicate(proposal_len) - ] - seq_groups = [ - seq_group for seq_group, proposal_len in zip( - seq_group_metadata_list, proposal_lens) if predicate(proposal_len) - ] - - return seq_groups, indices + nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], []) + zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], []) + for i, (seq_group, proposal_len) in enumerate( + zip(seq_group_metadata_list, proposal_lens)): + seq_groups, indices = nonzero_lists if proposal_len else zero_lists + seq_groups.append(seq_group) + indices.append(i) + return nonzero_lists, zero_lists def sampler_output_to_torch( - sampler_output_list: List[SamplerOutput], sampler_transposed: bool + sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Utility function which converts a list of SamplerOutput to tensors. @@ -149,18 +142,12 @@ def sampler_output_to_torch( dim=0, ) - if sampler_transposed: - sampled_token_probs = sampled_token_probs.transpose(0, 1) - # shape: [batch_size, num_sampler_output, vocab_size] sampled_token_logprobs = torch.stack( [sampler_output.logprobs for sampler_output in sampler_output_list], dim=0, ) - if sampler_transposed: - sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1) - # shape: [batch_size, num_sampler_output] sampled_token_ids = torch.stack( [ @@ -169,7 +156,10 @@ def sampler_output_to_torch( ], dim=0, ) + if sampler_transposed: + sampled_token_probs = sampled_token_probs.transpose(0, 1) + sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1) sampled_token_ids = sampled_token_ids.transpose(0, 1) if sampler_output_list[0].hidden_states is not None: diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py index 63e150b7b..5a06f3ee9 100644 --- a/tests/spec_decode/test_utils.py +++ b/tests/spec_decode/test_utils.py @@ -55,10 +55,9 @@ def fake_sequence_group_metadata(): def test_filter_zero_length_proposals(fake_sequence_group_metadata): proposal_lens = [0, 1, 0] - filtered_groups, indices = split_batch_by_proposal_len( - fake_sequence_group_metadata, - proposal_lens, - select_proposal_len_zero=True) + _, (filtered_groups, + indices) = split_batch_by_proposal_len(fake_sequence_group_metadata, + proposal_lens) expected_groups = [ fake_sequence_group_metadata[0], fake_sequence_group_metadata[2] @@ -71,10 +70,9 @@ def test_filter_zero_length_proposals(fake_sequence_group_metadata): def test_filter_non_zero_length_proposals(fake_sequence_group_metadata): proposal_lens = [0, 1, 2] - filtered_groups, indices = split_batch_by_proposal_len( - fake_sequence_group_metadata, - proposal_lens, - select_proposal_len_zero=False) + (filtered_groups, + indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata, + proposal_lens) expected_groups = [ fake_sequence_group_metadata[1], fake_sequence_group_metadata[2] @@ -86,8 +84,7 @@ def test_filter_non_zero_length_proposals(fake_sequence_group_metadata): def test_empty_inputs(): - filtered_groups, indices = split_batch_by_proposal_len( - [], [], select_proposal_len_zero=True) + _, (filtered_groups, indices) = split_batch_by_proposal_len([], []) assert filtered_groups == [] assert indices == [] @@ -95,10 +92,9 @@ def 
test_empty_inputs(): def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata): proposal_lens = [0, 0, 0] - filtered_groups, indices = split_batch_by_proposal_len( - fake_sequence_group_metadata, - proposal_lens, - select_proposal_len_zero=False) + (filtered_groups, + indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata, + proposal_lens) assert filtered_groups == [] assert indices == [] @@ -106,10 +102,9 @@ def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata): def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata): proposal_lens = [1, 1, 1] - filtered_groups, indices = split_batch_by_proposal_len( - fake_sequence_group_metadata, - proposal_lens, - select_proposal_len_zero=True) + _, (filtered_groups, + indices) = split_batch_by_proposal_len(fake_sequence_group_metadata, + proposal_lens) assert filtered_groups == [] assert indices == []
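Reviewer note, not part of the patch: the central interface change is that
split_batch_by_proposal_len no longer takes select_proposal_len_zero and is
called once instead of twice, returning both partitions in a single pass.
A minimal standalone sketch of the new contract, using plain strings as
stand-ins for SequenceGroupMetadata (an assumption for illustration only):

    from typing import List, Tuple

    def split_by_proposal_len(
        seq_groups: List[str],  # stand-in for SequenceGroupMetadata
        proposal_lens: List[int],
    ) -> Tuple[Tuple[List[str], List[int]], Tuple[List[str], List[int]]]:
        # Route each group into the nonzero (spec) or zero (non-spec)
        # bucket in one pass, remembering its original batch index.
        nonzero: Tuple[List[str], List[int]] = ([], [])
        zero: Tuple[List[str], List[int]] = ([], [])
        for i, (group, plen) in enumerate(zip(seq_groups, proposal_lens)):
            groups, indices = nonzero if plen else zero
            groups.append(group)
            indices.append(i)
        return nonzero, zero

    (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
        split_by_proposal_len(["a", "b", "c"], [0, 2, 0])
    assert (spec_seqs, spec_indices) == (["b"], [1])
    assert (non_spec_seqs, non_spec_indices) == (["a", "c"], [0, 2])

Call sites that previously filtered twice, once per predicate, now
destructure a single call, as the updated batch_expansion.py,
spec_decode_worker.py and test_utils.py hunks above show.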