Skip to content

Commit

Permalink
feat: Add image tag filter to support images within a prompt for vi…
Browse files Browse the repository at this point in the history
…sion models (#22)

* first draft

* add unit tests

* removed unused import

* remove useless test

* docs
  • Loading branch information
masci authored Nov 10, 2024
1 parent 0b240e6 commit 32c3841
Show file tree
Hide file tree
Showing 11 changed files with 187 additions and 15 deletions.
7 changes: 7 additions & 0 deletions docs/prompt.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ provided by Jinja, Banks supports the following ones, specific for prompt engine
show_signature_annotations: false
heading_level: 3

::: banks.filters.image.image
options:
show_root_full_path: false
show_symbol_type_heading: false
show_signature_annotations: false
heading_level: 3

::: banks.filters.lemmatize.lemmatize
options:
show_root_full_path: false
Expand Down
6 changes: 4 additions & 2 deletions src/banks/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jinja2 import Environment, PackageLoader, select_autoescape

from .config import config
from .filters import cache_control, lemmatize, tool
from .filters import cache_control, image, lemmatize, tool


def _add_extensions(_env):
Expand Down Expand Up @@ -38,7 +38,9 @@ def _add_extensions(_env):


# Setup custom filters and defaults
env.filters["lemmatize"] = lemmatize
env.filters["cache_control"] = cache_control
env.filters["image"] = image
env.filters["lemmatize"] = lemmatize
env.filters["tool"] = tool

_add_extensions(env)
6 changes: 3 additions & 3 deletions src/banks/extensions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ def content(self) -> ChatMessageContent:

return self._content_blocks

def handle_starttag(self, tag, _):
if tag == "content_block_txt":
def handle_starttag(self, tag, attrs): # noqa
if tag == "content_block":
self._parse_block_content = True

def handle_endtag(self, tag):
if tag == "content_block_txt":
if tag == "content_block":
self._parse_block_content = False

def handle_data(self, data):
Expand Down
2 changes: 2 additions & 0 deletions src/banks/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
#
# SPDX-License-Identifier: MIT
from .cache_control import cache_control
from .image import image
from .lemmatize import lemmatize
from .tool import tool

__all__ = (
"cache_control",
"image",
"lemmatize",
"tool",
)
6 changes: 3 additions & 3 deletions src/banks/filters/cache_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def cache_control(value: str, cache_type: str = "ephemeral") -> str:
```
Important:
this filter marks the content to cache by surrounding it with `<content_block_txt>` and
`</content_block_txt>`, so it's only useful when used within a `{% chat %}` block.
this filter marks the content to cache by surrounding it with `<content_block>` and
`</content_block>`, so it's only useful when used within a `{% chat %}` block.
"""
block = ContentBlock.model_validate({"type": "text", "text": value, "cache_control": {"type": cache_type}})
return f"<content_block_txt>{block.model_dump_json()}</content_block_txt>"
return f"<content_block>{block.model_dump_json()}</content_block>"
37 changes: 37 additions & 0 deletions src/banks/filters/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from pathlib import Path
from urllib.parse import urlparse

from banks.types import ContentBlock, ImageUrl


def _is_url(string: str) -> bool:
result = urlparse(string)
return all([result.scheme, result.netloc])


def image(value: str) -> str:
    """Turn a URL or a local file path into a serialized image ContentBlock.

    When `value` looks like a URL (it has both a scheme and a network
    location) it is referenced as-is; otherwise it is treated as a path on
    disk and handed to `ImageUrl.from_path`.

    Example:
        ```jinja
        Describe what you see
        {{ "path/to/image/file" | image }}
        ```

    Important:
        this filter marks the image content by surrounding it with `<content_block>` and
        `</content_block>`, so it's only useful when used within a `{% chat %}` block.
    """
    image_url = ImageUrl(url=value) if _is_url(value) else ImageUrl.from_path(Path(value))

    payload = {"type": "image_url", "image_url": image_url}
    serialized = ContentBlock.model_validate(payload).model_dump_json()
    return f"<content_block>{serialized}</content_block>"
9 changes: 8 additions & 1 deletion src/banks/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import base64
from enum import Enum
from inspect import Parameter, getdoc, signature
from pathlib import Path
from typing import Callable

from pydantic import BaseModel
Expand All @@ -26,9 +28,14 @@ class ImageUrl(BaseModel):
url: str

@classmethod
def from_base64(cls, media_type: str, base64_str: str):
def from_base64(cls, media_type: str, base64_str: str) -> Self:
return cls(url=f"data:{media_type};base64,{base64_str}")

@classmethod
def from_path(cls, file_path: Path) -> Self:
    """Build an instance whose URL embeds the file's bytes as a base64 `data:` URL.

    The media type is guessed from the file extension. When the extension is
    unknown, or maps to something that is not an image type, `image/jpeg` is
    used as the fallback.

    Args:
        file_path: location of the image file to embed.

    Returns:
        An instance created through `from_base64` with the guessed media type
        and the base64-encoded file content.

    Raises:
        OSError: if the file cannot be opened or read, e.g. `FileNotFoundError`
            when it does not exist.
    """
    import mimetypes  # function-scope import so the module's import block is untouched

    guessed_type, _ = mimetypes.guess_type(str(file_path))
    if guessed_type is not None and guessed_type.startswith("image/"):
        media_type = guessed_type
    else:
        # Unknown or non-image extension: fall back to JPEG.
        media_type = "image/jpeg"

    with open(file_path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode("utf-8")
    return cls.from_base64(media_type, encoded)


class ContentBlock(BaseModel):
type: ContentBlockType
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cache_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

def test_cache_control():
res = cache_control("foo", "ephemeral")
res = res.replace("<content_block_txt>", "")
res = res.replace("</content_block_txt>", "")
res = res.replace("<content_block>", "")
res = res.replace("</content_block>", "")
assert res == '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","image_url":null}'
8 changes: 4 additions & 4 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_content_block_parser_init():
def test_content_block_parser_single_with_cache_control():
p = _ContentBlockParser()
p.feed(
'<content_block_txt>{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}</content_block_txt>'
'<content_block>{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}</content_block>'
)
assert p.content == [
ContentBlock(type=ContentBlockType.text, cache_control=CacheControl(type="ephemeral"), text="foo", source=None)
Expand All @@ -39,15 +39,15 @@ def test_content_block_parser_single_with_cache_control():

def test_content_block_parser_single_no_cache_control():
p = _ContentBlockParser()
p.feed('<content_block_txt>{"type":"text","cache_control":null,"text":"foo","source":null}</content_block_txt>')
p.feed('<content_block>{"type":"text","cache_control":null,"text":"foo","source":null}</content_block>')
assert p.content == "foo"


def test_content_block_parser_multiple():
p = _ContentBlockParser()
p.feed(
'<content_block_txt>{"type":"text","cache_control":null,"text":"foo","source":null}</content_block_txt>'
'<content_block_txt>{"type":"text","cache_control":null,"text":"bar","source":null}</content_block_txt>'
'<content_block>{"type":"text","cache_control":null,"text":"foo","source":null}</content_block>'
'<content_block>{"type":"text","cache_control":null,"text":"bar","source":null}</content_block>'
)
assert p.content == [
ContentBlock(type=ContentBlockType.text, cache_control=None, text="foo", source=None),
Expand Down
76 changes: 76 additions & 0 deletions tests/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import json

import pytest

from banks.filters.image import _is_url, image


def test_is_url():
    """`_is_url` accepts full URLs and rejects paths, empty and malformed input."""
    accepted = (
        "https://example.com/image.jpg",
        "http://example.com/image.jpg",
        "ftp://example.com/image.jpg",
    )
    rejected = (
        "not_a_url.jpg",
        "/path/to/image.jpg",
        "relative/path/image.jpg",
        "",
        "https:\\example.com/image.jpg",
    )

    for candidate in accepted:
        assert _is_url(candidate) is True
    for candidate in rejected:
        assert _is_url(candidate) is False


def test_image_with_url():
    """The image filter keeps a URL input untouched inside the content block."""
    open_tag, close_tag = "<content_block>", "</content_block>"
    source_url = "https://example.com/image.jpg"

    rendered = image(source_url)

    # The JSON payload must be wrapped by the content block tags
    assert rendered.startswith(open_tag)
    assert rendered.endswith(close_tag)

    # Strip the wrapper tags and decode what is left
    payload = json.loads(rendered[len(open_tag) : -len(close_tag)])

    assert payload["type"] == "image_url"
    assert payload["image_url"]["url"] == source_url


def test_image_with_file_path(tmp_path):
    """The image filter embeds a local file as a base64 data URL."""
    open_tag, close_tag = "<content_block>", "</content_block>"

    # Write a stand-in image file into pytest's temporary directory
    image_file = tmp_path / "test_image.jpg"
    image_file.write_bytes(b"fake image content")

    rendered = image(str(image_file))

    # The JSON payload must be wrapped by the content block tags
    assert rendered.startswith(open_tag)
    assert rendered.endswith(close_tag)

    # Strip the wrapper tags and decode what is left
    payload = json.loads(rendered[len(open_tag) : -len(close_tag)])

    assert payload["type"] == "image_url"
    assert payload["image_url"]["url"].startswith("data:image/jpeg;base64,")


def test_image_with_nonexistent_file():
    """The image filter raises FileNotFoundError for a path that does not exist."""
    missing_path = "nonexistent/image.jpg"
    with pytest.raises(FileNotFoundError):
        image(missing_path)


def test_image_content_block_structure():
    """The serialized content block exposes the expected keys and value types."""
    open_tag, close_tag = "<content_block>", "</content_block>"

    rendered = image("https://example.com/image.jpg")
    payload = json.loads(rendered[len(open_tag) : -len(close_tag)])

    # Extra keys are allowed, but these two must be present
    assert {"type", "image_url"} <= set(payload.keys())
    assert payload["type"] == "image_url"

    nested = payload["image_url"]
    assert isinstance(nested, dict)
    assert "url" in nested
41 changes: 41 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import base64
from pathlib import Path

import pytest

from banks.types import ImageUrl


def test_image_url_from_base64():
    """`ImageUrl.from_base64` assembles a data URL from media type and payload."""
    mime = "image/jpeg"
    encoded = base64.b64encode("Hello, World!".encode()).decode("utf-8")

    built = ImageUrl.from_base64(mime, encoded)

    assert built.url == f"data:{mime};base64,{encoded}"


def test_image_url_from_path(tmp_path):
    """`ImageUrl.from_path` round-trips the file content through base64."""
    raw_bytes = b"fake image content"

    # Write a stand-in image file into pytest's temporary directory
    image_file = tmp_path / "test_image.jpg"
    image_file.write_bytes(raw_bytes)

    built = ImageUrl.from_path(image_file)

    # The URL must carry the data URI prefix
    assert built.url.startswith("data:image/jpeg;base64,")

    # Everything after the first comma is the base64 payload
    encoded = built.url.split(",")[1]
    assert base64.b64decode(encoded) == raw_bytes


def test_image_url_from_path_nonexistent():
    """`ImageUrl.from_path` raises FileNotFoundError when the file is missing."""
    missing_file = Path("nonexistent.jpg")
    with pytest.raises(FileNotFoundError):
        ImageUrl.from_path(missing_file)

0 comments on commit 32c3841

Please sign in to comment.