diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py
index 6dd643bbea2bb..b6330a5e5f7c5 100644
--- a/tests/samplers/test_rejection_sampler.py
+++ b/tests/samplers/test_rejection_sampler.py
@@ -150,9 +150,54 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     high=vocab_size,
                                     size=(batch_size, k),
                                     dtype=torch.int64)
+    generators = [None] * batch_size
     rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                      draft_token_ids)
+                      draft_token_ids, generators)
+
+
+@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
+@pytest.mark.parametrize("k", [1, 3, 6])
+@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
+@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
+@pytest.mark.parametrize("n_rep", [100])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
+                                   frac_seeded: float, n_rep: int,
+                                   device: str):
+    torch.set_default_device(device)
+    rejection_sampler = RejectionSampler()
+    rejection_sampler.init_gpu_tensors(rank=0)
+
+    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+
+    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
+
+    results = []
+    for _ in range(n_rep):
+        generators = [
+            torch.Generator(
+                device=device).manual_seed(i) if seeded_mask[i] else None
+            for i in range(batch_size)
+        ]
+        results.append(
+            rejection_sampler(target_probs, bonus_token_ids, draft_probs,
+                              draft_token_ids, generators))
+
+    for i in range(batch_size):
+        if seeded_mask[i]:
+            for j in range(1, n_rep):
+                assert torch.equal(results[j][i], results[0][i])
 
 
 @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@@ -197,10 +242,11 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
             raise AssertionError()
 
     oob_token_ids[0][0] = rogue_token_id
+    generators = [None] * batch_size
 
     with pytest.raises(AssertionError):
         rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                          draft_token_ids)
+                          draft_token_ids, generators)
 
 
 @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@@ -371,11 +417,15 @@ def _estimate_rejection_sampling_pdf(
                                        dtype=torch.int64,
                                        device="cuda").repeat(num_samples, 1)
 
+        # unseeded
+        generators = [None]
+
         # Get output tokens via rejection sampling.
         output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                   bonus_token_ids.to("cuda"),
                                                   draft_probs.to("cuda"),
-                                                  draft_token_ids.to("cuda"))
+                                                  draft_token_ids.to("cuda"),
+                                                  generators)
 
         # Remove bonus tokens
         output_token_ids = output_token_ids[:, :-1].flatten()
diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py
index da72f6d503c11..bd1ea43f0b101 100644
--- a/tests/spec_decode/e2e/conftest.py
+++ b/tests/spec_decode/e2e/conftest.py
@@ -1,6 +1,6 @@
 import asyncio
 from itertools import cycle
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 import pytest
 import ray
@@ -128,7 +128,9 @@ async def get_output(prompt, sampling_param) -> RequestOutput:
     try:
         for i in range(num_requests):
             prompt = prompts[i] if prompts is not None else None
-            res = asyncio.run(get_output(prompt, sampling_params))
+            params = sampling_params[i] if isinstance(
+                sampling_params, Sequence) else sampling_params
+            res = asyncio.run(get_output(prompt, params))
             outputs.append(res)
     finally:
         ray.shutdown()
@@ -267,7 +269,31 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
     the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
     the same when temperature is zero.
     """
-    temperature = 0.0
+
+    run_equality_correctness_test(baseline_llm_generator,
+                                  test_llm_generator,
+                                  batch_size,
+                                  max_output_len,
+                                  force_output_len,
+                                  temperature=0.0,
+                                  seeded=False,
+                                  print_tokens=print_tokens,
+                                  ensure_all_accepted=ensure_all_accepted)
+
+
+def run_equality_correctness_test(baseline_llm_generator,
+                                  test_llm_generator,
+                                  batch_size,
+                                  max_output_len,
+                                  force_output_len: bool,
+                                  temperature: float,
+                                  seeded: bool,
+                                  print_tokens: bool = False,
+                                  ensure_all_accepted: bool = False):
+    """Helper method that compares the outputs of both the baseline LLM and
+    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
+    the same when temperature is zero (or when temperature is > 0 and seeded).
+    """
 
     prompts = [
         "Hello, my name is",
@@ -286,11 +312,21 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
     # sampling params to ignore eos token.
     ignore_eos = force_output_len
 
-    sampling_params = SamplingParams(
-        max_tokens=max_output_len,
-        ignore_eos=ignore_eos,
-        temperature=temperature,
-    )
+    if seeded:
+        sampling_params = [
+            SamplingParams(
+                max_tokens=max_output_len,
+                ignore_eos=ignore_eos,
+                temperature=temperature,
+                seed=i,
+            ) for i in range(len(prompts))
+        ]
+    else:
+        sampling_params = SamplingParams(
+            max_tokens=max_output_len,
+            ignore_eos=ignore_eos,
+            temperature=temperature,
+        )
 
     (spec_batch_tokens, spec_batch_token_ids,
      acceptance_rate) = get_output_from_llm_generator(test_llm_generator,
diff --git a/tests/spec_decode/e2e/test_seed.py b/tests/spec_decode/e2e/test_seed.py
new file mode 100644
index 0000000000000..792d7cba0f270
--- /dev/null
+++ b/tests/spec_decode/e2e/test_seed.py
@@ -0,0 +1,44 @@
+import pytest
+
+from .conftest import run_equality_correctness_test
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+ "use_v2_block_manager": True, + + # speculative model + "speculative_model": "JackFram/llama-160m", + + # num speculative tokens + "num_speculative_tokens": 3, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [1, 8, 32]) +@pytest.mark.parametrize("temperature", [0.1, 1.0]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 10, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_seeded_consistency(baseline_llm_generator, batch_size: int, + temperature: float, output_len: int): + """Verify outputs are consistent across multiple runs with same seed + """ + run_equality_correctness_test(baseline_llm_generator, + baseline_llm_generator, + batch_size, + max_output_len=output_len, + temperature=temperature, + seeded=True, + force_output_len=True) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index e189610461a70..b4994083c797b 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -1,14 +1,14 @@ from functools import cached_property -from typing import Tuple +from typing import List, Optional, Tuple import torch import torch.jit from vllm.model_executor.layers.spec_decode_base_sampler import ( - SpecDecodeBaseSampler) + SpecDecodeStochasticBaseSampler) -class RejectionSampler(SpecDecodeBaseSampler): +class RejectionSampler(SpecDecodeStochasticBaseSampler): """Apply modified rejection sampling as described in "Accelerating Large Language Model Decoding with Speculative Sampling" https://arxiv.org/pdf/2302.01318.pdf. @@ -36,6 +36,7 @@ def forward( bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, + generators: List[Optional[torch.Generator]], ) -> torch.Tensor: """Sample token ids using rejection sampling. This accepts or rejects tokens proposed by the draft model using the probability of each token @@ -82,6 +83,7 @@ def forward( target_probs, draft_probs, draft_token_ids, + generators, )) output_token_ids = self._create_output( @@ -94,10 +96,11 @@ def forward( return output_token_ids def _batch_modified_rejection_sampling( - self, - target_probs: torch.Tensor, # [batch_size, k, vocab_size] - draft_probs: torch.Tensor, # [batch_size, k, vocab_size] - draft_token_ids: torch.Tensor, # [batch_size, k] + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + generators: List[Optional[torch.Generator]], ) -> Tuple[torch.Tensor, torch.Tensor]: """Perform modified rejection sampling on each sequence. @@ -114,22 +117,33 @@ def _batch_modified_rejection_sampling( # shape [batch_size, k] accepted = self._get_accepted(target_probs, draft_probs, - draft_token_ids) + draft_token_ids, generators) recovered_probs = self._get_recovered_probs( target_probs, draft_probs).reshape(batch_size * k, vocab_size) + seed_indices, non_seed_indices = self._split_batch_by_seeded( + generators, k=k) + # NOTE: the recovered_probs are overwritten by this method. 
-        recovered_token_ids = _multinomial(recovered_probs,
-                                           num_samples=1).reshape(
-                                               batch_size, k)
+        recovered_token_ids = _multinomial(
+            recovered_probs,
+            num_samples=1,
+            k=k,
+            generators=generators,
+            seed_indices=seed_indices,
+            # this arg is unused when None but torch.jit requires a list
+            non_seed_indices=non_seed_indices or [],
+        ).reshape(batch_size, k)
+
         return accepted, recovered_token_ids
 
     def _get_accepted(
-        self,
-        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
-        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
-        draft_token_ids: torch.Tensor,  # [batch_size, k]
+            self,
+            target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
+            draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
+            draft_token_ids: torch.Tensor,  # [batch_size, k]
+            generators: List[Optional[torch.Generator]],
     ) -> torch.Tensor:
         r"""Create bool matrix over the proposed draft tokens. If
         True, then a token can be accepted, else it should be
@@ -164,10 +178,28 @@ def _get_accepted(
         selected_target_probs = target_probs[batch_indices, probs_indicies,
                                              draft_token_ids]
 
-        uniform_rand = torch.rand(batch_size,
-                                  k,
-                                  dtype=self.probs_dtype,
-                                  device=target_probs.device)
+        seed_indices, non_seed_indices = self._split_batch_by_seeded(
+            generators)
+
+        if len(seed_indices) == 0:
+            uniform_rand = torch.rand_like(selected_target_probs)
+        else:
+            uniform_rand = torch.empty_like(selected_target_probs)
+
+            for idx in seed_indices:
+                uniform_rand[idx, :] = torch.rand(1,
+                                                  k,
+                                                  dtype=self.probs_dtype,
+                                                  device=target_probs.device,
+                                                  generator=generators[idx])
+
+            if non_seed_indices:
+                uniform_rand[non_seed_indices, :] = torch.rand(
+                    len(non_seed_indices),
+                    k,
+                    dtype=self.probs_dtype,
+                    device=target_probs.device)
+
         capped_ratio = torch.minimum(
             selected_target_probs / selected_draft_probs,
             torch.full((1, ), 1, device=target_probs.device))
@@ -240,6 +272,27 @@ def _smallest_positive_value(self) -> float:
         """
         return torch.finfo(self.probs_dtype).tiny
 
+    # partition batch into indices for which a generator is provided
+    # and indices for which no generator is provided
+    @staticmethod
+    def _split_batch_by_seeded(
+        generators: List[Optional[torch.Generator]],
+        k: int = 1,
+    ) -> Tuple[List[int], Optional[List[int]]]:
+
+        if all(generator is None for generator in generators):
+            seed_indices: List[int] = []
+            non_seed_indices: Optional[List[int]] = None
+        else:
+            seed_indices, non_seed_indices = [], []
+            for i, generator in enumerate(generators):
+                if generator is None:
+                    non_seed_indices.extend(range(k * i, k * (i + 1)))
+                else:
+                    seed_indices.extend(range(k * i, k * (i + 1)))
+
+        return seed_indices, non_seed_indices
+
 
 # torch.multinomial forces a GPU<->CPU sync.
 # Therefore, we use an optimized implementation instead that skips the sync.
@@ -250,12 +303,25 @@ def _smallest_positive_value(self) -> float:
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
+    k: int,
+    generators: List[Optional[torch.Generator]],
+    seed_indices: List[int],
+    non_seed_indices: List[int],
 ) -> torch.Tensor:
+
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
         # forces a GPU<->CPU sync).
         probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                          probs.shape[1]).contiguous().view(
                                              -1, probs.shape[1])
-    q = torch.empty_like(probs).exponential_(1.0)
+
+    q = torch.empty_like(probs)
+    if len(seed_indices) == 0:
+        q.exponential_(1.0)
+    else:
+        q[non_seed_indices] = q[non_seed_indices].exponential_(1.0)
+        for idx in seed_indices:
+            q[idx].exponential_(1.0, generator=generators[idx // k])
+
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py
index 692024056495c..08191da49d52f 100644
--- a/vllm/model_executor/layers/spec_decode_base_sampler.py
+++ b/vllm/model_executor/layers/spec_decode_base_sampler.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.jit
@@ -54,16 +54,6 @@ def probs_dtype(self):
     def token_id_dtype(self):
         return torch.int64
 
-    @abstractmethod
-    def forward(
-        self,
-        target_probs: torch.Tensor,
-        bonus_token_ids: torch.Tensor,
-        draft_probs: torch.Tensor,
-        draft_token_ids: torch.Tensor,
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
     def _create_output(
         self,
         accepted: torch.Tensor,  # [batch_size, k]
@@ -217,3 +207,36 @@ def _raise_if_out_of_bounds_vocab(
         assert torch.all(bonus_token_ids >= 0)
         assert torch.all(draft_token_ids < vocab_size)
         assert torch.all(draft_token_ids >= 0)
+
+
+class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
+    """Base class for samplers used for Speculative Decoding verification
+    step which are deterministic.
+    """
+
+    @abstractmethod
+    def forward(
+        self,
+        target_probs: torch.Tensor,
+        bonus_token_ids: torch.Tensor,
+        draft_probs: torch.Tensor,
+        draft_token_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
+    """Base class for samplers used for Speculative Decoding verification
+    step which are stochastic.
+    """
+
+    @abstractmethod
+    def forward(
+        self,
+        target_probs: torch.Tensor,
+        bonus_token_ids: torch.Tensor,
+        draft_probs: torch.Tensor,
+        draft_token_ids: torch.Tensor,
+        generators: List[Optional[torch.Generator]],
+    ) -> torch.Tensor:
+        raise NotImplementedError
diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py
index 9bf3c84a161c5..a87ea0eee57de 100644
--- a/vllm/model_executor/layers/typical_acceptance_sampler.py
+++ b/vllm/model_executor/layers/typical_acceptance_sampler.py
@@ -2,10 +2,10 @@
 import torch.jit
 
 from vllm.model_executor.layers.spec_decode_base_sampler import (
-    SpecDecodeBaseSampler)
+    SpecDecodeDeterministicBaseSampler)
 
 
-class TypicalAcceptanceSampler(SpecDecodeBaseSampler):
+class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
     """Apply typical acceptance sampling as described in section 3.3.1 in
     "MEDUSA: Simple LLM Inference Acceleration Framework with Multiple
     Decoding Heads"
diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py
index 40516556344e9..41f0aebf3c01b 100644
--- a/vllm/spec_decode/batch_expansion.py
+++ b/vllm/spec_decode/batch_expansion.py
@@ -4,7 +4,8 @@
 import torch
 
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
-                           SequenceGroupMetadata, get_all_seq_ids)
+                           SequenceGroupMetadata, SequenceGroupState,
+                           get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
@@ -292,6 +293,15 @@ def _create_single_target_seq_group_metadata(
         for data in new_seq_data_dict.values():
             data.update_num_computed_tokens(data.get_len() - 1)
 
+        if (seq_group_metadata.state is not None
+                and seq_group_metadata.state.generator is not None):
+            generator = torch.Generator(
+                device=seq_group_metadata.state.generator.device)
+            generator.set_state(seq_group_metadata.state.generator.get_state())
+            state = SequenceGroupState(generator=generator)
+        else:
+            state = None
+
         return SequenceGroupMetadata(
             request_id=seq_group_metadata.request_id,
             is_prompt=seq_group_metadata.is_prompt,
@@ -302,6 +312,7 @@ def _create_single_target_seq_group_metadata(
             },
             lora_request=None,
             token_chunk_size=1,
+            state=state,
         )
 
     def _split_scoring_output(
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 903264aad7a15..553c956ae45d9 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -9,7 +9,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.layers.spec_decode_base_sampler import (
-    SpecDecodeBaseSampler)
+    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
 from vllm.model_executor.layers.typical_acceptance_sampler import (
     TypicalAcceptanceSampler)
 from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
@@ -521,11 +521,28 @@ def _verify_tokens(
         # Get proposed tokens.
         proposal_token_ids = proposals.proposal_token_ids[spec_indices]
 
+        # Sampler arguments
+        sampler_extra_kwargs = {}
+        if isinstance(self.spec_decode_sampler,
+                      SpecDecodeStochasticBaseSampler):
+
+            # Get sequence group state
+            generators = []
+            for seq_group_metadata in seq_group_metadata_list:
+                if (seq_group_metadata.state is not None
+                        and seq_group_metadata.state.generator is not None):
+                    generators.append(seq_group_metadata.state.generator)
+                else:
+                    generators.append(None)
+
+            sampler_extra_kwargs["generators"] = generators
+
         accepted_token_ids = self.spec_decode_sampler(
             target_probs=proposal_verifier_probs,
             bonus_token_ids=bonus_token_ids,
             draft_probs=proposal_probs,
             draft_token_ids=proposal_token_ids,
+            **sampler_extra_kwargs,
         )
 
         # Append output tokens from non-speculative sequences to
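
A note on the sampling trick the patch extends: _multinomial in rejection_sampler.py avoids the GPU<->CPU sync of torch.multinomial by dividing each probability row by Exp(1) noise and taking the argmax (a Gumbel-max-style draw), and it now threads an optional per-row torch.Generator into that noise so seeded requests reproduce exactly. The following is a minimal standalone sketch of the idea, not the vLLM API; the function name and shapes are illustrative only.

from typing import List, Optional

import torch


def multinomial_no_sync(
        probs: torch.Tensor,  # [num_rows, vocab_size], rows sum to 1
        generators: List[Optional[torch.Generator]],  # one entry per row
) -> torch.Tensor:
    """Draw one token per row without forcing a GPU<->CPU sync."""
    q = torch.empty_like(probs)
    for i, gen in enumerate(generators):
        # q[i] is a view, so the in-place fill writes into q itself.
        if gen is None:
            q[i].exponential_(1.0)
        else:
            q[i].exponential_(1.0, generator=gen)
    # With q ~ Exp(1), argmax(p / q) is a sample from Categorical(p).
    return probs.div(q).argmax(dim=-1)


probs = torch.softmax(torch.randn(4, 16), dim=-1)
draws = [
    multinomial_no_sync(
        probs, [torch.Generator().manual_seed(i) for i in range(4)])
    for _ in range(2)
]
assert torch.equal(draws[0], draws[1])  # same seeds -> identical samples

Seeded rows re-draw the same noise on every call, which is what test_deterministic_when_seeded asserts batch-wide in the first hunk of this patch.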
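
The batch_expansion.py hunk clones a request's torch.Generator rather than sharing it, so scoring the expanded target sequences does not advance the random state the original request keeps using. A quick sketch of that cloning pattern in plain PyTorch (no vLLM objects; variable names are illustrative):

import torch

source = torch.Generator().manual_seed(42)  # per-request generator
clone = torch.Generator(device=source.device)
clone.set_state(source.get_state())  # copy the state, not the object

# Both generators now produce the same stream, independently of each other.
assert torch.equal(torch.rand(3, generator=source),
                   torch.rand(3, generator=clone))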
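
Finally, the rule implemented by _get_accepted: a draft token x is kept when a uniform draw falls below min(1, p_target(x) / p_draft(x)), and the patch simply sources that uniform from the sequence's generator when one exists. A scalar-level sketch of the rule, using an illustrative helper rather than the vLLM implementation:

from typing import Optional

import torch


def accept_draft_token(p_target: float,
                       p_draft: float,
                       generator: Optional[torch.Generator] = None) -> bool:
    """Accept with probability min(1, p_target / p_draft)."""
    u = torch.rand(1, generator=generator).item()
    return u < min(1.0, p_target / max(p_draft, 1e-20))


gen = torch.Generator().manual_seed(0)
# With a fixed seed the accept/reject decision is the same on every run.
print(accept_draft_token(p_target=0.30, p_draft=0.45, generator=gen))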