Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUGFIX] Raise an error for no draft token case when draft_tp>1 #6369

Merged
merged 13 commits into from
Jul 19, 2024
24 changes: 20 additions & 4 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def create_worker(
typical_acceptance_sampler_posterior_alpha: float,
) -> "SpecDecodeWorker":

allow_zero_draft_token_step = True
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
Expand All @@ -133,6 +134,8 @@ def create_worker(
if draft_tp == 1:
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
else:
allow_zero_draft_token_step = False
proposer_worker = MultiStepWorker(**draft_worker_kwargs)

proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
Expand All @@ -155,10 +158,12 @@ def create_worker(
logger.info("Configuring SpecDecodeWorker with sampler=%s",
type(spec_decode_sampler))

return SpecDecodeWorker(proposer_worker,
scorer_worker,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler)
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step)

def __init__(
self,
Expand All @@ -167,6 +172,7 @@ def __init__(
spec_decode_sampler: SpecDecodeBaseSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
):
"""
Create a SpecDecodeWorker.
Expand All @@ -187,11 +193,15 @@ def __init__(
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
for testing purposes.
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
"""
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
self.disable_by_batch_size = disable_by_batch_size or float("inf")
self.spec_decode_sampler = spec_decode_sampler
self.allow_no_draft_tokens = allow_zero_draft_token_step
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we mark this private, e.g. _allow_no_draft_tokens? we should have done this for all properties but we missed it

self._metrics = AsyncMetricsCollector(
self.spec_decode_sampler
) if metrics_collector is None else metrics_collector
Expand Down Expand Up @@ -461,6 +471,12 @@ def _run_speculative_decoding_step(
proposals = self.proposer_worker.get_spec_proposals(
execute_model_req, self._seq_with_bonus_token_in_last_step)

if not self.allow_no_draft_tokens and sum(
proposals.proposal_lens) == 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit worry about the overhead that sum brings, but I feel it's fine for now given that it won't be triggered with draft model TP=1. wdyt @cadedaniel?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, is it possible to store this field in proposals when we create proposals? that way we don't need an additional CPU-GPU-CPU sync

#TODO: Fix it #5814
raise RuntimeError("Cannot handle cases where distributed draft "
"workers generate no tokens")

proposal_scores = self.scorer.score_proposals(
execute_model_req,
proposals,
Expand Down
Loading