diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index d0f91a63b2d6a..a701f482b4ffb 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -288,15 +288,17 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, ensure_all_accepted=ensure_all_accepted) -def run_equality_correctness_test(baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len, - force_output_len: bool, - temperature: float, - seeded: bool, - print_tokens: bool = False, - ensure_all_accepted: bool = False): +def run_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + temperature: float, + seeded: bool, + print_tokens: bool = False, + ensure_all_accepted: bool = False, + expected_acceptance_rate: Optional[float] = None): """Helper method that compares the outputs of both the baseline LLM and the test LLM. It asserts greedy equality, e.g. that the outputs are exactly the same when temperature is zero (or when temperature is > 0 and seeded). @@ -357,5 +359,10 @@ def run_equality_correctness_test(baseline_llm_generator, print(f'{i=} {spec_token_ids=}') assert baseline_token_ids == spec_token_ids + print(f'{acceptance_rate=}') + if ensure_all_accepted: assert acceptance_rate == 1.0 + + if expected_acceptance_rate is not None: + assert acceptance_rate >= expected_acceptance_rate - 1e-2 diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 25067e7a4262c..c72e4595fd335 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -82,6 +82,48 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, force_output_len=True) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + }, +]) +@pytest.mark.parametrize("output_len", [2048]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify acceptance rate with different batch size and large output + length.""" + run_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + temperature=0.0, + seeded=True, + force_output_len=True, + expected_acceptance_rate=0.48) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index aec4847b96c35..ad6f3f313841d 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -1,6 +1,6 @@ from array import array from itertools import chain, count -from typing import Iterator, List, Tuple +from typing import Iterator, List, Optional, Tuple import torch @@ -88,21 +88,22 @@ def score_proposals( assert len(target_sampler_output) == 1, "expected single-step output" target_sampler_output = target_sampler_output[0] - all_tokens, all_probs, spec_logprobs = self._contract_batch( - contracted_bs=len(execute_model_req.seq_group_metadata_list), - target_sampler_output=target_sampler_output, - proposals=proposals, - num_scoring_tokens=num_scoring_tokens, - non_spec_indices=non_spec_indices, - spec_indices=spec_indices, - k=execute_model_req.num_lookahead_slots, - ) + (all_tokens, all_probs, spec_logprobs, + all_hidden_states) = self._contract_batch( + contracted_bs=len(execute_model_req.seq_group_metadata_list), + target_sampler_output=target_sampler_output, + proposals=proposals, + num_scoring_tokens=num_scoring_tokens, + non_spec_indices=non_spec_indices, + spec_indices=spec_indices, + k=execute_model_req.num_lookahead_slots, + ) return SpeculativeScores( probs=all_probs, token_ids=all_tokens, logprobs=spec_logprobs, - hidden_states=target_sampler_output.hidden_states, + hidden_states=all_hidden_states, ) def _expand_batch( @@ -145,10 +146,11 @@ def _expand_batch( num_scoring_tokens) def _contract_batch( - self, contracted_bs: int, target_sampler_output: SamplerOutput, - proposals: SpeculativeProposals, num_scoring_tokens: int, - non_spec_indices: List[int], spec_indices: List[int], - k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + self, contracted_bs: int, target_sampler_output: SamplerOutput, + proposals: SpeculativeProposals, num_scoring_tokens: int, + non_spec_indices: List[int], spec_indices: List[int], k: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: """Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences. @@ -156,9 +158,10 @@ def _contract_batch( contracted_bs is the original batch size, and the batch size that the target_sampler_output will be contracted to. """ - (target_token_ids, target_probs, target_logprobs, + (target_token_ids, target_probs, target_logprobs, target_hidden_states, non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) = self._split_scoring_output( + non_spec_target_logprobs, + non_spec_target_hidden_states) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) # Map distinct sequences used to score each token @@ -176,23 +179,40 @@ def _contract_batch( self._vocab_size) target_logprobs = target_logprobs.reshape(target_probs.shape) + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.reshape( + spec_expanded_bs, k + 1, target_hidden_states.shape[-1]) + all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1), fill_value=-1) all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size) all_logprobs = target_logprobs.new_full(size=all_probs.shape, fill_value=-float("inf")) + if target_sampler_output.hidden_states is not None: + all_hidden_states = target_hidden_states.new_zeros( + size=(contracted_bs, k + 1, target_hidden_states.shape[-1])) + else: + all_hidden_states = None + if non_spec_indices: all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_probs[non_spec_indices, :1, :] = non_spec_target_probs all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs + if all_hidden_states is not None: + all_hidden_states[ + non_spec_indices, :1, :] = non_spec_target_hidden_states + if spec_indices: all_tokens[spec_indices] = target_token_ids all_probs[spec_indices] = target_probs all_logprobs[spec_indices] = target_logprobs - return all_tokens, all_probs, all_logprobs + if all_hidden_states is not None: + all_hidden_states[spec_indices] = target_hidden_states + + return all_tokens, all_probs, all_logprobs, all_hidden_states def _create_scoring_model_input( self, @@ -327,8 +347,9 @@ def _create_single_target_seq_group_metadata( def _split_scoring_output( self, sampler_output: SamplerOutput, num_scoring_tokens: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor, Optional[torch.Tensor]]: """Split the target model output into speculative and non-speculative output. """ @@ -353,24 +374,37 @@ def _split_scoring_output( non_spec_logprobs, ) = sampler_output.logprobs.split(split_sizes) + if sampler_output.hidden_states is not None: + ( + spec_hidden_states, + non_spec_hidden_states, + ) = sampler_output.hidden_states.split(split_sizes) + else: + spec_hidden_states, non_spec_hidden_states = None, None + # Convert scores to tensors. sampler_output.sampled_token_probs = spec_probs sampler_output.sampled_token_ids = spec_sampled_tokens sampler_output.logprobs = spec_logprobs - (target_token_ids, target_probs, - target_logprobs) = sampler_output_to_torch([sampler_output], True) + sampler_output.hidden_states = spec_hidden_states + (target_token_ids, target_probs, target_logprobs, + target_hidden_states) = sampler_output_to_torch([sampler_output], + True) # Convert non-speculative output tokens to tensors. sampler_output.sampled_token_probs = non_spec_probs sampler_output.sampled_token_ids = non_spec_sampled_tokens sampler_output.logprobs = non_spec_logprobs + sampler_output.hidden_states = non_spec_hidden_states (non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) = sampler_output_to_torch([sampler_output], - True) + non_spec_target_logprobs, + non_spec_target_hidden_states) = sampler_output_to_torch( + [sampler_output], True) return (target_token_ids, target_probs, target_logprobs, - non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) + target_hidden_states, non_spec_target_token_ids, + non_spec_target_probs, non_spec_target_logprobs, + non_spec_target_hidden_states) def _create_target_seq_id_iterator( self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 63a00139cc09d..acf77a7349eef 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -646,9 +646,8 @@ def _verify_tokens( hidden_states = proposal_scores.hidden_states if hidden_states is not None: # Contract hidden states based on accepted tokens - hs_size = hidden_states.shape[1] - hidden_states = hidden_states.reshape(-1, max_proposal_len + 1, - hs_size) + hs_size = hidden_states.shape[-1] + accepted_index = accepted_token_ids + 1 # Convert -1 to 0 accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) index = accepted_index[:, None, None].expand(-1, 1, hs_size) diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index 1a56497030280..28f7f7eb069ab 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -242,7 +242,7 @@ def _merge_outputs( return proposal_tokens, proposal_probs, proposal_lens_tensor sampler_output = maybe_sampler_output - proposal_tokens, proposal_probs, _ = sampler_output_to_torch( + proposal_tokens, proposal_probs, *_ = sampler_output_to_torch( sampler_output, sampler_transposed) # Now, reformat the output GPU tensors such that each sequence has diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index c6223a97dba10..b85f2a6f70ac0 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -123,7 +123,7 @@ def split_batch_by_proposal_len( def sampler_output_to_torch( sampler_output_list: List[SamplerOutput], sampler_transposed: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Utility function which converts a list of SamplerOutput to tensors. sampler_transposed here is used as the indicator for whether @@ -169,7 +169,23 @@ def sampler_output_to_torch( if sampler_transposed: sampled_token_ids = sampled_token_ids.transpose(0, 1) - return sampled_token_ids, sampled_token_probs, sampled_token_logprobs + if sampler_output_list[0].hidden_states is not None: + # shape: [batch_size, num_sampler_output, hidden_dim] + sampled_hidden_states = torch.stack( + [ + sampler_output.hidden_states + for sampler_output in sampler_output_list + ], + dim=0, + ) + + if sampler_transposed: + sampled_hidden_states = sampled_hidden_states.transpose(0, 1) + else: + sampled_hidden_states = None + + return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs, + sampled_hidden_states) def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,