Commit

wip: add ColPali2 (double-head architecture)
tonywu71 committed Aug 21, 2024
1 parent 9413418 commit 82f1083
Showing 5 changed files with 339 additions and 0 deletions.
74 changes: 74 additions & 0 deletions colpali_engine/models/colpali_2/colpali_2_architecture.py
@@ -0,0 +1,74 @@
import torch
from torch import nn
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration, PaliGemmaPreTrainedModel

from colpali_engine.models.colpali_2.colpali_2_modeling_outputs import ColPali2ModelOutput


class ColPali2(PaliGemmaPreTrainedModel):
def __init__(self, config):
        super().__init__(config=config)
self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
self.dim = 128
self.single_vector_projector = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
self.multi_vector_projector = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
self.main_input_name = "doc_input_ids"

def forward(self, *args, **kwargs) -> ColPali2ModelOutput:
"""
        Forward pass through ColPali2. Returns both single-vector and multi-vector embeddings.
NOTE: Both the text and image processors should prepend the <CLS> token to the input_ids tensor
before passing it to the model.
Args:
- input_ids (torch.LongTensor): The input tokens tensor.
- attention_mask (torch.LongTensor): The attention mask tensor.
Returns:
        - ColPali2ModelOutput:
- single_vector (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
- multi_vector (torch.Tensor): Multi-vector embeddings of shape (batch_size, num_tokens, dim).
"""

# Forward pass through the VLM
vlm_outputs = self.model(*args, output_hidden_states=True, **kwargs)
vlm_last_hidden_states = vlm_outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)

# Head 1: Single-vector embedding
cls_last_hidden_state = vlm_last_hidden_states[:, 0, :] # (batch_size, hidden_size)
        single_vec_emb = self.single_vector_projector(cls_last_hidden_state)  # (batch_size, dim)
single_vec_emb = torch.nn.functional.normalize(single_vec_emb, dim=-1)

# Head 2: Multi-vector embedding
multi_vec_emb = self.multi_vector_projector(
vlm_last_hidden_states[:, 1:, :]
        )  # (batch_size, sequence_length - 1, dim)
multi_vec_emb = torch.nn.functional.normalize(multi_vec_emb, dim=-1)
        # Slice the attention mask to align with the token dimension (the <cls> position was excluded above)
        multi_vec_emb = multi_vec_emb * kwargs["attention_mask"][:, 1:].unsqueeze(-1)

return ColPali2ModelOutput(single_vector=single_vec_emb, multi_vector=multi_vec_emb)

def forward_single_vector(self, *args, **kwargs):
"""
        Forward pass through ColPali2. Returns only the single-vector embeddings.
"""
vlm_outputs = self.model(*args, output_hidden_states=True, **kwargs)
cls_last_hidden_state = vlm_outputs.hidden_states[-1][:, 0, :] # (batch_size, hidden_size)
        single_vec_emb = self.single_vector_projector(cls_last_hidden_state)  # (batch_size, dim)
single_vec_emb = torch.nn.functional.normalize(single_vec_emb, dim=-1)

return single_vec_emb

def forward_multi_vector(self, *args, **kwargs):
"""
        Forward pass through ColPali2. Returns only the multi-vector embeddings.
"""
vlm_outputs = self.model(*args, output_hidden_states=True, **kwargs)
multi_vec_emb = self.multi_vector_projector(
            vlm_outputs.hidden_states[-1][:, 1:, :]
        )  # (batch_size, sequence_length - 1, dim)
multi_vec_emb = torch.nn.functional.normalize(multi_vec_emb, dim=-1)
        # Slice the attention mask to align with the token dimension (the <cls> position was excluded above)
        multi_vec_emb = multi_vec_emb * kwargs["attention_mask"][:, 1:].unsqueeze(-1)

return multi_vec_emb
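
For reference, a minimal sketch of exercising the double-head forward pass (not part of this commit; the checkpoint name and the text-only dummy inputs are assumptions for illustration):

import torch
from transformers import AutoConfig

from colpali_engine.models.colpali_2.colpali_2_architecture import ColPali2

# Placeholder checkpoint: any PaliGemma config with a `text_config` should work here.
config = AutoConfig.from_pretrained("google/paligemma-3b-mix-448")
model = ColPali2(config).eval()  # randomly initialized weights, which is fine for a shape check

batch_size, seq_len = 2, 16
input_ids = torch.randint(0, config.text_config.vocab_size, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)

print(outputs.single_vector.shape)  # (batch_size, 128)
print(outputs.multi_vector.shape)   # (batch_size, seq_len - 1, 128)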
76 changes: 76 additions & 0 deletions colpali_engine/models/colpali_2/colpali_2_collator.py
@@ -0,0 +1,76 @@
from typing import Any, Dict, List

import torch

from colpali_engine.models.colpali_2.colpali_2_processor import ColPali2Processor


class ColPali2Collator:
def __init__(
self,
processor: ColPali2Processor,
max_length: int = 2048,
add_suffix: bool = False,
):
self.processor = processor
self.image_token_id = None
self.max_length = max_length
self.suffix = ""
if add_suffix:
self.suffix = "\n" * 10

def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
# Placeholders
texts_query = []
images = []

# Populate the placeholders
for example in examples:
if example["image"] is None:
raise ValueError("Image is None - This collator does not support `None` images yet.")

image = example["image"].convert("RGB")
images.append(image)

if example["query"] is None:
texts_query.append(None)
else:
query = example["query"]
query = f"Question: {query}"
texts_query.append(query)

# Process the documents
batch_doc = self.processor.process_image(
image=images,
padding="longest",
do_convert_rgb=True,
return_tensors="pt",
add_instruction_prompt=True,
)

# Process the queries
batch_query = None

# Check if some but not all queries are `None`
if all([t is None for t in texts_query]):
print("All queries are None. Returning `None` for all queries.")
elif any([t is None for t in texts_query]):
raise ValueError("Some queries are None. This collator does not support None queries yet.")
else:
            batch_query = self.processor.process_text(
                text=texts_query,
                padding="longest",
                return_tensors="pt",
                add_special_tokens=True,
            )

        # Prefix each key in the output dict with "doc_" or "query_" to avoid key conflicts
batch_all = {f"doc_{k}": v for k, v in batch_doc.items()}
del batch_doc
if batch_query is not None:
batch_query = {f"query_{k}": v for k, v in batch_query.items()}
batch_all.update(batch_query)
del batch_query

return batch_all
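
A sketch of wiring this collator into a PyTorch DataLoader (not part of this commit; `processor` and `dataset` are assumed to exist, with the dataset yielding dicts holding an "image" PIL image and an optional "query" string):

from torch.utils.data import DataLoader

from colpali_engine.models.colpali_2.colpali_2_collator import ColPali2Collator

collator = ColPali2Collator(processor=processor, max_length=2048)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collator)

batch = next(iter(dataloader))
# Document keys are prefixed with "doc_"; query keys (when queries are present) with "query_".
print(sorted(batch.keys()))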
82 changes: 82 additions & 0 deletions colpali_engine/models/colpali_2/colpali_2_loss.py
@@ -0,0 +1,82 @@
from typing import List, Optional

import torch
import torch.nn.functional as F # noqa: N812

from colpali_engine.models.colpali_2.colpali_2_modeling_outputs import ColPali2ModelOutput


class MatryoshkaCELoss(torch.nn.Module):
"""
Loss function for Matryoshka Representation Learning
Adapted from https://github.com/RAIVNLab/MRL/blob/7ccb42df6be05f3d21d0648aa03099bba46386bf/MRL.py#L11
"""

def __init__(self, relative_importance: Optional[List[float]] = None, **kwargs):
        super().__init__()
self.criterion = torch.nn.CrossEntropyLoss(**kwargs)
self.relative_importance = relative_importance

def forward(self, output, target):
# Calculate losses for each output and stack them. This is still O(N)
losses = torch.stack([self.criterion(output_i, target) for output_i in output])

# Set relative_importance to 1 if not specified
rel_importance = (
torch.ones_like(losses) if self.relative_importance is None else torch.tensor(self.relative_importance)
)

# Apply relative importance weights
weighted_losses = rel_importance * losses
return weighted_losses.sum()


class ColPali2Loss(torch.nn.Module):
def __init__(self):
super().__init__()
self.matryoshka_loss = MatryoshkaCELoss()
self.alpha: float = 0.5

def single_vector_loss(self, query_embeddings, doc_embeddings) -> torch.Tensor:
"""
query_embeddings: (batch_size, dim)
doc_embeddings: (batch_size, dim)
"""
scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)

        # NOTE: `MatryoshkaCELoss` expects a sequence of logits tensors (one per Matryoshka
        # granularity), so the full-dimension scores are wrapped in a single-element list here.
        loss_rowwise = self.matryoshka_loss([scores], torch.arange(scores.shape[0], device=scores.device))
return loss_rowwise

def multi_vector_loss(self, query_embeddings, doc_embeddings) -> torch.Tensor:
"""
query_embeddings: (batch_size, num_query_tokens, dim)
doc_embeddings: (batch_size, num_doc_tokens, dim)
"""
# Compute the ColBERT scores
scores = (
torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings).max(dim=3)[0].sum(dim=2)
) # (batch_size, batch_size)

# Positive scores are the diagonal of the scores matrix.
pos_scores = scores.diagonal() # (batch_size,)

        # The negative score for a given query is the maximum of its scores against all other pages.
        # NOTE: We exclude the diagonal (the positive pages) by subtracting a large constant from it,
        # so it can never be selected by the max operation.
        neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6  # (batch_size, batch_size)
neg_scores = neg_scores.max(dim=1)[0] # (batch_size,)

# Compute the loss
# The loss is computed as the negative log of the softmax of the positive scores
# relative to the negative scores.
# This can be simplified to log-sum-exp of negative scores minus the positive score
# for numerical stability.
loss = F.softplus(neg_scores - pos_scores).mean()

return loss

def forward(self, query_embeddings: ColPali2ModelOutput, doc_embeddings: ColPali2ModelOutput) -> torch.Tensor:
single_vector_loss = self.single_vector_loss(query_embeddings.single_vector, doc_embeddings.single_vector)
multi_vector_loss = self.multi_vector_loss(query_embeddings.multi_vector, doc_embeddings.multi_vector)
total_loss = self.alpha * single_vector_loss + (1 - self.alpha) * multi_vector_loss
return total_loss
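
A quick sanity-check sketch for the combined loss on random embeddings (not part of this commit; the shapes follow the docstrings above):

import torch

from colpali_engine.models.colpali_2.colpali_2_loss import ColPali2Loss
from colpali_engine.models.colpali_2.colpali_2_modeling_outputs import ColPali2ModelOutput

batch_size, num_query_tokens, num_doc_tokens, dim = 4, 8, 32, 128

query_embeddings = ColPali2ModelOutput(
    single_vector=torch.randn(batch_size, dim),
    multi_vector=torch.randn(batch_size, num_query_tokens, dim),
)
doc_embeddings = ColPali2ModelOutput(
    single_vector=torch.randn(batch_size, dim),
    multi_vector=torch.randn(batch_size, num_doc_tokens, dim),
)

loss = ColPali2Loss()(query_embeddings, doc_embeddings)
print(loss)  # scalar: 0.5 * single-vector CE + 0.5 * multi-vector softplus margin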
9 changes: 9 additions & 0 deletions colpali_engine/models/colpali_2/colpali_2_modeling_outputs.py
@@ -0,0 +1,9 @@
from dataclasses import dataclass

import torch


@dataclass
class ColPali2ModelOutput:
single_vector: torch.Tensor
multi_vector: torch.Tensor
98 changes: 98 additions & 0 deletions colpali_engine/models/colpali_2/colpali_2_processor.py
@@ -0,0 +1,98 @@
from __future__ import annotations

from typing import List, cast

import torch
from PIL import Image
from transformers import BatchFeature, LlamaTokenizerFast, PaliGemmaProcessor


class ColPali2Processor(PaliGemmaProcessor):
def __init__(self, processor: PaliGemmaProcessor, cls_token: str = "<unused1>"):
self.processor = processor
self.tokenizer = cast(LlamaTokenizerFast, self.processor.tokenizer) # type: ignore
self.special_tokens_map = self.tokenizer.special_tokens_map
self.cls_token = cls_token
if self.cls_token not in self.tokenizer.added_tokens_decoder:
raise ValueError(f"The tokenizer should have an `{cls_token}` token to be used as the <cls> token.")
self.special_tokens_map["cls_token"] = self.cls_token
self.cls_token_id = cast(int, self.tokenizer.convert_tokens_to_ids(self.cls_token))

def process_text(
self,
text: str | List[str],
padding: str = "longest",
return_tensors: str = "pt",
add_special_tokens: bool = True,
) -> BatchFeature:
"""
Process text inputs for the model.
If `add_special_tokens` is True (default), the text will be prepended with the <bos> token and appended with " \n".
"""
if add_special_tokens:
if isinstance(text, str):
text = self.tokenizer.bos_token + text + "\n"
elif isinstance(text, list):
text = [self.tokenizer.bos_token + t + "\n" for t in text]
else:
raise ValueError("text must be a string or a list of strings.")

tokenized_outputs = self.tokenizer(
text, padding=padding, return_tensors=return_tensors, add_special_tokens=add_special_tokens
)

return BatchFeature(
data={
"input_ids": tokenized_outputs["input_ids"],
"attention_mask": tokenized_outputs["attention_mask"],
}
)

def process_image(
self,
image: Image.Image | List[Image.Image],
padding: str = "longest",
do_convert_rgb: bool = True,
return_tensors: str = "pt",
add_instruction_prompt: bool = True,
) -> BatchFeature:
# NOTE: The special prompt was used at training time. If used, it will be appended at the end of the input_ids.
special_prompt = "Describe the image." if add_instruction_prompt else None

if isinstance(image, Image.Image):
text_input = [special_prompt]
elif isinstance(image, list):
text_input = [special_prompt] * len(image)
else:
raise ValueError("image must be a PIL Image or a list of PIL Images.")

batch_output = self.processor(
text=text_input, # type: ignore
images=image,
padding=padding,
do_convert_rgb=do_convert_rgb,
return_tensors=return_tensors,
)

batch_output["input_ids"] = batch_output["input_ids"][:, : self.processor.image_seq_length]
batch_output["pixel_values"] = batch_output["pixel_values"][:, : self.processor.image_seq_length]
batch_output["attention_mask"] = batch_output["attention_mask"][:, : self.processor.image_seq_length]

return batch_output

def decode(self, *args, **kwargs):
return self.tokenizer.decode(*args, **kwargs)

def batch_decode(self, *args, **kwargs):
return self.tokenizer.batch_decode(*args, **kwargs)

def is_cls_token_first(self, input_ids: torch.Tensor) -> bool:
"""
Check if the first token in each sequence of the batch is the CLS token.
Inputs:
- input_ids (torch.Tensor): The input_ids tensor (batch_size, sequence_length).
"""
if input_ids.dim() != 2:
raise ValueError("`input_ids` must be a 2D tensor.")
return cast(bool, torch.all(input_ids[:, 0] == self.cls_token_id).item())
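
A usage sketch for the processor wrapper (not part of this commit; the checkpoint name is a placeholder, and it assumes the underlying tokenizer exposes <unused1> as an added token, otherwise the constructor raises):

from PIL import Image
from transformers import PaliGemmaProcessor

from colpali_engine.models.colpali_2.colpali_2_processor import ColPali2Processor

base_processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-mix-448")
processor = ColPali2Processor(base_processor)

text_batch = processor.process_text(["What is ColPali?"])
image_batch = processor.process_image(Image.new("RGB", (448, 448)), add_instruction_prompt=True)

# False for now: `process_text` prepends <bos>, not the <cls> token that the architecture's
# docstring calls for, consistent with the WIP status of this commit.
print(processor.is_cls_token_first(text_batch["input_ids"]))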
