Commit acbb8c8

Fix PP with speculative decoding
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Dec 7, 2024
1 parent 571da8f commit acbb8c8
Showing 6 changed files with 59 additions and 32 deletions.
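
This commit makes pipeline parallelism (PP) work with speculative decoding: the PP test matrix gains a setup that uses ngram speculation, and five model definitions stop gating their LogitsProcessor and sampler construction on the last pipeline rank. A hedged sketch of the configuration the new test exercises, in offline-inference form (the model name is a placeholder, and the assumption that vllm.LLM accepts these engine arguments as keywords is mine, not from the commit):

    from vllm import LLM

    # 2-stage pipeline parallelism plus ngram speculative decoding,
    # proposing 5 tokens per step -- mirroring the new test case.
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        pipeline_parallel_size=2,
        speculative_model="ngram",
        num_speculative_tokens=5,
    )
    print(llm.generate("Hello, my name is")[0].outputs[0].text)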
32 changes: 29 additions & 3 deletions tests/distributed/test_pipeline_parallel.py
@@ -26,6 +26,8 @@ class ParallelSetup(NamedTuple):
     pp_size: int
     eager_mode: bool
     chunked_prefill: bool
+    speculative_model: Optional[str] = None
+    num_speculative_tokens: Optional[int] = None
 
 
 class PPTestOptions(NamedTuple):
@@ -77,6 +79,12 @@ def detailed(
                           pp_size=pp_base,
                           eager_mode=True,
                           chunked_prefill=False),
+            ParallelSetup(tp_size=tp_base,
+                          pp_size=pp_base,
+                          eager_mode=False,
+                          chunked_prefill=False,
+                          speculative_model="ngram",
+                          num_speculative_tokens=5),
         ],
         distributed_backends=["mp", "ray"],
         task=task,
@@ -247,9 +255,21 @@ def _compare_tp(
     *,
     method: Literal["generate", "encode"],
 ):
-    tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
-    multi_node_only, trust_remote_code, tokenizer_mode, \
-        load_format, hf_overrides = test_options
+    (
+        tp_size,
+        pp_size,
+        eager_mode,
+        chunked_prefill,
+        speculative_model,
+        num_speculative_tokens,
+    ) = parallel_setup
+    (
+        multi_node_only,
+        trust_remote_code,
+        tokenizer_mode,
+        load_format,
+        hf_overrides,
+    ) = test_options
 
     if num_gpus_available < tp_size * pp_size:
         pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
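
The expanded tuple unpacking is forced by the two new fields: unpacking a NamedTuple always binds every field, defaults included, so the old four-name form would now raise. A self-contained illustration (field names shortened; not the test code):

    from typing import NamedTuple, Optional

    class Setup(NamedTuple):
        tp_size: int
        pp_size: int
        speculative_model: Optional[str] = None  # new defaulted field

    setup = Setup(tp_size=2, pp_size=2)  # old constructor calls still work
    tp, pp, spec = setup                 # but unpacking must name all fields
    print(tp, pp, spec)                  # -> 2 2 None
    # tp, pp = setup                     # ValueError: too many values to unpack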
@@ -282,6 +302,12 @@ def _compare_tp(
         common_args.extend(["--load-format", load_format])
     if hf_overrides:
         common_args.extend(["--hf-overrides", hf_overrides])
+    if speculative_model:
+        common_args.extend(["--speculative-model", speculative_model])
+    if num_speculative_tokens:
+        common_args.extend(
+            ["--num-speculative-tokens",
+             str(num_speculative_tokens)])
 
     if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
             and chunked_prefill):
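
Because both new fields default to None, every pre-existing setup skips the two appended branches and its command line is unchanged; only the new ngram entry adds flags. A hedged reproduction of just this fragment (base engine arguments elided; not the full test):

    speculative_model = "ngram"
    num_speculative_tokens = 5

    common_args: list = []  # base engine arguments elided
    if speculative_model:
        common_args.extend(["--speculative-model", speculative_model])
    if num_speculative_tokens:
        common_args.extend(
            ["--num-speculative-tokens", str(num_speculative_tokens)])
    print(common_args)
    # -> ['--speculative-model', 'ngram', '--num-speculative-tokens', '5']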
10 changes: 5 additions & 5 deletions vllm/model_executor/models/exaone.py
@@ -469,14 +469,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             if config.tie_word_embeddings:
                 self.lm_head.weight = self.transformer.wte.weight
 
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
 
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size, logit_scale)
+        self.sampler = get_sampler()
+
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)

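
Each of the five model files makes the same move: LogitsProcessor and the sampler used to be built only inside the last-pipeline-rank branch (next to lm_head), and are now built unconditionally after it, so intermediate PP ranks also carry them. A self-contained toy illustrating the failure mode this presumably fixes (not vLLM code; vLLM's actual classes are stand-ins here):

    # Before: `sampler` exists only on the last pipeline rank, so any
    # speculative-decoding code path touching it on an earlier rank
    # would hit AttributeError.
    class ModelBefore:
        def __init__(self, is_last_rank: bool):
            if is_last_rank:
                self.sampler = object()  # stand-in for get_sampler()

    # After: built on every rank, mirroring the moved block.
    class ModelAfter:
        def __init__(self, is_last_rank: bool):
            self.sampler = object()

    print(hasattr(ModelBefore(is_last_rank=False), "sampler"))  # False
    print(hasattr(ModelAfter(is_last_rank=False), "sampler"))   # True

The granite.py, llama.py, nemotron.py, and solar.py hunks below apply the same relocation.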
17 changes: 9 additions & 8 deletions vllm/model_executor/models/granite.py
@@ -399,17 +399,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             if config.tie_word_embeddings:
                 self.lm_head.weight = self.model.embed_tokens.weight
 
-            logit_scale = getattr(config, "logit_scale", 1.0)
-
-            if hasattr(config, "logits_scaling"):
-                logit_scale /= config.logits_scaling
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    scale=logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
 
+        logit_scale = getattr(config, "logit_scale", 1.0)
+
+        if hasattr(config, "logits_scaling"):
+            logit_scale /= config.logits_scaling
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size,
+                                                scale=logit_scale)
+        self.sampler = get_sampler()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

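
granite.py carries one extra wrinkle that moves along with the block: Granite configs may define logits_scaling, which divides the logit scale before it reaches LogitsProcessor. A hedged numeric illustration (the value 16.0 is invented, not from the commit):

    class Cfg:                 # stand-in for the HF model config
        logit_scale = 1.0
        logits_scaling = 16.0  # hypothetical value

    logit_scale = getattr(Cfg, "logit_scale", 1.0)
    if hasattr(Cfg, "logits_scaling"):
        logit_scale /= Cfg.logits_scaling
    print(logit_scale)  # -> 0.0625, i.e. logits divided by 16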
11 changes: 5 additions & 6 deletions vllm/model_executor/models/llama.py
@@ -535,15 +535,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             if config.tie_word_embeddings:
                 self.lm_head = self.lm_head.tie_weights(
                     self.model.embed_tokens)
-
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
 
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size, logit_scale)
+        self.sampler = get_sampler()
+
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

11 changes: 6 additions & 5 deletions vllm/model_executor/models/nemotron.py
@@ -431,13 +431,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             if config.tie_word_embeddings:
                 self.lm_head.weight = self.model.embed_tokens.weight
 
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
 
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size, logit_scale)
+        self.sampler = get_sampler()
+
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

10 changes: 5 additions & 5 deletions vllm/model_executor/models/solar.py
@@ -439,14 +439,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             if config.tie_word_embeddings:
                 self.lm_head.weight = self.model.embed_tokens.weight
 
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
 
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size, logit_scale)
+        self.sampler = get_sampler()
+
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

