Skip to content

Commit

Permalink
Merge pull request #748 from Mirascope/fix-746
Browse files Browse the repository at this point in the history
Fix prompt template handling of PIL images
  • Loading branch information
willbakst authored Dec 17, 2024
2 parents cca29dc + 3632f17 commit caa009e
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 33 deletions.
2 changes: 2 additions & 0 deletions mirascope/core/base/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ._messages_decorator import MessagesDecorator, messages_decorator
from ._parse_content_template import parse_content_template
from ._parse_prompt_messages import parse_prompt_messages
from ._pil_image_to_bytes import pil_image_to_bytes
from ._protocols import (
AsyncCreateFn,
CalculateCost,
Expand Down Expand Up @@ -77,6 +78,7 @@
"messages_decorator",
"parse_content_template",
"parse_prompt_messages",
"pil_image_to_bytes",
"SetupCall",
"setup_call",
"setup_extract_tool",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from pydantic import BaseModel
from typing_extensions import TypeIs

from .._utils._get_image_type import get_image_type
from ..message_param import (
AudioPart,
BaseMessageParam,
Expand All @@ -18,6 +17,7 @@
TextPart,
)
from ..types import AudioSegment, Image, has_pil_module, has_pydub_module
from ._pil_image_to_bytes import pil_image_to_bytes

SAMPLE_WIDTH = 2
FRAME_RATE = 24000
Expand All @@ -43,11 +43,15 @@ def _convert_message_sequence_part_to_content_part(
):
return message_sequence_part
elif has_pil_module and isinstance(message_sequence_part, Image.Image):
image = message_sequence_part.tobytes()
media_type = (
Image.MIME[message_sequence_part.format]
if message_sequence_part.format
else "image/unknown"
)
return ImagePart(
type="image",
media_type=f"image/{get_image_type(image)}",
image=image,
media_type=media_type,
image=pil_image_to_bytes(message_sequence_part),
detail=None,
)
elif has_pydub_module and isinstance(message_sequence_part, AudioSegment):
Expand Down
17 changes: 14 additions & 3 deletions mirascope/core/base/_utils/_parse_content_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
ImagePart,
TextPart,
)
from ..types import Image, has_pil_module
from ._format_template import format_template
from ._get_audio_type import get_audio_type
from ._get_document_type import get_document_type
from ._get_image_type import get_image_type
from ._pil_image_to_bytes import pil_image_to_bytes

_PartType = Literal[
"text", "texts", "image", "images", "audio", "audios", "cache_control"
Expand Down Expand Up @@ -85,15 +87,24 @@ def _load_media(source: str | bytes) -> bytes:


def _construct_image_part(
source: str | bytes, options: dict[str, str] | None
source: str | bytes | Image.Image, options: dict[str, str] | None
) -> ImagePart:
image = _load_media(source)
if isinstance(source, Image.Image):
image = pil_image_to_bytes(source)
media_type = (
Image.MIME[source.format]
if has_pil_module and source.format
else "image/unknown"
)
else:
image = _load_media(source)
media_type = f"image/{get_image_type(image)}"
detail = None
if options:
detail = options.get("detail", None)
return ImagePart(
type="image",
media_type=f"image/{get_image_type(image)}",
media_type=media_type,
image=image,
detail=detail,
)
Expand Down
13 changes: 13 additions & 0 deletions mirascope/core/base/_utils/_pil_image_to_bytes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from io import BytesIO

from ..types import Image


def pil_image_to_bytes(image: Image.Image) -> bytes:
try:
image_bytes = BytesIO()
image.save(image_bytes, format=image.format if image.format else None)
image_bytes.seek(0)
return image_bytes.read()
except Exception as e:
raise ValueError(f"Error converting image to bytes: {e}") from e
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "mirascope"
version = "1.13.0"
version = "1.13.1"
description = "LLM abstractions that aren't obstructions"
readme = "README.md"
license = { file = "LICENSE" }
Expand Down
17 changes: 10 additions & 7 deletions tests/core/base/_utils/test_convert_messages_to_message_params.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import io
import wave
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest

Expand Down Expand Up @@ -72,9 +72,16 @@ def test_convert_message_sequence_part_to_content_part_with_document_part():
assert result == input_value


def test_convert_message_sequence_part_to_content_part_with_pil_image():
@patch(
"mirascope.core.base._utils._convert_messages_to_message_params.pil_image_to_bytes",
new_callable=MagicMock,
)
def test_convert_message_sequence_part_to_content_part_with_pil_image(
mock_pil_image_to_bytes: MagicMock,
):
mock_pil_image_to_bytes.return_value = b"image_bytes"
mock_image_instance = Mock()
mock_image_instance.tobytes.return_value = b"image_bytes"
mock_image_instance.format = "PNG"

from PIL import Image

Expand All @@ -87,10 +94,6 @@ def test_convert_message_sequence_part_to_content_part_with_pil_image():
"mirascope.core.base._utils._convert_messages_to_message_params.Image.Image",
Mock,
),
patch(
"mirascope.core.base._utils._convert_messages_to_message_params.get_image_type",
return_value="png",
),
patch(
"mirascope.core.base._utils._convert_messages_to_message_params.isinstance",
side_effect=lambda obj, cls: True
Expand Down
19 changes: 10 additions & 9 deletions tests/core/base/_utils/test_message_decorator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any, cast
from unittest.mock import Mock
from unittest.mock import MagicMock, patch

import pytest
from PIL import Image
Expand Down Expand Up @@ -138,13 +138,14 @@ def recommend_book(genre: str) -> BaseMessageParam:
assert result[0].content == "hello! recommend a fantasy book"


@pytest.fixture
def mock_image():
return Mock(spec=Image.Image)


def test_multimodal_return(mock_image):
mock_image.tobytes.return_value = b"\xff\xd8\xff" # JPEG magic number
@patch(
"mirascope.core.base._utils._convert_messages_to_message_params.pil_image_to_bytes",
new_callable=MagicMock,
)
def test_multimodal_return(mock_pil_image_to_bytes: MagicMock):
mock_pil_image_to_bytes.return_value = b"image_bytes"
mock_image = MagicMock(spec=Image.Image)
mock_image.format = "PNG"

@messages_decorator()
def recommend_book(previous_book: Image.Image) -> list[Any]:
Expand All @@ -159,7 +160,7 @@ def recommend_book(previous_book: Image.Image) -> list[Any]:
assert len(result[0].content) == 3
assert result[0].content[0] == TextPart(type="text", text="I just read this book:")
assert result[0].content[1] == ImagePart(
type="image", media_type="image/jpeg", image=b"\xff\xd8\xff", detail=None
type="image", media_type="image/png", image=b"image_bytes", detail=None
)
assert result[0].content[2] == TextPart(
type="text", text="What should I read next?"
Expand Down
46 changes: 37 additions & 9 deletions tests/core/base/_utils/test_parse_content_template.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Tests the `_utils.parse_content_template` function."""

from io import BytesIO
from unittest.mock import MagicMock, patch

import pytest
from PIL import Image

from mirascope.core.base._utils._parse_content_template import parse_content_template
from mirascope.core.base.message_param import (
Expand All @@ -28,17 +30,21 @@ def test_parse_content_template() -> None:
assert parse_content_template("user", template, values) == expected


@pytest.fixture
def mock_jpeg_bytes() -> bytes:
return b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.' \",#\x1c\x1c(7),01444\x1f'9=82<.342\xff\xc0\x00\x0b\x08\x00\x01\x00\x01\x01\x01\x11\x00\xff\xc4\x00\x1f\x00\x00\x01\x05\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\xff\xc4\x00\xb5\x10\x00\x02\x01\x03\x03\x02\x04\x03\x05\x05\x04\x04\x00\x00\x01}\x01\x02\x03\x00\x04\x11\x05\x12!1A\x06\x13Qa\x07\"q\x142\x81\x91\xa1\x08#B\xb1\xc1\x15R\xd1\xf0$3br\x82\t\n\x16\x17\x18\x19\x1a%&'()*456789:CDEFGHIJSTUVWXYZcdefghijstuvwxyz\x83\x84\x85\x86\x87\x88\x89\x8a\x92\x93\x94\x95\x96\x97\x98\x99\x9a\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xff\xda\x00\x08\x01\x01\x00\x00?\x00\xf9\xfe\xbf\xff\xd9"


@patch(
"mirascope.core.base._utils._parse_content_template.open", new_callable=MagicMock
)
@patch("urllib.request.urlopen", new_callable=MagicMock)
def test_parse_content_template_images(
mock_urlopen: MagicMock, mock_open: MagicMock
mock_urlopen: MagicMock, mock_open: MagicMock, mock_jpeg_bytes: bytes
) -> None:
"""Test the parse_content_template function with image templates."""
image_data = b"\xff\xd8\xffimage data"
mock_response = MagicMock()
mock_response.read = lambda: image_data
mock_response.read = lambda: mock_jpeg_bytes
mock_urlopen.return_value.__enter__.return_value = mock_response
mock_open.return_value.__enter__.return_value = mock_response
template = "Analyze this image: {url:image}"
Expand All @@ -47,21 +53,35 @@ def test_parse_content_template_images(
content=[
TextPart(type="text", text="Analyze this image:"),
ImagePart(
type="image", media_type="image/jpeg", image=image_data, detail=None
type="image",
media_type="image/jpeg",
image=mock_jpeg_bytes,
detail=None,
),
],
)
assert parse_content_template("user", template, {"url": "https://"}) == expected
assert parse_content_template("user", template, {"url": "./image.jpg"}) == expected
assert parse_content_template("user", template, {"url": image_data}) == expected
assert (
parse_content_template("user", template, {"url": mock_jpeg_bytes}) == expected
)
assert (
parse_content_template(
"user", template, {"url": Image.open(BytesIO(mock_jpeg_bytes))}
)
== expected
)

template = "Analyze this image: {url:image(detail=low)}"
expected = BaseMessageParam(
role="user",
content=[
TextPart(type="text", text="Analyze this image:"),
ImagePart(
type="image", media_type="image/jpeg", image=image_data, detail="low"
type="image",
media_type="image/jpeg",
image=mock_jpeg_bytes,
detail="low",
),
],
)
Expand All @@ -73,10 +93,16 @@ def test_parse_content_template_images(
content=[
TextPart(type="text", text="Analyze these images:"),
ImagePart(
type="image", media_type="image/jpeg", image=image_data, detail=None
type="image",
media_type="image/jpeg",
image=mock_jpeg_bytes,
detail=None,
),
ImagePart(
type="image", media_type="image/jpeg", image=image_data, detail=None
type="image",
media_type="image/jpeg",
image=mock_jpeg_bytes,
detail=None,
),
],
)
Expand All @@ -91,7 +117,9 @@ def test_parse_content_template_images(
== expected
)
assert (
parse_content_template("user", template, {"urls": [image_data, image_data]})
parse_content_template(
"user", template, {"urls": [mock_jpeg_bytes, mock_jpeg_bytes]}
)
== expected
)

Expand Down
34 changes: 34 additions & 0 deletions tests/core/base/_utils/test_pil_image_to_bytes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from io import BytesIO

import pytest
from PIL import Image

from mirascope.core.base._utils._pil_image_to_bytes import pil_image_to_bytes


@pytest.fixture
def mock_jpeg_bytes() -> bytes:
return b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.' \",#\x1c\x1c(7),01444\x1f'9=82<.342\xff\xc0\x00\x0b\x08\x00\x01\x00\x01\x01\x01\x11\x00\xff\xc4\x00\x1f\x00\x00\x01\x05\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\xff\xc4\x00\xb5\x10\x00\x02\x01\x03\x03\x02\x04\x03\x05\x05\x04\x04\x00\x00\x01}\x01\x02\x03\x00\x04\x11\x05\x12!1A\x06\x13Qa\x07\"q\x142\x81\x91\xa1\x08#B\xb1\xc1\x15R\xd1\xf0$3br\x82\t\n\x16\x17\x18\x19\x1a%&'()*456789:CDEFGHIJSTUVWXYZcdefghijstuvwxyz\x83\x84\x85\x86\x87\x88\x89\x8a\x92\x93\x94\x95\x96\x97\x98\x99\x9a\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xff\xda\x00\x08\x01\x01\x00\x00?\x00\xf9\xfe\xbf\xff\xd9"


@pytest.fixture
def mock_pil_image(mock_jpeg_bytes: bytes) -> Image.Image:
return Image.open(BytesIO(mock_jpeg_bytes))


def test_pil_image_to_bytes_success_jpg(
mock_jpeg_bytes: bytes, mock_pil_image: Image.Image
):
"""Tests pil_image_to_bytes with a JPEG image (mocking format)."""
image_bytes = pil_image_to_bytes(mock_pil_image)
assert isinstance(image_bytes, bytes)
assert image_bytes == mock_jpeg_bytes


def test_pil_image_to_bytes_no_format(mock_pil_image: Image.Image):
"""
Tests pil_image_to_bytes with an image with no format attribute.
"""
mock_pil_image.format = None
with pytest.raises(ValueError, match="Error converting image to bytes:"):
pil_image_to_bytes(mock_pil_image)

0 comments on commit caa009e

Please sign in to comment.