fixing hidden states handling in batch expansion
abhigoyal1997 committed Aug 14, 2024
1 parent e20233d commit 90bee1d
Showing 4 changed files with 81 additions and 32 deletions.
86 changes: 60 additions & 26 deletions vllm/spec_decode/batch_expansion.py
@@ -1,5 +1,5 @@
 from itertools import chain, count
-from typing import Iterator, List, Tuple
+from typing import Iterator, List, Optional, Tuple
 
 import torch
 
@@ -86,21 +86,22 @@ def score_proposals(
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]
 
-        all_tokens, all_probs, spec_logprobs = self._contract_batch(
-            contracted_bs=len(execute_model_req.seq_group_metadata_list),
-            target_sampler_output=target_sampler_output,
-            proposals=proposals,
-            num_scoring_tokens=num_scoring_tokens,
-            non_spec_indices=non_spec_indices,
-            spec_indices=spec_indices,
-            k=execute_model_req.num_lookahead_slots,
-        )
+        (all_tokens, all_probs, spec_logprobs,
+         all_hidden_states) = self._contract_batch(
+             contracted_bs=len(execute_model_req.seq_group_metadata_list),
+             target_sampler_output=target_sampler_output,
+             proposals=proposals,
+             num_scoring_tokens=num_scoring_tokens,
+             non_spec_indices=non_spec_indices,
+             spec_indices=spec_indices,
+             k=execute_model_req.num_lookahead_slots,
+         )
 
         return SpeculativeScores(
             probs=all_probs,
             token_ids=all_tokens,
             logprobs=spec_logprobs,
-            hidden_states=target_sampler_output.hidden_states,
+            hidden_states=all_hidden_states,
         )

     def _expand_batch(
@@ -143,20 +144,22 @@ def _expand_batch(
                 num_scoring_tokens)
 
     def _contract_batch(
-            self, contracted_bs: int, target_sampler_output: SamplerOutput,
-            proposals: SpeculativeProposals, num_scoring_tokens: int,
-            non_spec_indices: List[int], spec_indices: List[int],
-            k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+            self, contracted_bs: int, target_sampler_output: SamplerOutput,
+            proposals: SpeculativeProposals, num_scoring_tokens: int,
+            non_spec_indices: List[int], spec_indices: List[int], k: int
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
+               Optional[torch.Tensor]]:
         """Contract the expanded batch back into its original size.
         This maps the scores of speculative tokens back to their original
         sequences.
 
         contracted_bs is the original batch size, and the batch size that the
         target_sampler_output will be contracted to.
         """
-        (target_token_ids, target_probs, target_logprobs,
+        (target_token_ids, target_probs, target_logprobs, target_hidden_states,
          non_spec_target_token_ids, non_spec_target_probs,
-         non_spec_target_logprobs) = self._split_scoring_output(
+         non_spec_target_logprobs,
+         non_spec_target_hidden_states) = self._split_scoring_output(
              target_sampler_output, num_scoring_tokens)
 
         # Map distinct sequences used to score each token
@@ -174,23 +177,40 @@ def _contract_batch(
                                             self._vocab_size)
         target_logprobs = target_logprobs.reshape(target_probs.shape)
 
+        if target_hidden_states is not None:
+            target_hidden_states = target_hidden_states.reshape(
+                spec_expanded_bs, k + 1, target_hidden_states.shape[-1])
+
         all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                                fill_value=-1)
         all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
         all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                 fill_value=-float("inf"))
 
+        if target_sampler_output.hidden_states is not None:
+            all_hidden_states = target_hidden_states.new_zeros(
+                size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))
+        else:
+            all_hidden_states = None
+
         if non_spec_indices:
             all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
             all_probs[non_spec_indices, :1, :] = non_spec_target_probs
             all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
 
+            if all_hidden_states is not None:
+                all_hidden_states[
+                    non_spec_indices, :1, :] = non_spec_target_hidden_states
+
         if spec_indices:
             all_tokens[spec_indices] = target_token_ids
             all_probs[spec_indices] = target_probs
             all_logprobs[spec_indices] = target_logprobs
 
-        return all_tokens, all_probs, all_logprobs
+            if all_hidden_states is not None:
+                all_hidden_states[spec_indices] = target_hidden_states
+
+        return all_tokens, all_probs, all_logprobs, all_hidden_states
 
     def _create_scoring_model_input(
             self,
@@ -324,8 +344,9 @@ def _create_single_target_seq_group_metadata(
 
     def _split_scoring_output(
         self, sampler_output: SamplerOutput, num_scoring_tokens: int
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-               torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
+               Optional[torch.Tensor], torch.Tensor, torch.Tensor,
+               torch.Tensor, Optional[torch.Tensor]]:
         """Split the target model output into speculative and non-speculative
         output.
         """
@@ -350,24 +371,37 @@ def _split_scoring_output(
             non_spec_logprobs,
         ) = sampler_output.logprobs.split(split_sizes)
 
+        if sampler_output.hidden_states is not None:
+            (
+                spec_hidden_states,
+                non_spec_hidden_states,
+            ) = sampler_output.hidden_states.split(split_sizes)
+        else:
+            spec_hidden_states, non_spec_hidden_states = None, None
+
         # Convert scores to tensors.
         sampler_output.sampled_token_probs = spec_probs
         sampler_output.sampled_token_ids = spec_sampled_tokens
         sampler_output.logprobs = spec_logprobs
-        (target_token_ids, target_probs,
-         target_logprobs) = sampler_output_to_torch([sampler_output], True)
+        sampler_output.hidden_states = spec_hidden_states
+        (target_token_ids, target_probs, target_logprobs,
+         target_hidden_states) = sampler_output_to_torch([sampler_output],
+                                                         True)
 
         # Convert non-speculative output tokens to tensors.
         sampler_output.sampled_token_probs = non_spec_probs
         sampler_output.sampled_token_ids = non_spec_sampled_tokens
         sampler_output.logprobs = non_spec_logprobs
+        sampler_output.hidden_states = non_spec_hidden_states
         (non_spec_target_token_ids, non_spec_target_probs,
-         non_spec_target_logprobs) = sampler_output_to_torch([sampler_output],
-                                                             True)
+         non_spec_target_logprobs,
+         non_spec_target_hidden_states) = sampler_output_to_torch(
+             [sampler_output], True)
 
         return (target_token_ids, target_probs, target_logprobs,
-                non_spec_target_token_ids, non_spec_target_probs,
-                non_spec_target_logprobs)
+                target_hidden_states, non_spec_target_token_ids,
+                non_spec_target_probs, non_spec_target_logprobs,
+                non_spec_target_hidden_states)
 
     def _create_target_seq_id_iterator(
             self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
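The heart of this change is in _contract_batch: hidden states from the expanded batch are now scattered back into the original batch layout alongside tokens, probs, and logprobs. Below is a minimal standalone sketch of that contraction step; the sizes, the spec_hidden/non_spec_hidden tensors, and the index lists are all invented for illustration and stand in for the real SamplerOutput fields.

import torch

# Illustrative sizes only.
contracted_bs = 4          # original batch size
k = 3                      # lookahead slots per proposal
hidden_dim = 8
spec_indices = [0, 2]      # sequences scored with speculative tokens
non_spec_indices = [1, 3]  # sequences scored without proposals

# The expanded batch scores each speculative sequence at k + 1 positions,
# so the target model emits one hidden state per scored token.
spec_hidden = torch.randn(len(spec_indices) * (k + 1), hidden_dim)
non_spec_hidden = torch.randn(len(non_spec_indices), hidden_dim)

# Contract: reshape the speculative rows to [spec_bs, k + 1, hidden_dim],
# then scatter both halves back into a tensor of the original batch size.
all_hidden = spec_hidden.new_zeros(contracted_bs, k + 1, hidden_dim)
all_hidden[spec_indices] = spec_hidden.reshape(len(spec_indices), k + 1,
                                               hidden_dim)
all_hidden[non_spec_indices, :1, :] = non_spec_hidden.unsqueeze(1)

print(all_hidden.shape)  # torch.Size([4, 4, 8])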
5 changes: 2 additions & 3 deletions vllm/spec_decode/spec_decode_worker.py
@@ -646,9 +646,8 @@ def _verify_tokens(
         hidden_states = proposal_scores.hidden_states
         if hidden_states is not None:
             # Contract hidden states based on accepted tokens
-            hs_size = hidden_states.shape[1]
-            hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,
-                                                  hs_size)
+            hs_size = hidden_states.shape[-1]
 
             accepted_index = accepted_token_ids + 1  # Convert -1 to 0
             accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
             index = accepted_index[:, None, None].expand(-1, 1, hs_size)
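Since the scorer now returns hidden states already contracted to [batch_size, k + 1, hidden_dim], _verify_tokens can drop its reshape and only read the hidden size off the last dimension before gathering each sequence's last accepted hidden state. A self-contained sketch of that gather, with invented token ids and sizes:

import torch

# 2 sequences, k + 1 = 4 scored positions, hidden_dim = 3 (all invented).
hidden_states = torch.arange(2 * 4 * 3, dtype=torch.float).reshape(2, 4, 3)
accepted_token_ids = torch.tensor([[11, 12, -1, -1],   # 2 tokens accepted
                                   [21, 22, 23, 24]])  # all 4 accepted

hs_size = hidden_states.shape[-1]

# Rejected slots hold -1; shifting by +1 turns them into zeros so
# count_nonzero yields the number of accepted tokens per sequence.
accepted_index = (accepted_token_ids + 1).count_nonzero(dim=1).add_(-1)
index = accepted_index[:, None, None].expand(-1, 1, hs_size)
last_accepted = hidden_states.gather(dim=1, index=index).squeeze(1)

print(last_accepted)  # position 1 of sequence 0, position 3 of sequence 1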
2 changes: 1 addition & 1 deletion vllm/spec_decode/top1_proposer.py
@@ -242,7 +242,7 @@ def _merge_outputs(
             return proposal_tokens, proposal_probs, proposal_lens_tensor
 
         sampler_output = maybe_sampler_output
-        proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
+        proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
             sampler_output, sampler_transposed)
 
         # Now, reformat the output GPU tensors such that each sequence has
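This one-character fix exists because sampler_output_to_torch now returns four values instead of three: a bare `_` binds exactly one trailing value, while `*_` absorbs however many follow. A toy illustration (returns_four is a stand-in, not vLLM code):

def returns_four():
    # Mimics the new 4-tuple: tokens, probs, logprobs, hidden_states.
    return "tokens", "probs", "logprobs", None

# `*_` soaks up all trailing values, so the two results the caller cares
# about unpack cleanly even though the tuple grew.
tokens, probs, *_ = returns_four()

try:
    tokens, probs, _ = returns_four()  # exactly 3 names for 4 values
except ValueError as err:
    print(err)  # too many values to unpack (expected 3)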
20 changes: 18 additions & 2 deletions vllm/spec_decode/util.py
@@ -123,7 +123,7 @@ def split_batch_by_proposal_len(
 
 def sampler_output_to_torch(
         sampler_output_list: List[SamplerOutput], sampler_transposed: bool
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     """Utility function which converts a list of SamplerOutput to tensors.
 
     sampler_transposed here is used as the indicator for whether
@@ -169,7 +169,23 @@ def sampler_output_to_torch(
     if sampler_transposed:
         sampled_token_ids = sampled_token_ids.transpose(0, 1)
 
-    return sampled_token_ids, sampled_token_probs, sampled_token_logprobs
+    if sampler_output_list[0].hidden_states is not None:
+        # shape: [batch_size, num_sampler_output, hidden_dim]
+        sampled_hidden_states = torch.stack(
+            [
+                sampler_output.hidden_states
+                for sampler_output in sampler_output_list
+            ],
+            dim=0,
+        )
+
+        if sampler_transposed:
+            sampled_hidden_states = sampled_hidden_states.transpose(0, 1)
+    else:
+        sampled_hidden_states = None
+
+    return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs,
+            sampled_hidden_states)
 
 
 def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
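The new branch mirrors how token ids and probs are already stacked across the per-step SamplerOutputs, adding the same treatment for hidden states. A standalone sketch of just that logic, under the assumption that each step's hidden states are [batch_size, hidden_dim]; stack_hidden_states is a made-up helper, not part of the vLLM API:

from typing import List, Optional

import torch

def stack_hidden_states(hidden_states_list: List[Optional[torch.Tensor]],
                        transposed: bool) -> Optional[torch.Tensor]:
    if hidden_states_list[0] is None:
        return None
    # Stack per-step hidden states: [num_steps, batch_size, hidden_dim].
    stacked = torch.stack(hidden_states_list, dim=0)
    if transposed:
        # Batch-major layout: [batch_size, num_steps, hidden_dim].
        stacked = stacked.transpose(0, 1)
    return stacked

steps = [torch.randn(2, 8) for _ in range(3)]  # 3 steps, batch of 2
print(stack_hidden_states(steps, transposed=True).shape)  # [2, 3, 8]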
