[Bugfix][SpecDecode] kv corruption with bonus tokens in spec decode #9730
@@ -303,6 +303,7 @@ def test_multi_step_with_batch_expansion_correct_output():
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )
    multi_step_worker.set_include_gpu_probs_tensor()
    worker = create_worker(
        Worker,
        model_name,

@@ -397,6 +398,7 @@ def test_multi_step_with_batch_expansion_incorrect_output():
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )
    multi_step_worker.set_include_gpu_probs_tensor()
    worker = create_worker(
        Worker,
        model_name,
@@ -477,6 +479,104 @@ def test_multi_step_with_batch_expansion_incorrect_output():
    assert (num_mismatch > 0)


@torch.inference_mode()
@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
def test_multi_step_correct_kvcache(num_steps):
    """Verify that the KV cache of the draft model
    is correctly updated for sequences with bonus tokens.
    """
    seed = 100
    model_name = "JackFram/llama-68m"

    block_size = 16
    num_gpu_blocks = 2048 // block_size
    batch_size = 1
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
Comment: If we don't set model_runner_cls, will this test the non-TP1DraftModelRunner case? If so, we could parameterize the test to run with model_runner_cls both set and unset, and test both changes.
Comment: Yeah, actually I found that for the CI machine, since the backend is XFORMERS instead of FLASH_ATTN, it will go to the else branch here by default because
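A minimal sketch of what that parametrization could look like, assuming create_worker can simply be called without model_runner_cls to get the default runner (the kwarg handling below is an assumption, not code from this PR):

```python
import pytest
import torch

# Hypothetical parametrization: run the test once with TP1DraftModelRunner and
# once with the default model runner, so both code paths touched by this PR are
# covered.  create_worker, MultiStepWorker and TP1DraftModelRunner come from
# this test module's existing imports.
@torch.inference_mode()
@pytest.mark.parametrize('model_runner_cls', [TP1DraftModelRunner, None])
@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
def test_multi_step_correct_kvcache(num_steps, model_runner_cls):
    extra_kwargs = ({} if model_runner_cls is None else
                    {'model_runner_cls': model_runner_cls})
    multi_step_worker = create_worker(MultiStepWorker, "JackFram/llama-68m",
                                      16, 2048 // 16, 100, **extra_kwargs)
    # ... rest of the test as written below.
```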
    )
    multi_step_worker.set_include_gpu_probs_tensor()
    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    prompts = [[0] for _ in range(batch_size)]
    # Already generate two tokens for the sequence
    # so that we can simulate the bonus token case
    multi_step_continuations = [[
        random.randint(0, 1000),
        random.randint(0, 1000)
    ] for _ in prompts]
    final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]

    seq_ids_with_bonus_token_in_last_step = set(range(batch_size))
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=multi_step_continuations,
        final_prompt_lens=final_prompt_lens)

    # Run multi-step.
    zero_kv_cache(multi_step_worker.cache_engine)
    multi_step_output, _ = multi_step_worker.sampler_output(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list),
        sample_len=num_steps,
        seq_ids_with_bonus_token_in_last_step=
        seq_ids_with_bonus_token_in_last_step)

    # Run single-step repeatedly.
    zero_kv_cache(worker.cache_engine)
    # Generate the kv cache for the bonus token first
    single_step_continuations = [c[:1] for c in multi_step_continuations]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=single_step_continuations,
        final_prompt_lens=final_prompt_lens)
    single_step_output = worker.execute_model(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list))
    for _ in range(num_steps):
        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=multi_step_continuations,
Comment: nit - would this be the same as continuations=single_step_continuations? If so, could we do that here and below for readability?
Comment: multi_step_continuations has one more token for each request. For example
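A made-up illustration of that relationship (the comment's own example is cut off above; the token ids below are hypothetical):

```python
# With two pre-generated tokens per sequence, the single-step worker first runs
# with only the first continuation token so the bonus token's KV gets written,
# while the multi-step worker always sees both.
multi_step_continuations = [[101, 205]]   # hypothetical token ids
single_step_continuations = [c[:1] for c in multi_step_continuations]
assert single_step_continuations == [[101]]   # one token shorter per request
```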
            final_prompt_lens=final_prompt_lens)

        single_step_output = worker.execute_model(
            execute_model_req=ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list))

        for i, seq_group_output in enumerate(single_step_output[-1]):
            multi_step_continuations[i].append(
                seq_group_output.samples[0].output_token)

    # Verify that the KV cache of the single-step and
    # multi-step workers are the same.
    single_step_gpu_cache = worker.cache_engine[0].gpu_cache
    multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache
    num_layers = len(single_step_gpu_cache)
    allclose = lambda a, b: torch.allclose(
        a.cuda(), b.cuda(), rtol=1e-2, atol=1e-2)
    for i in range(num_layers):
        assert allclose(single_step_gpu_cache[i][0],
                        multi_step_gpu_cache[i][0])
        assert allclose(single_step_gpu_cache[i][1],
                        multi_step_gpu_cache[i][1])


@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
    """Verify Top1Proposer correctly handles case where all sequences
@@ -54,6 +54,8 @@ def __init__(self, *args, **kwargs):

        super().__init__(*args, **kwargs)

        self.indices_of_seq_with_bonus_tokens = None

    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                  num_queries):

@@ -159,6 +161,10 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
        # TODO: Add soft-tuning prompt adapter support
        return not self.prompt_adapter_config

    def set_indices_of_seq_with_bonus_tokens(self,
                                             indices_of_seq_with_bonus_tokens):
        self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens

    @torch.inference_mode()
    def execute_model(
        self,

@@ -284,11 +290,28 @@ def execute_model(
                model_input.sampling_metadata)

            # Sample the next token.
            outputs.append(
                self.model.sample(
                    logits=logits,
                    sampling_metadata=model_input.sampling_metadata,
                ))
            output = self.model.sample(
                logits=logits,
                sampling_metadata=model_input.sampling_metadata,
            )
            outputs.append(output)

            if model_input.attn_metadata.num_prefills == 0 \
                    and self.indices_of_seq_with_bonus_tokens is not None:
                assert output.sampled_token_ids is not None
                # output.sampled_token_ids should be of shape (num_seqs, 1)
                nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
                assert num_tokens_per_seq == 1
                count = 0
                for i in range(nums_seqs):
                    bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
                        count]
                    if i != bonus_seq_idx:
                        # The following might cause a cpu->gpu sync
Comment: In this case aren't both tensors on GPU? Will this still cause a sync?
Comment: Yeah, I benchmarked this with 160m + 7B on H100; the performance difference is negligible. Before and after this PR, the proposal times are almost the same.
Comment: nit: maybe update the comment to say that the benchmark shows comparable performance?
                        output.sampled_token_ids[
                            i, :] = model_input.input_tokens[bonus_seq_idx]
                    else:
                        count += 1

            # Prepare inputs for the next step
            if step != num_steps - 1:
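To make the index/count bookkeeping above concrete, here is a small standalone toy run of the same overwrite logic, using made-up token ids and an expanded batch of four rows in which rows 1 and 3 carry the bonus token:

```python
import torch

# Toy re-run of the overwrite logic above (hypothetical values, CPU tensors).
# Rows 0 and 2 are the copies without the bonus token; they are forced to the
# token that their paired bonus row took as input, so the next step's inputs
# (and hence the draft model's KV cache) stay consistent across both copies.
indices_of_seq_with_bonus_tokens = [1, 3]
sampled_token_ids = torch.tensor([[11], [12], [13], [14]])  # shape (num_seqs, 1)
input_tokens = torch.tensor([7, 8, 9, 10])                  # this step's inputs

count = 0
for i in range(sampled_token_ids.shape[0]):
    bonus_seq_idx = indices_of_seq_with_bonus_tokens[count]
    if i != bonus_seq_idx:
        sampled_token_ids[i, :] = input_tokens[bonus_seq_idx]
    else:
        count += 1

print(sampled_token_ids)  # tensor([[ 8], [12], [10], [14]])
```

Rows that carry the bonus token keep their freshly sampled ids; only their paired copies are rewritten.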
@@ -81,6 +81,8 @@ def sampler_output(
                # Here we run the draft_model_runner with multi-step prepare
                # on the GPU directly
                expanded_request.num_steps = sample_len
                self.model_runner.set_indices_of_seq_with_bonus_tokens(
                    indices_of_seq_with_bonus_tokens)
                model_outputs = self.execute_model(
                    execute_model_req=expanded_request)
            else:

@@ -97,7 +99,8 @@ def sampler_output(
                model_output = model_output[0]

                self._append_new_tokens(
                    model_output, expanded_request.seq_group_metadata_list)
Comment: Currently, as you mentioned in the PR description, this does not take care of the TP1DraftModelRunner case. I was wondering if we should do the first run with the expanded request and then the subsequent ones with the original request; that way both cases would be taken care of. However, TP1DraftModelRunner would then no longer have the multi-step run for all K steps (it would be 1 step first with the expanded request and K-1 multi-step). cc: @LiuXiaoxuanPKU
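A rough sketch of that suggestion, for illustration only (original_request and the num_steps split are assumptions, not code from this PR):

```python
# Hypothetical control flow for the TP1DraftModelRunner path: run step 0 with
# the expanded request so the bonus-token KV cache is written, then run the
# remaining K-1 steps with the original (non-expanded) request.
expanded_request.num_steps = 1
model_outputs = self.execute_model(execute_model_req=expanded_request)

original_request.num_steps = sample_len - 1  # original_request is hypothetical
model_outputs.extend(
    self.execute_model(execute_model_req=original_request))
```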
                    model_output, expanded_request.seq_group_metadata_list,
                    indices_of_seq_with_bonus_tokens)
                model_outputs.append(model_output)

        filtered_model_outputs = self._filter_model_output(

@@ -221,13 +224,15 @@ def get_spec_proposals(
    @staticmethod
    def _append_new_tokens(
            model_output: List[SamplerOutput],
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
            seq_group_metadata_list: List[SequenceGroupMetadata],
            indices_of_seq_with_bonus_tokens: List[int]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done outside of the worker, but it is
        required if the worker is to perform multiple forward passes.
        """
        for seq_group_metadata, sequence_group_outputs in zip(
                seq_group_metadata_list, model_output):
        count = 0
        for index, (seq_group_metadata, sequence_group_outputs) in enumerate(
                zip(seq_group_metadata_list, model_output)):
            seq_group_metadata.is_prompt = False

            for seq_output in sequence_group_outputs.samples:
@@ -237,6 +242,16 @@ def _append_new_tokens(

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]
Comment: Changing token_logprob here is challenging because the log probability of the bonus token would be required. I'm unsure whether this change is actually necessary at this point, because I think filtered_model_output will be returned at the end.
Comment: Will we be using this logprob at all, since we will be filtering out these sequences in the final output? In that case we could add a note saying that this is a fake/unused logprob.
Comment: Yeah, I think the logprobs of the sequences without the bonus token will likely not be used. I will check this precisely and add a note about it.
                # Determine the actual token ID to be generated,
                # considering bonus tokens
                if index != indices_of_seq_with_bonus_tokens[count]:
                    bonus_seq_metadata = seq_group_metadata_list[
                        indices_of_seq_with_bonus_tokens[count]]
                    _, bonus_token_seq_data = next(
                        iter(bonus_seq_metadata.seq_data.items()))
                    token_id = bonus_token_seq_data.output_token_ids[-1]
                else:
                    count += 1

                seq.append_token_id(token_id, token_logprob.logprob)
                seq.update_num_computed_tokens(1)
Comment: Why is this needed?
Comment: As we changed the logic of the draft model runner, it now requires that the sampler output is not None here.
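For context on the exchange above, a hedged sketch of what set_include_gpu_probs_tensor is expected to do, assuming the sampler exposes an include_gpu_probs_tensor flag (the exact attribute path is an assumption, not taken from this diff):

```python
# Hedged sketch (attribute path assumed): enable GPU-side sampled token ids so
# that output.sampled_token_ids is populated, satisfying the assertion added in
# the draft model runner's execute_model().
def set_include_gpu_probs_tensor(self) -> None:
    self.model_runner.model.sampler.include_gpu_probs_tensor = True
```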