Skip to content

Commit

Permalink
[Bugfix] Fix spec decoding when seed is none in a batch (vllm-project…
Browse files Browse the repository at this point in the history
…#10863)

Signed-off-by: Wallas Santos <[email protected]>
Signed-off-by: lucast2021 <[email protected]>
  • Loading branch information
wallashss authored and lucast2021 committed Dec 21, 2024
1 parent 547fafb commit 8da0465
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 7 deletions.
63 changes: 63 additions & 0 deletions tests/samplers/test_rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,69 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
assert torch.equal(results[j][i], results[0][i])


@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
device: str, use_flashinfer: bool):
torch.set_default_device(device)
set_random_seed(0)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)

single_batches = []
for i in range(batch_size):
single_batches.append((draft_probs[i].clone().unsqueeze(0),
draft_token_ids[i].clone().unsqueeze(0),
target_probs[i].clone().unsqueeze(0),
bonus_token_ids[i].clone().unsqueeze(0),
draft_token_ids[i].clone().unsqueeze(0)))

set_random_seed(0)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)

results = []
seeded_seqs = {
i: torch.Generator(device=device).manual_seed(i)
for i in range(1, batch_size) # 0 is seed None
}
batch_result = rejection_sampler(target_probs.clone(),
bonus_token_ids.clone(),
draft_probs.clone(),
draft_token_ids.clone(), seeded_seqs)

set_random_seed(0)

rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
for i in range(batch_size):
request_seeded_seqs = {
0: torch.Generator(device=device).manual_seed(i)
} if seeded_seqs.get(i) is not None else None
(draft_probs, draft_token_ids, target_probs, bonus_token_ids,
draft_token_ids) = single_batches[i]
results.append(
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, request_seeded_seqs))
for i in range(batch_size):
assert torch.equal(batch_result[i], results[i].squeeze(0))


@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
Expand Down
10 changes: 3 additions & 7 deletions vllm/model_executor/layers/rejection_sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import cached_property
from importlib.util import find_spec
from typing import Dict, List, Optional, Tuple
from typing import Dict, Optional, Tuple

import torch
import torch.jit
Expand Down Expand Up @@ -386,16 +386,12 @@ def _multinomial(
if not seeded_seqs:
q.exponential_(1.0)
else:
non_seeded_indices: List[int] = []
start = 0
for idx in range(len(q) // k):
end = start + k
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.extend(list(range(start, end)))
else:
q[start:end].exponential_(1.0, generator=generator)
# Note: generator might be None for non seeded
q[start:end].exponential_(1.0, generator=generator)
start = end
q[non_seeded_indices].exponential_(1.0)

return probs.div_(q).argmax(dim=1).view(-1, num_samples)

0 comments on commit 8da0465

Please sign in to comment.