Add ColPali with double-head architecture #22

Closed
76 changes: 76 additions & 0 deletions colpali_engine/collators/colpali_2_collator.py
@@ -0,0 +1,76 @@
from typing import Any, Dict, List

import torch

from colpali_engine.models.paligemma.colpali_2.processing_colpali_2 import ColPali2Processor


class ColPali2Collator:
def __init__(
self,
processor: ColPali2Processor,
max_length: int = 2048,
add_suffix: bool = False,
):
self.processor = processor
self.image_token_id = None
self.max_length = max_length
self.suffix = ""
if add_suffix:
self.suffix = "\n" * 10

def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
# Placeholders
texts_query = []
images = []

# Populate the placeholders
for example in examples:
if example["image"] is None:
raise ValueError("Image is None - This collator does not support `None` images yet.")

image = example["image"].convert("RGB")
images.append(image)

if example["query"] is None:
texts_query.append(None)
else:
                query = example["query"]
                # Append the buffer suffix tokens (query augmentation) when enabled via `add_suffix`.
                query = f"Question: {query}{self.suffix}"
                texts_query.append(query)

# Process the documents
batch_doc = self.processor.process_image(
image=images,
padding="longest",
do_convert_rgb=True,
return_tensors="pt",
add_instruction_prompt=True,
)

# Process the queries
batch_query = None

        # Check whether some but not all queries are `None`
        if all(t is None for t in texts_query):
            print("All queries are `None`. Returning `None` for all queries.")
        elif any(t is None for t in texts_query):
            raise ValueError("Some queries are `None`. This collator does not support mixed `None` queries yet.")
        else:
            # NOTE: assumes the processor exposes a text-processing counterpart
            # (`process_text`) to `process_image`.
            batch_query = self.processor.process_text(
                text=texts_query,
                padding="longest",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )

        # Prefix each key in the output dict with "doc_" or "query_" to avoid key conflicts
batch_all = {f"doc_{k}": v for k, v in batch_doc.items()}
del batch_doc
if batch_query is not None:
batch_query = {f"query_{k}": v for k, v in batch_query.items()}
batch_all.update(batch_query)
del batch_query

return batch_all
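
A minimal usage sketch for the collator, assuming `processor` is an already-constructed ColPali2Processor instance (its instantiation depends on the underlying PaliGemma checkpoint and is omitted here); the images and queries below are dummy placeholders:

from PIL import Image

from colpali_engine.collators.colpali_2_collator import ColPali2Collator

# Assumption: `processor` is a ready ColPali2Processor instance.
collator = ColPali2Collator(processor=processor, max_length=2048)

examples = [
    {"image": Image.new("RGB", (448, 448)), "query": "What is the total revenue?"},
    {"image": Image.new("RGB", (448, 448)), "query": "Who signed the contract?"},
]

batch = collator(examples)
print(sorted(batch.keys()))  # tensors keyed with "doc_..." and "query_..." prefixes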
1 change: 1 addition & 0 deletions colpali_engine/compression/__init__.py
@@ -0,0 +1 @@
from .pooling import MultiVectorPooler
1 change: 1 addition & 0 deletions colpali_engine/compression/pooling/__init__.py
@@ -0,0 +1 @@
from .multi_vector_pooler import MultiVectorPooler
44 changes: 44 additions & 0 deletions colpali_engine/compression/pooling/multi_vector_pooler.py
@@ -0,0 +1,44 @@
from typing import ClassVar, List

import torch
import torch.nn as nn


class MultiVectorPooler(nn.Module):
supported_pooling_strategies: ClassVar[List[str]] = ["mean", "sum", "max"]

def __init__(self, pooling_strategy: str = "mean"):
"""
Initialize the MultiVectorPooler with a specified pooling strategy.

Args:
- pooling_strategy (SupportedPoolingType): The type of pooling to apply.
"""
super().__init__()

if pooling_strategy not in self.supported_pooling_strategies:
raise ValueError(
f"Unsupported pooling type: {pooling_strategy}. Use one of {self.supported_pooling_strategies}."
)
self.pooling_strategy = pooling_strategy

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
"""
Apply the pooling operation on the input tensor.

Args:
- input_tensor (torch.Tensor): A 3D tensor with shape (batch_size, num_tokens, dim).

Returns:
- torch.Tensor: A 2D tensor with shape (batch_size, dim) after pooling.
"""
if self.pooling_strategy == "mean":
pooled_tensor = torch.mean(input_tensor, dim=1)
elif self.pooling_strategy == "sum":
pooled_tensor = torch.sum(input_tensor, dim=1)
elif self.pooling_strategy == "max":
pooled_tensor, _ = torch.max(input_tensor, dim=1)
else:
raise ValueError(f"Unsupported pooling strategy: {self.pooling_strategy}.")

return pooled_tensor
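
A quick shape check for the pooler (runnable as-is with PyTorch):

import torch

from colpali_engine.compression import MultiVectorPooler

# Pool a batch of 4 sequences of 16 token embeddings of dim 128 into one vector each.
pooler = MultiVectorPooler(pooling_strategy="mean")
multi_vectors = torch.randn(4, 16, 128)  # (batch_size, num_tokens, dim)
pooled = pooler(multi_vectors)
assert pooled.shape == (4, 128)  # (batch_size, dim)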
198 changes: 198 additions & 0 deletions colpali_engine/loss/colpali_2_loss.py
@@ -0,0 +1,198 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812

from colpali_engine.models.paligemma.colpali_2.modeling_colpali_2 import ColPali2ModelOutput


@dataclass(kw_only=True)
class ColPali2LossOutputs:
single_vector_loss: torch.Tensor
multi_vector_loss: torch.Tensor
distillation_loss: Optional[torch.Tensor] = None
total_loss: torch.Tensor


class MatryoshkaCELoss(torch.nn.Module):
"""
Loss function for Matryoshka Representation Learning.

Adapted from https://github.com/RAIVNLab/MRL/blob/7ccb42df6be05f3d21d0648aa03099bba46386bf/MRL.py#L11
"""

def __init__(self, relative_importance: Optional[List[float]] = None, **kwargs):
        super().__init__()
self.criterion = torch.nn.CrossEntropyLoss(**kwargs)
self.relative_importance = relative_importance

def forward(self, output, target) -> torch.Tensor:
# Calculate losses for each output and stack them. This is still O(N)
losses = torch.stack([self.criterion(output_i, target) for output_i in output])

# Set relative_importance to 1 if not specified
rel_importance = (
torch.ones_like(losses) if self.relative_importance is None else torch.tensor(self.relative_importance)
)

# Apply relative importance weights
weighted_losses = rel_importance * losses
return weighted_losses.sum()
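
    # Usage sketch (comment only): for Matryoshka training, `output` is expected to be a
    # sequence of logit tensors, one per nested embedding size, e.g. scores computed from
    # embeddings truncated to 32, 64, and 128 dims, each of shape (batch_size, batch_size);
    # `target` is torch.arange(batch_size). The result is the sum of the per-size
    # cross-entropies, weighted by `relative_importance`.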


class ColPali2Loss(torch.nn.Module):
"""
Loss function for ColPali2.

The loss function is a combination of two losses:
1. Single-vector loss: Cross-entropy (with optional Matryoshka) loss between the query and document
single-vector embeddings.
2. Multi-vector loss: Margin loss between the query and document multi-vector embeddings.
"""

def __init__(
self,
alpha: float = 0.5,
use_matryoshka_loss: bool = True,
use_distillation_loss: bool = True,
beta: float = 0.5,
temperature: float = 2.0,
):
super().__init__()
self.alpha = alpha
self.use_matryoshka_loss = use_matryoshka_loss
self.use_distillation_loss = use_distillation_loss
self.beta = beta
self.temperature = temperature
self.single_vector_loss_fn = MatryoshkaCELoss() if self.use_matryoshka_loss else F.cross_entropy

def single_vector_loss(
self,
query_embeddings: torch.Tensor,
doc_embeddings: torch.Tensor,
return_scores: bool = False,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
"""
Loss function for the single-vector head.

query_embeddings: (batch_size, dim)
doc_embeddings: (batch_size, dim)
"""
scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)

loss = self.single_vector_loss_fn(scores, torch.arange(scores.shape[0], device=scores.device)) # (1,)

if return_scores:
return loss, scores
else:
return loss

def multi_vector_loss(
self,
query_embeddings: torch.Tensor,
doc_embeddings: torch.Tensor,
return_scores: bool = False,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
"""
Loss function for the multi-vector head.

query_embeddings: (batch_size, num_query_tokens, dim)
doc_embeddings: (batch_size, num_doc_tokens, dim)

NOTE: If `return_scores` is True, the function will return only the positive scores, i.e.
the diagonal of the scores matrix.
"""
# Compute the ColBERT scores
scores = (
torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings).max(dim=3)[0].sum(dim=2)
) # (batch_size, batch_size)

# Positive scores are the diagonal of the scores matrix.
pos_scores = scores.diagonal() # (batch_size,)

        # The negative score for a given query is the maximum of its scores against all other pages.
        # NOTE: We exclude the positive (diagonal) scores from the max operation by subtracting a
        # large constant (1e6) from the diagonal, so they can never be selected as the maximum.
        neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6  # (batch_size, batch_size)
neg_scores = neg_scores.max(dim=1)[0] # (batch_size,)

# Compute the margin loss
loss = F.softplus(neg_scores - pos_scores).mean() # (1,)

if return_scores:
return loss, pos_scores
else:
return loss

def distillation_loss(
self,
teacher_scores: torch.Tensor,
student_scores: torch.Tensor,
teacher_score_upper_bound: int,
    ) -> torch.Tensor:
"""
Compute the distillation loss between the multi-vector head (teacher) and
the single-vector head (student).

Inputs:
- teacher_scores: (batch_size)
- student_scores: (batch_size)
- teacher_score_upper_bound: The upper bound of the teacher scores.
"""
kl_div_loss = nn.KLDivLoss(reduction="batchmean")

# NOTE: Both the teacher and student scores should be turned into log-probabilities before
# computing the KL-divergence.
# The embeddings are normalized, thus we know the lower and upper bounds of the scores:
# - Teacher: the multi-vector scores (MaxSim) are between 0 and N_q, N_q being the number of query tokens
# - Student: the single-vector scores are between -1 and 1.

        # Convert the scores to logits (log-odds); `eps` clamps the inputs to [eps, 1 - eps]
        teacher_logits = torch.logit(teacher_scores / teacher_score_upper_bound, eps=1e-6)
        student_logits = torch.logit(student_scores, eps=1e-6)

# NOTE:
# - KLDivLoss argument order is the opposite of the KL(·||·) mathematical function.
# - KLDivLoss expects log-probabilities for `input` to avoid underflow issues.
loss_kd = self.temperature**2 * kl_div_loss(
input=student_logits / self.temperature,
target=teacher_logits / self.temperature,
) # (1,)

return loss_kd

def forward(
self,
query_embeddings: ColPali2ModelOutput,
doc_embeddings: ColPali2ModelOutput,
) -> ColPali2LossOutputs:
"""
Compute the total loss for the ColPali2 model.
"""

single_vector_loss, single_vector_scores = self.single_vector_loss(
query_embeddings.single_vec_emb, doc_embeddings.single_vec_emb, return_scores=True
)
multi_vector_loss, multi_vector_scores = self.multi_vector_loss(
query_embeddings.multi_vec_emb, doc_embeddings.multi_vec_emb, return_scores=True
)

total_loss = self.alpha * single_vector_loss + (1 - self.alpha) * multi_vector_loss

distillation_loss = None
if self.use_distillation_loss:
            distillation_loss = self.distillation_loss(
                teacher_scores=multi_vector_scores,
                # The single-vector head returns the full score matrix; keep only the positive
                # (diagonal) scores to match the teacher scores' (batch_size,) shape.
                student_scores=single_vector_scores.diagonal(),
                teacher_score_upper_bound=query_embeddings.multi_vec_emb.shape[1],  # TODO: find the correct upper bound
            )
total_loss += self.beta * distillation_loss

return ColPali2LossOutputs(
single_vector_loss=single_vector_loss,
multi_vector_loss=multi_vector_loss,
distillation_loss=distillation_loss,
total_loss=total_loss,
)
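
A shape-level smoke test of the two contrastive heads (a sketch: the Matryoshka and distillation terms are disabled for simplicity, and `ColPali2ModelOutput` is assumed to be constructible directly from its `single_vec_emb` and `multi_vec_emb` fields, as used in `forward` above):

import torch
import torch.nn.functional as F

from colpali_engine.loss.colpali_2_loss import ColPali2Loss
from colpali_engine.models.paligemma.colpali_2.modeling_colpali_2 import ColPali2ModelOutput

batch_size, num_query_tokens, num_doc_tokens, dim = 8, 32, 1024, 128

# L2-normalized random embeddings, matching the score bounds assumed by the loss.
query_out = ColPali2ModelOutput(
    single_vec_emb=F.normalize(torch.randn(batch_size, dim), dim=-1),
    multi_vec_emb=F.normalize(torch.randn(batch_size, num_query_tokens, dim), dim=-1),
)
doc_out = ColPali2ModelOutput(
    single_vec_emb=F.normalize(torch.randn(batch_size, dim), dim=-1),
    multi_vec_emb=F.normalize(torch.randn(batch_size, num_doc_tokens, dim), dim=-1),
)

loss_fn = ColPali2Loss(alpha=0.5, use_matryoshka_loss=False, use_distillation_loss=False)
outputs = loss_fn(query_out, doc_out)
print(outputs.single_vector_loss, outputs.multi_vector_loss, outputs.total_loss)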
11 changes: 10 additions & 1 deletion colpali_engine/models/__init__.py
@@ -1,2 +1,11 @@
from .idefics_2 import BiIdefics2, ColIdefics2, ColIdefics2Processor
from .paligemma import BiPali, BiPaliProj, ColPali, ColPaliProcessor
from .paligemma import (
BiPali,
BiPaliProj,
ColPali,
ColPaliProcessor,
ColPali2,
ColPali2Processor,
ColPali2ModelOutput,
ColPali2Config,
)
6 changes: 6 additions & 0 deletions colpali_engine/models/paligemma/__init__.py
@@ -1,2 +1,8 @@
from .bipali import BiPali, BiPaliProj
from .colpali import ColPali, ColPaliProcessor
from .colpali_2 import (
ColPali2,
ColPali2Processor,
ColPali2ModelOutput,
ColPali2Config,
)
3 changes: 3 additions & 0 deletions colpali_engine/models/paligemma/colpali_2/__init__.py
@@ -0,0 +1,3 @@
from .configuration_colpali_2 import ColPali2Config
from .modeling_colpali_2 import ColPali2, ColPali2ModelOutput
from .processing_colpali_2 import ColPali2Processor
22 changes: 22 additions & 0 deletions colpali_engine/models/paligemma/colpali_2/configuration_colpali_2.py
@@ -0,0 +1,22 @@
from transformers import PretrainedConfig
from transformers.models.paligemma.modeling_paligemma import PaliGemmaConfig


class ColPali2Config(PretrainedConfig):
"""
Configuration for the ColPali2 model.
"""

def __init__(
self,
vlm_config: PaliGemmaConfig,
single_vector_projector_dim: int = 128,
single_vector_pool_strategy: str = "mean",
multi_vector_projector_dim: int = 128,
**kwargs,
):
super().__init__(**kwargs)
self.vlm_config = vlm_config
self.single_vector_projector_dim = single_vector_projector_dim
self.single_vector_pool_strategy = single_vector_pool_strategy
self.multi_vector_projector_dim = multi_vector_projector_dim
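
A minimal construction sketch for the config (assumption: a default PaliGemmaConfig is an acceptable stand-in for `vlm_config`; a real setup would load the config of the target checkpoint):

from transformers.models.paligemma.modeling_paligemma import PaliGemmaConfig

from colpali_engine.models import ColPali2Config

vlm_config = PaliGemmaConfig()  # stand-in default config
config = ColPali2Config(
    vlm_config=vlm_config,
    single_vector_projector_dim=128,
    single_vector_pool_strategy="mean",
    multi_vector_projector_dim=128,
)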