Skip to content

Commit

Permalink
feat: add cache_control filter to support prompt caching in Anthrop…
Browse files Browse the repository at this point in the history
…ic (#16)

* reorganize import paths

* feat: add support for prompt caching

* fix tests

* more unit tests

* fix linter

* add filter to docs

* fix docs

* fix docs
  • Loading branch information
masci authored Oct 6, 2024
1 parent fdb35d7 commit 9b350af
Show file tree
Hide file tree
Showing 18 changed files with 226 additions and 54 deletions.
17 changes: 12 additions & 5 deletions docs/prompt.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ In addition to all the [builtin filters](https://jinja.palletsprojects.com/en/3.
provided by Jinja, Banks supports the following ones, specific for prompt engineering.


::: banks.filters.cache_control.cache_control
options:
show_root_full_path: false
show_symbol_type_heading: false
show_signature_annotations: false
heading_level: 3

::: banks.filters.lemmatize.lemmatize
options:
show_root_full_path: false
Expand Down Expand Up @@ -53,12 +60,12 @@ Insert into the prompt a canary word that can be checked later with `Prompt.cana
to ensure the original prompt was not leaked.

Example:
```python
from banks import Prompt
```python
from banks import Prompt

p = Prompt("{{canary_word}}Hello, World!")
p.text() ## outputs 'BANKS[5f0bbba4]Hello, World!'
```
p = Prompt("{{canary_word}}Hello, World!")
p.text() ## outputs 'BANKS[5f0bbba4]Hello, World!'
```

## Macros

Expand Down
3 changes: 2 additions & 1 deletion src/banks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: MIT
from .config import config
from .env import env
from .prompt import AsyncPrompt, ChatMessage, Prompt
from .prompt import AsyncPrompt, Prompt
from .types import ChatMessage

__all__ = ("env", "Prompt", "AsyncPrompt", "config", "ChatMessage")
7 changes: 4 additions & 3 deletions src/banks/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jinja2 import Environment, PackageLoader, select_autoescape

from .config import config
from .filters import lemmatize
from .filters import cache_control, lemmatize


def _add_extensions(_env):
Expand All @@ -13,13 +13,13 @@ def _add_extensions(_env):
For example, we use banks to manage the system prompt in `GenerateExtension`
"""
from .extensions.chat import ChatMessage # pylint: disable=import-outside-toplevel
from .extensions.chat import ChatExtension # pylint: disable=import-outside-toplevel
from .extensions.generate import GenerateExtension # pylint: disable=import-outside-toplevel
from .extensions.inference_endpoint import HFInferenceEndpointsExtension # pylint: disable=import-outside-toplevel

_env.add_extension(GenerateExtension)
_env.add_extension(HFInferenceEndpointsExtension)
_env.add_extension(ChatMessage)
_env.add_extension(ChatExtension)


# Init the Jinja env
Expand All @@ -37,4 +37,5 @@ def _add_extensions(_env):

# Setup custom filters and defaults
env.filters["lemmatize"] = lemmatize
env.filters["cache_control"] = cache_control
_add_extensions(env)
50 changes: 46 additions & 4 deletions src/banks/extensions/chat.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import json
from html.parser import HTMLParser

from jinja2 import TemplateSyntaxError, nodes
from jinja2.ext import Extension

from banks.types import ChatMessage, ChatMessageContent, ContentBlock, ContentBlockType

SUPPORTED_TYPES = ("system", "user")


Expand All @@ -16,7 +18,7 @@ def chat(role: str): # pylint: disable=W0613
will return a list of `ChatMessage` instances.
Example:
```
```jinja
{% chat role="system" %}
You are a helpful assistant.
{% endchat %}
Expand All @@ -28,7 +30,44 @@ def chat(role: str): # pylint: disable=W0613
"""


class ChatMessage(Extension):
class _ContentBlockParser(HTMLParser):
"""A parser used to extract text surrounded by `<content_block_txt>` and `</content_block_txt>` tags."""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._parse_block_content = False
self._content_blocks: list[ContentBlock] = []

@property
def content(self) -> ChatMessageContent:
"""Returns ChatMessageContent data that can be directly assigned to ChatMessage.content.
If only one block is present, this block is of type text and has no cache control set, we just
return it as plain text for simplicity.
"""
if len(self._content_blocks) == 1:
block = self._content_blocks[0]
if block.type == "text" and block.cache_control is None:
return block.text or ""

return self._content_blocks

def handle_starttag(self, tag, _):
if tag == "content_block_txt":
self._parse_block_content = True

def handle_endtag(self, tag):
if tag == "content_block_txt":
self._parse_block_content = False

def handle_data(self, data):
if self._parse_block_content:
self._content_blocks.append(ContentBlock.model_validate_json(data))
else:
self._content_blocks.append(ContentBlock(type=ContentBlockType.text, text=data))


class ChatExtension(Extension):
"""
`chat` can be used to render prompt text as structured ChatMessage objects.
Expand Down Expand Up @@ -85,4 +124,7 @@ def _store_chat_messages(self, role, caller):
"""
Helper callback.
"""
return json.dumps({"role": role, "content": caller()})
parser = _ContentBlockParser()
parser.feed(caller())
cm = ChatMessage(role=role, content=parser.content)
return cm.model_dump_json()
2 changes: 1 addition & 1 deletion src/banks/extensions/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def generate(model_name: str): # pylint: disable=W0613
`generate` can be used to call the LiteLLM API passing the tag text as a prompt and get back some content.
Example:
```
```jinja
{% generate "write a tweet with positive sentiment" "gpt-3.5-turbo" %}
Feeling grateful for all the opportunities that come my way! #positivity #productivity
```
Expand Down
2 changes: 1 addition & 1 deletion src/banks/extensions/inference_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class HFInferenceEndpointsExtension(Extension):
passing a prompt to get back some content.
Example:
```
```jinja
{% inference_endpoint "write a tweet with positive sentiment", "https://foo.aws.endpoints.huggingface.cloud" %}
Life is beautiful, full of opportunities & positivity
```
Expand Down
8 changes: 6 additions & 2 deletions src/banks/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from banks.filters.lemmatize import lemmatize
from .cache_control import cache_control
from .lemmatize import lemmatize

__all__ = ("lemmatize",)
__all__ = (
"cache_control",
"lemmatize",
)
24 changes: 24 additions & 0 deletions src/banks/filters/cache_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from banks.types import ContentBlock


def cache_control(value: str, cache_type: str = "ephemeral") -> str:
"""Wrap the filtered value into a ContentBlock with the proper cache_control field set.
The resulting ChatMessage will have the field `content` populated with a list of ContentBlock objects.
Example:
```jinja
{{ "This is a long, long text" | cache_control("ephemeral") }}
This is short and won't be cached.
```
Important:
this filter marks the content to cache by surrounding it with `<content_block_txt>` and
`</content_block_txt>`, so it's only useful when used within a `{% chat %}` block.
"""
block = ContentBlock.model_validate({"type": "text", "text": value, "cache_control": {"type": cache_type}})
return f"<content_block_txt>{block.model_dump_json()}</content_block_txt>"
12 changes: 6 additions & 6 deletions src/banks/filters/lemmatize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from banks.errors import MissingDependencyError

try:
from simplemma import text_lemmatizer
from simplemma import text_lemmatizer # type: ignore

SIMPLEMMA_AVAIL = True
except ImportError:
except ImportError: # pragma: no cover
SIMPLEMMA_AVAIL = False


Expand All @@ -17,10 +17,10 @@ def lemmatize(text: str) -> str:
to English.
Example:
```
{{"The dog is running" | lemmatize}}
"the dog be run"
```
```jinja
{{"The dog is running" | lemmatize}}
"the dog be run"
```
Note:
Simplemma must be manually installed to use this filter
Expand Down
29 changes: 23 additions & 6 deletions src/banks/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,20 @@
#
# SPDX-License-Identifier: MIT
import uuid
from typing import Any
from typing import Any, Protocol, Self

from pydantic import BaseModel, ValidationError

from .cache import DefaultCache, RenderCache
from .config import config
from .env import env
from .errors import AsyncError
from .types import ChatMessage
from .utils import generate_canary_word

DEFAULT_VERSION = "0"


class ChatMessage(BaseModel):
role: str
content: str


class BasePrompt:
def __init__(
self,
Expand Down Expand Up @@ -219,3 +215,24 @@ async def text(self, data: dict[str, Any] | None = None) -> str:
rendered: str = await self._template.render_async(data)
self._render_cache.set(data, rendered)
return rendered


class PromptRegistry(Protocol): # pragma: no cover
"""Interface to be implemented by concrete prompt registries."""

def get(self, *, name: str, version: str | None = None) -> Prompt: ...

def set(self, *, prompt: Prompt, overwrite: bool = False) -> None: ...


class PromptModel(BaseModel):
"""Serializable representation of a Prompt."""

text: str
name: str | None = None
version: str | None = None
metadata: dict[str, Any] | None = None

@classmethod
def from_prompt(cls: type[Self], prompt: Prompt) -> Self:
return cls(text=prompt.raw, name=prompt.name, version=prompt.version, metadata=prompt.metadata)
3 changes: 1 addition & 2 deletions src/banks/registries/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

from banks import Prompt
from banks.errors import InvalidPromptError, PromptNotFoundError
from banks.prompt import DEFAULT_VERSION
from banks.types import PromptModel
from banks.prompt import DEFAULT_VERSION, PromptModel

# Constants
DEFAULT_INDEX_NAME = "index.json"
Expand Down
3 changes: 1 addition & 2 deletions src/banks/registries/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from pydantic import BaseModel

from banks.errors import InvalidPromptError, PromptNotFoundError
from banks.prompt import Prompt
from banks.types import PromptModel
from banks.prompt import Prompt, PromptModel


class PromptRegistryIndex(BaseModel):
Expand Down
52 changes: 37 additions & 15 deletions src/banks/types.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,51 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from typing import Any, Protocol, Self
from enum import Enum

from pydantic import BaseModel

from .prompt import Prompt
# pylint: disable=invalid-name


class PromptRegistry(Protocol): # pragma: no cover
"""Interface to be implemented by concrete prompt registries."""
class ContentBlockType(str, Enum):
text = "text"
image = "image"

def get(self, *, name: str, version: str | None = None) -> Prompt: ...

def set(self, *, prompt: Prompt, overwrite: bool = False) -> None: ...
class MediaTypeBlockType(str, Enum):
image_jpeg = "image/jpeg"
image_png = "image/png"
image_gif = "image/gif"
image_webp = "image/webp"


class PromptModel(BaseModel):
"""Serializable representation of a Prompt."""
class CacheControl(BaseModel):
type: str = "ephemeral"

text: str
name: str | None = None
version: str | None = None
metadata: dict[str, Any] | None = None

@classmethod
def from_prompt(cls: type[Self], prompt: Prompt) -> Self:
return cls(text=prompt.raw, name=prompt.name, version=prompt.version, metadata=prompt.metadata)
class Source(BaseModel):
type: str = "base64"
media_type: MediaTypeBlockType
data: str

class Config:
use_enum_values = True


class ContentBlock(BaseModel):
type: ContentBlockType
cache_control: CacheControl | None = None
text: str | None = None
source: Source | None = None

class Config:
use_enum_values = True


ChatMessageContent = list[ContentBlock] | str


class ChatMessage(BaseModel):
role: str
content: str | ChatMessageContent
7 changes: 7 additions & 0 deletions tests/templates/cache.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{% chat role="user" %}
The book to analize is included in the tags <book> and </book>.

<book>{{ book | cache_control("ephemeral") }}</book>

What is the title of this book? Only output the title.
{% endchat %}
8 changes: 8 additions & 0 deletions tests/test_cache_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from banks.filters.cache_control import cache_control


def test_cache_control():
res = cache_control("foo", "ephemeral")
res = res.replace("<content_block_txt>", "")
res = res.replace("</content_block_txt>", "")
assert res == '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}'
Loading

0 comments on commit 9b350af

Please sign in to comment.