[Bugfix] Make spec. decode respect per-request seed. (vllm-project#6034)
Signed-off-by: Thomas Parnell <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
2 people authored and phil committed Aug 6, 2024
1 parent 4f9680e commit c5aa35c
Showing 8 changed files with 293 additions and 46 deletions.
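For context, the change threads an optional per-request torch.Generator, built from each request's seed, down into the rejection sampler. A minimal sketch of that flow (the build_generators helper below is hypothetical; only SamplingParams and the new sampler signature come from this change):

import torch
from vllm import SamplingParams

def build_generators(params_list, device="cuda"):
    # Hypothetical helper: one generator per request, None when unseeded.
    return [
        torch.Generator(device=device).manual_seed(p.seed)
        if p.seed is not None else None for p in params_list
    ]

params = [SamplingParams(temperature=1.0, seed=42),
          SamplingParams(temperature=1.0)]  # second request is unseeded
generators = build_generators(params)
# The list is then passed alongside the probability tensors, mirroring the
# new signature:
#   rejection_sampler(target_probs, bonus_token_ids, draft_probs,
#                     draft_token_ids, generators)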
56 changes: 53 additions & 3 deletions tests/samplers/test_rejection_sampler.py
@@ -150,9 +150,54 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
generators = [None] * batch_size

rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids)
draft_token_ids, generators)


@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("n_rep", [100])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
frac_seeded: float, n_rep: int,
device: str):
torch.set_default_device(device)
rejection_sampler = RejectionSampler()
rejection_sampler.init_gpu_tensors(rank=0)

draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)

seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded

results = []
for _ in range(n_rep):
generators = [
torch.Generator(
device=device).manual_seed(i) if seeded_mask[i] else None
for i in range(batch_size)
]
results.append(
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, generators))

for i in range(batch_size):
if seeded_mask[i]:
for j in range(1, n_rep):
assert torch.equal(results[j][i], results[0][i])


@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@@ -197,10 +242,11 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
raise AssertionError()

oob_token_ids[0][0] = rogue_token_id
generators = [None] * batch_size

with pytest.raises(AssertionError):
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids)
draft_token_ids, generators)


@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@@ -371,11 +417,15 @@ def _estimate_rejection_sampling_pdf(
dtype=torch.int64,
device="cuda").repeat(num_samples, 1)

# unseeded
generators = [None]

# Get output tokens via rejection sampling.
output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
bonus_token_ids.to("cuda"),
draft_probs.to("cuda"),
draft_token_ids.to("cuda"))
draft_token_ids.to("cuda"),
generators)

# Remove bonus tokens
output_token_ids = output_token_ids[:, :-1].flatten()
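The determinism checked by the new test rests on a basic torch property: generators seeded identically produce identical random streams. A quick standalone illustration (not part of the diff):

import torch

g1 = torch.Generator().manual_seed(7)
g2 = torch.Generator().manual_seed(7)
assert torch.equal(torch.rand(4, generator=g1),
                   torch.rand(4, generator=g2))  # same seed, same draws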
52 changes: 44 additions & 8 deletions tests/spec_decode/e2e/conftest.py
@@ -1,6 +1,6 @@
import asyncio
from itertools import cycle
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union

import pytest
import ray
@@ -128,7 +128,9 @@ async def get_output(prompt, sampling_param) -> RequestOutput:
try:
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
res = asyncio.run(get_output(prompt, sampling_params))
params = sampling_params[i] if isinstance(
sampling_params, Sequence) else sampling_params
res = asyncio.run(get_output(prompt, params))
outputs.append(res)
finally:
ray.shutdown()
@@ -267,7 +269,31 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
the same when temperature is zero.
"""
temperature = 0.0

run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len,
temperature=0.0,
seeded=False,
print_tokens=print_tokens,
ensure_all_accepted=ensure_all_accepted)


def run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
temperature: float,
seeded: bool,
print_tokens: bool = False,
ensure_all_accepted: bool = False):
"""Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts equality, i.e. that the outputs are exactly
    the same when temperature is zero (or when temperature is > 0 and seeded).
"""

prompts = [
"Hello, my name is",
@@ -286,11 +312,21 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
# sampling params to ignore eos token.
ignore_eos = force_output_len

sampling_params = SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
)
if seeded:
sampling_params = [
SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
seed=i,
) for i in range(len(prompts))
]
else:
sampling_params = SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
)

(spec_batch_tokens, spec_batch_token_ids,
acceptance_rate) = get_output_from_llm_generator(test_llm_generator,
44 changes: 44 additions & 0 deletions tests/spec_decode/e2e/test_seed.py
@@ -0,0 +1,44 @@
import pytest

from .conftest import run_equality_correctness_test


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# speculative model
"speculative_model": "JackFram/llama-160m",
# num speculative tokens
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [1, 8, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
10,
])
@pytest.mark.parametrize("seed", [1])
def test_seeded_consistency(baseline_llm_generator, batch_size: int,
temperature: float, output_len: int):
"""Verify outputs are consistent across multiple runs with same seed
"""
run_equality_correctness_test(baseline_llm_generator,
baseline_llm_generator,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=True,
force_output_len=True)
106 changes: 86 additions & 20 deletions vllm/model_executor/layers/rejection_sampler.py
@@ -1,14 +1,14 @@
from functools import cached_property
from typing import Tuple
from typing import List, Optional, Tuple

import torch
import torch.jit

from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
SpecDecodeStochasticBaseSampler)


class RejectionSampler(SpecDecodeBaseSampler):
class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""Apply modified rejection sampling as described in "Accelerating Large
Language Model Decoding with Speculative Sampling"
https://arxiv.org/pdf/2302.01318.pdf.
@@ -36,6 +36,7 @@ def forward(
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
generators: List[Optional[torch.Generator]],
) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
@@ -82,6 +83,7 @@ def forward(
target_probs,
draft_probs,
draft_token_ids,
generators,
))

output_token_ids = self._create_output(
@@ -94,10 +96,11 @@ def forward(
return output_token_ids

def _batch_modified_rejection_sampling(
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.
@@ -114,22 +117,33 @@ def _batch_modified_rejection_sampling(

# shape [batch_size, k]
accepted = self._get_accepted(target_probs, draft_probs,
draft_token_ids)
draft_token_ids, generators)

recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)

seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators, k=k)

# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(recovered_probs,
num_samples=1).reshape(
batch_size, k)
recovered_token_ids = _multinomial(
recovered_probs,
num_samples=1,
k=k,
generators=generators,
seed_indices=seed_indices,
# this arg is unused when None but torch.jit requires a list
non_seed_indices=non_seed_indices or [],
).reshape(batch_size, k)

return accepted, recovered_token_ids

def _get_accepted(
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
@@ -164,10 +178,28 @@ def _get_accepted(
selected_target_probs = target_probs[batch_indices, probs_indicies,
draft_token_ids]

uniform_rand = torch.rand(batch_size,
k,
dtype=self.probs_dtype,
device=target_probs.device)
seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators)

if len(seed_indices) == 0:
uniform_rand = torch.rand_like(selected_target_probs)
else:
uniform_rand = torch.empty_like(selected_target_probs)

for idx in seed_indices:
uniform_rand[idx, :] = torch.rand(1,
k,
dtype=self.probs_dtype,
device=target_probs.device,
generator=generators[idx])

if non_seed_indices:
uniform_rand[non_seed_indices, :] = torch.rand(
len(non_seed_indices),
k,
dtype=self.probs_dtype,
device=target_probs.device)

capped_ratio = torch.minimum(
selected_target_probs / selected_draft_probs,
torch.full((1, ), 1, device=target_probs.device))
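For reference, the comparison above implements the acceptance rule from the speculative sampling paper: a drafted token x is accepted with probability min(1, p_target(x) / p_draft(x)), and with this change the uniform draw comes from the request's own generator when one is supplied. A small self-contained sketch under those assumptions (tensor names are illustrative, not vLLM's):

import torch

def accept_mask(target_p, draft_p, uniform):
    # target_p, draft_p: probabilities of the k drafted tokens, shape [k]
    # uniform: U(0, 1) draws of the same shape (seeded or unseeded)
    capped_ratio = torch.clamp(target_p / draft_p, max=1.0)
    return uniform < capped_ratio  # True where the draft token is accepted

g = torch.Generator().manual_seed(0)
u = torch.rand(3, generator=g)  # reproducible per-request draws
print(accept_mask(torch.tensor([0.2, 0.9, 0.05]),
                  torch.tensor([0.4, 0.3, 0.10]), u))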
@@ -240,6 +272,27 @@ def _smallest_positive_value(self) -> float:
"""
return torch.finfo(self.probs_dtype).tiny

# partition batch into indices for which a generator is provided
    # and indices for which no generator is provided
@staticmethod
def _split_batch_by_seeded(
generators: List[Optional[torch.Generator]],
k: int = 1,
) -> Tuple[List[int], Optional[List[int]]]:

if all(generator is None for generator in generators):
seed_indices: List[int] = []
non_seed_indices: Optional[List[int]] = None
else:
seed_indices, non_seed_indices = [], []
for i, generator in enumerate(generators):
if generator is None:
non_seed_indices.extend(range(k * i, k * (i + 1)))
else:
seed_indices.extend(range(k * i, k * (i + 1)))

return seed_indices, non_seed_indices
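
A quick worked example of the index bookkeeping above, based directly on the code shown: each batch entry expands into k consecutive flat indices, and an all-unseeded batch returns None for the second element so callers can keep the fast path.

import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler

gens = [None, torch.Generator(), None]  # only batch entry 1 is seeded
seed_idx, non_seed_idx = RejectionSampler._split_batch_by_seeded(gens, k=2)
assert seed_idx == [2, 3]            # entry 1 -> flat indices 2 and 3
assert non_seed_idx == [0, 1, 4, 5]  # entries 0 and 2 are unseeded
assert RejectionSampler._split_batch_by_seeded([None, None]) == ([], None)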


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
@@ -250,12 +303,25 @@ def _smallest_positive_value(self) -> float:
def _multinomial(
probs: torch.Tensor,
num_samples: int,
k: int,
generators: List[Optional[torch.Generator]],
seed_indices: List[int],
non_seed_indices: List[int],
) -> torch.Tensor:

if num_samples > 1:
        # This is equivalent to torch.repeat_interleave (which also
# forces a GPU<->CPU sync).
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
q = torch.empty_like(probs).exponential_(1.0)

q = torch.empty_like(probs)
if len(seed_indices) == 0:
q.exponential_(1.0)
else:
q[non_seed_indices].exponential_(1.0)
for idx in seed_indices:
q[idx].exponential_(1.0, generator=generators[idx // k])

return probs.div_(q).argmax(dim=1).view(-1, num_samples)
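The sync-free sampler above relies on the exponential-race (Gumbel-max) identity: if q_i ~ Exp(1) independently, then argmax_i(p_i / q_i) is distributed as Categorical(p), so dividing by exponential noise and taking an argmax replaces torch.multinomial without a device sync. A small sanity-check sketch (not part of the change):

import torch

probs = torch.tensor([0.1, 0.6, 0.3])
n = 200_000
q = torch.empty(n, probs.numel()).exponential_(1.0)
samples = (probs / q).argmax(dim=1)
freqs = torch.bincount(samples, minlength=3).float() / n
print(freqs)  # approaches tensor([0.1000, 0.6000, 0.3000]) as n grows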