feat: move ColPali2 classes to match restructured package
tonywu71 committed Sep 9, 2024
1 parent bb0a87a commit dc28fe0
Showing 9 changed files with 118 additions and 27 deletions.
1 change: 1 addition & 0 deletions colpali_engine/compression/__init__.py
@@ -0,0 +1 @@
from .pooling import MultiVectorPooler, MultiVectorPoolingStrategy
1 change: 1 addition & 0 deletions colpali_engine/compression/pooling/__init__.py
@@ -0,0 +1 @@
from .multi_vector_pooler import MultiVectorPooler, MultiVectorPoolingStrategy
@@ -4,23 +4,25 @@
 import torch.nn as nn
 
 
-class PoolingStrategy(str, Enum):
+class MultiVectorPoolingStrategy(str, Enum):
     MEAN = "mean"
     SUM = "sum"
     MAX = "max"
 
 
 class MultiVectorPooler(nn.Module):
-    def __init__(self, pooling_strategy: str = PoolingStrategy.MEAN):
+    def __init__(self, pooling_strategy: str = MultiVectorPoolingStrategy.MEAN):
         """
         Initialize the MultiVectorPooler with a specified pooling strategy.
         Args:
         - pooling_strategy (SupportedPoolingType): The type of pooling to apply.
         """
         super(MultiVectorPooler, self).__init__()
-        if not isinstance(pooling_strategy, PoolingStrategy):
-            raise ValueError(f"Unsupported pooling type: {pooling_strategy}. Use one of {list(PoolingStrategy)}.")
+        if not isinstance(pooling_strategy, MultiVectorPoolingStrategy):
+            raise ValueError(
+                f"Unsupported pooling type: {pooling_strategy}. Use one of {list(MultiVectorPoolingStrategy)}."
+            )
         self.pooling_strategy = pooling_strategy
 
     def forward(self, input_tensor) -> torch.Tensor:
@@ -33,11 +35,11 @@ def forward(self, input_tensor) -> torch.Tensor:
         Returns:
         - torch.Tensor: A 2D tensor with shape (batch_size, dim) after pooling.
         """
-        if self.pooling_strategy == PoolingStrategy.MEAN:
+        if self.pooling_strategy == MultiVectorPoolingStrategy.MEAN:
            pooled_tensor = torch.mean(input_tensor, dim=1)
-        elif self.pooling_strategy == PoolingStrategy.SUM:
+        elif self.pooling_strategy == MultiVectorPoolingStrategy.SUM:
            pooled_tensor = torch.sum(input_tensor, dim=1)
-        elif self.pooling_strategy == PoolingStrategy.MAX:
+        elif self.pooling_strategy == MultiVectorPoolingStrategy.MAX:
            pooled_tensor, _ = torch.max(input_tensor, dim=1)
         else:
             raise ValueError(f"Unsupported pooling strategy: {self.pooling_strategy}.")
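
A minimal usage sketch of the renamed pooling API, relying on the colpali_engine.compression re-export added above (the input tensor is a dummy, not real model output):

import torch

from colpali_engine.compression import MultiVectorPooler, MultiVectorPoolingStrategy

embeddings = torch.randn(2, 5, 8)  # (batch_size, num_tokens, dim)

pooler = MultiVectorPooler(pooling_strategy=MultiVectorPoolingStrategy.MAX)
pooled = pooler(embeddings)  # pooled over the token dimension (dim=1)
print(pooled.shape)  # torch.Size([2, 8])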
18 changes: 9 additions & 9 deletions colpali_engine/loss/colpali_2_loss.py
@@ -5,7 +5,15 @@
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 
-from colpali_engine.models.late_interaction.colpali_2.colpali_2_modeling_outputs import ColPali2ModelOutput
+from colpali_engine.models.paligemma.colpali_2.modeling_colpali_2 import ColPali2ModelOutput
+
+
+@dataclass(kw_only=True)
+class ColPali2LossOutputs:
+    single_vector_loss: torch.Tensor
+    multi_vector_loss: torch.Tensor
+    distillation_loss: Optional[torch.Tensor] = None
+    total_loss: torch.Tensor
 
 
 class MatryoshkaCELoss(torch.nn.Module):
@@ -34,14 +42,6 @@ def forward(self, output, target) -> torch.Tensor:
         return weighted_losses.sum()
 
 
-@dataclass(kw_only=True)
-class ColPali2LossOutputs:
-    single_vector_loss: torch.Tensor
-    multi_vector_loss: torch.Tensor
-    distillation_loss: Optional[torch.Tensor] = None
-    total_loss: torch.Tensor
-
-
 class ColPali2Loss(torch.nn.Module):
     """
     Loss function for ColPali2.
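
Since ColPali2LossOutputs is keyword-only, a caller would presumably populate it along these lines (placeholder tensors; the real loss computation inside ColPali2Loss is not shown in this diff):

import torch

from colpali_engine.loss.colpali_2_loss import ColPali2LossOutputs

single_loss = torch.tensor(0.7)  # placeholder single-vector loss term
multi_loss = torch.tensor(0.4)   # placeholder multi-vector loss term

outputs = ColPali2LossOutputs(
    single_vector_loss=single_loss,
    multi_vector_loss=multi_loss,
    total_loss=single_loss + multi_loss,  # distillation_loss defaults to None
)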

This file was deleted.

@@ -4,9 +4,9 @@
 from torch import nn
 from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration, PaliGemmaPreTrainedModel
 
-from colpali_engine.models.late_interaction.colpali_2.colpali_2_config import ColPali2Config
+from colpali_engine.compression.pooling.multi_vector_pooler import MultiVectorPooler
 from colpali_engine.models.late_interaction.colpali_2.colpali_2_modeling_outputs import ColPali2ModelOutput
-from colpali_engine.models.late_interaction.colpali_2.colpali_2_utils import MultiVectorPooler
+from colpali_engine.models.paligemma.colpali_2.configuration_colpali_2 import ColPali2Config
 
 
 class ColPali2(PaliGemmaPreTrainedModel):
96 changes: 96 additions & 0 deletions colpali_engine/models/paligemma/colpali_2/modeling_colpali_2.py
@@ -0,0 +1,96 @@
from dataclasses import dataclass
from typing import cast

import torch
from torch import nn
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration, PaliGemmaPreTrainedModel

from colpali_engine.compression.pooling.multi_vector_pooler import MultiVectorPooler
from colpali_engine.models.paligemma.colpali_2.configuration_colpali_2 import ColPali2Config


@dataclass
class ColPali2ModelOutput:
    single_vec_emb: torch.Tensor
    multi_vec_emb: torch.Tensor


class ColPali2(PaliGemmaPreTrainedModel):
    def __init__(self, config: ColPali2Config):
        super(ColPali2, self).__init__(config=config)

        self.config = cast(ColPali2Config, self.config)
        self.model = PaliGemmaForConditionalGeneration(self.config.vlm_config)

        self.single_vector_projector = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
        self.multi_vector_pooler = MultiVectorPooler(pooling_strategy=self.config.single_vector_pool_strategy)
        self.multi_vector_projector = nn.Linear(self.model.config.text_config.hidden_size, self.dim)

        self.main_input_name = "doc_input_ids"

    @property
    def single_vector_projector_dim(self) -> int:
        return self.config.single_vector_projector_dim

    @property
    def multi_vector_projector_dim(self) -> int:
        return self.config.multi_vector_projector_dim

    def forward(self, *args, **kwargs) -> ColPali2ModelOutput:
        """
        Forward pass through ColPali. Returns both single-vector and multi-vector embeddings.
        NOTE: Both the text and image processors should prepend the <CLS> token to the input_ids tensor
        before passing it to the model.
        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.
        Returns:
        - ColPali2ModelOutput:
            - single_vector (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
            - multi_vector (torch.Tensor): Multi-vector embeddings of shape (batch_size, num_tokens, dim).
        """

        # Forward pass through the VLM
        vlm_outputs = self.model(*args, output_hidden_states=True, **kwargs)
        vlm_last_hidden_states = vlm_outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)

        # Head 1: Single-vector embedding
        pooled_output = self.multi_vector_pooler(vlm_last_hidden_states)  # (batch_size, hidden_size)
        single_vec_emb = self.single_vector_projector(pooled_output)  # (batch_size, hidden_size)
        single_vec_emb = torch.nn.functional.normalize(single_vec_emb, dim=-1)

        # Head 2: Multi-vector embedding
        multi_vec_emb = self.multi_vector_projector(
            vlm_last_hidden_states[:, 1:, :]
        )  # (batch_size, sequence_length, hidden_size)
        multi_vec_emb = torch.nn.functional.normalize(multi_vec_emb, dim=-1)
        multi_vec_emb = multi_vec_emb * kwargs["attention_mask"].unsqueeze(-1)

        return ColPali2ModelOutput(single_vec_emb=single_vec_emb, multi_vec_emb=multi_vec_emb)

    def forward_single_vector(self, *args, **kwargs) -> torch.Tensor:
        """
        Forward pass through ColPali. Returns only the single-vector embeddings.
        """
        vlm_outputs = self.model(*args, output_hidden_states=True, **kwargs)
        pooled_output = self.multi_vector_pooler(vlm_outputs.hidden_states[-1])  # (batch_size, hidden_size)
        single_vec_emb = self.single_vector_projector(pooled_output)  # (batch_size, hidden_size)
        single_vec_emb = torch.nn.functional.normalize(single_vec_emb, dim=-1)

        return single_vec_emb

    def forward_multi_vector(self, *args, **kwargs) -> torch.Tensor:
        """
        Forward pass through ColPali. Returns only the multi-vector embeddings.
        """
        vlm_outputs = self.model(*args, output_hidden_states=True, **kwargs)
        multi_vec_emb = self.multi_vector_projector(
            vlm_outputs.hidden_states[-1]
        )  # (batch_size, sequence_length, hidden_size)
        multi_vec_emb = torch.nn.functional.normalize(multi_vec_emb, dim=-1)
        multi_vec_emb = multi_vec_emb * kwargs["attention_mask"].unsqueeze(-1)

        return multi_vec_emb
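
The forward pass above returns one dense vector plus per-token vectors; a hedged sketch of how these two heads are typically consumed downstream (plain dot-product scoring and ColBERT-style MaxSim, neither of which is part of this commit), with dummy tensors standing in for model outputs:

import torch
import torch.nn.functional as F

dim = 128

# Single-vector head: one normalized embedding per query/document, scored by dot product.
query_single = F.normalize(torch.randn(dim), dim=-1)
doc_single = F.normalize(torch.randn(dim), dim=-1)
single_score = query_single @ doc_single

# Multi-vector head: per-token embeddings, scored with MaxSim
# (best-matching document token for each query token, then summed).
query_multi = F.normalize(torch.randn(12, dim), dim=-1)  # (num_query_tokens, dim)
doc_multi = F.normalize(torch.randn(40, dim), dim=-1)    # (num_doc_tokens, dim)
multi_score = (query_multi @ doc_multi.T).max(dim=1).values.sum()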
