wip: add ColPali2 (double-head architecture)
Showing 5 changed files with 339 additions and 0 deletions.
@@ -0,0 +1,74 @@
import torch
from torch import nn
from transformers.models.paligemma.modeling_paligemma import (
    PaliGemmaForConditionalGeneration,
    PaliGemmaPreTrainedModel,
)

from colpali_engine.models.colpali_2.colpali_2_modeling_outputs import ColPali2ModelOutput


class ColPali2(PaliGemmaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config=config)
        self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
        self.dim = 128
        self.single_vector_projector = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
        self.multi_vector_projector = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs) -> ColPali2ModelOutput:
        """
        Forward pass through ColPali2. Returns both single-vector and multi-vector embeddings.

        NOTE: Both the text and image processors should prepend the <CLS> token to the input_ids tensor
        before passing it to the model.

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - ColPali2ModelOutput:
            - single_vector (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
            - multi_vector (torch.Tensor): Multi-vector embeddings of shape (batch_size, num_tokens, dim).
        """
        # Forward pass through the VLM
        vlm_outputs = self.model(*args, output_hidden_states=True, **kwargs)
        vlm_last_hidden_states = vlm_outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)

        # Head 1: Single-vector embedding, projected from the <CLS> token (first position)
        cls_last_hidden_state = vlm_last_hidden_states[:, 0, :]  # (batch_size, hidden_size)
        single_vec_emb = self.single_vector_projector(cls_last_hidden_state)  # (batch_size, dim)
        single_vec_emb = torch.nn.functional.normalize(single_vec_emb, dim=-1)

        # Head 2: Multi-vector embeddings, projected from the remaining tokens
        multi_vec_emb = self.multi_vector_projector(
            vlm_last_hidden_states[:, 1:, :]
        )  # (batch_size, sequence_length - 1, dim)
        multi_vec_emb = torch.nn.functional.normalize(multi_vec_emb, dim=-1)
        # Zero out the padding embeddings. The mask is sliced to drop the <CLS> position so
        # its shape matches the multi-vector embeddings.
        multi_vec_emb = multi_vec_emb * kwargs["attention_mask"][:, 1:].unsqueeze(-1)

        return ColPali2ModelOutput(single_vector=single_vec_emb, multi_vector=multi_vec_emb)

    def forward_single_vector(self, *args, **kwargs):
        """
        Forward pass through ColPali2. Returns only the single-vector embeddings.
        """
        vlm_outputs = self.model(*args, output_hidden_states=True, **kwargs)
        cls_last_hidden_state = vlm_outputs.hidden_states[-1][:, 0, :]  # (batch_size, hidden_size)
        single_vec_emb = self.single_vector_projector(cls_last_hidden_state)  # (batch_size, dim)
        single_vec_emb = torch.nn.functional.normalize(single_vec_emb, dim=-1)

        return single_vec_emb

    def forward_multi_vector(self, *args, **kwargs):
        """
        Forward pass through ColPali2. Returns only the multi-vector embeddings.
        """
        vlm_outputs = self.model(*args, output_hidden_states=True, **kwargs)
        multi_vec_emb = self.multi_vector_projector(
            vlm_outputs.hidden_states[-1][:, 1:, :]
        )  # (batch_size, sequence_length - 1, dim)
        multi_vec_emb = torch.nn.functional.normalize(multi_vec_emb, dim=-1)
        multi_vec_emb = multi_vec_emb * kwargs["attention_mask"][:, 1:].unsqueeze(-1)

        return multi_vec_emb
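For reference, a minimal sketch of how the double-head forward pass can be exercised as a shape check. The checkpoint name, the dummy <CLS>-prepended inputs, and the text-only call are illustrative assumptions, not part of this commit:

import torch
from transformers import AutoConfig

# Hypothetical base config; any PaliGemma checkpoint works for a shape check.
config = AutoConfig.from_pretrained("google/paligemma-3b-pt-224")
model = ColPali2(config).eval()

# Dummy inputs, with the <CLS> token assumed to sit at position 0 (batch_size=2, sequence_length=16).
input_ids = torch.randint(0, 1000, (2, 16))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)

print(outputs.single_vector.shape)  # torch.Size([2, 128])
print(outputs.multi_vector.shape)   # torch.Size([2, 15, 128])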
@@ -0,0 +1,76 @@
from typing import Any, Dict, List

import torch

from colpali_engine.models.colpali_2.colpali_2_processor import ColPali2Processor


class ColPali2Collator:
    def __init__(
        self,
        processor: ColPali2Processor,
        max_length: int = 2048,
        add_suffix: bool = False,
    ):
        self.processor = processor
        self.image_token_id = None
        self.max_length = max_length
        self.suffix = ""
        if add_suffix:
            self.suffix = "\n" * 10

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Placeholders
        texts_query = []
        images = []

        # Populate the placeholders
        for example in examples:
            if example["image"] is None:
                raise ValueError("Image is None - This collator does not support `None` images yet.")

            image = example["image"].convert("RGB")
            images.append(image)

            if example["query"] is None:
                texts_query.append(None)
            else:
                query = example["query"]
                query = f"Question: {query}"
                texts_query.append(query)

        # Process the documents
        batch_doc = self.processor.process_image(
            image=images,
            padding="longest",
            do_convert_rgb=True,
            return_tensors="pt",
            add_instruction_prompt=True,
        )

        # Process the queries
        batch_query = None
        # Check whether all, some, or none of the queries are `None`
        if all(t is None for t in texts_query):
            print("All queries are None. Returning `None` for all queries.")
        elif any(t is None for t in texts_query):
            raise ValueError("Some queries are None. This collator does not support None queries yet.")
        else:
            batch_query = self.processor.process_text(
                text=texts_query,
                padding="longest",
                return_tensors="pt",
            )

        # Prefix each key in the output dict with "doc_" or "query_" to avoid key conflicts
        batch_all = {f"doc_{k}": v for k, v in batch_doc.items()}
        del batch_doc
        if batch_query is not None:
            batch_query = {f"query_{k}": v for k, v in batch_query.items()}
            batch_all.update(batch_query)
            del batch_query

        return batch_all
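A minimal sketch of wiring the collator into a DataLoader. Here `dataset` and `processor` are hypothetical stand-ins: the dataset is assumed to yield dicts with "image" (a PIL image) and "query" (a string or None) keys, and the processor is an already-built ColPali2Processor:

from torch.utils.data import DataLoader

# `dataset` and `processor` are hypothetical stand-ins (see the lead-in above).
collator = ColPali2Collator(processor=processor, max_length=2048)
dataloader = DataLoader(dataset, batch_size=4, collate_fn=collator)

batch = next(iter(dataloader))
# Document tensors are prefixed with "doc_", query tensors with "query_".
print(sorted(batch.keys()))  # e.g. ['doc_attention_mask', 'doc_input_ids', 'doc_pixel_values', 'query_attention_mask', ...]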
@@ -0,0 +1,82 @@
from typing import List, Optional

import torch
import torch.nn.functional as F  # noqa: N812

from colpali_engine.models.colpali_2.colpali_2_modeling_outputs import ColPali2ModelOutput


class MatryoshkaCELoss(torch.nn.Module):
    """
    Loss function for Matryoshka Representation Learning.
    Adapted from https://github.com/RAIVNLab/MRL/blob/7ccb42df6be05f3d21d0648aa03099bba46386bf/MRL.py#L11
    """

    def __init__(self, relative_importance: Optional[List[float]] = None, **kwargs):
        super().__init__()
        self.criterion = torch.nn.CrossEntropyLoss(**kwargs)
        self.relative_importance = relative_importance

    def forward(self, output, target):
        # Calculate losses for each output granularity and stack them. This is still O(N).
        losses = torch.stack([self.criterion(output_i, target) for output_i in output])

        # Set relative_importance to 1 if not specified
        rel_importance = (
            torch.ones_like(losses) if self.relative_importance is None else torch.tensor(self.relative_importance)
        )

        # Apply relative importance weights
        weighted_losses = rel_importance * losses
        return weighted_losses.sum()


class ColPali2Loss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.matryoshka_loss = MatryoshkaCELoss()
        self.alpha: float = 0.5
    def single_vector_loss(self, query_embeddings, doc_embeddings) -> torch.Tensor:
        """
        query_embeddings: (batch_size, dim)
        doc_embeddings: (batch_size, dim)
        """
        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)

        # MatryoshkaCELoss iterates over a sequence of score matrices (one per embedding
        # granularity), so the single full-dimension score matrix is wrapped in a list here.
        loss_rowwise = self.matryoshka_loss([scores], torch.arange(scores.shape[0], device=scores.device))
        return loss_rowwise
    def multi_vector_loss(self, query_embeddings, doc_embeddings) -> torch.Tensor:
        """
        query_embeddings: (batch_size, num_query_tokens, dim)
        doc_embeddings: (batch_size, num_doc_tokens, dim)
        """
        # Compute the ColBERT scores: for each query token, take the maximum similarity over
        # the document tokens, then sum over the query tokens.
        scores = (
            torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings).max(dim=3)[0].sum(dim=2)
        )  # (batch_size, batch_size)

        # Positive scores are the diagonal of the scores matrix.
        pos_scores = scores.diagonal()  # (batch_size,)

        # The negative score for a given query is the maximum of its scores against all other pages.
        # NOTE: The diagonal is excluded from the max operation by subtracting a very large value from it.
        neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6  # (batch_size, batch_size)
        neg_scores = neg_scores.max(dim=1)[0]  # (batch_size,)

        # The loss is the negative log-softmax of the positive score relative to the hardest
        # negative score, which simplifies to softplus(neg - pos) for numerical stability.
        loss = F.softplus(neg_scores - pos_scores).mean()

        return loss
    def forward(self, query_embeddings: ColPali2ModelOutput, doc_embeddings: ColPali2ModelOutput) -> torch.Tensor:
        single_vector_loss = self.single_vector_loss(query_embeddings.single_vector, doc_embeddings.single_vector)
        multi_vector_loss = self.multi_vector_loss(query_embeddings.multi_vector, doc_embeddings.multi_vector)
        total_loss = self.alpha * single_vector_loss + (1 - self.alpha) * multi_vector_loss
        return total_loss
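To check shapes and see how the two heads combine, here is a small sketch with random, L2-normalized embeddings (the sizes are arbitrary and the values meaningless):

import torch
import torch.nn.functional as F

loss_fn = ColPali2Loss()

batch_size, num_query_tokens, num_doc_tokens, dim = 4, 8, 16, 128
query_out = ColPali2ModelOutput(
    single_vector=F.normalize(torch.randn(batch_size, dim), dim=-1),
    multi_vector=F.normalize(torch.randn(batch_size, num_query_tokens, dim), dim=-1),
)
doc_out = ColPali2ModelOutput(
    single_vector=F.normalize(torch.randn(batch_size, dim), dim=-1),
    multi_vector=F.normalize(torch.randn(batch_size, num_doc_tokens, dim), dim=-1),
)

# alpha * single-vector CE loss + (1 - alpha) * multi-vector softplus loss
loss = loss_fn(query_out, doc_out)
print(loss.item())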
colpali_engine/models/colpali_2/colpali_2_modeling_outputs.py (9 additions, 0 deletions)
@@ -0,0 +1,9 @@
from dataclasses import dataclass

import torch


@dataclass
class ColPali2ModelOutput:
    single_vector: torch.Tensor
    multi_vector: torch.Tensor
@@ -0,0 +1,98 @@
from __future__ import annotations

from typing import List, cast

import torch
from PIL import Image
from transformers import BatchFeature, LlamaTokenizerFast, PaliGemmaProcessor


class ColPali2Processor(PaliGemmaProcessor):
    def __init__(self, processor: PaliGemmaProcessor, cls_token: str = "<unused1>"):
        self.processor = processor
        self.tokenizer = cast(LlamaTokenizerFast, self.processor.tokenizer)  # type: ignore
        self.special_tokens_map = self.tokenizer.special_tokens_map
        self.cls_token = cls_token
        # `get_added_vocab` maps added token strings to their ids.
        if self.cls_token not in self.tokenizer.get_added_vocab():
            raise ValueError(f"The tokenizer should have an `{cls_token}` token to be used as the <cls> token.")
        self.special_tokens_map["cls_token"] = self.cls_token
        self.cls_token_id = cast(int, self.tokenizer.convert_tokens_to_ids(self.cls_token))
    def process_text(
        self,
        text: str | List[str],
        padding: str = "longest",
        return_tensors: str = "pt",
        add_special_tokens: bool = True,
    ) -> BatchFeature:
        """
        Process text inputs for the model.
        If `add_special_tokens` is True (default), the text is prepended with the <bos> token and
        appended with "\n".
        """
        if add_special_tokens:
            if isinstance(text, str):
                text = self.tokenizer.bos_token + text + "\n"
            elif isinstance(text, list):
                text = [self.tokenizer.bos_token + t + "\n" for t in text]
            else:
                raise ValueError("text must be a string or a list of strings.")

        # Special tokens are handled manually above, so the tokenizer must not add them again.
        tokenized_outputs = self.tokenizer(
            text, padding=padding, return_tensors=return_tensors, add_special_tokens=False
        )

        return BatchFeature(
            data={
                "input_ids": tokenized_outputs["input_ids"],
                "attention_mask": tokenized_outputs["attention_mask"],
            }
        )
    def process_image(
        self,
        image: Image.Image | List[Image.Image],
        padding: str = "longest",
        do_convert_rgb: bool = True,
        return_tensors: str = "pt",
        add_instruction_prompt: bool = True,
    ) -> BatchFeature:
        # NOTE: The special prompt was used at training time. If used, it will be appended at the end of the input_ids.
        special_prompt = "Describe the image." if add_instruction_prompt else None

        if isinstance(image, Image.Image):
            text_input = [special_prompt]
        elif isinstance(image, list):
            text_input = [special_prompt] * len(image)
        else:
            raise ValueError("image must be a PIL Image or a list of PIL Images.")

        batch_output = self.processor(
            text=text_input,  # type: ignore
            images=image,
            padding=padding,
            do_convert_rgb=do_convert_rgb,
            return_tensors=return_tensors,
        )

        # Keep only the image tokens: truncate `input_ids` and `attention_mask` to the image
        # sequence length. `pixel_values` is left untouched since it is not token-indexed.
        batch_output["input_ids"] = batch_output["input_ids"][:, : self.processor.image_seq_length]
        batch_output["attention_mask"] = batch_output["attention_mask"][:, : self.processor.image_seq_length]

        return batch_output
    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    def is_cls_token_first(self, input_ids: torch.Tensor) -> bool:
        """
        Check if the first token in each sequence of the batch is the <cls> token.

        Args:
        - input_ids (torch.Tensor): The input_ids tensor (batch_size, sequence_length).
        """
        if input_ids.dim() != 2:
            raise ValueError("`input_ids` must be a 2D tensor.")
        return cast(bool, torch.all(input_ids[:, 0] == self.cls_token_id).item())
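A short sketch of the intended processor usage. The base checkpoint name and the blank PIL image are placeholders, and the tokenizer is assumed to expose `<unused1>` as an added token:

from PIL import Image
from transformers import PaliGemmaProcessor

# Hypothetical base checkpoint (see the lead-in above).
base_processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224")
processor = ColPali2Processor(base_processor)

batch_text = processor.process_text(["Is the model architecture described on this page?"])
print(batch_text["input_ids"].shape)  # (1, num_text_tokens)

page = Image.new("RGB", (448, 448), color="white")  # placeholder page image
batch_image = processor.process_image(page, add_instruction_prompt=True)
print(batch_image["input_ids"].shape)  # (1, image_seq_length) after truncation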