[Doc] Explicitly state that PP isn't compatible with speculative decoding yet (#10975)

Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 authored Dec 7, 2024
1 parent 39e227c commit c889d58
Showing 8 changed files with 32 additions and 9 deletions.
3 changes: 3 additions & 0 deletions docs/source/usage/spec_decode.rst
@@ -8,6 +8,9 @@ Speculative decoding
not usually yield inter-token latency reductions for all prompt datasets or sampling parameters. The work
to optimize it is ongoing and can be followed in `this issue. <https://github.com/vllm-project/vllm/issues/4630>`_

.. warning::
Currently, speculative decoding in vLLM is not compatible with pipeline parallelism.

This document shows how to use `Speculative Decoding <https://x.com/karpathy/status/1697318534555336961>`_ with vLLM.
Speculative decoding is a technique which improves inter-token latency in memory-bound LLM inference.
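For reference, draft-model speculative decoding is enabled through the engine arguments. The sketch below follows the examples later in this document and assumes the `speculative_model` and `num_speculative_tokens` keyword arguments of this vLLM release (the model names are only illustrative). With pipeline parallelism left at its default of 1, this configuration is unaffected by the new restriction.

```python
from vllm import LLM, SamplingParams

prompts = ["The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Draft-model speculative decoding: a small draft model proposes tokens
# that the larger target model then verifies in a single forward pass.
llm = LLM(
    model="facebook/opt-6.7b",
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")
```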

16 changes: 13 additions & 3 deletions tests/distributed/test_pipeline_parallel.py
@@ -247,9 +247,19 @@ def _compare_tp(
*,
method: Literal["generate", "encode"],
):
tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
multi_node_only, trust_remote_code, tokenizer_mode, \
load_format, hf_overrides = test_options
(
tp_size,
pp_size,
eager_mode,
chunked_prefill,
) = parallel_setup
(
multi_node_only,
trust_remote_code,
tokenizer_mode,
load_format,
hf_overrides,
) = test_options

if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
3 changes: 2 additions & 1 deletion vllm/model_executor/models/exaone.py
@@ -473,10 +473,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()

self.sampler = get_sampler()

self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
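The same relocation of `get_sampler()` out of the last-rank branch is repeated in granite.py, llama.py, nemotron.py, and solar.py below. A self-contained toy sketch of the resulting constructor pattern (the real condition is a pipeline-parallel rank check such as `get_pp_group().is_last_rank`, and the stubs here only stand in for vLLM's own helpers):

```python
class PPMissingLayer:
    """Stand-in for vLLM's placeholder module used on non-last PP ranks."""


def get_sampler():
    """Stand-in for vllm.model_executor.layers.sampler.get_sampler."""
    return object()


class ToyForCausalLM:
    def __init__(self, is_last_rank: bool):
        if is_last_rank:
            # Only the last pipeline rank owns the LM head and logits processing.
            self.lm_head = object()
        else:
            self.lm_head = PPMissingLayer()

        # Moved out of the branch: every pipeline-parallel rank now constructs
        # a sampler, so the attribute exists regardless of rank.
        self.sampler = get_sampler()
```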

5 changes: 3 additions & 2 deletions vllm/model_executor/models/granite.py
@@ -400,16 +400,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.lm_head.weight = self.model.embed_tokens.weight

logit_scale = getattr(config, "logit_scale", 1.0)

if hasattr(config, "logits_scaling"):
logit_scale /= config.logits_scaling

self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
scale=logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()

self.sampler = get_sampler()

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

3 changes: 2 additions & 1 deletion vllm/model_executor/models/llama.py
@@ -540,10 +540,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()

self.sampler = get_sampler()

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

4 changes: 3 additions & 1 deletion vllm/model_executor/models/nemotron.py
@@ -435,9 +435,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()

self.sampler = get_sampler()

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

3 changes: 2 additions & 1 deletion vllm/model_executor/models/solar.py
@@ -443,10 +443,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()

self.sampler = get_sampler()

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

4 changes: 4 additions & 0 deletions vllm/spec_decode/spec_decode_worker.py
@@ -54,6 +54,10 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
speculative_config: SpeculativeConfig = vllm_config.speculative_config
assert speculative_config is not None

if vllm_config.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError("Speculative decoding is currently "
"incompatible with pipeline parallelism")

draft_worker_kwargs = kwargs.copy()

kwargs["model_runner_cls"] = TargetModelRunner
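A hypothetical invocation that would now hit this check (the model names are placeholders, and `speculative_model`, `num_speculative_tokens`, and `pipeline_parallel_size` are assumed to be accepted as engine arguments in this release):

```python
from vllm import LLM

try:
    llm = LLM(
        model="facebook/opt-6.7b",
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
        pipeline_parallel_size=2,  # > 1 triggers the new NotImplementedError
    )
except NotImplementedError as exc:
    # Expected: speculative decoding is currently incompatible with pipeline parallelism
    print(exc)
```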
