Skip to content

Commit

Permalink
[TPU] Fix greedy decoding (vllm-project#6933)
Browse files Browse the repository at this point in the history
Signed-off-by: Alvant <[email protected]>
  • Loading branch information
WoosukKwon authored and Alvant committed Oct 26, 2024
1 parent 47ada3b commit 95eb1a5
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions vllm/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@

logger = init_logger(__name__)

_PAD_SLOT_ID = -1 # NOTE(woosuk): In PyTorch XLA, index -1 is ignored.
# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME(woosuk): A temporary hack to support `n > 1`.
Expand Down Expand Up @@ -414,10 +416,7 @@ def _prepare_sample(
best_of = []
for seq_group_metadata in seq_group_metadata_list:
sampling_params = seq_group_metadata.sampling_params
# NOTE(woosuk): Here we mimic argmax sampling by applying a very
# low temperature. This is not accurate.
t.append(sampling_params.temperature
if sampling_params.temperature >= 1e-5 else 1e-5)
t.append(sampling_params.temperature)
if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
raise NotImplementedError(
"Top-p sampling is currently disabled for the TPU backend "
Expand Down Expand Up @@ -678,13 +677,23 @@ def forward(
hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, sampling_metadata)

logits = logits / t.unsqueeze(dim=1)
# Argmax sampling.
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

# Zero temperature means greedy decoding. Avoid division by zero.
nonzero_t = torch.where(t != 0, t, 1.0)
logits = logits / nonzero_t.unsqueeze(dim=1)
if _ENABLE_TOP_P:
logits = _apply_top_p(logits, p.unsqueeze(dim=1))

# Random sampling.
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
next_token_ids = torch.multinomial(probs,
num_samples,
replacement=True)
sampled_token_ids = torch.multinomial(probs,
num_samples,
replacement=True)
next_token_ids = torch.where(t != 0, sampled_token_ids,
argmax_token_ids)
return next_token_ids


Expand Down

0 comments on commit 95eb1a5

Please sign in to comment.