-
-
Notifications
You must be signed in to change notification settings - Fork 5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BUGFIX] Raise an error for no draft token case when draft_tp>1 #6369
Changes from 6 commits
dadfa82
ad8390c
a934b12
e93781d
10e6441
6edf8fc
5398d7a
a70ccc9
c4b6f72
02dc475
c2382b5
fa1f463
4b338d2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -109,6 +109,7 @@ def create_worker( | |
typical_acceptance_sampler_posterior_alpha: float, | ||
) -> "SpecDecodeWorker": | ||
|
||
allow_zero_draft_token_step = True | ||
ngram_prompt_lookup_max = ( | ||
draft_worker_kwargs.pop("ngram_prompt_lookup_max")) | ||
ngram_prompt_lookup_min = ( | ||
|
@@ -133,6 +134,8 @@ def create_worker( | |
if draft_tp == 1: | ||
draft_worker_kwargs[ | ||
"model_runner_cls"] = TP1DraftModelRunner | ||
else: | ||
allow_zero_draft_token_step = False | ||
proposer_worker = MultiStepWorker(**draft_worker_kwargs) | ||
|
||
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( | ||
|
@@ -155,10 +158,12 @@ def create_worker( | |
logger.info("Configuring SpecDecodeWorker with sampler=%s", | ||
type(spec_decode_sampler)) | ||
|
||
return SpecDecodeWorker(proposer_worker, | ||
scorer_worker, | ||
disable_by_batch_size=disable_by_batch_size, | ||
spec_decode_sampler=spec_decode_sampler) | ||
return SpecDecodeWorker( | ||
proposer_worker, | ||
scorer_worker, | ||
disable_by_batch_size=disable_by_batch_size, | ||
spec_decode_sampler=spec_decode_sampler, | ||
allow_zero_draft_token_step=allow_zero_draft_token_step) | ||
|
||
def __init__( | ||
self, | ||
|
@@ -167,6 +172,7 @@ def __init__( | |
spec_decode_sampler: SpecDecodeBaseSampler, | ||
metrics_collector: Optional[AsyncMetricsCollector] = None, | ||
disable_by_batch_size: Optional[int] = None, | ||
allow_zero_draft_token_step: Optional[bool] = True, | ||
): | ||
""" | ||
Create a SpecDecodeWorker. | ||
|
@@ -187,11 +193,15 @@ def __init__( | |
disable speculative decoding for new incoming requests. | ||
metrics_collector: Helper class for collecting metrics; can be set | ||
for testing purposes. | ||
allow_zero_draft_token_step: whether to allow a step where the draft | ||
model generates no draft token; should disallow when the tp of | ||
draft model is larger than 1 (TODO: #5814) | ||
""" | ||
self.proposer_worker = proposer_worker | ||
self.scorer_worker = scorer_worker | ||
self.disable_by_batch_size = disable_by_batch_size or float("inf") | ||
self.spec_decode_sampler = spec_decode_sampler | ||
self.allow_no_draft_tokens = allow_zero_draft_token_step | ||
self._metrics = AsyncMetricsCollector( | ||
self.spec_decode_sampler | ||
) if metrics_collector is None else metrics_collector | ||
|
@@ -461,6 +471,12 @@ def _run_speculative_decoding_step( | |
proposals = self.proposer_worker.get_spec_proposals( | ||
execute_model_req, self._seq_with_bonus_token_in_last_step) | ||
|
||
if not self.allow_no_draft_tokens and sum( | ||
proposals.proposal_lens) == 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a bit worry about the overhead that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, is it possible to store this field in |
||
#TODO: Fix it #5814 | ||
raise RuntimeError("Cannot handle cases where distributed draft " | ||
"workers generate no tokens") | ||
|
||
proposal_scores = self.scorer.score_proposals( | ||
execute_model_req, | ||
proposals, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can we mark this private, e.g.
_allow_no_draft_tokens
? we should have done this for all properties but we missed it