diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index 3951619c6e3ec..57b6c7f907ae2 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -4,9 +4,9 @@
 import torch
 import torch.nn as nn
 
-from vllm.distributed import tensor_model_parallel_gather, tensor_model_parallel_all_gather
+from vllm.distributed import tensor_model_parallel_gather
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.utils import is_hpu
+
 
 class LogitsProcessor(nn.Module):
     """Process logits and apply logits processors from sampling metadata.
@@ -50,9 +50,7 @@ def forward(
         # Get the logits for the next tokens.
         logits = self._get_logits(hidden_states, embedding, embedding_bias)
 
-        # NOTE(kzawora): allgather on HPU will cause logits to be not None,
-        # and we need to guard against applying logits processors on non-driver worker
-        if logits is not None and sampling_metadata.seq_groups is not None:
+        if logits is not None:
             logits *= self.scale
 
             # Apply logits processors (if any).
@@ -66,9 +64,7 @@ def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
         logits = torch.matmul(hidden_states, embedding.t())
         if embedding_bias is not None:
             logits += embedding_bias
-        # NOTE(kzawora): HPU PT bridge is missing support for single-rank gather. We'll use all-gather on Gaudi for now.
-        gather_op = tensor_model_parallel_all_gather if is_hpu() else tensor_model_parallel_gather
-        logits = gather_op(logits)
+        logits = tensor_model_parallel_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
             logits = logits[:, :self.org_vocab_size]