Commit 92bef6b

Merge branch 'vllm-project:main' into mypy-checking
2 parents 14a0f69 + 1ecc645 commit 92bef6b

5 files changed (+95, -13 lines)

docs/source/getting_started/debugging.rst (+5)
@@ -86,6 +86,11 @@ If GPU/CPU communication cannot be established, you can use the following Python
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 
 pynccl = PyNcclCommunicator(group=gloo_group, device=local_rank)
+# pynccl is enabled by default for 0.6.5+,
+# but for 0.6.4 and below, we need to enable it manually.
+# Keep this line for backward compatibility, because people
+# prefer to read the latest documentation.
+pynccl.disabled = False
 
 s = torch.cuda.Stream()
 with torch.cuda.stream(s):
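
Note: the added comment only documents the version behaviour; the snippet itself always sets the flag. A minimal, hypothetical sketch of gating the manual enable on the installed vLLM version (the version check below is illustrative and not part of the documented snippet; only `pynccl.disabled = False` appears in the change):

    from packaging import version

    import vllm

    # `pynccl` is the PyNcclCommunicator created earlier in the snippet.
    # Illustrative assumption: only flip the flag on vLLM 0.6.4 and older,
    # where pynccl is not enabled by default.
    if version.parse(vllm.__version__) < version.parse("0.6.5"):
        pynccl.disabled = False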

tests/samplers/test_rejection_sampler.py (+63)
@@ -200,6 +200,69 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
             assert torch.equal(results[j][i], results[0][i])
 
 
+@pytest.mark.parametrize("k", [1, 3, 6])
+@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
+@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
+@torch.inference_mode()
+def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
+                            device: str, use_flashinfer: bool):
+    torch.set_default_device(device)
+    set_random_seed(0)
+    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+
+    single_batches = []
+    for i in range(batch_size):
+        single_batches.append((draft_probs[i].clone().unsqueeze(0),
+                               draft_token_ids[i].clone().unsqueeze(0),
+                               target_probs[i].clone().unsqueeze(0),
+                               bonus_token_ids[i].clone().unsqueeze(0),
+                               draft_token_ids[i].clone().unsqueeze(0)))
+
+    set_random_seed(0)
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
+    rejection_sampler.init_gpu_tensors(device=device)
+
+    results = []
+    seeded_seqs = {
+        i: torch.Generator(device=device).manual_seed(i)
+        for i in range(1, batch_size)  # 0 is seed None
+    }
+    batch_result = rejection_sampler(target_probs.clone(),
+                                     bonus_token_ids.clone(),
+                                     draft_probs.clone(),
+                                     draft_token_ids.clone(), seeded_seqs)
+
+    set_random_seed(0)
+
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
+    rejection_sampler.init_gpu_tensors(device=device)
+    for i in range(batch_size):
+        request_seeded_seqs = {
+            0: torch.Generator(device=device).manual_seed(i)
+        } if seeded_seqs.get(i) is not None else None
+        (draft_probs, draft_token_ids, target_probs, bonus_token_ids,
+         draft_token_ids) = single_batches[i]
+        results.append(
+            rejection_sampler(target_probs, bonus_token_ids, draft_probs,
+                              draft_token_ids, request_seeded_seqs))
+    for i in range(batch_size):
+        assert torch.equal(batch_result[i], results[i].squeeze(0))
+
+
 @pytest.mark.parametrize("k", [1, 3, 6])
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
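
Note: the new test asserts that sampling a mixed seeded/unseeded batch yields exactly the same tokens as sampling each request individually with an identically seeded generator. That equivalence relies on torch.Generator being deterministic for a given seed, which can be checked in isolation (standalone example, independent of vLLM):

    import torch

    # Two generators seeded identically produce identical exponential draws,
    # whether filling a whole batch slice or a single row.
    g1 = torch.Generator().manual_seed(7)
    g2 = torch.Generator().manual_seed(7)
    a = torch.empty(2, 5).exponential_(1.0, generator=g1)
    b = torch.empty(2, 5).exponential_(1.0, generator=g2)
    assert torch.equal(a, b)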

vllm/model_executor/layers/rejection_sampler.py (+3, -7)
@@ -1,6 +1,6 @@
 from functools import cached_property
 from importlib.util import find_spec
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Optional, Tuple
 
 import torch
 import torch.jit
@@ -386,16 +386,12 @@ def _multinomial(
     if not seeded_seqs:
         q.exponential_(1.0)
     else:
-        non_seeded_indices: List[int] = []
         start = 0
         for idx in range(len(q) // k):
             end = start + k
             generator = seeded_seqs.get(idx)
-            if generator is None:
-                non_seeded_indices.extend(list(range(start, end)))
-            else:
-                q[start:end].exponential_(1.0, generator=generator)
+            # Note: generator might be None for non-seeded sequences.
+            q[start:end].exponential_(1.0, generator=generator)
             start = end
-        q[non_seeded_indices].exponential_(1.0)
 
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
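
Note: dropping the non_seeded_indices bookkeeping works because Tensor.exponential_ accepts generator=None and then falls back to the default RNG, so unseeded chunks no longer need to be collected and filled in a separate pass. A standalone illustration of that fallback:

    import torch

    q = torch.empty(4, 8)
    gen = torch.Generator().manual_seed(0)

    # Seeded chunk: reproducible draws from the explicit generator.
    q[:2].exponential_(1.0, generator=gen)
    # Unseeded chunk: generator=None means the default RNG is used, the same
    # as calling exponential_(1.0) with no generator argument at all.
    q[2:].exponential_(1.0, generator=None)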

vllm/utils.py (+11, -1)
@@ -1577,8 +1577,18 @@ def direct_register_custom_op(
     library object. If you want to bind the operator to a different library,
     make sure the library object is alive when the operator is used.
     """
-    if is_in_doc_build() or not supports_custom_op():
+    if is_in_doc_build():
         return
+
+    if not supports_custom_op():
+        assert not current_platform.is_cuda_alike(), (
+            "cuda platform needs torch>=2.4 to support custom op, "
+            "chances are you are using an old version of pytorch "
+            "or a custom build of pytorch. It is recommended to "
+            "use vLLM in a fresh new environment and let it install "
+            "the required dependencies.")
+        return
+
     import torch.library
     if hasattr(torch.library, "infer_schema"):
         schema_str = torch.library.infer_schema(op_func,
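
Note: splitting the early return lets documentation builds keep skipping registration silently, while runtime environments without custom-op support now fail loudly on CUDA-like platforms. For context, a hedged sketch of what a feature probe such as supports_custom_op() typically boils down to (assumption: the real helper in vllm/utils.py may differ in detail):

    import torch

    def supports_custom_op() -> bool:
        # torch.library.custom_op first appeared in PyTorch 2.4, which is why
        # the assert above points CUDA users at torch>=2.4.
        return hasattr(torch.library, "custom_op")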

vllm/worker/model_runner.py (+13, -5)
@@ -13,6 +13,7 @@
 import torch
 import torch.distributed
 import torch.nn as nn
+from tqdm import tqdm
 
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
@@ -21,7 +22,8 @@
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import get_kv_transfer_group, get_pp_group
-from vllm.distributed.parallel_state import graph_capture
+from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
+                                             graph_capture)
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
@@ -1413,8 +1415,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
         logger.info("Capturing cudagraphs for decoding. This may lead to "
                     "unexpected consequences if the model is not static. To "
                     "run the model in eager mode, set 'enforce_eager=True' or "
-                    "use '--enforce-eager' in the CLI.")
-        logger.info("If out-of-memory error occurs during cudagraph capture,"
+                    "use '--enforce-eager' in the CLI. "
+                    "If out-of-memory error occurs during cudagraph capture,"
                     " consider decreasing `gpu_memory_utilization` or "
                     "switching to eager mode. You can also reduce the "
                     "`max_num_seqs` as needed to decrease memory usage.")
@@ -1451,8 +1453,14 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
             # memory usage of CUDA graph.
             for virtual_engine in range(
                     self.parallel_config.pipeline_parallel_size):
-                for batch_size in \
-                        self.vllm_config.compilation_config.capture_sizes:
+                # Only rank 0 should print the progress bar during capture.
+                capture_sizes = (
+                    tqdm(
+                        self.vllm_config.compilation_config.capture_sizes,
+                        desc="Capturing CUDA graph shapes",
+                    ) if get_tensor_model_parallel_rank() == 0 else
+                    self.vllm_config.compilation_config.capture_sizes)
+                for batch_size in capture_sizes:
                     attn_metadata = (
                         self.attn_state.graph_capture_get_metadata_for_batch(
                             batch_size,
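
Note: the tqdm wrapper is applied only on tensor-parallel rank 0, so a multi-GPU run prints a single progress bar instead of one per rank. The same pattern, extracted into a small hypothetical helper for illustration (vLLM inlines the conditional rather than using a function like this):

    from typing import Iterable, TypeVar

    from tqdm import tqdm

    T = TypeVar("T")

    def maybe_tqdm(iterable: Iterable[T], rank: int, desc: str) -> Iterable[T]:
        # Show a progress bar only on rank 0; other ranks iterate silently.
        return tqdm(iterable, desc=desc) if rank == 0 else iterable

    # Usage sketch, mirroring capture_model():
    #   for batch_size in maybe_tqdm(capture_sizes,
    #                                get_tensor_model_parallel_rank(),
    #                                "Capturing CUDA graph shapes"):
    #       ...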
