[Speculative Decoding] Fixing hidden states handling in batch expansion
abhigoyal1997 authored and omrishiv committed Aug 26, 2024
1 parent 71b4af3 commit d241e8e
Showing 6 changed files with 139 additions and 41 deletions.
25 changes: 16 additions & 9 deletions tests/spec_decode/e2e/conftest.py
@@ -288,15 +288,17 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
ensure_all_accepted=ensure_all_accepted)


def run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
temperature: float,
seeded: bool,
print_tokens: bool = False,
ensure_all_accepted: bool = False):
def run_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
temperature: float,
seeded: bool,
print_tokens: bool = False,
ensure_all_accepted: bool = False,
expected_acceptance_rate: Optional[float] = None):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
the same when temperature is zero (or when temperature is > 0 and seeded).
@@ -357,5 +359,10 @@ def run_equality_correctness_test(baseline_llm_generator,
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids

print(f'{acceptance_rate=}')

if ensure_all_accepted:
assert acceptance_rate == 1.0

if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2
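
For reference, the acceptance rate asserted here is the fraction of draft (speculative) tokens that the target model accepts, likely read from the engine's speculative-decoding metrics (the tests enable stats logging for this reason); the new expected_acceptance_rate check allows a 1e-2 tolerance for run-to-run noise. A minimal sketch of the ratio itself, with illustrative names rather than the exact metric fields:

def compute_acceptance_rate(num_accepted_tokens: int,
                            num_draft_tokens: int) -> float:
    # Fraction of proposed draft tokens accepted by the target model.
    # Names are illustrative; the test helper reads this value from the
    # engine's spec-decode metrics rather than computing it directly.
    if num_draft_tokens == 0:
        return 0.0
    return num_accepted_tokens / num_draft_tokens
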
42 changes: 42 additions & 0 deletions tests/spec_decode/e2e/test_mlp_correctness.py
@@ -82,6 +82,48 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
},
])
@pytest.mark.parametrize("output_len", [2048])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify acceptance rate with different batch size and large output
length."""
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
temperature=0.0,
seeded=True,
force_output_len=True,
expected_acceptance_rate=0.48)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
86 changes: 60 additions & 26 deletions vllm/spec_decode/batch_expansion.py
@@ -1,6 +1,6 @@
from array import array
from itertools import chain, count
from typing import Iterator, List, Tuple
from typing import Iterator, List, Optional, Tuple

import torch

@@ -88,21 +88,22 @@ def score_proposals(
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]

all_tokens, all_probs, spec_logprobs = self._contract_batch(
contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)
(all_tokens, all_probs, spec_logprobs,
all_hidden_states) = self._contract_batch(
contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)

return SpeculativeScores(
probs=all_probs,
token_ids=all_tokens,
logprobs=spec_logprobs,
hidden_states=target_sampler_output.hidden_states,
hidden_states=all_hidden_states,
)

def _expand_batch(
@@ -145,20 +145,22 @@ def _expand_batch(
num_scoring_tokens)

def _contract_batch(
self, contracted_bs: int, target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int],
k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
self, contracted_bs: int, target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int], k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
contracted_bs is the original batch size, i.e. the batch size that
target_sampler_output will be contracted to.
"""
(target_token_ids, target_probs, target_logprobs,
(target_token_ids, target_probs, target_logprobs, target_hidden_states,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs) = self._split_scoring_output(
non_spec_target_logprobs,
non_spec_target_hidden_states) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)

# Map distinct sequences used to score each token
@@ -176,23 +179,40 @@ def _contract_batch(
self._vocab_size)
target_logprobs = target_logprobs.reshape(target_probs.shape)

if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape(
spec_expanded_bs, k + 1, target_hidden_states.shape[-1])

all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
fill_value=-1)
all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
all_logprobs = target_logprobs.new_full(size=all_probs.shape,
fill_value=-float("inf"))

if target_sampler_output.hidden_states is not None:
all_hidden_states = target_hidden_states.new_zeros(
size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))
else:
all_hidden_states = None

if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs

if all_hidden_states is not None:
all_hidden_states[
non_spec_indices, :1, :] = non_spec_target_hidden_states

if spec_indices:
all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs
all_logprobs[spec_indices] = target_logprobs

return all_tokens, all_probs, all_logprobs
if all_hidden_states is not None:
all_hidden_states[spec_indices] = target_hidden_states

return all_tokens, all_probs, all_logprobs, all_hidden_states

def _create_scoring_model_input(
self,
@@ -327,8 +347,9 @@ def _create_single_target_seq_group_metadata(

def _split_scoring_output(
self, sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor, Optional[torch.Tensor]]:
"""Split the target model output into speculative and non-speculative
output.
"""
@@ -353,24 +374,37 @@ def _split_scoring_output(
non_spec_logprobs,
) = sampler_output.logprobs.split(split_sizes)

if sampler_output.hidden_states is not None:
(
spec_hidden_states,
non_spec_hidden_states,
) = sampler_output.hidden_states.split(split_sizes)
else:
spec_hidden_states, non_spec_hidden_states = None, None

# Convert scores to tensors.
sampler_output.sampled_token_probs = spec_probs
sampler_output.sampled_token_ids = spec_sampled_tokens
sampler_output.logprobs = spec_logprobs
(target_token_ids, target_probs,
target_logprobs) = sampler_output_to_torch([sampler_output], True)
sampler_output.hidden_states = spec_hidden_states
(target_token_ids, target_probs, target_logprobs,
target_hidden_states) = sampler_output_to_torch([sampler_output],
True)

# Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens
sampler_output.logprobs = non_spec_logprobs
sampler_output.hidden_states = non_spec_hidden_states
(non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs) = sampler_output_to_torch([sampler_output],
True)
non_spec_target_logprobs,
non_spec_target_hidden_states) = sampler_output_to_torch(
[sampler_output], True)

return (target_token_ids, target_probs, target_logprobs,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs)
target_hidden_states, non_spec_target_token_ids,
non_spec_target_probs, non_spec_target_logprobs,
non_spec_target_hidden_states)

def _create_target_seq_id_iterator(
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
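
The substance of this file's change is that _contract_batch now scatters the target model's hidden states back into a [contracted_bs, k + 1, hidden_dim] tensor alongside the tokens, probs, and logprobs, instead of passing the still-expanded hidden states through to SpeculativeScores. A standalone sketch of that contraction step, with illustrative shapes and plain PyTorch only:

import torch

def contract_hidden_states(contracted_bs: int, k: int,
                           spec_hidden: torch.Tensor,      # [num_spec, k + 1, hidden_dim]
                           non_spec_hidden: torch.Tensor,  # [num_non_spec, 1, hidden_dim]
                           spec_indices: list,
                           non_spec_indices: list) -> torch.Tensor:
    hidden_dim = spec_hidden.shape[-1]
    # Rows that received no hidden states stay zero, mirroring new_zeros above.
    out = spec_hidden.new_zeros(contracted_bs, k + 1, hidden_dim)
    if non_spec_indices:
        # Non-speculative sequences were scored at a single position only.
        out[non_spec_indices, :1, :] = non_spec_hidden
    if spec_indices:
        # Speculative sequences were scored at k + 1 positions each.
        out[spec_indices] = spec_hidden
    return out
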
5 changes: 2 additions & 3 deletions vllm/spec_decode/spec_decode_worker.py
@@ -646,9 +646,8 @@ def _verify_tokens(
hidden_states = proposal_scores.hidden_states
if hidden_states is not None:
# Contract hidden states based on accepted tokens
hs_size = hidden_states.shape[1]
hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,
hs_size)
hs_size = hidden_states.shape[-1]

accepted_index = accepted_token_ids + 1 # Convert -1 to 0
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
index = accepted_index[:, None, None].expand(-1, 1, hs_size)
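
With the hidden states already contracted to [batch_size, k + 1, hidden_size] by the scorer, _verify_tokens drops its own reshape and only reads the hidden size from the last dimension before selecting the hidden state of the last accepted token in each sequence. A sketch of that selection; the final gather is an assumption about the code beyond the shown hunk:

import torch

def last_accepted_hidden_state(hidden_states: torch.Tensor,      # [bs, k + 1, hs_size]
                               accepted_token_ids: torch.Tensor  # [bs, k + 1], -1 = rejected
                               ) -> torch.Tensor:
    hs_size = hidden_states.shape[-1]
    # Shift -1 (rejected) to 0 so count_nonzero counts accepted positions,
    # then subtract 1 to get the index of the last accepted token per row.
    accepted_index = (accepted_token_ids + 1).count_nonzero(dim=1).add_(-1)
    index = accepted_index[:, None, None].expand(-1, 1, hs_size)
    return hidden_states.gather(dim=1, index=index).squeeze(1)   # [bs, hs_size]
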
2 changes: 1 addition & 1 deletion vllm/spec_decode/top1_proposer.py
@@ -242,7 +242,7 @@ def _merge_outputs(
return proposal_tokens, proposal_probs, proposal_lens_tensor

sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
sampler_output, sampler_transposed)

# Now, reformat the output GPU tensors such that each sequence has
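
Because sampler_output_to_torch now returns a fourth value (the stacked hidden states), _merge_outputs switches to star-unpacking so it can keep ignoring everything after the proposal tokens and probabilities. A trivial illustration of the pattern:

def returns_four():
    return "tokens", "probs", "logprobs", "hidden_states"

tokens, probs, *_ = returns_four()  # *_ absorbs any number of trailing values
assert (tokens, probs) == ("tokens", "probs")
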
20 changes: 18 additions & 2 deletions vllm/spec_decode/util.py
@@ -123,7 +123,7 @@ def split_batch_by_proposal_len(

def sampler_output_to_torch(
sampler_output_list: List[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
@@ -169,7 +169,23 @@ def sampler_output_to_torch(
if sampler_transposed:
sampled_token_ids = sampled_token_ids.transpose(0, 1)

return sampled_token_ids, sampled_token_probs, sampled_token_logprobs
if sampler_output_list[0].hidden_states is not None:
# shape: [batch_size, num_sampler_output, hidden_dim]
sampled_hidden_states = torch.stack(
[
sampler_output.hidden_states
for sampler_output in sampler_output_list
],
dim=0,
)

if sampler_transposed:
sampled_hidden_states = sampled_hidden_states.transpose(0, 1)
else:
sampled_hidden_states = None

return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs,
sampled_hidden_states)


def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
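
The utility stacks per-step hidden states the same way it already stacks token ids and probabilities, transposing to [batch_size, num_steps, hidden_dim] when the caller asks for it and returning None when no hidden states were produced. An isolated sketch of just that step, with illustrative shapes:

from typing import List, Optional

import torch

def stack_hidden_states(per_step_hidden: List[Optional[torch.Tensor]],
                        transposed: bool) -> Optional[torch.Tensor]:
    # Each entry is [batch_size, hidden_dim] for one decode step, or None.
    if per_step_hidden[0] is None:
        return None
    stacked = torch.stack(per_step_hidden, dim=0)  # [num_steps, batch_size, hidden_dim]
    if transposed:
        stacked = stacked.transpose(0, 1)          # [batch_size, num_steps, hidden_dim]
    return stacked
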
