Disable spec-decode + chunked-prefill for draft models with tensor parallelism > 1 (vllm-project#10136)

Signed-off-by: Sourashis Roy <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
sroy745 authored and tlrmchlsmth committed Nov 23, 2024
1 parent d63b2ab commit 461c69b
Showing 2 changed files with 83 additions and 8 deletions.
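
For context, a minimal sketch of the configuration this commit starts rejecting, assuming the vLLM `LLM` entrypoint of this era. The kwargs mirror the test parameters below; treat the exact invocation as illustrative, not a supported recipe:

```python
from vllm import LLM

# Chunked prefill plus speculative decoding with the draft model running at
# tensor parallel size > 1 is the combination this commit disables. With the
# new check in SpeculativeConfig.maybe_create_spec_config, constructing the
# engine like this raises a ValueError mentioning "tensor parallel size 1".
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    enable_chunked_prefill=True,
    tensor_parallel_size=2,
    speculative_draft_tensor_parallel_size=2,
)
```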
46 changes: 46 additions & 0 deletions tests/spec_decode/e2e/test_compatibility.py
@@ -50,3 +50,49 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
     with pytest.raises(ValueError, match="cannot be larger than"):
         get_output_from_llm_generator(test_llm_generator, prompts,
                                       sampling_params)
+
+
+@pytest.mark.parametrize("common_llm_kwargs",
+                         [{
+                             "model": "meta-llama/Llama-2-7b-chat-hf",
+                             "speculative_model": "JackFram/llama-68m",
+                             "num_speculative_tokens": 5,
+                             "enable_chunked_prefill": "True",
+                         }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "tensor_parallel_size": 2,
+        "speculative_draft_tensor_parallel_size": 2,
+    },
+    {
+        "tensor_parallel_size": 4,
+        "speculative_draft_tensor_parallel_size": 4,
+    },
+    {
+        "tensor_parallel_size": 8,
+        "speculative_draft_tensor_parallel_size": 8,
+    },
+])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_xfail_chunked_prefill_draft_model_tp_not_one(
+        test_llm_generator):
+    """Verify that speculative decoding fails if chunked prefill is enabled
+    for a draft model with tensor parallelism greater than 1.
+    """
+    output_len = 128
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+    ]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    with pytest.raises(ValueError, match="with tensor parallel size 1"):
+        get_output_from_llm_generator(test_llm_generator, prompts,
+                                      sampling_params)
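
The complementary case still passes the new check: with the draft model pinned to tensor parallel size 1, chunked prefill and speculative decoding may be combined. A hedged sketch — this configuration is implied by the error message, not exercised by this commit's tests:

```python
from vllm import LLM

# Same setup as the rejected example above, but the draft model is pinned
# to tensor parallel size 1, which the new check in config.py allows even
# with chunked prefill enabled.
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    enable_chunked_prefill=True,
    tensor_parallel_size=2,
    speculative_draft_tensor_parallel_size=1,
)
```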
45 changes: 37 additions & 8 deletions vllm/config.py
@@ -1388,6 +1388,23 @@ def maybe_create_spec_config(
                     "Chunked prefill and hidden-state based draft models are "
                     "not compatible.")

+            speculative_draft_tensor_parallel_size = \
+                SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
+                    target_parallel_config,
+                    speculative_draft_tensor_parallel_size,
+                    draft_hf_config
+                )
+
+            if (enable_chunked_prefill and
+                    speculative_draft_tensor_parallel_size != 1):
+                # TODO - Investigate why the error reported in
+                # https://github.com/vllm-project/vllm/pull/9291#issuecomment-2463266258
+                # is happening and re-enable it.
+                raise ValueError(
+                    "Chunked prefill and speculative decoding can be enabled "
+                    "simultaneously only for draft models with tensor "
+                    "parallel size 1.")
+
             draft_model_config.max_model_len = (
                 SpeculativeConfig._maybe_override_draft_max_model_len(
                     speculative_max_model_len,
@@ -1466,15 +1483,16 @@ def _maybe_override_draft_max_model_len(
         )

     @staticmethod
-    def create_draft_parallel_config(
-        target_parallel_config: ParallelConfig,
-        speculative_draft_tensor_parallel_size: Optional[int],
-        draft_hf_config: PretrainedConfig,
-    ) -> ParallelConfig:
-        """Create a parallel config for use by the draft worker.
-        This is mostly a copy of the target parallel config, except the tp_size.
-        """
+    def _verify_and_get_draft_model_tensor_parallel_size(
+            target_parallel_config: ParallelConfig,
+            speculative_draft_tensor_parallel_size: Optional[int],
+            draft_hf_config: PretrainedConfig) -> int:
+        """
+        Verifies and adjusts the tensor parallel size for a draft model
+        specified using speculative_draft_tensor_parallel_size.
+        """
+        # If speculative_draft_tensor_parallel_size is unset then set it
+        # appropriately else verify that it is set correctly.
         if speculative_draft_tensor_parallel_size is None:
             if draft_hf_config.model_type == "mlp_speculator":
                 speculative_draft_tensor_parallel_size = 1
@@ -1490,7 +1508,18 @@ def create_draft_parallel_config(
             raise ValueError(
                 f"{speculative_draft_tensor_parallel_size=} cannot be "
                 f"other value than 1 or target model tensor_parallel_size")
+        return speculative_draft_tensor_parallel_size
+
+    @staticmethod
+    def create_draft_parallel_config(
+        target_parallel_config: ParallelConfig,
+        speculative_draft_tensor_parallel_size: int,
+        draft_hf_config: PretrainedConfig,
+    ) -> ParallelConfig:
+        """Create a parallel config for use by the draft worker.
+        This is mostly a copy of the target parallel config, except the tp_size.
+        """
         draft_parallel_config = ParallelConfig(
             pipeline_parallel_size=target_parallel_config.
             pipeline_parallel_size,
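
Read in isolation, the refactored verification reduces to the following self-contained sketch. The names are shortened for illustration, and the branch hidden behind the collapsed hunk context is assumed to check against 1 or the target size, as the error message indicates:

```python
from typing import Optional

def verify_draft_tp(target_tp: int,
                    draft_tp: Optional[int],
                    draft_model_type: str) -> int:
    """Standalone sketch of _verify_and_get_draft_model_tensor_parallel_size."""
    if draft_tp is None:
        # Unset: mlp_speculator drafts only run at tp=1; any other draft
        # defaults to the target model's tensor parallel size.
        draft_tp = 1 if draft_model_type == "mlp_speculator" else target_tp
    elif draft_tp not in (1, target_tp):
        # Set explicitly: only 1 or the target tp size is accepted.
        raise ValueError(
            f"speculative_draft_tensor_parallel_size={draft_tp} cannot be "
            "other value than 1 or target model tensor_parallel_size")
    return draft_tp

# The caller in maybe_create_spec_config then rejects chunked prefill
# whenever the verified draft tp is not 1.
assert verify_draft_tp(4, None, "llama") == 4
assert verify_draft_tp(4, 1, "llama") == 1
```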
