From dadfa823a920c786d297dadf298aa73a07dad9b4 Mon Sep 17 00:00:00 2001 From: Wooyeon Lee Date: Fri, 12 Jul 2024 18:00:27 +0900 Subject: [PATCH 01/13] fix it --- vllm/spec_decode/spec_decode_worker.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 3c8e3dee46831..a57aa3e416341 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -158,7 +158,8 @@ def create_worker( return SpecDecodeWorker(proposer_worker, scorer_worker, disable_by_batch_size=disable_by_batch_size, - spec_decode_sampler=spec_decode_sampler) + spec_decode_sampler=spec_decode_sampler, + allow_zero_draft_token_step=draft_tp == 1) def __init__( self, @@ -167,6 +168,7 @@ def __init__( spec_decode_sampler: SpecDecodeBaseSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, + allow_zero_draft_token_step: Optional[bool] = True, ): """ Create a SpecDecodeWorker. @@ -187,11 +189,15 @@ def __init__( disable speculative decoding for new incoming requests. metrics_collector: Helper class for collecting metrics; can be set for testing purposes. + allow_zero_draft_token_step: whether to allow a step where the draft + model generates no draft token; should disallow when the tp of + draft model is larger than 1 (TODO: #5814) """ self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker self.disable_by_batch_size = disable_by_batch_size or float("inf") self.spec_decode_sampler = spec_decode_sampler + self.allow_no_draft_tokens = allow_zero_draft_token_step self._metrics = AsyncMetricsCollector( self.spec_decode_sampler ) if metrics_collector is None else metrics_collector @@ -461,6 +467,11 @@ def _run_speculative_decoding_step( proposals = self.proposer_worker.get_spec_proposals( execute_model_req, self._seq_with_bonus_token_in_last_step) + if not self.allow_no_draft_tokens and sum(proposals.proposal_lens) == 0: + #TODO: Fix it #5814 + raise RuntimeError("Distributed draft worker cannot handle when " + "there's no draft tokens") + proposal_scores = self.scorer.score_proposals( execute_model_req, proposals, From ad8390c4b587c3eb7476457071e94a29c91373cd Mon Sep 17 00:00:00 2001 From: Wooyeon Lee Date: Fri, 12 Jul 2024 18:19:30 +0900 Subject: [PATCH 02/13] yapf --- vllm/spec_decode/spec_decode_worker.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index a57aa3e416341..56bc42e8cefef 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -467,10 +467,11 @@ def _run_speculative_decoding_step( proposals = self.proposer_worker.get_spec_proposals( execute_model_req, self._seq_with_bonus_token_in_last_step) - if not self.allow_no_draft_tokens and sum(proposals.proposal_lens) == 0: + if not self.allow_no_draft_tokens and sum( + proposals.proposal_lens) == 0: #TODO: Fix it #5814 raise RuntimeError("Distributed draft worker cannot handle when " - "there's no draft tokens") + "there's no draft tokens") proposal_scores = self.scorer.score_proposals( execute_model_req, From a934b12797d49c57c6ad9a1289b9350f70b032e1 Mon Sep 17 00:00:00 2001 From: Wooyeon Lee Date: Fri, 12 Jul 2024 21:52:22 +0900 Subject: [PATCH 03/13] fix --- vllm/spec_decode/spec_decode_worker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 56bc42e8cefef..907094024f601 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -109,6 +109,7 @@ def create_worker( typical_acceptance_sampler_posterior_alpha: float, ) -> "SpecDecodeWorker": + allow_zero_draft_token_step = False ngram_prompt_lookup_max = ( draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( @@ -133,6 +134,7 @@ def create_worker( if draft_tp == 1: draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner + allow_zero_draft_token_step = True proposer_worker = MultiStepWorker(**draft_worker_kwargs) proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( @@ -159,7 +161,7 @@ def create_worker( scorer_worker, disable_by_batch_size=disable_by_batch_size, spec_decode_sampler=spec_decode_sampler, - allow_zero_draft_token_step=draft_tp == 1) + allow_zero_draft_token_step=allow_zero_draft_token_step) def __init__( self, From e93781d6a096254a866096236fae3c6eccfb9a8f Mon Sep 17 00:00:00 2001 From: Wooyeon Lee Date: Fri, 12 Jul 2024 21:57:05 +0900 Subject: [PATCH 04/13] allow zero token step for other cases --- vllm/spec_decode/spec_decode_worker.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 907094024f601..0cf5dddb69ca6 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -109,7 +109,7 @@ def create_worker( typical_acceptance_sampler_posterior_alpha: float, ) -> "SpecDecodeWorker": - allow_zero_draft_token_step = False + allow_zero_draft_token_step = True ngram_prompt_lookup_max = ( draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( @@ -134,7 +134,8 @@ def create_worker( if draft_tp == 1: draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner - allow_zero_draft_token_step = True + else: + allow_zero_draft_token_step = False proposer_worker = MultiStepWorker(**draft_worker_kwargs) proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( From 10e644180a3b3b7a2597ea02ff2695f1e0fd5253 Mon Sep 17 00:00:00 2001 From: Wooyeon Lee Date: Fri, 12 Jul 2024 22:03:32 +0900 Subject: [PATCH 05/13] update comment --- vllm/spec_decode/spec_decode_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 0cf5dddb69ca6..eb8a72ebc0bbe 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -473,8 +473,8 @@ def _run_speculative_decoding_step( if not self.allow_no_draft_tokens and sum( proposals.proposal_lens) == 0: #TODO: Fix it #5814 - raise RuntimeError("Distributed draft worker cannot handle when " - "there's no draft tokens") + raise RuntimeError("Cannot handle cases where distributed draft " + "workers generate no tokens") proposal_scores = self.scorer.score_proposals( execute_model_req, From 6edf8fc114cacecfe67f3e206a0a353d20766dbf Mon Sep 17 00:00:00 2001 From: Wooyeon Lee Date: Fri, 12 Jul 2024 23:21:04 +0900 Subject: [PATCH 06/13] yapf --- vllm/spec_decode/spec_decode_worker.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index eb8a72ebc0bbe..c7931e5d50197 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -158,11 +158,12 @@ def create_worker( logger.info("Configuring SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) - return SpecDecodeWorker(proposer_worker, - scorer_worker, - disable_by_batch_size=disable_by_batch_size, - spec_decode_sampler=spec_decode_sampler, - allow_zero_draft_token_step=allow_zero_draft_token_step) + return SpecDecodeWorker( + proposer_worker, + scorer_worker, + disable_by_batch_size=disable_by_batch_size, + spec_decode_sampler=spec_decode_sampler, + allow_zero_draft_token_step=allow_zero_draft_token_step) def __init__( self, From 5398d7ab0f16d514492829de226e50693a8623bf Mon Sep 17 00:00:00 2001 From: Wooyeon Lee Date: Thu, 27 Jun 2024 10:39:40 +0900 Subject: [PATCH 07/13] test_skip_speculation --- .../e2e/test_integration_dist_tp4.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp4.py b/tests/spec_decode/e2e/test_integration_dist_tp4.py index 56cb0147d9e4f..10ee57084c876 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp4.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp4.py @@ -58,3 +58,62 @@ def test_draft_model_tp_lt_target_model_tp4(test_llm_generator, batch_size, max_output_len=32, force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-160m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "tensor_parallel_size": 4, + + # Use AsyncLLM engine, so that the engine runs in its own process. + # Otherwise, since vLLM does not follow true SPMD, the test runner + # process will have both the engine and the rank0 worker. NCCL is not + # cleaned up properly, and its server host thread leaks, causing the + # second run of the test to fail with internal NCCL error. + "use_async": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Artificially limit the draft model max model len; this forces vLLM + # to skip speculation once the sequences grow beyond 32-k tokens. + "speculative_max_model_len": 32, + + }, + ]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # This must be a good bit larger than speculative_max_model_len so that + # we can test the case where all seqs are skipped, but still small to + # ensure fast test. + 64, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_skip_speculation(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality when some (or all) sequences skip speculation. + We do this by setting the max model len of the draft model to an + artificially low value, such that when the sequences grow beyond it, they + are skipped in speculative decoding. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + From a70ccc93290b5ce49ca33b83603ab0c896bba67f Mon Sep 17 00:00:00 2001 From: Wooyeon Lee Date: Mon, 15 Jul 2024 10:20:04 +0900 Subject: [PATCH 08/13] error on test_skip_spec --- tests/spec_decode/e2e/test_integration_dist_tp4.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp4.py b/tests/spec_decode/e2e/test_integration_dist_tp4.py index 10ee57084c876..4ca22df4e7a41 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp4.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp4.py @@ -111,9 +111,10 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator, artificially low value, such that when the sequences grow beyond it, they are skipped in speculative decoding. """ - run_greedy_equality_correctness_test(baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len=output_len, - force_output_len=True) + with pytest.raises(RuntimeError): + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) From c4b6f721fd4df965ff6d051b7d7225fd5e939f5d Mon Sep 17 00:00:00 2001 From: Wooyeon Lee Date: Mon, 15 Jul 2024 10:36:28 +0900 Subject: [PATCH 09/13] add comment --- tests/spec_decode/e2e/test_integration_dist_tp4.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp4.py b/tests/spec_decode/e2e/test_integration_dist_tp4.py index 4ca22df4e7a41..f0d95b99fb61c 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp4.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp4.py @@ -91,7 +91,6 @@ def test_draft_model_tp_lt_target_model_tp4(test_llm_generator, # Artificially limit the draft model max model len; this forces vLLM # to skip speculation once the sequences grow beyond 32-k tokens. "speculative_max_model_len": 32, - }, ]) @pytest.mark.parametrize("batch_size", [8]) @@ -106,10 +105,12 @@ def test_draft_model_tp_lt_target_model_tp4(test_llm_generator, @pytest.mark.parametrize("seed", [1]) def test_skip_speculation(baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): - """Verify greedy equality when some (or all) sequences skip speculation. + """Verify job failure with RuntimeError when all sequences skip speculation. We do this by setting the max model len of the draft model to an artificially low value, such that when the sequences grow beyond it, they are skipped in speculative decoding. + + TODO: fix it to pass without raising Error. (#5814) """ with pytest.raises(RuntimeError): run_greedy_equality_correctness_test(baseline_llm_generator, From 02dc475cab3f817458b8d0791adc0750615cabb7 Mon Sep 17 00:00:00 2001 From: Wooyeon Lee Date: Mon, 15 Jul 2024 10:38:52 +0900 Subject: [PATCH 10/13] yapf --- tests/spec_decode/e2e/test_integration_dist_tp4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp4.py b/tests/spec_decode/e2e/test_integration_dist_tp4.py index f0d95b99fb61c..a720e75717ea3 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp4.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp4.py @@ -118,4 +118,3 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator, batch_size, max_output_len=output_len, force_output_len=True) - From c2382b516cf8055c4c0d580fef9296dd5a409d0a Mon Sep 17 00:00:00 2001 From: Wooyeon Lee Date: Mon, 15 Jul 2024 18:47:45 +0900 Subject: [PATCH 11/13] mark. need 4 gpus --- tests/spec_decode/e2e/test_integration_dist_tp4.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp4.py b/tests/spec_decode/e2e/test_integration_dist_tp4.py index a720e75717ea3..49e4a5f8150b5 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp4.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp4.py @@ -60,6 +60,8 @@ def test_draft_model_tp_lt_target_model_tp4(test_llm_generator, force_output_len=True) +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") @pytest.mark.parametrize( "common_llm_kwargs", [{ From fa1f463ff279d1148164ac2f4715400519208d7d Mon Sep 17 00:00:00 2001 From: Wooyeon Lee Date: Thu, 18 Jul 2024 12:24:50 +0900 Subject: [PATCH 12/13] no_proposals flag in SpeculativeProposals --- vllm/spec_decode/interfaces.py | 3 +++ vllm/spec_decode/spec_decode_worker.py | 5 ++--- vllm/spec_decode/top1_proposer.py | 1 + 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index d109d8edc1b0b..818302a68ec9f 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -22,6 +22,9 @@ class SpeculativeProposals: # The valid length of each proposal; can be zero. proposal_lens: torch.Tensor + # A flag to mark that there's no available proposals + no_proposals: bool + def __repr__(self): return (f"SpeculativeProposals(" f"proposal_token_ids={self.proposal_token_ids}, " diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index c7931e5d50197..d9e775c9ddd7f 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -201,7 +201,7 @@ def __init__( self.scorer_worker = scorer_worker self.disable_by_batch_size = disable_by_batch_size or float("inf") self.spec_decode_sampler = spec_decode_sampler - self.allow_no_draft_tokens = allow_zero_draft_token_step + self._allow_zero_draft_token_step = allow_zero_draft_token_step self._metrics = AsyncMetricsCollector( self.spec_decode_sampler ) if metrics_collector is None else metrics_collector @@ -471,8 +471,7 @@ def _run_speculative_decoding_step( proposals = self.proposer_worker.get_spec_proposals( execute_model_req, self._seq_with_bonus_token_in_last_step) - if not self.allow_no_draft_tokens and sum( - proposals.proposal_lens) == 0: + if not self._allow_zero_draft_token_step and proposals.no_proposals: #TODO: Fix it #5814 raise RuntimeError("Cannot handle cases where distributed draft " "workers generate no tokens") diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index 7b34b5d34208b..5628ebabc349e 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -108,6 +108,7 @@ def get_spec_proposals( proposal_token_ids=proposal_tokens, proposal_probs=proposal_probs, proposal_lens=proposal_lens, + no_proposals=maybe_sampler_output is None ) return proposals From 4b338d22c272bc8ea97495bf1a7073e4f323b04e Mon Sep 17 00:00:00 2001 From: Wooyeon Lee Date: Thu, 18 Jul 2024 13:25:06 +0900 Subject: [PATCH 13/13] mypy yapf --- vllm/spec_decode/interfaces.py | 2 +- vllm/spec_decode/top1_proposer.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index 818302a68ec9f..11ab09f10c1f5 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -23,7 +23,7 @@ class SpeculativeProposals: proposal_lens: torch.Tensor # A flag to mark that there's no available proposals - no_proposals: bool + no_proposals: bool = False def __repr__(self): return (f"SpeculativeProposals(" diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index 5628ebabc349e..59257f7a61a4d 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -108,8 +108,7 @@ def get_spec_proposals( proposal_token_ids=proposal_tokens, proposal_probs=proposal_probs, proposal_lens=proposal_lens, - no_proposals=maybe_sampler_output is None - ) + no_proposals=maybe_sampler_output is None) return proposals