Skip to content

Commit

Permalink
[Bugfix][SpecDecode] kv corruption with bonus tokens in spec decode (v…
Browse files Browse the repository at this point in the history
…llm-project#9730)

Co-authored-by: LiuXiaoxuanPKU <[email protected]>
  • Loading branch information
2 people authored and weilong.yu committed Dec 13, 2024
1 parent 83d899c commit e6137a7
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 10 deletions.
107 changes: 107 additions & 0 deletions tests/spec_decode/test_multi_step_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import pytest
import torch

from vllm.attention.selector import (_Backend,
global_force_attn_backend_context_manager)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
Expand Down Expand Up @@ -303,6 +305,7 @@ def test_multi_step_with_batch_expansion_correct_output():
seed,
model_runner_cls=TP1DraftModelRunner,
)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(
Worker,
model_name,
Expand Down Expand Up @@ -397,6 +400,7 @@ def test_multi_step_with_batch_expansion_incorrect_output():
seed,
model_runner_cls=TP1DraftModelRunner,
)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(
Worker,
model_name,
Expand Down Expand Up @@ -477,6 +481,109 @@ def test_multi_step_with_batch_expansion_incorrect_output():
assert (num_mismatch > 0)


@torch.inference_mode()
@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
# The choice of backends forces the multi_step_worker to choose between
# the vanilla model_runner and TP1DraftModelRunner and that we can test
# both code paths.
@pytest.mark.parametrize('attn_backend',
[_Backend.XFORMERS, _Backend.FLASH_ATTN])
def test_multi_step_correct_kvcache(num_steps, attn_backend):
"""Verify that the KV cache of the draft model
is correctly updated for sequences with bonus token.
"""
seed = 100
model_name = "JackFram/llama-68m"

block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 1

with global_force_attn_backend_context_manager(attn_backend):
dtype = 'float16' if attn_backend == _Backend.FLASH_ATTN else 'float32'
multi_step_worker = create_worker(MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
dtype=dtype)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(Worker,
model_name,
block_size,
num_gpu_blocks,
seed,
dtype=dtype)

prompts = [[0] for _ in range(batch_size)]
# Already generate two tokens for the sequence
# so that we can simulate the bonus token case
multi_step_continuations = [[
random.randint(0, 1000),
random.randint(0, 1000)
] for _ in prompts]
final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]

seq_ids_with_bonus_token_in_last_step = set(range(batch_size))
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)

# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
multi_step_worker.sampler_output(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=
seq_ids_with_bonus_token_in_last_step)

# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
# Generate the kv cache for the bonus token first
single_step_continuations = [c[:1] for c in multi_step_continuations]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=single_step_continuations,
final_prompt_lens=final_prompt_lens)
single_step_output = worker.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list))
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)

single_step_output = worker.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list))

for i, seq_group_output in enumerate(single_step_output[-1]):
multi_step_continuations[i].append(
seq_group_output.samples[0].output_token)

# Verify that the KV cache of the single-step and
# multi-step workers are the same.
single_step_gpu_cache = worker.cache_engine[0].gpu_cache
multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache
num_layers = len(single_step_gpu_cache)
allclose = lambda a, b: torch.allclose(
a.cuda(), b.cuda(), rtol=1e-2, atol=1e-2)
for i in range(num_layers):
assert allclose(single_step_gpu_cache[i][0],
multi_step_gpu_cache[i][0])
assert allclose(single_step_gpu_cache[i][1],
multi_step_gpu_cache[i][1])


@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
"""Verify Top1Proposer correctly handles case where all sequences
Expand Down
4 changes: 3 additions & 1 deletion tests/spec_decode/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,14 @@ def create_worker(cls: Callable[..., T],
seed: int,
is_driver_worker: bool = True,
enforce_eager: bool = True,
model_runner_cls: Optional[ModelRunner] = None) -> T:
model_runner_cls: Optional[ModelRunner] = None,
dtype: Optional[str] = "auto") -> T:
engine_args = EngineArgs(
model=model_name,
seed=seed,
block_size=block_size,
enforce_eager=enforce_eager,
dtype=dtype,
)
engine_config = engine_args.create_engine_config()

Expand Down
35 changes: 30 additions & 5 deletions vllm/spec_decode/draft_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def __init__(self, *args, **kwargs):

super().__init__(*args, **kwargs)

self.indices_of_seq_with_bonus_tokens = None

def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):

Expand Down Expand Up @@ -159,6 +161,10 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
# TODO: Add soft-tuning prompt adapter support
return not self.prompt_adapter_config

def set_indices_of_seq_with_bonus_tokens(self,
indices_of_seq_with_bonus_tokens):
self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens

@torch.inference_mode()
def execute_model(
self,
Expand Down Expand Up @@ -284,11 +290,30 @@ def execute_model(
model_input.sampling_metadata)

# Sample the next token.
outputs.append(
self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
))
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
outputs.append(output)

if model_input.attn_metadata.num_prefills == 0 \
and self.indices_of_seq_with_bonus_tokens is not None:
assert output.sampled_token_ids is not None
# output.sampled_token_ids should be of shape (num_seqs, 1)
nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
assert num_tokens_per_seq == 1
count = 0
for i in range(nums_seqs):
bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
count]
if i != bonus_seq_idx:
# The following might cause a cpu->gpu sync
# However, the performance impact is negligible as we
# benchmarked on H100.
output.sampled_token_ids[
i, :] = model_input.input_tokens[bonus_seq_idx]
else:
count += 1

# Prepare inputs for the next step
if step != num_steps - 1:
Expand Down
23 changes: 19 additions & 4 deletions vllm/spec_decode/multi_step_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def sampler_output(
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request.num_steps = sample_len
self.model_runner.set_indices_of_seq_with_bonus_tokens(
indices_of_seq_with_bonus_tokens)
model_outputs = self.execute_model(
execute_model_req=expanded_request)
else:
Expand All @@ -97,7 +99,8 @@ def sampler_output(
model_output = model_output[0]

self._append_new_tokens(
model_output, expanded_request.seq_group_metadata_list)
model_output, expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output)

filtered_model_outputs = self._filter_model_output(
Expand Down Expand Up @@ -221,13 +224,15 @@ def get_spec_proposals(
@staticmethod
def _append_new_tokens(
model_output: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
seq_group_metadata_list: List[SequenceGroupMetadata],
indices_of_seq_with_bonus_tokens: List[int]) -> None:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done outside of the worker, but it is
required if the worker is to perform multiple forward passes.
"""
for seq_group_metadata, sequence_group_outputs in zip(
seq_group_metadata_list, model_output):
count = 0
for index, (seq_group_metadata, sequence_group_outputs) in enumerate(
zip(seq_group_metadata_list, model_output)):
seq_group_metadata.is_prompt = False

for seq_output in sequence_group_outputs.samples:
Expand All @@ -237,6 +242,16 @@ def _append_new_tokens(

token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]
# Determine the actual token ID to be generated,
# considering bonus tokens
if index != indices_of_seq_with_bonus_tokens[count]:
bonus_seq_metadata = seq_group_metadata_list[
indices_of_seq_with_bonus_tokens[count]]
_, bonus_token_seq_data = next(
iter(bonus_seq_metadata.seq_data.items()))
token_id = bonus_token_seq_data.output_token_ids[-1]
else:
count += 1

seq.append_token_id(token_id, token_logprob.logprob)
seq.update_num_computed_tokens(1)
Expand Down

0 comments on commit e6137a7

Please sign in to comment.