-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add `cache_control` filter to support prompt caching in Anthropic (#16)
* reorganize import paths
* feat: add support for prompt caching
* fix tests
* more unit tests
* fix linter
* add filter to docs
* fix docs
* fix docs
- Loading branch information
Showing 18 changed files with 226 additions and 54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
import json | ||
from html.parser import HTMLParser | ||
|
||
from jinja2 import TemplateSyntaxError, nodes | ||
from jinja2.ext import Extension | ||
|
||
from banks.types import ChatMessage, ChatMessageContent, ContentBlock, ContentBlockType | ||
|
||
SUPPORTED_TYPES = ("system", "user") | ||
|
||
|
||
|
@@ -16,7 +18,7 @@ def chat(role: str): # pylint: disable=W0613 | |
will return a list of `ChatMessage` instances. | ||
Example: | ||
``` | ||
```jinja | ||
{% chat role="system" %} | ||
You are a helpful assistant. | ||
{% endchat %} | ||
|
@@ -28,7 +30,44 @@ def chat(role: str): # pylint: disable=W0613 | |
""" | ||
|
||
|
||
class ChatMessage(Extension): | ||
class _ContentBlockParser(HTMLParser): | ||
"""A parser used to extract text surrounded by `<content_block_txt>` and `</content_block_txt>` tags.""" | ||
|
||
def __init__(self, *args, **kwargs) -> None: | ||
super().__init__(*args, **kwargs) | ||
self._parse_block_content = False | ||
self._content_blocks: list[ContentBlock] = [] | ||
|
||
@property | ||
def content(self) -> ChatMessageContent: | ||
"""Returns ChatMessageContent data that can be directly assigned to ChatMessage.content. | ||
If only one block is present, this block is of type text and has no cache control set, we just | ||
return it as plain text for simplicity. | ||
""" | ||
if len(self._content_blocks) == 1: | ||
block = self._content_blocks[0] | ||
if block.type == "text" and block.cache_control is None: | ||
return block.text or "" | ||
|
||
return self._content_blocks | ||
|
||
def handle_starttag(self, tag, _): | ||
if tag == "content_block_txt": | ||
self._parse_block_content = True | ||
|
||
def handle_endtag(self, tag): | ||
if tag == "content_block_txt": | ||
self._parse_block_content = False | ||
|
||
def handle_data(self, data): | ||
if self._parse_block_content: | ||
self._content_blocks.append(ContentBlock.model_validate_json(data)) | ||
else: | ||
self._content_blocks.append(ContentBlock(type=ContentBlockType.text, text=data)) | ||
|
||
|
||
class ChatExtension(Extension): | ||
""" | ||
`chat` can be used to render prompt text as structured ChatMessage objects. | ||
|
@@ -85,4 +124,7 @@ def _store_chat_messages(self, role, caller): | |
""" | ||
Helper callback. | ||
""" | ||
return json.dumps({"role": role, "content": caller()}) | ||
parser = _ContentBlockParser() | ||
parser.feed(caller()) | ||
cm = ChatMessage(role=role, content=parser.content) | ||
return cm.model_dump_json() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,10 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
from banks.filters.lemmatize import lemmatize | ||
from .cache_control import cache_control | ||
from .lemmatize import lemmatize | ||
|
||
__all__ = ("lemmatize",) | ||
__all__ = ( | ||
"cache_control", | ||
"lemmatize", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
from banks.types import ContentBlock | ||
|
||
|
||
def cache_control(value: str, cache_type: str = "ephemeral") -> str: | ||
"""Wrap the filtered value into a ContentBlock with the proper cache_control field set. | ||
The resulting ChatMessage will have the field `content` populated with a list of ContentBlock objects. | ||
Example: | ||
```jinja | ||
{{ "This is a long, long text" | cache_control("ephemeral") }} | ||
This is short and won't be cached. | ||
``` | ||
Important: | ||
this filter marks the content to cache by surrounding it with `<content_block_txt>` and | ||
`</content_block_txt>`, so it's only useful when used within a `{% chat %}` block. | ||
""" | ||
block = ContentBlock.model_validate({"type": "text", "text": value, "cache_control": {"type": cache_type}}) | ||
return f"<content_block_txt>{block.model_dump_json()}</content_block_txt>" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,51 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
from typing import Any, Protocol, Self | ||
from enum import Enum | ||
|
||
from pydantic import BaseModel | ||
|
||
from .prompt import Prompt | ||
# pylint: disable=invalid-name | ||
|
||
|
||
class PromptRegistry(Protocol): # pragma: no cover | ||
"""Interface to be implemented by concrete prompt registries.""" | ||
class ContentBlockType(str, Enum): | ||
text = "text" | ||
image = "image" | ||
|
||
def get(self, *, name: str, version: str | None = None) -> Prompt: ... | ||
|
||
def set(self, *, prompt: Prompt, overwrite: bool = False) -> None: ... | ||
class MediaTypeBlockType(str, Enum): | ||
image_jpeg = "image/jpeg" | ||
image_png = "image/png" | ||
image_gif = "image/gif" | ||
image_webp = "image/webp" | ||
|
||
|
||
class PromptModel(BaseModel): | ||
"""Serializable representation of a Prompt.""" | ||
class CacheControl(BaseModel): | ||
type: str = "ephemeral" | ||
|
||
text: str | ||
name: str | None = None | ||
version: str | None = None | ||
metadata: dict[str, Any] | None = None | ||
|
||
@classmethod | ||
def from_prompt(cls: type[Self], prompt: Prompt) -> Self: | ||
return cls(text=prompt.raw, name=prompt.name, version=prompt.version, metadata=prompt.metadata) | ||
class Source(BaseModel): | ||
type: str = "base64" | ||
media_type: MediaTypeBlockType | ||
data: str | ||
|
||
class Config: | ||
use_enum_values = True | ||
|
||
|
||
class ContentBlock(BaseModel): | ||
type: ContentBlockType | ||
cache_control: CacheControl | None = None | ||
text: str | None = None | ||
source: Source | None = None | ||
|
||
class Config: | ||
use_enum_values = True | ||
|
||
|
||
ChatMessageContent = list[ContentBlock] | str | ||
|
||
|
||
class ChatMessage(BaseModel): | ||
role: str | ||
content: str | ChatMessageContent |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{% chat role="user" %} | ||
The book to analyze is included in the tags <book> and </book>. | ||
|
||
<book>{{ book | cache_control("ephemeral") }}</book> | ||
|
||
What is the title of this book? Only output the title. | ||
{% endchat %} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from banks.filters.cache_control import cache_control | ||
|
||
|
||
def test_cache_control(): | ||
res = cache_control("foo", "ephemeral") | ||
res = res.replace("<content_block_txt>", "") | ||
res = res.replace("</content_block_txt>", "") | ||
assert res == '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}' |
Oops, something went wrong.