Skip to content

Commit

Permalink
[BUGFIX] Raise an error for no draft token case when draft_tp>1 (vllm…
Browse files Browse the repository at this point in the history
  • Loading branch information
wooyeonlee0 authored and jimpang committed Jul 24, 2024
1 parent b7962ea commit c4ac0f2
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 5 deletions.
62 changes: 62 additions & 0 deletions tests/spec_decode/e2e/test_integration_dist_tp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,65 @@ def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
batch_size,
max_output_len=32,
force_output_len=True)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 4,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
},
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# This must be a good bit larger than speculative_max_model_len so that
# we can test the case where all seqs are skipped, but still small to
# ensure fast test.
64,
])
@pytest.mark.parametrize("seed", [1])
def test_skip_speculation(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify job failure with RuntimeError when all sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.
TODO: fix it to pass without raising Error. (#5814)
"""
with pytest.raises(RuntimeError):
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
3 changes: 3 additions & 0 deletions vllm/spec_decode/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class SpeculativeProposals:
# The valid length of each proposal; can be zero.
proposal_lens: torch.Tensor

# A flag to mark that there's no available proposals
no_proposals: bool = False

def __repr__(self):
return (f"SpeculativeProposals("
f"proposal_token_ids={self.proposal_token_ids}, "
Expand Down
23 changes: 19 additions & 4 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def create_worker(
typical_acceptance_sampler_posterior_alpha: float,
) -> "SpecDecodeWorker":

allow_zero_draft_token_step = True
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
Expand All @@ -133,6 +134,8 @@ def create_worker(
if draft_tp == 1:
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
else:
allow_zero_draft_token_step = False
proposer_worker = MultiStepWorker(**draft_worker_kwargs)

proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
Expand All @@ -155,10 +158,12 @@ def create_worker(
logger.info("Configuring SpecDecodeWorker with sampler=%s",
type(spec_decode_sampler))

return SpecDecodeWorker(proposer_worker,
scorer_worker,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler)
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step)

def __init__(
self,
Expand All @@ -167,6 +172,7 @@ def __init__(
spec_decode_sampler: SpecDecodeBaseSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
):
"""
Create a SpecDecodeWorker.
Expand All @@ -187,11 +193,15 @@ def __init__(
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
for testing purposes.
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
"""
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
self.disable_by_batch_size = disable_by_batch_size or float("inf")
self.spec_decode_sampler = spec_decode_sampler
self._allow_zero_draft_token_step = allow_zero_draft_token_step
self._metrics = AsyncMetricsCollector(
self.spec_decode_sampler
) if metrics_collector is None else metrics_collector
Expand Down Expand Up @@ -461,6 +471,11 @@ def _run_speculative_decoding_step(
proposals = self.proposer_worker.get_spec_proposals(
execute_model_req, self._seq_with_bonus_token_in_last_step)

if not self._allow_zero_draft_token_step and proposals.no_proposals:
#TODO: Fix it #5814
raise RuntimeError("Cannot handle cases where distributed draft "
"workers generate no tokens")

proposal_scores = self.scorer.score_proposals(
execute_model_req,
proposals,
Expand Down
2 changes: 1 addition & 1 deletion vllm/spec_decode/top1_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_spec_proposals(
proposal_token_ids=proposal_tokens,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens,
)
no_proposals=maybe_sampler_output is None)

return proposals

Expand Down

0 comments on commit c4ac0f2

Please sign in to comment.