Skip to content

Commit

Permalink
revert softmax inside the pooledr
Browse files Browse the repository at this point in the history
  • Loading branch information
jason9693 committed Oct 26, 2024
1 parent d66d5bc commit fc1b238
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
9 changes: 8 additions & 1 deletion vllm/model_executor/layers/pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@ class Pooler(nn.Module):
normalize: Whether to normalize the pooled data.
"""

def __init__(self, pooling_type: PoolingType, normalize: bool):
def __init__(self,
pooling_type: PoolingType,
normalize: bool,
softmax: bool = False):
super().__init__()

self.pooling_type = pooling_type
self.normalize = normalize
self.softmax = softmax

def forward(
self,
Expand Down Expand Up @@ -64,6 +68,9 @@ def forward(
if self.normalize:
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)

if self.softmax:
pooled_data = nn.functional.softmax(pooled_data, dim=-1)

pooled_outputs = [
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
]
Expand Down
7 changes: 4 additions & 3 deletions vllm/model_executor/models/qwen2_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def __init__(
self.score = RowParallelLinear(config.hidden_size,
config.num_labels,
quant_config=quant_config)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
self._pooler = Pooler(pooling_type=PoolingType.LAST,
normalize=False,
softmax=True)

def forward(
self,
Expand All @@ -97,8 +99,7 @@ def pooler(
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
pooled = self._pooler(hidden_states, pooling_metadata)
return nn.functional.softmax(pooled, dim=-1)
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self,
Expand Down

0 comments on commit fc1b238

Please sign in to comment.