Disable LogProbs generation for spec decoding
sroy745 committed Jul 16, 2024
1 parent 6212d5f commit 1eddc27
Showing 8 changed files with 136 additions and 80 deletions.
2 changes: 2 additions & 0 deletions vllm/engine/output_processor/multi_step.py
@@ -75,6 +75,8 @@ def process_outputs(self, sequence_group: SequenceGroup,
assert len(seqs) == 1, (
"Beam search not supported in multi-step decoding.")
seq = seqs[0]
#for output in outputs:
# print('output ' + str(output))

# Since there's only one sequence per sequence group, we can take the
# first sample.
35 changes: 26 additions & 9 deletions vllm/model_executor/layers/sampler.py
@@ -90,7 +90,9 @@ def forward(
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
logprobs = None
# TODO(sroy) - Add flag for this.
#logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
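# With logprobs left as None, _get_logprobs below returns empty placeholder
# entries instead of computing per-token log probabilities.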

# Sample the next tokens.
sample_results, maybe_sampled_tokens_tensor = _sample(
@@ -109,13 +111,21 @@ def forward(
on_device_tensors = None

# Get the logprobs query results.
#print('sample_results ' + str(sample_results))
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results)
return _build_sampler_output(sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors)
output = _build_sampler_output(sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors)
#print('output ' + str(output))
return output
#return _build_sampler_output(sample_results,
# sampling_metadata,
# prompt_logprobs,
# sample_logprobs,
# on_device_tensors=on_device_tensors)

@property
def _should_modify_greedy_probs_inplace(self) -> bool:
@@ -472,10 +482,10 @@ def _sample_with_torch(

# Create output tensor for sampled token ids.
if include_gpu_probs_tensor:
sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
sampled_token_ids_tensor = torch.empty(probs.shape[0],
1,
dtype=torch.long,
device=logprobs.device)
device=probs.device)
else:
sampled_token_ids_tensor = None

@@ -492,7 +502,9 @@
sample_metadata[sampling_type] = (seq_group_id, seq_groups)
long_sample_indices = sample_indices.long()
if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[long_sample_indices],
#greedy_samples = torch.argmax(logprobs[long_sample_indices],
# dim=-1)
greedy_samples = torch.argmax(probs[long_sample_indices],
dim=-1)

if sampled_token_ids_tensor is not None:
@@ -720,6 +732,11 @@ def _get_logprobs(
Returns:
A tuple of prompt and sample logprobs per sequence group in a batch.
"""
if logprobs is None:
empty_sampled_logprob: List[SampleLogprobs] = [[None] * len(sample_results)] * (
len(sampling_metadata.seq_groups))
empty_prompt_logprob: List[PromptLogprobs] = [None] * len(sampling_metadata.seq_groups)
return empty_prompt_logprob, empty_sampled_logprob
# The index of query token to calculate logprobs. It includes both
# prompt and sample logprob indices.
query_indices: List[int] = []
7 changes: 5 additions & 2 deletions vllm/sequence.py
@@ -326,9 +326,12 @@ def append_token_id(
token_id: int,
logprobs: Dict[int, Logprob],
) -> None:
assert token_id in logprobs
if logprobs is not None:
assert token_id in logprobs
self.data.append_token_id(token_id, logprobs[token_id].logprob)
else:
self.data.append_token_id(token_id, 0.0)
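# A 0.0 placeholder logprob is recorded when logprob generation is disabled.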
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob)

def get_len(self) -> int:
return self.data.get_len()
30 changes: 21 additions & 9 deletions vllm/spec_decode/batch_expansion.py
@@ -169,23 +169,29 @@ def _contract_batch(
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
target_probs = target_probs.reshape(*target_token_ids.shape,
self._vocab_size)
target_logprobs = target_logprobs.reshape(target_probs.shape)

if target_logprobs is not None:
target_logprobs = target_logprobs.reshape(target_probs.shape)

all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
fill_value=-1)
all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
all_logprobs = target_logprobs.new_full(size=all_probs.shape,
fill_value=-float("inf"))
all_logprobs = None
if target_logprobs is not None:
all_logprobs = target_logprobs.new_full(
size=all_probs.shape, fill_value=-float("inf"))
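# When logprob generation is disabled, all_logprobs simply stays None.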

if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
if all_logprobs is not None:
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs

if spec_indices:
all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs
all_logprobs[spec_indices] = target_logprobs
if all_logprobs is not None:
all_logprobs[spec_indices] = target_logprobs

return all_tokens, all_probs, all_logprobs

@@ -327,10 +333,16 @@ def _split_scoring_output(
) = sampler_output.sampled_token_probs.split(split_sizes)
(spec_sampled_tokens, non_spec_sampled_tokens
) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
(
spec_logprobs,
non_spec_logprobs,
) = sampler_output.logprobs.split(split_sizes)
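# Logprobs are split only when the scorer actually produced them; otherwise
# both the spec and non-spec halves stay None.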
if sampler_output.logprobs is not None:
(
spec_logprobs,
non_spec_logprobs,
) = sampler_output.logprobs.split(split_sizes)
else:
spec_logprobs, non_spec_logprobs = None, None

# Convert scores to tensors.
sampler_output.sampled_token_probs = spec_probs
8 changes: 5 additions & 3 deletions vllm/spec_decode/draft_model_runner.py
@@ -110,9 +110,11 @@ def update_model_input(
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]

seq.append_token_id(token_id, token_logprob.logprob)
if seq_output.logprobs is not None:
token_logprob = seq_output.logprobs[token_id]
seq.append_token_id(token_id, token_logprob.logprob)
else:
seq.append_token_id(token_id, 0.0)
seq.update_num_computed_tokens(1)

return self.prepare_model_input(self.cached_seq_group_metadata_list)
6 changes: 4 additions & 2 deletions vllm/spec_decode/multi_step_worker.py
@@ -216,9 +216,11 @@ def _append_new_tokens(
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]
token_logprob = 0.0
if seq_output.logprobs is not None:
token_logprob = seq_output.logprobs[token_id].logprob

seq.append_token_id(token_id, token_logprob.logprob)
seq.append_token_id(token_id, token_logprob)
seq.update_num_computed_tokens(1)

@staticmethod
87 changes: 50 additions & 37 deletions vllm/spec_decode/spec_decode_worker.py
@@ -569,25 +569,31 @@ def _create_output_sampler_list(
the same number of outputs.
"""
batch_size, num_steps = accepted_token_ids.shape

# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
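# Logprob bookkeeping (ranks, top-k) is only done when target logprobs were
# actually produced.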
if target_logprobs is not None:
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)

# Get the logprobs/rank of the accepted tokens.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
logprob_tensor=target_logprobs_by_step,
sampled_token_ids=accepted_token_ids_by_step,
)

# Get the logprobs/rank of the accepted tokens.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
logprob_tensor=target_logprobs_by_step,
sampled_token_ids=accepted_token_ids_by_step,
)

# Get the top-k logprobs (which may or may not include the logprob of
# the accepted token).
(topk_logprobs_by_step,
topk_indices_by_step) = target_logprobs_by_step.topk(
k=self.scorer_worker.model_config.max_logprobs,
dim=-1,
)
# Get the top-k logprobs (which may or may not include the logprob of
# the accepted token).
(topk_logprobs_by_step,
topk_indices_by_step) = target_logprobs_by_step.topk(
k=self.scorer_worker.model_config.max_logprobs,
dim=-1,
)
accepted_token_id_ranks_by_step = (
accepted_token_id_ranks_by_step.tolist())
accepted_token_id_logprobs_by_step = (
accepted_token_id_logprobs_by_step.tolist())
topk_logprobs_by_step = topk_logprobs_by_step.tolist()
topk_indices_by_step = topk_indices_by_step.tolist()

# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
@@ -598,12 +604,6 @@

# Serialize all tensors to CPU Python lists.
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
accepted_token_id_ranks_by_step = (
accepted_token_id_ranks_by_step.tolist())
accepted_token_id_logprobs_by_step = (
accepted_token_id_logprobs_by_step.tolist())
topk_logprobs_by_step = topk_logprobs_by_step.tolist()
topk_indices_by_step = topk_indices_by_step.tolist()

# Construct the output on a per-step, per-sequence basis.
sampler_output_list: List[SamplerOutput] = []
@@ -616,20 +616,33 @@
for sequence_index in range(batch_size):
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs = num_logprobs_per_seq[sequence_index]
step_output_token_ids.append(
create_sequence_group_output(
token_id=accepted_token_ids_by_step[step_index]
[sequence_index],
token_id_logprob_rank=accepted_token_id_ranks_by_step[
step_index][sequence_index],
token_id_logprob=accepted_token_id_logprobs_by_step[
step_index][sequence_index],
seq_id=seq_ids[sequence_index],
topk_token_ids=topk_indices_by_step[step_index]
[sequence_index][:num_logprobs],
topk_logprobs=topk_logprobs_by_step[step_index]
[sequence_index][:num_logprobs],
))
if target_logprobs is not None:
step_output_token_ids.append(
create_sequence_group_output(
token_id=accepted_token_ids_by_step[step_index]
[sequence_index],
token_id_logprob_rank=accepted_token_id_ranks_by_step[
step_index][sequence_index],
token_id_logprob=accepted_token_id_logprobs_by_step[
step_index][sequence_index],
seq_id=seq_ids[sequence_index],
topk_token_ids=topk_indices_by_step[step_index]
[sequence_index][:num_logprobs],
topk_logprobs=topk_logprobs_by_step[step_index]
[sequence_index][:num_logprobs],
))
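# With logprobs disabled, emit -1 rank/logprob sentinels and empty top-k
# lists instead of real logprob data.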
else:
step_output_token_ids.append(
create_sequence_group_output(
token_id=accepted_token_ids_by_step[step_index]
[sequence_index],
token_id_logprob_rank=-1,
token_id_logprob=-1,
seq_id=seq_ids[sequence_index],
topk_token_ids=[],
topk_logprobs=[],
))

sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))

41 changes: 23 additions & 18 deletions vllm/spec_decode/util.py
@@ -68,19 +68,21 @@ def create_sequence_group_output(
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
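# When logprob generation is disabled, sentinel values are passed in and no
# Logprob entries are built below.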
logprobs: Dict[int, Logprob] = {
token_id: Logprob(
logprob=token_id_logprob,
rank=token_id_logprob_rank,
),
}
logprobs.update({
topk_token_ids[topk_logprob_index]: Logprob(
logprob=topk_logprobs[topk_logprob_index],
rank=topk_logprob_index + 1,
)
for topk_logprob_index, _ in enumerate(topk_token_ids)
})
logprobs = None
if token_id_logprob >= 0.0:
logprobs: Dict[int, Logprob] = {
token_id: Logprob(
logprob=token_id_logprob,
rank=token_id_logprob_rank,
),
}
logprobs.update({
topk_token_ids[topk_logprob_index]: Logprob(
logprob=topk_logprobs[topk_logprob_index],
rank=topk_logprob_index + 1,
)
for topk_logprob_index, _ in enumerate(topk_token_ids)
})

return CompletionSequenceGroupOutput(
samples=[
@@ -149,12 +151,15 @@ def sampler_output_to_torch(
sampled_token_probs = sampled_token_probs.transpose(0, 1)

# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs = torch.stack(
[sampler_output.logprobs for sampler_output in sampler_output_list],
dim=0,
)
if sampler_output_list[0].logprobs is not None:
sampled_token_logprobs = torch.stack(
[sampler_output.logprobs for sampler_output in sampler_output_list],
dim=0,
)
else:
sampled_token_logprobs = None
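# Nothing to stack when logprob generation was disabled; propagate None.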

if sampler_transposed:
if sampler_transposed and sampled_token_logprobs is not None:
sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)

# shape: [batch_size, num_sampler_output]
