From 3617a96487da008d64785bea7a73137ca012920d Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Wed, 6 Nov 2024 10:45:45 +0900
Subject: [PATCH] [Bugfix][SpecDecode] kv corruption with bonus tokens in spec decode (#9730)

Co-authored-by: LiuXiaoxuanPKU
Signed-off-by: Loc Huynh
---
 tests/spec_decode/test_multi_step_worker.py | 107 ++++++++++++++++++++
 tests/spec_decode/utils.py                  |   4 +-
 vllm/spec_decode/draft_model_runner.py      |  35 ++++++-
 vllm/spec_decode/multi_step_worker.py       |  23 ++++-
 4 files changed, 159 insertions(+), 10 deletions(-)

diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py
index e6f7f480eebb2..0b5d82b6610ca 100644
--- a/tests/spec_decode/test_multi_step_worker.py
+++ b/tests/spec_decode/test_multi_step_worker.py
@@ -5,6 +5,8 @@
 import pytest
 import torch
 
+from vllm.attention.selector import (_Backend,
+                                     global_force_attn_backend_context_manager)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
@@ -303,6 +305,7 @@ def test_multi_step_with_batch_expansion_correct_output():
         seed,
         model_runner_cls=TP1DraftModelRunner,
     )
+    multi_step_worker.set_include_gpu_probs_tensor()
     worker = create_worker(
         Worker,
         model_name,
@@ -397,6 +400,7 @@ def test_multi_step_with_batch_expansion_incorrect_output():
         seed,
         model_runner_cls=TP1DraftModelRunner,
     )
+    multi_step_worker.set_include_gpu_probs_tensor()
     worker = create_worker(
         Worker,
         model_name,
@@ -477,6 +481,109 @@ def test_multi_step_with_batch_expansion_incorrect_output():
     assert (num_mismatch > 0)
 
 
+@torch.inference_mode()
+@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
+# The choice of backends forces the multi_step_worker to choose between
+# the vanilla model_runner and TP1DraftModelRunner so that we can test
+# both code paths.
+@pytest.mark.parametrize('attn_backend',
+                         [_Backend.XFORMERS, _Backend.FLASH_ATTN])
+def test_multi_step_correct_kvcache(num_steps, attn_backend):
+    """Verify that the KV cache of the draft model
+    is correctly updated for sequences with a bonus token.
+    """
+    seed = 100
+    model_name = "JackFram/llama-68m"
+
+    block_size = 16
+    num_gpu_blocks = 2048 // block_size
+    batch_size = 1
+
+    with global_force_attn_backend_context_manager(attn_backend):
+        dtype = 'float16' if attn_backend == _Backend.FLASH_ATTN else 'float32'
+        multi_step_worker = create_worker(MultiStepWorker,
+                                          model_name,
+                                          block_size,
+                                          num_gpu_blocks,
+                                          seed,
+                                          model_runner_cls=TP1DraftModelRunner,
+                                          dtype=dtype)
+        multi_step_worker.set_include_gpu_probs_tensor()
+        worker = create_worker(Worker,
+                               model_name,
+                               block_size,
+                               num_gpu_blocks,
+                               seed,
+                               dtype=dtype)
+
+        prompts = [[0] for _ in range(batch_size)]
+        # The sequence already has two generated tokens
+        # so that we can simulate the bonus-token case.
+        multi_step_continuations = [[
+            random.randint(0, 1000),
+            random.randint(0, 1000)
+        ] for _ in prompts]
+        final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]
+
+        seq_ids_with_bonus_token_in_last_step = set(range(batch_size))
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=multi_step_continuations,
+            final_prompt_lens=final_prompt_lens)
+
+        # Run multi-step.
+        zero_kv_cache(multi_step_worker.cache_engine)
+        multi_step_worker.sampler_output(execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+                                         sample_len=num_steps,
+                                         seq_ids_with_bonus_token_in_last_step=
+                                         seq_ids_with_bonus_token_in_last_step)
+
+        # Run single-step repeatedly.
+        zero_kv_cache(worker.cache_engine)
+        # Generate the KV cache for the bonus token first.
+        single_step_continuations = [c[:1] for c in multi_step_continuations]
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=single_step_continuations,
+            final_prompt_lens=final_prompt_lens)
+        single_step_output = worker.execute_model(
+            execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list))
+        for _ in range(num_steps):
+            seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+                prompts,
+                num_gpu_blocks,
+                block_size,
+                continuations=multi_step_continuations,
+                final_prompt_lens=final_prompt_lens)
+
+            single_step_output = worker.execute_model(
+                execute_model_req=ExecuteModelRequest(
+                    seq_group_metadata_list=seq_group_metadata_list))
+
+            for i, seq_group_output in enumerate(single_step_output[-1]):
+                multi_step_continuations[i].append(
+                    seq_group_output.samples[0].output_token)
+
+        # Verify that the KV caches of the single-step and
+        # multi-step workers are the same.
+        single_step_gpu_cache = worker.cache_engine[0].gpu_cache
+        multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache
+        num_layers = len(single_step_gpu_cache)
+        allclose = lambda a, b: torch.allclose(
+            a.cuda(), b.cuda(), rtol=1e-2, atol=1e-2)
+        for i in range(num_layers):
+            assert allclose(single_step_gpu_cache[i][0],
+                            multi_step_gpu_cache[i][0])
+            assert allclose(single_step_gpu_cache[i][1],
+                            multi_step_gpu_cache[i][1])
+
+
 @torch.inference_mode()
 def test_draft_proposals_full_speculation_len():
     """Verify Top1Proposer correctly handles case where all sequences
diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py
index 6cf0cfb09b8fa..e5cb0530f9961 100644
--- a/tests/spec_decode/utils.py
+++ b/tests/spec_decode/utils.py
@@ -68,12 +68,14 @@ def create_worker(cls: Callable[..., T],
                   seed: int,
                   is_driver_worker: bool = True,
                   enforce_eager: bool = True,
-                  model_runner_cls: Optional[ModelRunner] = None) -> T:
+                  model_runner_cls: Optional[ModelRunner] = None,
+                  dtype: Optional[str] = "auto") -> T:
     engine_args = EngineArgs(
         model=model_name,
         seed=seed,
         block_size=block_size,
         enforce_eager=enforce_eager,
+        dtype=dtype,
     )
     engine_config = engine_args.create_engine_config()
 
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index 17cc0ad1a4a3a..6330ac027db74 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -54,6 +54,8 @@ def __init__(self, *args, **kwargs):
 
         super().__init__(*args, **kwargs)
 
+        self.indices_of_seq_with_bonus_tokens = None
+
     def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                   num_queries):
 
@@ -159,6 +161,10 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
         # TODO: Add soft-tuning prompt adapter support
         return not self.prompt_adapter_config
 
+    def set_indices_of_seq_with_bonus_tokens(self,
+                                             indices_of_seq_with_bonus_tokens):
+        self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -284,11 +290,30 @@ def execute_model(
                                                model_input.sampling_metadata)
 
             # Sample the next token.
-            outputs.append(
-                self.model.sample(
-                    logits=logits,
-                    sampling_metadata=model_input.sampling_metadata,
-                ))
+            output = self.model.sample(
+                logits=logits,
+                sampling_metadata=model_input.sampling_metadata,
+            )
+            outputs.append(output)
+
+            if model_input.attn_metadata.num_prefills == 0 \
+                and self.indices_of_seq_with_bonus_tokens is not None:
+                assert output.sampled_token_ids is not None
+                # output.sampled_token_ids should be of shape (num_seqs, 1)
+                nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
+                assert num_tokens_per_seq == 1
+                count = 0
+                for i in range(nums_seqs):
+                    bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
+                        count]
+                    if i != bonus_seq_idx:
+                        # The following might cause a cpu->gpu sync
+                        # However, the performance impact is negligible as we
+                        # benchmarked on H100.
+                        output.sampled_token_ids[
+                            i, :] = model_input.input_tokens[bonus_seq_idx]
+                    else:
+                        count += 1
 
             # Prepare inputs for the next step
             if step != num_steps - 1:
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index 4b53fbe056c47..f49b98f5c9528 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -81,6 +81,8 @@ def sampler_output(
             # Here we run the draft_model_runner with multi-step prepare
             # on the GPU directly
             expanded_request.num_steps = sample_len
+            self.model_runner.set_indices_of_seq_with_bonus_tokens(
+                indices_of_seq_with_bonus_tokens)
             model_outputs = self.execute_model(
                 execute_model_req=expanded_request)
         else:
@@ -97,7 +99,8 @@ def sampler_output(
                     model_output = model_output[0]
 
                 self._append_new_tokens(
-                    model_output, expanded_request.seq_group_metadata_list)
+                    model_output, expanded_request.seq_group_metadata_list,
+                    indices_of_seq_with_bonus_tokens)
                 model_outputs.append(model_output)
 
         filtered_model_outputs = self._filter_model_output(
@@ -221,13 +224,15 @@ def get_spec_proposals(
     @staticmethod
     def _append_new_tokens(
             model_output: List[SamplerOutput],
-            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
+            seq_group_metadata_list: List[SequenceGroupMetadata],
+            indices_of_seq_with_bonus_tokens: List[int]) -> None:
         """Given model output from a single run, append the tokens to the
         sequences. This is normally done outside of the worker, but it is
         required if the worker is to perform multiple forward passes.
         """
-        for seq_group_metadata, sequence_group_outputs in zip(
-                seq_group_metadata_list, model_output):
+        count = 0
+        for index, (seq_group_metadata, sequence_group_outputs) in enumerate(
+                zip(seq_group_metadata_list, model_output)):
             seq_group_metadata.is_prompt = False
 
             for seq_output in sequence_group_outputs.samples:
@@ -237,6 +242,16 @@ def _append_new_tokens(
 
                 token_id = seq_output.output_token
                 token_logprob = seq_output.logprobs[token_id]
+                # Determine the actual token ID to be generated,
+                # considering bonus tokens
+                if index != indices_of_seq_with_bonus_tokens[count]:
+                    bonus_seq_metadata = seq_group_metadata_list[
+                        indices_of_seq_with_bonus_tokens[count]]
+                    _, bonus_token_seq_data = next(
+                        iter(bonus_seq_metadata.seq_data.items()))
+                    token_id = bonus_token_seq_data.output_token_ids[-1]
+                else:
+                    count += 1
 
                 seq.append_token_id(token_id, token_logprob.logprob)
                 seq.update_num_computed_tokens(1)
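
Reviewer note (not part of the patch): the count-based index walk used above in TP1DraftModelRunner.execute_model, and again in MultiStepWorker._append_new_tokens, can be hard to follow from the diff alone. The sketch below is a minimal, self-contained illustration of that walk, assuming indices_of_seq_with_bonus_tokens is sorted and lists the rows of the expanded batch that kept their bonus token; the helper name sync_non_bonus_rows and the toy tensors are hypothetical and exist only for this example.

import torch


def sync_non_bonus_rows(sampled: torch.Tensor,
                        bonus_indices: list,
                        current_inputs: torch.Tensor) -> torch.Tensor:
    """Overwrite every row NOT listed in bonus_indices with the token its
    bonus counterpart is consuming at the current step (the same walk as
    the patched execute_model loop)."""
    count = 0
    for i in range(sampled.shape[0]):
        bonus_seq_idx = bonus_indices[count]
        if i != bonus_seq_idx:
            # Bonus-free copy: force it to "sample" the bonus row's current
            # input token so both copies append the same token and the
            # draft KV cache being back-filled stays in lockstep.
            sampled[i, :] = current_inputs[bonus_seq_idx]
        else:
            count += 1
    return sampled


# Expanded batch of 4 rows; rows 1 and 3 kept their bonus token.
sampled = torch.tensor([[7], [11], [13], [17]])
current_inputs = torch.tensor([0, 42, 0, 99])  # per-row input token this step
print(sync_non_bonus_rows(sampled, [1, 3], current_inputs))
# -> tensor([[42], [11], [99], [17]])

In this toy batch, rows 0 and 2 are the bonus-free copies; after the call they carry the token their bonus counterparts (rows 1 and 3) are consuming this step, which is what lets the expanded sequences write matching KV-cache entries for the draft model. _append_new_tokens applies the same walk, but takes the token from the bonus sequence's last output token instead of the current input tensor.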