From 57b7be0e1c4e594c58a78297ab65fbb3ec206958 Mon Sep 17 00:00:00 2001 From: William Lin Date: Thu, 8 Aug 2024 22:42:45 -0700 Subject: [PATCH] [Speculative decoding] [Multi-Step] decouple should_modify_greedy_probs_inplace (#6971) --- tests/samplers/test_sampler.py | 27 ++++++++++++++++++- vllm/lora/layers.py | 4 +++ vllm/model_executor/layers/sampler.py | 4 +-- vllm/spec_decode/medusa_worker.py | 3 +++ vllm/spec_decode/multi_step_worker.py | 4 +++ vllm/spec_decode/proposer_worker_base.py | 4 +++ .../spec_decode/smaller_tp_proposer_worker.py | 6 +++++ vllm/spec_decode/spec_decode_worker.py | 3 +++ 8 files changed, 52 insertions(+), 3 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index bf062e4a5c09d..f1370e411241c 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -1,7 +1,7 @@ import itertools import random from typing import Dict, List, Optional, Tuple -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest import torch @@ -703,3 +703,28 @@ def test_sampling_params(sampling_params: List[SamplingParams]): assert tokens1[0] == tokens2[1] assert tokens1[1] == tokens2[0] + + +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_sampler_include_gpu_probs_tensor(device: str): + set_random_seed(42) + torch.set_default_device(device) + batch_size = random.randint(1, 256) + _, fake_logits, sampler = _prepare_test(batch_size) + sampler.include_gpu_probs_tensor = True + sampler.should_modify_greedy_probs_inplace = False + + sampling_params = SamplingParams(temperature=0) + + mock_inplace = Mock() + with patch( + "vllm.model_executor.layers.sampler._modify_greedy_probs_inplace", + mock_inplace): + + sampler_output = _do_sample(batch_size, fake_logits, sampler, + sampling_params, device) + mock_inplace.assert_not_called() + + assert sampler_output.sampled_token_probs is not None + assert sampler_output.logprobs is not None + assert sampler_output.sampled_token_ids is not None diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index e3316059dc6d1..a8ea67991a375 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1067,6 +1067,10 @@ def org_vocab_size(self): def include_gpu_probs_tensor(self): return self.base_layer.include_gpu_probs_tensor + @property + def should_modify_greedy_probs_inplace(self): + return self.base_layer.should_modify_greedy_probs_inplace + def create_lora_weights( self, max_loras: int, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 6632b1c434582..cc78a0ea3b869 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -51,6 +51,7 @@ def __init__(self): # containing the sampled token ids and probabilities. This is used by # speculative decoding. self.include_gpu_probs_tensor = False + self.should_modify_greedy_probs_inplace = False def _init_sampling_tensors( self, @@ -177,8 +178,7 @@ def _should_modify_greedy_probs_inplace(self) -> bool: This is used by speculative decoding, which requires that the sampling method be encoded into the probability distribution. """ - # Modify greedy probs if include_gpu_probs_tensor is set. - return self.include_gpu_probs_tensor + return self.should_modify_greedy_probs_inplace def _get_bin_counts_and_mask( diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py index 4b82f7bf92bab..d1809e49c2a8f 100644 --- a/vllm/spec_decode/medusa_worker.py +++ b/vllm/spec_decode/medusa_worker.py @@ -35,6 +35,9 @@ def init_device(self): def set_include_gpu_probs_tensor(self): pass + def set_should_modify_greedy_probs_inplace(self): + pass + @torch.inference_mode() def sampler_output( self, diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 91689324557b5..65bfb5dc8d5c6 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -46,6 +46,10 @@ def set_include_gpu_probs_tensor(self) -> None: # Need include_gpu_probs_tensor for MultiStepWorker self.model_runner.model.sampler.include_gpu_probs_tensor = True + def set_should_modify_greedy_probs_inplace(self) -> None: + self.model_runner.model.sampler.should_modify_greedy_probs_inplace = ( + True) + @torch.inference_mode() def sampler_output( self, diff --git a/vllm/spec_decode/proposer_worker_base.py b/vllm/spec_decode/proposer_worker_base.py index 51cefc0cbca8b..efb8ee25ba2f9 100644 --- a/vllm/spec_decode/proposer_worker_base.py +++ b/vllm/spec_decode/proposer_worker_base.py @@ -28,6 +28,10 @@ def set_include_gpu_probs_tensor(self) -> None: """Implementation optional""" pass + def set_should_modify_greedy_probs_inplace(self) -> None: + """Implementation optional""" + pass + class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC): """Proposer worker which does not use a model with kvcache""" diff --git a/vllm/spec_decode/smaller_tp_proposer_worker.py b/vllm/spec_decode/smaller_tp_proposer_worker.py index 0dbb924d25400..215ede52fb812 100644 --- a/vllm/spec_decode/smaller_tp_proposer_worker.py +++ b/vllm/spec_decode/smaller_tp_proposer_worker.py @@ -83,6 +83,12 @@ def set_include_gpu_probs_tensor(self) -> None: # Need include_gpu_probs_tensor for multi_step_worker self._worker.set_include_gpu_probs_tensor() + def set_should_modify_greedy_probs_inplace(self) -> None: + if self._is_dummy: + return + + self._worker.set_should_modify_greedy_probs_inplace() + def load_model(self) -> None: if self._is_dummy: return diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 690aad505e215..63a00139cc09d 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -295,7 +295,10 @@ def _configure_model_sampler_for_spec_decode(self): """ (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor ) = True + (self.scorer_worker.model_runner.model.sampler. + should_modify_greedy_probs_inplace) = True self.proposer_worker.set_include_gpu_probs_tensor() + self.proposer_worker.set_should_modify_greedy_probs_inplace() def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of cache blocks to use.