Skip to content

Commit

Permalink
feat: tweaks for ColPali2
Browse files Browse the repository at this point in the history
  • Loading branch information
tonywu71 committed Sep 10, 2024
1 parent db8cd5b commit 19c70d4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@


class ColPali2Config(PretrainedConfig):
"""
Configuration for the ColPali2 model.
"""

def __init__(
self,
vlm_config: PaliGemmaConfig,
Expand Down
19 changes: 13 additions & 6 deletions colpali_engine/models/paligemma/colpali_2/modeling_colpali_2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import cast
from typing import ClassVar, cast

import torch
from torch import nn
Expand All @@ -16,17 +16,24 @@ class ColPali2ModelOutput:


class ColPali2(PaliGemmaPreTrainedModel):
main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related

def __init__(self, config: ColPali2Config):
super(ColPali2, self).__init__(config=config)
super().__init__(config=config)

self.config = cast(ColPali2Config, self.config)
self.model = PaliGemmaForConditionalGeneration(self.config.vlm_config)

self.single_vector_projector = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
self.multi_vector_pooler = MultiVectorPooler(pooling_strategy=self.config.single_vector_pool_strategy)
self.multi_vector_projector = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
self.single_vector_projector = nn.Linear(
in_features=self.model.config.text_config.hidden_size,
out_features=self.config.single_vector_projector_dim,
)

self.main_input_name = "doc_input_ids"
self.multi_vector_pooler = MultiVectorPooler(pooling_strategy=self.config.single_vector_pool_strategy)
self.multi_vector_projector = nn.Linear(
in_features=self.model.config.text_config.hidden_size,
out_features=self.config.multi_vector_projector_dim,
)

@property
def single_vector_projector_dim(self) -> int:
Expand Down

0 comments on commit 19c70d4

Please sign in to comment.