
Commit 973617a

comaniac and cadedaniel authored

[Speculative decoding][Re-take] Enable TP>1 speculative decoding (vllm-project#4840)

Co-authored-by: Cade Daniel <[email protected]>
Co-authored-by: Cade Daniel <[email protected]>

1 parent 30e7543  commit 973617a

12 files changed: +297 -182 lines changed

.buildkite/test-pipeline.yaml (+1)

@@ -42,6 +42,7 @@ steps:
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
+  - pytest -v -s spec_decode/e2e/test_integration_dist.py

 - label: Distributed Tests (Multiple Groups)
   working_dir: "/vllm-workspace/tests"

benchmarks/benchmark_latency.py (+6)

@@ -18,6 +18,8 @@ def main(args: argparse.Namespace):
     # NOTE(woosuk): If the request cannot be processed in a single batch,
     # the engine will automatically process the request in multiple batches.
     llm = LLM(model=args.model,
+              speculative_model=args.speculative_model,
+              num_speculative_tokens=args.num_speculative_tokens,
               tokenizer=args.tokenizer,
               quantization=args.quantization,
               tensor_parallel_size=args.tensor_parallel_size,
@@ -28,6 +30,7 @@ def main(args: argparse.Namespace):
               quantization_param_path=args.quantization_param_path,
               device=args.device,
               ray_workers_use_nsight=args.ray_workers_use_nsight,
+              use_v2_block_manager=args.use_v2_block_manager,
               enable_chunked_prefill=args.enable_chunked_prefill,
               download_dir=args.download_dir,
               block_size=args.block_size)
@@ -99,6 +102,8 @@ def run_to_completion(profile_dir: Optional[str] = None):
         description='Benchmark the latency of processing a single batch of '
         'requests till completion.')
     parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--speculative-model', type=str, default=None)
+    parser.add_argument('--num-speculative-tokens', type=int, default=None)
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
@@ -181,6 +186,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
                         action='store_true',
                         help='If True, the prefill requests can be chunked based on the '
                         'max_num_batched_tokens')
+    parser.add_argument('--use-v2-block-manager', action='store_true')
     parser.add_argument(
         "--ray-workers-use-nsight",
         action='store_true',
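The new benchmark flags feed straight into the LLM constructor call in the first hunk. Below is a minimal sketch of the equivalent programmatic setup; the model choice and token counts are illustrative (they mirror the e2e tests touched by this commit), not prescribed by the change.

# Sketch: programmatic equivalent of running the benchmark with
#   --speculative-model JackFram/llama-68m --num-speculative-tokens 5 \
#   --use-v2-block-manager
# Values are illustrative only; they mirror the e2e tests in this commit.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="JackFram/llama-68m",  # draft model used for proposals
    num_speculative_tokens=5,                # proposal length per step
    use_v2_block_manager=True,               # required for spec decode
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)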

tests/spec_decode/e2e/test_compatibility.py (-50)

@@ -5,56 +5,6 @@
 from .conftest import get_output_from_llm_generator
 
 
-@pytest.mark.parametrize(
-    "common_llm_kwargs",
-    [{
-        "model": "JackFram/llama-68m",
-        "speculative_model": "JackFram/llama-68m",
-        "num_speculative_tokens": 5,
-
-        # Required for spec decode.
-        "use_v2_block_manager": True
-    }])
-@pytest.mark.parametrize(
-    "per_test_common_llm_kwargs",
-    [
-        {
-            # Expect failure as spec decode not supported by
-            # Ray backend.
-            "worker_use_ray": True,
-        },
-    ])
-@pytest.mark.parametrize("test_llm_kwargs", [{}])
-@pytest.mark.parametrize("seed", [1])
-def test_spec_decode_xfail_ray(test_llm_generator):
-    """Verify that speculative decoding with Ray fails.
-    """
-    output_len = 128
-    temperature = 0.0
-
-    prompts = [
-        "Hello, my name is",
-    ]
-
-    sampling_params = SamplingParams(
-        max_tokens=output_len,
-        ignore_eos=True,
-        temperature=temperature,
-    )
-
-    try:
-        with pytest.raises(
-                AssertionError,
-                match="Speculative decoding not yet supported for "):
-            get_output_from_llm_generator(test_llm_generator, prompts,
-                                          sampling_params)
-    finally:
-        # we need to free up ray resource,
-        # so that latter test could use the gpu we allocated here
-        import ray
-        ray.shutdown()
-
-
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
tests/spec_decode/e2e/test_integration.py (new file, +44)

@@ -0,0 +1,44 @@
+"""Tests which cover integration of the speculative decoding framework with
+other features, e.g. cuda graphs.
+"""
+
+import pytest
+
+from .conftest import run_greedy_equality_correctness_test
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Verify equality when cuda graphs allowed.
+        "enforce_eager": False,
+        "model": "JackFram/llama-68m",
+    }])
+@pytest.mark.parametrize(
+    "per_test_common_llm_kwargs",
+    [
+        {
+            # Identical models.
+            "speculative_model": "JackFram/llama-68m",
+            "num_speculative_tokens": 5,
+        },
+    ])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("output_len", [32])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
+                                batch_size, output_len):
+    """Verify spec decode equality when cuda graphs are enabled.
+    """
+    run_greedy_equality_correctness_test(
+        baseline_llm_generator,
+        test_llm_generator,
+        batch_size,
+        max_output_len=output_len,
+        force_output_len=True,
+    )
tests/spec_decode/e2e/test_integration_dist.py (new file, +65)

@@ -0,0 +1,65 @@
+"""Tests which cover integration of the speculative decoding framework with
+tensor parallelism.
+"""
+
+import pytest
+import torch
+
+from vllm.utils import is_hip
+
+from .conftest import run_greedy_equality_correctness_test
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "tensor_parallel_size": 2,
+
+        # Use AsyncLLM engine, so that the engine runs in its own process.
+        # Otherwise, since vLLM does not follow true SPMD, the test runner
+        # process will have both the engine and the rank0 worker. NCCL is not
+        # cleaned up properly, and its server host thread leaks, causing the
+        # second run of the test to fail with internal NCCL error.
+        "use_async": True,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 3,
+    },
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+    },
+])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
+                              batch_size: int, output_len: int):
+    """Verify greedy equality when tensor parallelism is used.
+    """
+    if is_hip():
+        pytest.skip("hip is not well-supported yet")
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
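The second test_llm_kwargs variant above exercises ngram prompt-lookup speculation with a tensor-parallel target model, which is the combination this commit enables. A hedged sketch of the corresponding user-facing configuration follows; the values simply mirror the test parameters, and at least two GPUs are assumed.

# Sketch of the TP>1 + ngram speculative decoding setup exercised by
# test_target_model_tp_gt_1 above. Values mirror the test; needs >= 2 GPUs.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",
    tensor_parallel_size=2,       # shard the target model across 2 GPUs
    speculative_model="[ngram]",  # prompt-lookup (ngram) proposer
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=3,
    use_v2_block_manager=True,    # required for spec decode
    enforce_eager=True,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)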

tests/spec_decode/e2e/test_multistep_correctness.py (-37)

@@ -611,40 +611,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
                                          batch_size,
                                          max_output_len=output_len,
                                          force_output_len=True)
-
-
-@pytest.mark.parametrize(
-    "common_llm_kwargs",
-    [{
-        # Required for spec decode.
-        "use_v2_block_manager": True,
-
-        # Verify equality when cuda graphs allowed.
-        "enforce_eager": False,
-        "model": "JackFram/llama-68m",
-    }])
-@pytest.mark.parametrize(
-    "per_test_common_llm_kwargs",
-    [
-        {
-            # Identical models.
-            "speculative_model": "JackFram/llama-68m",
-            "num_speculative_tokens": 5,
-        },
-    ])
-@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [{}])
-@pytest.mark.parametrize("batch_size", [8])
-@pytest.mark.parametrize("output_len", [32])
-@pytest.mark.parametrize("seed", [1])
-def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
-                                batch_size, output_len):
-    """Verify spec decode equality when cuda graphs are enabled.
-    """
-    run_greedy_equality_correctness_test(
-        baseline_llm_generator,
-        test_llm_generator,
-        batch_size,
-        max_output_len=output_len,
-        force_output_len=True,
-    )

vllm/distributed/communication_op.py (+5 -5)

@@ -219,16 +219,16 @@ def broadcast_tensor_dict(
     to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
     dtypes).
     """
+    # Bypass the function if we are using only 1 GPU.
+    if (not torch.distributed.is_initialized()
+            or torch.distributed.get_world_size(group=group) == 1):
+        return tensor_dict
+
     group = group or torch.distributed.group.WORLD
     metadata_group = metadata_group or get_cpu_world_group()
     ranks = torch.distributed.get_process_group_ranks(group)
     assert src in ranks, f"Invalid src rank ({src})"
 
-    # Bypass the function if we are using only 1 GPU.
-    world_size = torch.distributed.get_world_size(group=group)
-    if world_size == 1:
-        return tensor_dict
-
     rank = torch.distributed.get_rank()
     if rank == src:
         metadata_list: List[Tuple[Any, Any]] = []
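The net effect of this reordering is that the single-GPU bypass now also covers the case where torch.distributed was never initialized, and it fires before any process-group lookup that would otherwise fail in that case. A generic sketch of the early-exit pattern is shown below; maybe_broadcast is a hypothetical helper for illustration, not vLLM's broadcast_tensor_dict.

# Generic sketch of the early-exit pattern (not vLLM's code): check for the
# degenerate single-process case before touching any process-group state, so
# the call is a no-op even when torch.distributed was never initialized.
from typing import Any, Dict, Optional

import torch


def maybe_broadcast(obj: Dict[str, Any],
                    group: Optional[torch.distributed.ProcessGroup] = None,
                    src: int = 0) -> Dict[str, Any]:
    if (not torch.distributed.is_initialized()
            or torch.distributed.get_world_size(group=group) == 1):
        return obj  # nothing to broadcast to; return the input unchanged
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"
    payload = [obj]
    torch.distributed.broadcast_object_list(payload, src=src, group=group)
    return payload[0]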

vllm/executor/gpu_executor.py (+17 -54)

@@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase):
 
     def _init_executor(self) -> None:
         """Initialize the worker and load the model.
-
-        If speculative decoding is enabled, we instead create the speculative
-        worker.
         """
-        if self.speculative_config is None:
-            self._init_non_spec_worker()
-        else:
-            self._init_spec_worker()
+        assert self.parallel_config.world_size == 1, (
+            "GPUExecutor only supports single GPU.")
+
+        self.driver_worker = self._create_worker()
+        self.driver_worker.init_device()
+        self.driver_worker.load_model()
 
     def _get_worker_kwargs(
             self,
@@ -45,66 +44,30 @@ def _get_worker_kwargs(
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,
+            speculative_config=self.speculative_config,
             is_driver_worker=rank == 0,
         )
 
     def _create_worker(self,
                        local_rank: int = 0,
                        rank: int = 0,
                        distributed_init_method: Optional[str] = None):
+
+        if self.speculative_config is None:
+            worker_module_name = "vllm.worker.worker"
+            worker_class_name = "Worker"
+        else:
+            worker_module_name = "vllm.spec_decode.spec_decode_worker"
+            worker_class_name = "create_spec_worker"
+
         wrapper = WorkerWrapperBase(
-            worker_module_name="vllm.worker.worker",
-            worker_class_name="Worker",
+            worker_module_name=worker_module_name,
+            worker_class_name=worker_class_name,
         )
         wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
                                                       distributed_init_method))
         return wrapper.worker
 
-    def _init_non_spec_worker(self):
-        assert self.parallel_config.world_size == 1, (
-            "GPUExecutor only supports single GPU.")
-
-        self.driver_worker = self._create_worker()
-        self.driver_worker.init_device()
-        self.driver_worker.load_model()
-
-    def _init_spec_worker(self):
-        """Initialize a SpecDecodeWorker, using a draft model for proposals.
-        """
-        assert self.speculative_config is not None
-
-        from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
-
-        target_worker = self._create_worker()
-
-        draft_worker_kwargs = self._get_worker_kwargs()
-        # Override draft-model specific worker args.
-        draft_worker_kwargs.update(
-            model_config=self.speculative_config.draft_model_config,
-            parallel_config=self.speculative_config.draft_parallel_config,
-            ngram_prompt_lookup_max=self.speculative_config.
-            ngram_prompt_lookup_max,
-            ngram_prompt_lookup_min=self.speculative_config.
-            ngram_prompt_lookup_min,
-            # TODO allow draft-model specific load config.
-            #load_config=self.load_config,
-        )
-
-        spec_decode_worker = SpecDecodeWorker.create_worker(
-            scorer_worker=target_worker,
-            draft_worker_kwargs=draft_worker_kwargs,
-            disable_by_batch_size=self.speculative_config.
-            speculative_disable_by_batch_size,
-        )
-
-        assert self.parallel_config.world_size == 1, (
-            "GPUExecutor only supports single GPU.")
-
-        self.driver_worker = spec_decode_worker
-
-        # Load model handled in spec decode worker.
-        self.driver_worker.init_device()
-
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks by invoking the
         underlying worker.
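With this refactor, GPUExecutor always builds its single driver worker through WorkerWrapperBase and only decides which module/class name to hand it; the spec-decode path is selected by name rather than through a dedicated init method, which is what lets the multi-GPU executors reuse the same selection. Below is a hedged sketch of that select-by-name pattern; LazyWorkerWrapper and make_worker are hypothetical stand-ins for illustration, not WorkerWrapperBase's actual implementation.

# Illustrative sketch of lazy, name-based worker construction: resolve the
# worker class (or a factory such as create_spec_worker) at init time, so the
# executor itself never imports spec-decode code directly.
# Not WorkerWrapperBase's actual implementation.
import importlib
from typing import Any


class LazyWorkerWrapper:

    def __init__(self, worker_module_name: str, worker_class_name: str):
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        self.worker: Any = None

    def init_worker(self, **kwargs) -> None:
        module = importlib.import_module(self.worker_module_name)
        # The attribute may be a class (Worker) or a factory function
        # (create_spec_worker); either way, calling it yields a worker.
        worker_factory = getattr(module, self.worker_class_name)
        self.worker = worker_factory(**kwargs)


def make_worker(speculative_config, **worker_kwargs):
    # Mirrors the branch in GPUExecutor._create_worker: pick the module/class
    # pair based on whether a speculative config is present.
    if speculative_config is None:
        names = ("vllm.worker.worker", "Worker")
    else:
        names = ("vllm.spec_decode.spec_decode_worker", "create_spec_worker")
    wrapper = LazyWorkerWrapper(*names)
    wrapper.init_worker(**worker_kwargs)
    return wrapper.worker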
