-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add
image
tag filter to support images within a prompt for vi…
…sion models (#22) * first draft * add unit tests * removed unused import * remove useless test * docs
- Loading branch information
Showing
11 changed files
with
187 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
from pathlib import Path | ||
from urllib.parse import urlparse | ||
|
||
from banks.types import ContentBlock, ImageUrl | ||
|
||
|
||
def _is_url(string: str) -> bool: | ||
result = urlparse(string) | ||
return all([result.scheme, result.netloc]) | ||
|
||
|
||
def image(value: str) -> str: | ||
"""Wrap the filtered value into a ContentBlock of type image. | ||
The resulting ChatMessage will have the field `content` populated with a list of ContentBlock objects. | ||
Example: | ||
```jinja | ||
Describe what you see | ||
{{ "path/to/image/file" | image }} | ||
``` | ||
Important: | ||
this filter marks the content to cache by surrounding it with `<content_block>` and | ||
`</content_block>`, so it's only useful when used within a `{% chat %}` block. | ||
""" | ||
if _is_url(value): | ||
image_url = ImageUrl(url=value) | ||
else: | ||
image_url = ImageUrl.from_path(Path(value)) | ||
|
||
block = ContentBlock.model_validate({"type": "image_url", "image_url": image_url}) | ||
return f"<content_block>{block.model_dump_json()}</content_block>" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,10 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
import base64 | ||
from enum import Enum | ||
from inspect import Parameter, getdoc, signature | ||
from pathlib import Path | ||
from typing import Callable | ||
|
||
from pydantic import BaseModel | ||
|
@@ -26,9 +28,14 @@ class ImageUrl(BaseModel): | |
url: str | ||
|
||
@classmethod | ||
def from_base64(cls, media_type: str, base64_str: str): | ||
def from_base64(cls, media_type: str, base64_str: str) -> Self: | ||
return cls(url=f"data:{media_type};base64,{base64_str}") | ||
|
||
@classmethod | ||
def from_path(cls, file_path: Path) -> Self: | ||
with open(file_path, "rb") as image_file: | ||
return cls.from_base64("image/jpeg", base64.b64encode(image_file.read()).decode("utf-8")) | ||
|
||
|
||
class ContentBlock(BaseModel): | ||
type: ContentBlockType | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import json | ||
|
||
import pytest | ||
|
||
from banks.filters.image import _is_url, image | ||
|
||
|
||
def test_is_url(): | ||
"""Test the internal URL validation function""" | ||
assert _is_url("https://example.com/image.jpg") is True | ||
assert _is_url("http://example.com/image.jpg") is True | ||
assert _is_url("ftp://example.com/image.jpg") is True | ||
assert _is_url("not_a_url.jpg") is False | ||
assert _is_url("/path/to/image.jpg") is False | ||
assert _is_url("relative/path/image.jpg") is False | ||
assert _is_url("") is False | ||
assert _is_url("https:\\example.com/image.jpg") is False | ||
|
||
|
||
def test_image_with_url(): | ||
"""Test image filter with a URL input""" | ||
url = "https://example.com/image.jpg" | ||
result = image(url) | ||
|
||
# Verify the content block wrapper | ||
assert result.startswith("<content_block>") | ||
assert result.endswith("</content_block>") | ||
|
||
# Parse the JSON content | ||
json_content = result[15:-16] # Remove wrapper tags | ||
content_block = json.loads(json_content) | ||
|
||
assert content_block["type"] == "image_url" | ||
assert content_block["image_url"]["url"] == url | ||
|
||
|
||
def test_image_with_file_path(tmp_path): | ||
"""Test image filter with a file path input""" | ||
# Create a temporary test image file | ||
test_image = tmp_path / "test_image.jpg" | ||
test_content = b"fake image content" | ||
test_image.write_bytes(test_content) | ||
|
||
result = image(str(test_image)) | ||
|
||
# Verify the content block wrapper | ||
assert result.startswith("<content_block>") | ||
assert result.endswith("</content_block>") | ||
|
||
# Parse the JSON content | ||
json_content = result[15:-16] # Remove wrapper tags | ||
content_block = json.loads(json_content) | ||
|
||
assert content_block["type"] == "image_url" | ||
assert content_block["image_url"]["url"].startswith("data:image/jpeg;base64,") | ||
|
||
|
||
def test_image_with_nonexistent_file(): | ||
"""Test image filter with a nonexistent file path""" | ||
with pytest.raises(FileNotFoundError): | ||
image("nonexistent/image.jpg") | ||
|
||
|
||
def test_image_content_block_structure(): | ||
"""Test the structure of the generated content block""" | ||
url = "https://example.com/image.jpg" | ||
result = image(url) | ||
|
||
json_content = result[15:-16] # Remove wrapper tags | ||
content_block = json.loads(json_content) | ||
|
||
# Verify the content block has all expected fields | ||
assert set(content_block.keys()) >= {"type", "image_url"} | ||
assert content_block["type"] == "image_url" | ||
assert isinstance(content_block["image_url"], dict) | ||
assert "url" in content_block["image_url"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import base64 | ||
from pathlib import Path | ||
|
||
import pytest | ||
|
||
from banks.types import ImageUrl | ||
|
||
|
||
def test_image_url_from_base64(): | ||
"""Test creating ImageUrl from base64 encoded data""" | ||
test_data = "Hello, World!" | ||
base64_data = base64.b64encode(test_data.encode()).decode("utf-8") | ||
media_type = "image/jpeg" | ||
|
||
image_url = ImageUrl.from_base64(media_type, base64_data) | ||
expected_url = f"data:{media_type};base64,{base64_data}" | ||
assert image_url.url == expected_url | ||
|
||
|
||
def test_image_url_from_path(tmp_path): | ||
"""Test creating ImageUrl from a file path""" | ||
# Create a temporary test image file | ||
test_image = tmp_path / "test_image.jpg" | ||
test_content = b"fake image content" | ||
test_image.write_bytes(test_content) | ||
|
||
image_url = ImageUrl.from_path(test_image) | ||
|
||
# Verify the URL starts with the expected data URI prefix | ||
assert image_url.url.startswith("data:image/jpeg;base64,") | ||
|
||
# Decode the base64 part and verify the content matches | ||
base64_part = image_url.url.split(",")[1] | ||
decoded_content = base64.b64decode(base64_part) | ||
assert decoded_content == test_content | ||
|
||
|
||
def test_image_url_from_path_nonexistent(): | ||
"""Test creating ImageUrl from a nonexistent file path""" | ||
with pytest.raises(FileNotFoundError): | ||
ImageUrl.from_path(Path("nonexistent.jpg")) |