Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pixtral] Improve image processing using cv2 resize instead of PIL #49

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 127 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mistral_common"
version = "1.4.0"
version = "1.4.1"
description = ""
authors = ["bam4d <[email protected]>"]
readme = "README.md"
Expand Down Expand Up @@ -35,6 +35,7 @@ typing-extensions = "^4.11.0"
tiktoken = "^0.7.0"
pillow = "^10.3.0"
requests = "^2.0.0"
opencv-python-headless = "^4.10.0.84"

[tool.poetry.group.dev.dependencies]
types-jsonschema = "4.21.0.20240118"
Expand Down
2 changes: 1 addition & 1 deletion src/mistral_common/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.4.0"
__version__ = "1.4.1"
11 changes: 5 additions & 6 deletions src/mistral_common/tokens/tokenizers/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from io import BytesIO
from typing import Tuple, Union

import cv2
import numpy as np
from PIL import Image

Expand Down Expand Up @@ -54,22 +55,21 @@ def _convert_to_rgb(image: Image.Image) -> Image.Image:


def normalize(
image: Image.Image,
np_image: np.ndarray,
mean: Tuple[float, float, float],
std: Tuple[float, float, float],
) -> np.ndarray:
"""
Normalize a tensor image with mean and standard deviation.

Args:
image (Image.Image): Image to be normalized.
np_image (np.ndarray): Image to be normalized.
mean (tuple[float, float, float]): Mean for each channel.
std (tuple[float, float, float]): Standard deviation for each channel.

Returns:
np.ndarray: Normalized image with shape (C, H, W).
"""
np_image = np.array(image, dtype=np.float32)
np_image = np_image / 255.0

assert len(np_image.shape) == 3, f"{np_image.shape=}"
Expand All @@ -81,9 +81,8 @@ def normalize(


def transform_image(image: Image.Image, new_size: Tuple[int, int]) -> np.ndarray:
    """
    Convert an image to RGB, resize it with OpenCV and normalize it.

    Args:
        image (Image.Image): Image to be transformed.
        new_size (tuple[int, int]): Target size, forwarded unchanged to
            ``cv2.resize`` as its ``dsize`` argument.

    Returns:
        np.ndarray: Output of ``normalize`` applied with DATASET_MEAN and DATASET_STD.
    """
    # The RGB image is turned into a float32 array *before* resizing, so the
    # bicubic interpolation runs on float data; `normalize` expects an ndarray.
    rgb_array = np.array(_convert_to_rgb(image), dtype=np.float32)
    np_image = cv2.resize(rgb_array, new_size, interpolation=cv2.INTER_CUBIC)
    return normalize(np_image, DATASET_MEAN, DATASET_STD)


class ImageEncoder(MultiModalEncoder):
Expand Down
40 changes: 40 additions & 0 deletions tests/test_multimodal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import base64
from io import BytesIO
from typing import Tuple

import numpy as np
import pytest
import requests
from mistral_common.protocol.instruct.messages import (
Expand Down Expand Up @@ -70,6 +72,44 @@ def test_image_encoder(mm_config: MultimodalConfig, special_token_ids: SpecialIm
assert len(tokens) == (w + 1) * h


@pytest.mark.parametrize("size", [(200, 311), (300, 212), (251, 1374), (1475, 477), (1344, 1544), (2133, 3422)])
def test_image_processing(
    mm_config: MultimodalConfig, special_token_ids: SpecialImageIDs, size: Tuple[int, int]
) -> None:
    """Pin the output size and pixel checksum of the image preprocessing.

    For each requested ``size`` an image is fetched from picsum.photos, run
    through ``ImageEncoder`` with ``max_image_size = 1024``, and compared
    against the recorded expected size and sum of absolute pixel values.
    """
    mm_config.max_image_size = 1024
    mm_encoder = ImageEncoder(mm_config, special_token_ids)

    # Expected sizes, as recorded below: an image whose longer side exceeds
    # 1024 is scaled down so that side becomes 1024 (aspect ratio kept), and
    # each side is then rounded UP to a multiple of 16 (e.g. 200 -> 208).
    EXP_IMG_SIZES = {
        (200, 311): (208, 320),
        (300, 212): (304, 224),
        (251, 1374): (192, 1024),
        (1475, 477): (1024, 336),
        (1344, 1544): (896, 1024),
        (2133, 3422): (640, 1024),
    }
    # integration test to make sure the img processing stays 100% the same
    EXP_IMG_SUM = {
        (200, 311): 232038.65023772235,
        (300, 212): 182668.98900347573,
        (251, 1374): 726925.9371541862,
        (1475, 477): 985935.4162606588,
        (1344, 1544): 2982953.705365115,
        (2133, 3422): 2304438.4010818982,
    }

    url = f"https://picsum.photos/id/237/{size[0]}/{size[1]}"

    content = ImageURLChunk(image_url=url)

    image = mm_encoder(content).image

    # Transposing reverses the axes, so the first two entries of the shape
    # are compared against the expected size tuple.
    actual_size = image.transpose().shape[:2]
    assert actual_size == EXP_IMG_SIZES[size], actual_size

    # Compare the checksum in BOTH directions: without the outer abs() any
    # sum smaller than expected (even an all-zero image) would pass.
    actual_sum = np.abs(image).sum()
    assert abs(actual_sum - EXP_IMG_SUM[size]) < 1e-5, actual_sum


def test_image_encoder_formats(mm_config: MultimodalConfig, special_token_ids: SpecialImageIDs) -> None:
mm_encoder = ImageEncoder(mm_config, special_token_ids)

Expand Down