Skip to content

Commit

Permalink
fix: fix MultiVectorPooler strategy initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
tonywu71 committed Sep 10, 2024
1 parent 19c70d4 commit ea3f281
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 16 deletions.
2 changes: 1 addition & 1 deletion colpali_engine/compression/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .pooling import MultiVectorPooler, MultiVectorPoolingStrategy
from .pooling import MultiVectorPooler
2 changes: 1 addition & 1 deletion colpali_engine/compression/pooling/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .multi_vector_pooler import MultiVectorPooler, MultiVectorPoolingStrategy
from .multi_vector_pooler import MultiVectorPooler
25 changes: 11 additions & 14 deletions colpali_engine/compression/pooling/multi_vector_pooler.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,24 @@
from enum import Enum
from typing import ClassVar, List

import torch
import torch.nn as nn


class MultiVectorPoolingStrategy(str, Enum):
MEAN = "mean"
SUM = "sum"
MAX = "max"


class MultiVectorPooler(nn.Module):
def __init__(self, pooling_strategy: str = MultiVectorPoolingStrategy.MEAN):
supported_pooling_strategies: ClassVar[List[str]] = ["mean", "sum", "max"]

def __init__(self, pooling_strategy: str = "mean"):
"""
Initialize the MultiVectorPooler with a specified pooling strategy.
Args:
- pooling_strategy (SupportedPoolingType): The type of pooling to apply.
"""
super(MultiVectorPooler, self).__init__()
if not isinstance(pooling_strategy, MultiVectorPoolingStrategy):
super().__init__()

if pooling_strategy not in self.supported_pooling_strategies:
raise ValueError(
f"Unsupported pooling type: {pooling_strategy}. Use one of {list(MultiVectorPoolingStrategy)}."
f"Unsupported pooling type: {pooling_strategy}. Use one of {self.supported_pooling_strategies}."
)
self.pooling_strategy = pooling_strategy

Expand All @@ -35,11 +32,11 @@ def forward(self, input_tensor) -> torch.Tensor:
Returns:
- torch.Tensor: A 2D tensor with shape (batch_size, dim) after pooling.
"""
if self.pooling_strategy == MultiVectorPoolingStrategy.MEAN:
if self.pooling_strategy == "mean":
pooled_tensor = torch.mean(input_tensor, dim=1)
elif self.pooling_strategy == MultiVectorPoolingStrategy.SUM:
elif self.pooling_strategy == "sum":
pooled_tensor = torch.sum(input_tensor, dim=1)
elif self.pooling_strategy == MultiVectorPoolingStrategy.MAX:
elif self.pooling_strategy == "max":
pooled_tensor, _ = torch.max(input_tensor, dim=1)
else:
raise ValueError(f"Unsupported pooling strategy: {self.pooling_strategy}.")
Expand Down

0 comments on commit ea3f281

Please sign in to comment.