Skip to content

Commit

Permalink
feat: add tests for ColPali2
Browse files Browse the repository at this point in the history
  • Loading branch information
tonywu71 committed Sep 10, 2024
1 parent 61ecbd1 commit f31b4a2
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 0 deletions.
194 changes: 194 additions & 0 deletions tests/models/paligemma/colpali_2/test_modeling_colpali_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import logging
from typing import Generator, List, cast

import pytest
import torch
from PIL import Image
from transformers import BatchFeature
from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig

from colpali_engine.models import ColPali2, ColPali2Config, ColPali2Processor
from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch

logger = logging.getLogger(__name__)


@pytest.fixture(scope="module")
def colpali_2_config() -> Generator[ColPali2Config, None, None]:
    """
    Build the ColPali2 configuration shared by the tests of this module.

    The VLM sub-config is loaded from the "google/paligemma-3b-mix-448"
    checkpoint. Both projectors use a dimension of 128 and the single-vector
    pooling strategy is "mean".
    """
    vlm_config = cast(
        PaliGemmaConfig,
        PaliGemmaConfig.from_pretrained("google/paligemma-3b-mix-448"),
    )

    yield ColPali2Config(
        vlm_config=vlm_config,
        single_vector_projector_dim=128,
        single_vector_pool_strategy="mean",
        multi_vector_projector_dim=128,
    )


@pytest.fixture(scope="module")
def colpali_2_from_config(colpali_2_config: ColPali2Config) -> Generator[ColPali2, None, None]:
    """
    Provide a ColPali2 model instantiated from `colpali_2_config`.

    The fixture is module-scoped; the torch teardown helper is called once the
    fixture is finalized.
    """
    # The resolved device is only reported in the logs; it is not passed to the model.
    device = get_torch_device("auto")
    logger.info(f"Device used: {device}")

    yield ColPali2(config=colpali_2_config)

    # Everything after the `yield` runs as the fixture's teardown.
    tear_down_torch()


@pytest.fixture(scope="module")
def colpali_2_model_path() -> str:
    """
    Path of the ColPali2 checkpoint in the hub.

    No checkpoint is available yet, so requesting this fixture (directly or
    through `colpali_2_from_pretrained`) skips the requesting test.
    """
    # `pytest.skip` is an imperative call, not a decorator: evaluating it at import time
    # (as `@pytest.skip(...)`) breaks the collection of the whole module. Calling it from
    # inside the fixture only skips the tests that depend on it.
    # TODO: return the hub path here once a model has been published.
    pytest.skip("No model available in the hub yet")


@pytest.fixture(scope="module")
def colpali_2_from_pretrained(colpali_2_model_path: str) -> Generator[ColPali2, None, None]:
    """
    Provide a ColPali2 model loaded in bfloat16 from `colpali_2_model_path`.

    Tests using this fixture are skipped for as long as `colpali_2_model_path`
    skips. The torch teardown helper is called once the fixture is finalized.
    """
    device = get_torch_device("auto")
    logger.info(f"Device used: {device}")

    yield cast(
        ColPali2,
        ColPali2.from_pretrained(
            colpali_2_model_path,
            torch_dtype=torch.bfloat16,
            device_map=device,
        ),
    )

    # Everything after the `yield` runs as the fixture's teardown.
    tear_down_torch()


@pytest.fixture(scope="class")
def processor() -> Generator[ColPali2Processor, None, None]:
    """
    Provide a ColPali2Processor loaded from the "google/paligemma-3b-mix-448" checkpoint.
    """
    loaded_processor = ColPali2Processor.from_pretrained("google/paligemma-3b-mix-448")
    yield cast(ColPali2Processor, loaded_processor)


@pytest.fixture(scope="class")
def images() -> Generator[List[Image.Image], None, None]:
    """
    Provide two dummy RGB images: a white 32x32 one and a black 16x16 one.
    """
    white_image = Image.new("RGB", (32, 32), color="white")
    black_image = Image.new("RGB", (16, 16), color="black")
    yield [white_image, black_image]


@pytest.fixture(scope="class")
def queries() -> Generator[List[str], None, None]:
    """
    Provide three sample text queries.
    """
    sample_queries: List[str] = [
        "Does Manu like to play football?",
        "Are Benjamin, Antoine, Merve, and Jo friends?",
        "Is byaldi a dish or an awesome repository for RAG?",
    ]
    yield sample_queries


@pytest.fixture(scope="function")
def batch_queries(processor: ColPali2Processor, queries: List[str]) -> Generator[BatchFeature, None, None]:
    """
    Provide the output of `processor.process_queries` for the sample queries.
    """
    processed_queries = processor.process_queries(queries)
    yield processed_queries


@pytest.fixture(scope="function")
def batch_images(processor: ColPali2Processor, images: List[Image.Image]) -> Generator[BatchFeature, None, None]:
    """
    Provide the output of `processor.process_images` for the dummy images.
    """
    processed_images = processor.process_images(images)
    yield processed_images


class TestLoadColPali2:
    """
    Test the different ways to load ColPali2.
    """

    @pytest.mark.slow
    def test_load_colpali_2_from_config(self, colpali_2_config: ColPali2Config):
        """
        Instantiate ColPali2 from a config and check that the config values are
        reflected on the model.
        """
        # The resolved device is only reported in the logs; it is not passed to the model.
        device = get_torch_device("auto")
        logger.info(f"Device used: {device}")

        try:
            model = ColPali2(config=colpali_2_config)

            assert isinstance(model, ColPali2)
            assert model.single_vector_projector_dim == colpali_2_config.single_vector_projector_dim
            # NOTE(review): this compares the *multi*-vector pooler with the *single*-vector
            # pool strategy of the config - confirm this is the intended attribute.
            assert model.multi_vector_pooler.pooling_strategy == colpali_2_config.single_vector_pool_strategy
            assert model.multi_vector_projector_dim == colpali_2_config.multi_vector_projector_dim
        finally:
            # Run the torch teardown even when loading or an assertion fails, so a failure
            # here does not leave the cleanup undone for the tests that follow.
            tear_down_torch()

    @pytest.mark.slow
    def test_load_colpali_2_from_pretrained(self, colpali_2_from_config: ColPali2):
        """
        Check that the model fixture yields a ColPali2 instance.
        """
        # NOTE(review): despite the test name, the fixture used here is the config-built
        # model - presumably a placeholder until a hub checkpoint exists; confirm.
        assert isinstance(colpali_2_from_config, ColPali2)


class TestForwardSingleVector:
    """
    Check the single-vector forward pass of ColPali2 on image and query batches.
    """

    @pytest.mark.slow
    def test_colpali_2_forward_images(
        self,
        colpali_2_from_config: ColPali2,
        batch_images: BatchFeature,
    ):
        """
        The image batch must produce a 2D tensor of shape (batch size, single-vector projector dim).
        """
        model = colpali_2_from_config

        # Gradients are not needed to check the output shape
        with torch.no_grad():
            embeddings = model.forward_single_vector(**batch_images)

        assert isinstance(embeddings, torch.Tensor)
        assert embeddings.dim() == 2

        n_rows, n_cols = embeddings.shape
        assert n_rows == len(batch_images["input_ids"])
        assert n_cols == model.single_vector_projector_dim

    @pytest.mark.slow
    def test_colpali_2_forward_queries(
        self,
        colpali_2_from_config: ColPali2,
        batch_queries: BatchFeature,
    ):
        """
        The query batch must produce a 2D tensor of shape (batch size, single-vector projector dim).
        """
        model = colpali_2_from_config

        # Gradients are not needed to check the output shape
        with torch.no_grad():
            embeddings = model.forward_single_vector(**batch_queries)

        assert isinstance(embeddings, torch.Tensor)
        assert embeddings.dim() == 2

        n_rows, n_cols = embeddings.shape
        assert n_rows == len(batch_queries["input_ids"])
        assert n_cols == model.single_vector_projector_dim


class TestForwardMultiVector:
    """
    Check the multi-vector forward pass of ColPali2 on image and query batches.
    """

    @pytest.mark.slow
    def test_colpali_2_forward_images(
        self,
        colpali_2_from_config: ColPali2,
        batch_images: BatchFeature,
    ):
        """
        The image batch must produce a 3D tensor whose first dimension is the batch size
        and whose last dimension is the multi-vector projector dim.
        """
        model = colpali_2_from_config

        # Gradients are not needed to check the output shape
        with torch.no_grad():
            embeddings = model.forward_multi_vector(**batch_images)

        assert isinstance(embeddings, torch.Tensor)
        assert embeddings.dim() == 3

        # The middle dimension (number of tokens) is not checked
        n_items, _, n_channels = embeddings.shape
        assert n_items == len(batch_images["input_ids"])
        assert n_channels == model.multi_vector_projector_dim

    @pytest.mark.slow
    def test_colpali_2_forward_queries(
        self,
        colpali_2_from_config: ColPali2,
        batch_queries: BatchFeature,
    ):
        """
        The query batch must produce a 3D tensor whose first dimension is the batch size
        and whose last dimension is the multi-vector projector dim.
        """
        model = colpali_2_from_config

        # Gradients are not needed to check the output shape
        with torch.no_grad():
            embeddings = model.forward_multi_vector(**batch_queries)

        assert isinstance(embeddings, torch.Tensor)
        assert embeddings.dim() == 3

        # The middle dimension (number of tokens) is not checked
        n_items, _, n_channels = embeddings.shape
        assert n_items == len(batch_queries["input_ids"])
        assert n_channels == model.multi_vector_projector_dim
50 changes: 50 additions & 0 deletions tests/models/paligemma/colpali_2/test_processing_colpali_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Generator, cast

import pytest
import torch
from PIL import Image

from colpali_engine.models import ColPali2Processor


@pytest.fixture(scope="module")
def colpali_model_path() -> str:
    """
    Checkpoint identifier used to load the processor under test.
    """
    checkpoint_id = "google/paligemma-3b-mix-448"
    return checkpoint_id


@pytest.fixture(scope="module")
def processor_from_pretrained(colpali_model_path: str) -> Generator[ColPali2Processor, None, None]:
    """
    Provide a ColPali2Processor loaded from `colpali_model_path`, once per module.
    """
    loaded_processor = ColPali2Processor.from_pretrained(colpali_model_path)
    yield cast(ColPali2Processor, loaded_processor)


def test_load_processor_from_pretrained(processor_from_pretrained: ColPali2Processor):
    """
    The processor fixture must yield a ColPali2Processor instance.
    """
    loaded_processor = processor_from_pretrained
    assert isinstance(loaded_processor, ColPali2Processor)


def test_process_images(processor_from_pretrained: ColPali2Processor):
    """
    Processing a single dummy image must return `pixel_values` of shape (1, 3, 448, 448).
    """
    # One black 16x16 RGB image is enough to check the output shape
    dummy_batch = [Image.new("RGB", (16, 16), color="black")]

    features = processor_from_pretrained.process_images(dummy_batch)

    expected_shape = torch.Size([1, 3, 448, 448])
    assert "pixel_values" in features
    assert features["pixel_values"].shape == expected_shape


def test_process_queries(processor_from_pretrained: ColPali2Processor):
    """
    Processing text queries must return tensor `input_ids` with one row per query.
    """
    sample_queries = [
        "Does Manu like to play football?",
        "Are Benjamin, Antoine, Merve, and Jo friends?",
        "Is byaldi a dish or a nice repository for RAG?",
    ]

    encoded = processor_from_pretrained.process_queries(sample_queries)

    assert "input_ids" in encoded
    assert isinstance(encoded["input_ids"], torch.Tensor)

    n_encoded_rows = cast(torch.Tensor, encoded["input_ids"]).shape[0]
    assert n_encoded_rows == len(sample_queries)

0 comments on commit f31b4a2

Please sign in to comment.