From acbb8c8a9665f934bcea91a416f41e27bb40d630 Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Sat, 7 Dec 2024 02:28:45 +0000
Subject: [PATCH] Fix PP with speculative decoding

Signed-off-by: DarkLight1337
---
 tests/distributed/test_pipeline_parallel.py | 32 +++++++++++++++++++--
 vllm/model_executor/models/exaone.py        | 10 +++----
 vllm/model_executor/models/granite.py       | 17 +++++------
 vllm/model_executor/models/llama.py         | 11 ++++---
 vllm/model_executor/models/nemotron.py      | 11 +++----
 vllm/model_executor/models/solar.py         | 10 +++----
 6 files changed, 59 insertions(+), 32 deletions(-)

diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 386877e0e0a2c..eb3382439d969 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -26,6 +26,8 @@ class ParallelSetup(NamedTuple):
     pp_size: int
     eager_mode: bool
     chunked_prefill: bool
+    speculative_model: Optional[str] = None
+    num_speculative_tokens: Optional[int] = None
 
 
 class PPTestOptions(NamedTuple):
@@ -77,6 +79,12 @@ def detailed(
                               pp_size=pp_base,
                               eager_mode=True,
                               chunked_prefill=False),
+                ParallelSetup(tp_size=tp_base,
+                              pp_size=pp_base,
+                              eager_mode=False,
+                              chunked_prefill=False,
+                              speculative_model="ngram",
+                              num_speculative_tokens=5),
             ],
             distributed_backends=["mp", "ray"],
             task=task,
@@ -247,9 +255,21 @@ def _compare_tp(
     *,
     method: Literal["generate", "encode"],
 ):
-    tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
-    multi_node_only, trust_remote_code, tokenizer_mode, \
-        load_format, hf_overrides = test_options
+    (
+        tp_size,
+        pp_size,
+        eager_mode,
+        chunked_prefill,
+        speculative_model,
+        num_speculative_tokens,
+    ) = parallel_setup
+    (
+        multi_node_only,
+        trust_remote_code,
+        tokenizer_mode,
+        load_format,
+        hf_overrides,
+    ) = test_options
 
     if num_gpus_available < tp_size * pp_size:
         pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
@@ -282,6 +302,12 @@ def _compare_tp(
         common_args.extend(["--load-format", load_format])
     if hf_overrides:
         common_args.extend(["--hf-overrides", hf_overrides])
+    if speculative_model:
+        common_args.extend(["--speculative-model", speculative_model])
+    if num_speculative_tokens:
+        common_args.extend(
+            ["--num-speculative-tokens",
+             str(num_speculative_tokens)])
 
     if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
             and chunked_prefill):
diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py
index 5ca26d53a17e7..263ee17e7373e 100644
--- a/vllm/model_executor/models/exaone.py
+++ b/vllm/model_executor/models/exaone.py
@@ -469,14 +469,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             if config.tie_word_embeddings:
                 self.lm_head.weight = self.transformer.wte.weight
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
 
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size, logit_scale)
+        self.sampler = get_sampler()
+
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index bd2394e71c973..81f479681243b 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -399,17 +399,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             if config.tie_word_embeddings:
                 self.lm_head.weight = self.model.embed_tokens.weight
-            logit_scale = getattr(config, "logit_scale", 1.0)
-
-            if hasattr(config, "logits_scaling"):
-                logit_scale /= config.logits_scaling
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    scale=logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
 
+        logit_scale = getattr(config, "logit_scale", 1.0)
+
+        if hasattr(config, "logits_scaling"):
+            logit_scale /= config.logits_scaling
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size,
+                                                scale=logit_scale)
+        self.sampler = get_sampler()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 31dfb235ae877..9062c3c313ed5 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -535,15 +535,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             if config.tie_word_embeddings:
                 self.lm_head = self.lm_head.tie_weights(
                     self.model.embed_tokens)
-
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
 
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size, logit_scale)
+        self.sampler = get_sampler()
+
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py
index c7b4c22b6896b..94ca6c4caf1a5 100644
--- a/vllm/model_executor/models/nemotron.py
+++ b/vllm/model_executor/models/nemotron.py
@@ -431,13 +431,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             if config.tie_word_embeddings:
                 self.lm_head.weight = self.model.embed_tokens.weight
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
+
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size, logit_scale)
+        self.sampler = get_sampler()
+
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index f58710d215056..aa3c031053c26 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -439,14 +439,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             if config.tie_word_embeddings:
                 self.lm_head.weight = self.model.embed_tokens.weight
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
 
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size, logit_scale)
+        self.sampler = get_sampler()
+
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)