From 8da04652c308da6968718369c1339b1722667e0a Mon Sep 17 00:00:00 2001
From: Wallas Henrique
Date: Fri, 20 Dec 2024 02:15:31 -0300
Subject: [PATCH] [Bugfix] Fix spec decoding when seed is none in a batch
 (#10863)

Signed-off-by: Wallas Santos
Signed-off-by: lucast2021
---
 tests/samplers/test_rejection_sampler.py | 63 +++++++++++++++++++
 .../layers/rejection_sampler.py          | 10 +--
 2 files changed, 66 insertions(+), 7 deletions(-)

diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py
index f5497976faf7a..397fa2cc85821 100644
--- a/tests/samplers/test_rejection_sampler.py
+++ b/tests/samplers/test_rejection_sampler.py
@@ -200,6 +200,69 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
             assert torch.equal(results[j][i], results[0][i])
 
 
+@pytest.mark.parametrize("k", [1, 3, 6])
+@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
+@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
+@torch.inference_mode()
+def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
+                            device: str, use_flashinfer: bool):
+    torch.set_default_device(device)
+    set_random_seed(0)
+    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+
+    single_batches = []
+    for i in range(batch_size):
+        single_batches.append((draft_probs[i].clone().unsqueeze(0),
+                               draft_token_ids[i].clone().unsqueeze(0),
+                               target_probs[i].clone().unsqueeze(0),
+                               bonus_token_ids[i].clone().unsqueeze(0),
+                               draft_token_ids[i].clone().unsqueeze(0)))
+
+    set_random_seed(0)
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
+    rejection_sampler.init_gpu_tensors(device=device)
+
+    results = []
+    seeded_seqs = {
+        i: torch.Generator(device=device).manual_seed(i)
+        for i in range(1, batch_size)  # 0 is seed None
+    }
+    batch_result = rejection_sampler(target_probs.clone(),
+                                     bonus_token_ids.clone(),
+                                     draft_probs.clone(),
+                                     draft_token_ids.clone(), seeded_seqs)
+
+    set_random_seed(0)
+
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
+    rejection_sampler.init_gpu_tensors(device=device)
+    for i in range(batch_size):
+        request_seeded_seqs = {
+            0: torch.Generator(device=device).manual_seed(i)
+        } if seeded_seqs.get(i) is not None else None
+        (draft_probs, draft_token_ids, target_probs, bonus_token_ids,
+         draft_token_ids) = single_batches[i]
+        results.append(
+            rejection_sampler(target_probs, bonus_token_ids, draft_probs,
+                              draft_token_ids, request_seeded_seqs))
+    for i in range(batch_size):
+        assert torch.equal(batch_result[i], results[i].squeeze(0))
+
+
 @pytest.mark.parametrize("k", [1, 3, 6])
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
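Editor's note (not part of the patch): in the test above, seeded_seqs maps a
batch row index to a torch.Generator; a row with no entry is unseeded, and row
0 is deliberately left out so the batch mixes seeded and unseeded requests.
When a request is then replayed as a batch of one, its generator has to be
re-keyed to row 0, which is what request_seeded_seqs does. A minimal sketch of
that re-keying, with make_generator as a hypothetical helper:

    import torch
    from typing import Dict, Optional

    def make_generator(seed: int, device: str = "cpu") -> torch.Generator:
        # Hypothetical helper: one dedicated generator per seeded request.
        return torch.Generator(device=device).manual_seed(seed)

    # Batch of 4 requests: row 0 unseeded (no entry), rows 1-3 seeded.
    seeded_seqs: Dict[int, torch.Generator] = {
        i: make_generator(i)
        for i in range(1, 4)
    }

    def rekey_for_single_batch(i: int) -> Optional[Dict[int, torch.Generator]]:
        # When request i becomes row 0 of a single-row batch, a fresh
        # generator with the same seed is installed under key 0.
        if seeded_seqs.get(i) is None:
            return None  # request i was unseeded in the mixed batch, too
        return {0: make_generator(i)}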
diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py
index 3ab0ba9e9f5c2..97a1b0c9603bd 100644
--- a/vllm/model_executor/layers/rejection_sampler.py
+++ b/vllm/model_executor/layers/rejection_sampler.py
@@ -1,6 +1,6 @@
 from functools import cached_property
 from importlib.util import find_spec
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Optional, Tuple
 
 import torch
 import torch.jit
@@ -386,16 +386,12 @@ def _multinomial(
     if not seeded_seqs:
         q.exponential_(1.0)
     else:
-        non_seeded_indices: List[int] = []
         start = 0
         for idx in range(len(q) // k):
             end = start + k
             generator = seeded_seqs.get(idx)
-            if generator is None:
-                non_seeded_indices.extend(list(range(start, end)))
-            else:
-                q[start:end].exponential_(1.0, generator=generator)
+            # Note: generator might be None for non-seeded sequences
+            q[start:end].exponential_(1.0, generator=generator)
             start = end
-        q[non_seeded_indices].exponential_(1.0)
 
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
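Editor's note (not part of the patch): the root cause is a PyTorch indexing
subtlety. q[non_seeded_indices] uses advanced (integer-list) indexing, which
returns a copy, so the in-place exponential_(1.0) on the removed last line
filled a temporary tensor and never wrote back into q; since q is allocated
uninitialized earlier in _multinomial (via torch.empty_like in this file),
the unseeded rows kept garbage values. Slice indexing such as q[start:end]
returns a view, and Tensor.exponential_ accepts generator=None (falling back
to the default RNG), so a single per-chunk call now handles seeded and
unseeded rows alike. A minimal standalone sketch of the copy-vs-view
distinction, illustrative only:

    import torch

    q = torch.zeros(4, 8)

    # Advanced (list) indexing returns a copy: the in-place exponential_
    # fills the copy and q itself is left untouched. This was the bug.
    q[[0, 2]].exponential_(1.0)
    assert torch.all(q == 0)

    # Basic slicing returns a view, so exponential_ really writes into q.
    # generator=None simply uses the default RNG, which is why the patched
    # loop no longer needs to special-case unseeded rows.
    q[0:2].exponential_(1.0, generator=None)
    assert torch.all(q[0:2] > 0)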