diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py
index d109d8edc1b0b..818302a68ec9f 100644
--- a/vllm/spec_decode/interfaces.py
+++ b/vllm/spec_decode/interfaces.py
@@ -22,6 +22,9 @@ class SpeculativeProposals:
     # The valid length of each proposal; can be zero.
     proposal_lens: torch.Tensor
 
+    # A flag to mark that there's no available proposals
+    no_proposals: bool
+
     def __repr__(self):
         return (f"SpeculativeProposals("
                 f"proposal_token_ids={self.proposal_token_ids}, "
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index c7931e5d50197..d9e775c9ddd7f 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -201,7 +201,7 @@ def __init__(
         self.scorer_worker = scorer_worker
         self.disable_by_batch_size = disable_by_batch_size or float("inf")
         self.spec_decode_sampler = spec_decode_sampler
-        self.allow_no_draft_tokens = allow_zero_draft_token_step
+        self._allow_zero_draft_token_step = allow_zero_draft_token_step
         self._metrics = AsyncMetricsCollector(
             self.spec_decode_sampler
         ) if metrics_collector is None else metrics_collector
@@ -471,8 +471,7 @@ def _run_speculative_decoding_step(
         proposals = self.proposer_worker.get_spec_proposals(
             execute_model_req, self._seq_with_bonus_token_in_last_step)
 
-        if not self.allow_no_draft_tokens and sum(
-                proposals.proposal_lens) == 0:
+        if not self._allow_zero_draft_token_step and proposals.no_proposals:
             #TODO: Fix it #5814
             raise RuntimeError("Cannot handle cases where distributed draft "
                                "workers generate no tokens")
diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py
index 7b34b5d34208b..5628ebabc349e 100644
--- a/vllm/spec_decode/top1_proposer.py
+++ b/vllm/spec_decode/top1_proposer.py
@@ -108,6 +108,7 @@ def get_spec_proposals(
             proposal_token_ids=proposal_tokens,
             proposal_probs=proposal_probs,
             proposal_lens=proposal_lens,
+            no_proposals=maybe_sampler_output is None
         )
 
         return proposals