Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Speculative Decoding] Fixing hidden states handling in batch expansion #7508

Merged
merged 6 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions tests/spec_decode/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,15 +288,17 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
ensure_all_accepted=ensure_all_accepted)


def run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
temperature: float,
seeded: bool,
print_tokens: bool = False,
ensure_all_accepted: bool = False):
def run_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
temperature: float,
seeded: bool,
print_tokens: bool = False,
ensure_all_accepted: bool = False,
expected_acceptance_rate: Optional[float] = None):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero (or when temperature is > 0 and seeded).
Expand Down Expand Up @@ -357,5 +359,10 @@ def run_equality_correctness_test(baseline_llm_generator,
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids

print(f'{acceptance_rate=}')

if ensure_all_accepted:
assert acceptance_rate == 1.0

if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2
42 changes: 42 additions & 0 deletions tests/spec_decode/e2e/test_mlp_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,48 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
},
])
@pytest.mark.parametrize("output_len", [2048])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify acceptance rate with different batch size and large output
length."""
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
temperature=0.0,
seeded=True,
force_output_len=True,
expected_acceptance_rate=0.48)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@njhill do you know what value we should expect here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(like, is this ballpark correct)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cadedaniel the value looks reasonable but I’m actually not sure what’s expected, will try to find out when I get a chance.



@pytest.mark.parametrize(
"common_llm_kwargs",
[{
Expand Down
86 changes: 60 additions & 26 deletions vllm/spec_decode/batch_expansion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import chain, count
from typing import Iterator, List, Tuple
from typing import Iterator, List, Optional, Tuple

import torch

Expand Down Expand Up @@ -86,21 +86,22 @@ def score_proposals(
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]

all_tokens, all_probs, spec_logprobs = self._contract_batch(
contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)
(all_tokens, all_probs, spec_logprobs,
all_hidden_states) = self._contract_batch(
contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)

return SpeculativeScores(
probs=all_probs,
token_ids=all_tokens,
logprobs=spec_logprobs,
hidden_states=target_sampler_output.hidden_states,
hidden_states=all_hidden_states,
)

def _expand_batch(
Expand Down Expand Up @@ -143,20 +144,22 @@ def _expand_batch(
num_scoring_tokens)

def _contract_batch(
self, contracted_bs: int, target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int],
k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
self, contracted_bs: int, target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int], k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.

contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""
(target_token_ids, target_probs, target_logprobs,
(target_token_ids, target_probs, target_logprobs, target_hidden_states,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs) = self._split_scoring_output(
non_spec_target_logprobs,
non_spec_target_hidden_states) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)

# Map distinct sequences used to score each token
Expand All @@ -174,23 +177,40 @@ def _contract_batch(
self._vocab_size)
target_logprobs = target_logprobs.reshape(target_probs.shape)

if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape(
spec_expanded_bs, k + 1, target_hidden_states.shape[-1])

all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
fill_value=-1)
all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
all_logprobs = target_logprobs.new_full(size=all_probs.shape,
fill_value=-float("inf"))

if target_sampler_output.hidden_states is not None:
all_hidden_states = target_hidden_states.new_zeros(
size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))
else:
all_hidden_states = None

if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs

if all_hidden_states is not None:
all_hidden_states[
non_spec_indices, :1, :] = non_spec_target_hidden_states

if spec_indices:
all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs
all_logprobs[spec_indices] = target_logprobs

return all_tokens, all_probs, all_logprobs
if all_hidden_states is not None:
all_hidden_states[spec_indices] = target_hidden_states

return all_tokens, all_probs, all_logprobs, all_hidden_states

def _create_scoring_model_input(
self,
Expand Down Expand Up @@ -324,8 +344,9 @@ def _create_single_target_seq_group_metadata(

def _split_scoring_output(
self, sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor, Optional[torch.Tensor]]:
"""Split the target model output into speculative and non-speculative
output.
"""
Expand All @@ -350,24 +371,37 @@ def _split_scoring_output(
non_spec_logprobs,
) = sampler_output.logprobs.split(split_sizes)

if sampler_output.hidden_states is not None:
(
spec_hidden_states,
non_spec_hidden_states,
) = sampler_output.hidden_states.split(split_sizes)
else:
spec_hidden_states, non_spec_hidden_states = None, None

# Convert scores to tensors.
sampler_output.sampled_token_probs = spec_probs
sampler_output.sampled_token_ids = spec_sampled_tokens
sampler_output.logprobs = spec_logprobs
(target_token_ids, target_probs,
target_logprobs) = sampler_output_to_torch([sampler_output], True)
sampler_output.hidden_states = spec_hidden_states
(target_token_ids, target_probs, target_logprobs,
target_hidden_states) = sampler_output_to_torch([sampler_output],
True)

# Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens
sampler_output.logprobs = non_spec_logprobs
sampler_output.hidden_states = non_spec_hidden_states
(non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs) = sampler_output_to_torch([sampler_output],
True)
non_spec_target_logprobs,
non_spec_target_hidden_states) = sampler_output_to_torch(
[sampler_output], True)

return (target_token_ids, target_probs, target_logprobs,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs)
target_hidden_states, non_spec_target_token_ids,
non_spec_target_probs, non_spec_target_logprobs,
non_spec_target_hidden_states)

def _create_target_seq_id_iterator(
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
Expand Down
5 changes: 2 additions & 3 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,9 +646,8 @@ def _verify_tokens(
hidden_states = proposal_scores.hidden_states
if hidden_states is not None:
# Contract hidden states based on accepted tokens
hs_size = hidden_states.shape[1]
hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,
hs_size)
hs_size = hidden_states.shape[-1]

accepted_index = accepted_token_ids + 1 # Convert -1 to 0
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
index = accepted_index[:, None, None].expand(-1, 1, hs_size)
Expand Down
2 changes: 1 addition & 1 deletion vllm/spec_decode/top1_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def _merge_outputs(
return proposal_tokens, proposal_probs, proposal_lens_tensor

sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
sampler_output, sampler_transposed)

# Now, reformat the output GPU tensors such that each sequence has
Expand Down
20 changes: 18 additions & 2 deletions vllm/spec_decode/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def split_batch_by_proposal_len(

def sampler_output_to_torch(
sampler_output_list: List[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
Expand Down Expand Up @@ -169,7 +169,23 @@ def sampler_output_to_torch(
if sampler_transposed:
sampled_token_ids = sampled_token_ids.transpose(0, 1)

return sampled_token_ids, sampled_token_probs, sampled_token_logprobs
if sampler_output_list[0].hidden_states is not None:
# shape: [batch_size, num_sampler_output, hidden_dim]
sampled_hidden_states = torch.stack(
[
sampler_output.hidden_states
for sampler_output in sampler_output_list
],
dim=0,
)

if sampler_transposed:
sampled_hidden_states = sampled_hidden_states.transpose(0, 1)
else:
sampled_hidden_states = None

return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs,
sampled_hidden_states)


def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
Expand Down
Loading