diff --git a/docs/prompt.md b/docs/prompt.md index 8617226..91e9042 100644 --- a/docs/prompt.md +++ b/docs/prompt.md @@ -21,6 +21,13 @@ In addition to all the [builtin filters](https://jinja.palletsprojects.com/en/3. provided by Jinja, Banks supports the following ones, specific for prompt engineering. +::: banks.filters.cache_control.cache_control + options: + show_root_full_path: false + show_symbol_type_heading: false + show_signature_annotations: false + heading_level: 3 + ::: banks.filters.lemmatize.lemmatize options: show_root_full_path: false @@ -53,12 +60,12 @@ Insert into the prompt a canary word that can be checked later with `Prompt.cana to ensure the original prompt was not leaked. Example: -```python -from banks import Prompt + ```python + from banks import Prompt -p = Prompt("{{canary_word}}Hello, World!") -p.text() ## outputs 'BANKS[5f0bbba4]Hello, World!' -``` + p = Prompt("{{canary_word}}Hello, World!") + p.text() ## outputs 'BANKS[5f0bbba4]Hello, World!' + ``` ## Macros diff --git a/src/banks/__init__.py b/src/banks/__init__.py index eed8efa..f96f434 100644 --- a/src/banks/__init__.py +++ b/src/banks/__init__.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: MIT from .config import config from .env import env -from .prompt import AsyncPrompt, ChatMessage, Prompt +from .prompt import AsyncPrompt, Prompt +from .types import ChatMessage __all__ = ("env", "Prompt", "AsyncPrompt", "config", "ChatMessage") diff --git a/src/banks/env.py b/src/banks/env.py index ba1ef73..f020d27 100644 --- a/src/banks/env.py +++ b/src/banks/env.py @@ -4,7 +4,7 @@ from jinja2 import Environment, PackageLoader, select_autoescape from .config import config -from .filters import lemmatize +from .filters import cache_control, lemmatize def _add_extensions(_env): @@ -13,13 +13,13 @@ def _add_extensions(_env): For example, we use banks to manage the system prompt in `GenerateExtension` """ - from .extensions.chat import ChatMessage # pylint: disable=import-outside-toplevel + from .extensions.chat import ChatExtension # pylint: disable=import-outside-toplevel from .extensions.generate import GenerateExtension # pylint: disable=import-outside-toplevel from .extensions.inference_endpoint import HFInferenceEndpointsExtension # pylint: disable=import-outside-toplevel _env.add_extension(GenerateExtension) _env.add_extension(HFInferenceEndpointsExtension) - _env.add_extension(ChatMessage) + _env.add_extension(ChatExtension) # Init the Jinja env @@ -37,4 +37,5 @@ def _add_extensions(_env): # Setup custom filters and defaults env.filters["lemmatize"] = lemmatize +env.filters["cache_control"] = cache_control _add_extensions(env) diff --git a/src/banks/extensions/chat.py b/src/banks/extensions/chat.py index 63fa6ba..f73e193 100644 --- a/src/banks/extensions/chat.py +++ b/src/banks/extensions/chat.py @@ -1,11 +1,13 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT -import json +from html.parser import HTMLParser from jinja2 import TemplateSyntaxError, nodes from jinja2.ext import Extension +from banks.types import ChatMessage, ChatMessageContent, ContentBlock, ContentBlockType + SUPPORTED_TYPES = ("system", "user") @@ -16,7 +18,7 @@ def chat(role: str): # pylint: disable=W0613 will return a list of `ChatMessage` instances. Example: - ``` + ```jinja {% chat role="system" %} You are a helpful assistant. {% endchat %} @@ -28,7 +30,44 @@ def chat(role: str): # pylint: disable=W0613 """ -class ChatMessage(Extension): +class _ContentBlockParser(HTMLParser): + """A parser used to extract text surrounded by `` and `` tags.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._parse_block_content = False + self._content_blocks: list[ContentBlock] = [] + + @property + def content(self) -> ChatMessageContent: + """Returns ChatMessageContent data that can be directly assigned to ChatMessage.content. + + If only one block is present, this block is of type text and has no cache control set, we just + return it as plain text for simplicity. + """ + if len(self._content_blocks) == 1: + block = self._content_blocks[0] + if block.type == "text" and block.cache_control is None: + return block.text or "" + + return self._content_blocks + + def handle_starttag(self, tag, _): + if tag == "content_block_txt": + self._parse_block_content = True + + def handle_endtag(self, tag): + if tag == "content_block_txt": + self._parse_block_content = False + + def handle_data(self, data): + if self._parse_block_content: + self._content_blocks.append(ContentBlock.model_validate_json(data)) + else: + self._content_blocks.append(ContentBlock(type=ContentBlockType.text, text=data)) + + +class ChatExtension(Extension): """ `chat` can be used to render prompt text as structured ChatMessage objects. @@ -85,4 +124,7 @@ def _store_chat_messages(self, role, caller): """ Helper callback. """ - return json.dumps({"role": role, "content": caller()}) + parser = _ContentBlockParser() + parser.feed(caller()) + cm = ChatMessage(role=role, content=parser.content) + return cm.model_dump_json() diff --git a/src/banks/extensions/generate.py b/src/banks/extensions/generate.py index 6d693af..d9f0abe 100644 --- a/src/banks/extensions/generate.py +++ b/src/banks/extensions/generate.py @@ -21,7 +21,7 @@ def generate(model_name: str): # pylint: disable=W0613 `generate` can be used to call the LiteLLM API passing the tag text as a prompt and get back some content. Example: - ``` + ```jinja {% generate "write a tweet with positive sentiment" "gpt-3.5-turbo" %} Feeling grateful for all the opportunities that come my way! #positivity #productivity ``` diff --git a/src/banks/extensions/inference_endpoint.py b/src/banks/extensions/inference_endpoint.py index 99fee6f..300892b 100644 --- a/src/banks/extensions/inference_endpoint.py +++ b/src/banks/extensions/inference_endpoint.py @@ -15,7 +15,7 @@ class HFInferenceEndpointsExtension(Extension): passing a prompt to get back some content. Example: - ``` + ```jinja {% inference_endpoint "write a tweet with positive sentiment", "https://foo.aws.endpoints.huggingface.cloud" %} Life is beautiful, full of opportunities & positivity ``` diff --git a/src/banks/filters/__init__.py b/src/banks/filters/__init__.py index ae16d7a..23d05f3 100644 --- a/src/banks/filters/__init__.py +++ b/src/banks/filters/__init__.py @@ -1,6 +1,10 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT -from banks.filters.lemmatize import lemmatize +from .cache_control import cache_control +from .lemmatize import lemmatize -__all__ = ("lemmatize",) +__all__ = ( + "cache_control", + "lemmatize", +) diff --git a/src/banks/filters/cache_control.py b/src/banks/filters/cache_control.py new file mode 100644 index 0000000..65f4640 --- /dev/null +++ b/src/banks/filters/cache_control.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi +# +# SPDX-License-Identifier: MIT +from banks.types import ContentBlock + + +def cache_control(value: str, cache_type: str = "ephemeral") -> str: + """Wrap the filtered value into a ContentBlock with the proper cache_control field set. + + The resulting ChatMessage will have the field `content` populated with a list of ContentBlock objects. + + Example: + ```jinja + {{ "This is a long, long text" | cache_control("ephemeral") }} + + This is short and won't be cached. + ``` + + Important: + this filter marks the content to cache by surrounding it with `` and + ``, so it's only useful when used within a `{% chat %}` block. + """ + block = ContentBlock.model_validate({"type": "text", "text": value, "cache_control": {"type": cache_type}}) + return f"{block.model_dump_json()}" diff --git a/src/banks/filters/lemmatize.py b/src/banks/filters/lemmatize.py index 277c210..1b344e7 100644 --- a/src/banks/filters/lemmatize.py +++ b/src/banks/filters/lemmatize.py @@ -4,10 +4,10 @@ from banks.errors import MissingDependencyError try: - from simplemma import text_lemmatizer + from simplemma import text_lemmatizer # type: ignore SIMPLEMMA_AVAIL = True -except ImportError: +except ImportError: # pragma: no cover SIMPLEMMA_AVAIL = False @@ -17,10 +17,10 @@ def lemmatize(text: str) -> str: to English. Example: - ``` - {{"The dog is running" | lemmatize}} - "the dog be run" - ``` + ```jinja + {{"The dog is running" | lemmatize}} + "the dog be run" + ``` Note: Simplemma must be manually installed to use this filter diff --git a/src/banks/prompt.py b/src/banks/prompt.py index df163c1..edc0963 100644 --- a/src/banks/prompt.py +++ b/src/banks/prompt.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: MIT import uuid -from typing import Any +from typing import Any, Protocol, Self from pydantic import BaseModel, ValidationError @@ -10,16 +10,12 @@ from .config import config from .env import env from .errors import AsyncError +from .types import ChatMessage from .utils import generate_canary_word DEFAULT_VERSION = "0" -class ChatMessage(BaseModel): - role: str - content: str - - class BasePrompt: def __init__( self, @@ -219,3 +215,24 @@ async def text(self, data: dict[str, Any] | None = None) -> str: rendered: str = await self._template.render_async(data) self._render_cache.set(data, rendered) return rendered + + +class PromptRegistry(Protocol): # pragma: no cover + """Interface to be implemented by concrete prompt registries.""" + + def get(self, *, name: str, version: str | None = None) -> Prompt: ... + + def set(self, *, prompt: Prompt, overwrite: bool = False) -> None: ... + + +class PromptModel(BaseModel): + """Serializable representation of a Prompt.""" + + text: str + name: str | None = None + version: str | None = None + metadata: dict[str, Any] | None = None + + @classmethod + def from_prompt(cls: type[Self], prompt: Prompt) -> Self: + return cls(text=prompt.raw, name=prompt.name, version=prompt.version, metadata=prompt.metadata) diff --git a/src/banks/registries/directory.py b/src/banks/registries/directory.py index 353e2ab..6b3385c 100644 --- a/src/banks/registries/directory.py +++ b/src/banks/registries/directory.py @@ -9,8 +9,7 @@ from banks import Prompt from banks.errors import InvalidPromptError, PromptNotFoundError -from banks.prompt import DEFAULT_VERSION -from banks.types import PromptModel +from banks.prompt import DEFAULT_VERSION, PromptModel # Constants DEFAULT_INDEX_NAME = "index.json" diff --git a/src/banks/registries/file.py b/src/banks/registries/file.py index 9eced74..f10b1f8 100644 --- a/src/banks/registries/file.py +++ b/src/banks/registries/file.py @@ -6,8 +6,7 @@ from pydantic import BaseModel from banks.errors import InvalidPromptError, PromptNotFoundError -from banks.prompt import Prompt -from banks.types import PromptModel +from banks.prompt import Prompt, PromptModel class PromptRegistryIndex(BaseModel): diff --git a/src/banks/types.py b/src/banks/types.py index 90acbc7..43b34ef 100644 --- a/src/banks/types.py +++ b/src/banks/types.py @@ -1,29 +1,51 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT -from typing import Any, Protocol, Self +from enum import Enum from pydantic import BaseModel -from .prompt import Prompt +# pylint: disable=invalid-name -class PromptRegistry(Protocol): # pragma: no cover - """Interface to be implemented by concrete prompt registries.""" +class ContentBlockType(str, Enum): + text = "text" + image = "image" - def get(self, *, name: str, version: str | None = None) -> Prompt: ... - def set(self, *, prompt: Prompt, overwrite: bool = False) -> None: ... +class MediaTypeBlockType(str, Enum): + image_jpeg = "image/jpeg" + image_png = "image/png" + image_gif = "image/gif" + image_webp = "image/webp" -class PromptModel(BaseModel): - """Serializable representation of a Prompt.""" +class CacheControl(BaseModel): + type: str = "ephemeral" - text: str - name: str | None = None - version: str | None = None - metadata: dict[str, Any] | None = None - @classmethod - def from_prompt(cls: type[Self], prompt: Prompt) -> Self: - return cls(text=prompt.raw, name=prompt.name, version=prompt.version, metadata=prompt.metadata) +class Source(BaseModel): + type: str = "base64" + media_type: MediaTypeBlockType + data: str + + class Config: + use_enum_values = True + + +class ContentBlock(BaseModel): + type: ContentBlockType + cache_control: CacheControl | None = None + text: str | None = None + source: Source | None = None + + class Config: + use_enum_values = True + + +ChatMessageContent = list[ContentBlock] | str + + +class ChatMessage(BaseModel): + role: str + content: str | ChatMessageContent diff --git a/tests/templates/cache.jinja b/tests/templates/cache.jinja new file mode 100644 index 0000000..66f4284 --- /dev/null +++ b/tests/templates/cache.jinja @@ -0,0 +1,7 @@ +{% chat role="user" %} +The book to analize is included in the tags and . + +{{ book | cache_control("ephemeral") }} + +What is the title of this book? Only output the title. +{% endchat %} \ No newline at end of file diff --git a/tests/test_cache_control.py b/tests/test_cache_control.py new file mode 100644 index 0000000..0a6e378 --- /dev/null +++ b/tests/test_cache_control.py @@ -0,0 +1,8 @@ +from banks.filters.cache_control import cache_control + + +def test_cache_control(): + res = cache_control("foo", "ephemeral") + res = res.replace("", "") + res = res.replace("", "") + assert res == '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}' diff --git a/tests/test_chat.py b/tests/test_chat.py index 2941d75..71e3be1 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -2,6 +2,8 @@ from jinja2 import TemplateSyntaxError from banks import Prompt +from banks.extensions.chat import _ContentBlockParser +from banks.types import CacheControl, ContentBlock, ContentBlockType def test_wrong_tag(): @@ -17,3 +19,43 @@ def test_wrong_tag_params(): def test_wrong_role_type(): with pytest.raises(TemplateSyntaxError): Prompt('{% chat role="does not exist" %}{% endchat %}') + + +def test_content_block_parser_init(): + p = _ContentBlockParser() + assert p._parse_block_content is False + assert p._content_blocks == [] + + +def test_content_block_parser_single_with_cache_control(): + p = _ContentBlockParser() + p.feed( + '{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}' + ) + assert p.content == [ + ContentBlock(type=ContentBlockType.text, cache_control=CacheControl(type="ephemeral"), text="foo", source=None) + ] + + +def test_content_block_parser_single_no_cache_control(): + p = _ContentBlockParser() + p.feed('{"type":"text","cache_control":null,"text":"foo","source":null}') + assert p.content == "foo" + + +def test_content_block_parser_multiple(): + p = _ContentBlockParser() + p.feed( + '{"type":"text","cache_control":null,"text":"foo","source":null}' + '{"type":"text","cache_control":null,"text":"bar","source":null}' + ) + assert p.content == [ + ContentBlock(type=ContentBlockType.text, cache_control=None, text="foo", source=None), + ContentBlock(type=ContentBlockType.text, cache_control=None, text="bar", source=None), + ] + + +def test_content_block_parser_other_tags(): + p = _ContentBlockParser() + p.feed("FOO") + assert p.content == "FOO" diff --git a/tests/test_file_registry.py b/tests/test_file_registry.py index f04d3b8..7710dd5 100644 --- a/tests/test_file_registry.py +++ b/tests/test_file_registry.py @@ -1,9 +1,8 @@ import pytest from banks.errors import InvalidPromptError, PromptNotFoundError -from banks.prompt import Prompt +from banks.prompt import Prompt, PromptModel from banks.registries.file import FilePromptRegistry, PromptRegistryIndex -from banks.types import PromptModel @pytest.fixture diff --git a/tests/test_prompt.py b/tests/test_prompt.py index 366c482..a206ec3 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -88,10 +88,10 @@ def test_chat_messages(): assert ( p.text() == """ -{"role": "system", "content": "You are a helpful assistant.\\n"} -{"role": "user", "content": "Hello, how are you?\\n"} -{"role": "system", "content": "I'm doing well, thank you! How can I assist you today?\\n"} -{"role": "user", "content": "Can you explain quantum computing?\\n"} +{"role":"system","content":"You are a helpful assistant.\\n"} +{"role":"user","content":"Hello, how are you?\\n"} +{"role":"system","content":"I'm doing well, thank you! How can I assist you today?\\n"} +{"role":"user","content":"Can you explain quantum computing?\\n"} Some random text. """.strip() )