Skip to content

Commit

Permalink
Adds MEAN pooling type
Browse files Browse the repository at this point in the history
Signed-off-by: Flavia Beo <[email protected]>
  • Loading branch information
flaviabeo committed Oct 21, 2024
1 parent 03d72a7 commit 0d6123a
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions vllm/model_executor/layers/pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ class PoolingType(IntEnum):
ALL = 1
CLS = 2
MEAN = 3
MAX = 4


class PoolingConfig():
Expand Down Expand Up @@ -49,7 +48,7 @@ class Pooler(nn.Module):
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use (LAST, ALL, CLS).
pooling_type: The type of pooling to use (LAST, ALL, CLS, MEAN).
normalize: Whether to normalize the pooled data.
"""

Expand Down Expand Up @@ -83,6 +82,17 @@ def forward(
for prompt_len in prompt_lens:
pooled_data.append(hidden_states[offset:offset + prompt_len])
offset += prompt_len
elif self.pooling_type == PoolingType.MEAN:
# Calculate mean pooling
cumsum = torch.cumsum(hidden_states, dim=0)
start_indices = torch.cat([
torch.tensor([0], device=hidden_states.device),
torch.cumsum(prompt_lens[:-1], dim=0)
])
end_indices = torch.cumsum(prompt_lens, dim=0)
pooled_data = (
cumsum[end_indices - 1] - cumsum[start_indices] +
hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")

Expand Down

0 comments on commit 0d6123a

Please sign in to comment.