From 95eb1a52c6b0f5237f5bd9284ab646eb1caa91eb Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 30 Jul 2024 02:06:29 -0700
Subject: [PATCH] [TPU] Fix greedy decoding (#6933)

Signed-off-by: Alvant
---
 vllm/worker/tpu_model_runner.py | 27 ++++++++++++++++++---------
 1 file changed, 18 insertions(+), 9 deletions(-)

diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index 1692094af8c41..cf4cc5535ba5b 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -28,7 +28,9 @@
 
 logger = init_logger(__name__)
 
-_PAD_SLOT_ID = -1  # NOTE(woosuk): In PyTorch XLA, index -1 is ignored.
+# Here we utilize the behavior that out-of-bound index is ignored.
+# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
+_PAD_SLOT_ID = 1_000_000_000
 # FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
 _ENABLE_TOP_P = False
 # FIXME(woosuk): A temporary hack to support `n > 1`.
@@ -414,10 +416,7 @@ def _prepare_sample(
         best_of = []
         for seq_group_metadata in seq_group_metadata_list:
             sampling_params = seq_group_metadata.sampling_params
-            # NOTE(woosuk): Here we mimic argmax sampling by applying a very
-            # low temperature. This is not accurate.
-            t.append(sampling_params.temperature
-                     if sampling_params.temperature >= 1e-5 else 1e-5)
+            t.append(sampling_params.temperature)
             if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                 raise NotImplementedError(
                     "Top-p sampling is currently disabled for the TPU backend "
@@ -678,13 +677,23 @@ def forward(
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
-        logits = logits / t.unsqueeze(dim=1)
+        # Argmax sampling.
+        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
+
+        # Zero temperature means greedy decoding. Avoid division by zero.
+        nonzero_t = torch.where(t != 0, t, 1.0)
+        logits = logits / nonzero_t.unsqueeze(dim=1)
         if _ENABLE_TOP_P:
             logits = _apply_top_p(logits, p.unsqueeze(dim=1))
+
+        # Random sampling.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
-        next_token_ids = torch.multinomial(probs,
-                                           num_samples,
-                                           replacement=True)
+        sampled_token_ids = torch.multinomial(probs,
+                                              num_samples,
+                                              replacement=True)
+        next_token_ids = torch.where(t != 0, sampled_token_ids,
+                                     argmax_token_ids)
         return next_token_ids
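
The patched forward pass no longer fakes greedy decoding with a tiny temperature; it computes both candidates for every row, an argmax token and a multinomial sample from softmax(logits / t) where zero temperatures are replaced by 1.0 only to keep the division well defined, and then picks the argmax result wherever the request's temperature is exactly zero. Below is a minimal, standalone sketch of that greedy-vs-random selection, assuming [batch, vocab] logits and a 1-D [batch] temperature vector; the `sample` helper name and the explicit `t.unsqueeze(dim=1)` broadcast in the final `torch.where` are illustrative choices for this example, not vLLM's actual interfaces.

import torch


def sample(logits: torch.Tensor, t: torch.Tensor,
           num_samples: int = 1) -> torch.Tensor:
    # Hypothetical helper: greedy decoding where t == 0, temperature
    # sampling elsewhere. logits: [batch, vocab], t: [batch].

    # Greedy candidates, repeated to match the multinomial output shape
    # [batch, num_samples].
    argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
    argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

    # Zero temperature means greedy decoding; substitute 1.0 so the
    # division is well defined (those rows are overwritten below anyway).
    nonzero_t = torch.where(t != 0, t, torch.ones_like(t))
    scaled_logits = logits / nonzero_t.unsqueeze(dim=1)

    # Random-sampling path.
    probs = torch.softmax(scaled_logits, dim=-1, dtype=torch.float32)
    sampled_token_ids = torch.multinomial(probs, num_samples,
                                          replacement=True)

    # Per-request selection: sampled token where t != 0, argmax otherwise.
    return torch.where(t.unsqueeze(dim=1) != 0, sampled_token_ids,
                       argmax_token_ids)


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(4, 32)
    t = torch.tensor([0.0, 0.7, 0.0, 1.0])
    # Rows 0 and 2 are greedy, so both samples in those rows equal the
    # argmax token.
    print(sample(logits, t, num_samples=2))

Computing both paths and merging them with torch.where, rather than branching in Python, presumably keeps a single traced graph that serves greedy and random requests alike, which matters on TPU where data-dependent control flow would trigger recompilation.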