From b25eca82f67094a3b113b2fe2256630e3a900da4 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 2 Oct 2024 08:36:50 +0200 Subject: [PATCH] feat: add chat message support --- src/banks/__init__.py | 9 ++------- src/banks/extensions/chat.py | 13 ++++++------- src/banks/prompt.py | 34 ++++++++++++++++++++++++++++++++-- tests/templates/chat.jinja | 2 ++ tests/test_chat.py | 19 +++++++++++++++++++ tests/test_prompt.py | 35 ++++++++++++++++++++++++++++++++++- 6 files changed, 95 insertions(+), 17 deletions(-) create mode 100644 tests/test_chat.py diff --git a/src/banks/__init__.py b/src/banks/__init__.py index 681c3f0..eed8efa 100644 --- a/src/banks/__init__.py +++ b/src/banks/__init__.py @@ -3,11 +3,6 @@ # SPDX-License-Identifier: MIT from .config import config from .env import env -from .prompt import AsyncPrompt, Prompt +from .prompt import AsyncPrompt, ChatMessage, Prompt -__all__ = ( - "env", - "Prompt", - "AsyncPrompt", - "config", -) +__all__ = ("env", "Prompt", "AsyncPrompt", "config", "ChatMessage") diff --git a/src/banks/extensions/chat.py b/src/banks/extensions/chat.py index c0b49ce..c3590d0 100644 --- a/src/banks/extensions/chat.py +++ b/src/banks/extensions/chat.py @@ -1,23 +1,23 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT -import html -import os +import json -import requests from jinja2 import TemplateSyntaxError, nodes from jinja2.ext import Extension - SUPPORTED_TYPES = ("system", "user") class ChatMessage(Extension): """ + `chat` can be used to render prompt text as structured ChatMessage objects. Example: ``` - + {% chat role="system" %} + You are a helpful assistant. + {% endchat %} ``` """ @@ -65,5 +65,4 @@ def _store_chat_messages(self, role, caller): """ Helper callback. """ - print({"role": role, "content": caller()}) - return caller() + return json.dumps({"role": role, "content": caller()}) diff --git a/src/banks/prompt.py b/src/banks/prompt.py index f3dcd4b..df163c1 100644 --- a/src/banks/prompt.py +++ b/src/banks/prompt.py @@ -1,8 +1,11 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT +import uuid from typing import Any +from pydantic import BaseModel, ValidationError + from .cache import DefaultCache, RenderCache from .config import config from .env import env @@ -12,6 +15,11 @@ DEFAULT_VERSION = "0" +class ChatMessage(BaseModel): + role: str + content: str + + class BasePrompt: def __init__( self, @@ -37,7 +45,7 @@ def __init__( be used. """ self._metadata = metadata or {} - self._name = name + self._name = name or str(uuid.uuid4()) self._raw: str = text self._render_cache = render_cache or DefaultCache() self._template = env.from_string(text) @@ -55,7 +63,7 @@ def metadata(self) -> dict[str, Any]: return self._metadata @property - def name(self) -> str | None: + def name(self) -> str: return self._name @property @@ -105,6 +113,28 @@ def text(self, data: dict[str, Any] | None = None) -> str: self._render_cache.set(data, rendered) return rendered + def chat_messages(self, data: dict[str, Any] | None = None) -> list[ChatMessage]: + """ + Render the prompt using variables present in `data` + + Parameters: + data: A dictionary containing the context variables. + """ + data = self._get_context(data) + rendered = self._render_cache.get(data) + if not rendered: + rendered = self._template.render(data) + self._render_cache.set(data, rendered) + + messages: list[ChatMessage] = [] + for line in rendered.strip().split("\n"): + try: + messages.append(ChatMessage.model_validate_json(line)) + except ValidationError: + # Ignore lines that are not a message + pass + return messages + class AsyncPrompt(BasePrompt): """ diff --git a/tests/templates/chat.jinja b/tests/templates/chat.jinja index 03eb49b..95d7980 100644 --- a/tests/templates/chat.jinja +++ b/tests/templates/chat.jinja @@ -13,3 +13,5 @@ I'm doing well, thank you! How can I assist you today? {% chat role="user" %} Can you explain quantum computing? {% endchat %} + +Some random text. \ No newline at end of file diff --git a/tests/test_chat.py b/tests/test_chat.py new file mode 100644 index 0000000..2941d75 --- /dev/null +++ b/tests/test_chat.py @@ -0,0 +1,19 @@ +import pytest +from jinja2 import TemplateSyntaxError + +from banks import Prompt + + +def test_wrong_tag(): + with pytest.raises(TemplateSyntaxError): + Prompt("{% chat %}{% endchat %}") + + +def test_wrong_tag_params(): + with pytest.raises(TemplateSyntaxError): + Prompt('{% chat foo="bar" %}{% endchat %}') + + +def test_wrong_role_type(): + with pytest.raises(TemplateSyntaxError): + Prompt('{% chat role="does not exist" %}{% endchat %}') diff --git a/tests/test_prompt.py b/tests/test_prompt.py index f05c966..21ab0be 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -1,10 +1,11 @@ +from pathlib import Path from unittest import mock import pytest import regex as re from jinja2 import Environment -from banks import AsyncPrompt, Prompt +from banks import AsyncPrompt, Prompt, ChatMessage from banks.cache import DefaultCache from banks.errors import AsyncError @@ -78,3 +79,35 @@ def test__get_context(): assert p._get_context(None) == p.defaults data = {"foo": 42} assert p._get_context(data) == data | p.defaults + + +def test_chat_messages(): + p_file = Path(__file__).parent / "templates" / "chat.jinja" + p = Prompt(p_file.read_text()) + + assert ( + p.text() + == """ +{"role": "system", "content": "You are a helpful assistant.\\n"} +{"role": "user", "content": "Hello, how are you?\\n"} +{"role": "system", "content": "I'm doing well, thank you! How can I assist you today?\\n"} +{"role": "user", "content": "Can you explain quantum computing?\\n"} +Some random text. +""".strip() + ) + + assert p.chat_messages() == [ + ChatMessage(role="system", content="You are a helpful assistant.\n"), + ChatMessage(role="user", content="Hello, how are you?\n"), + ChatMessage(role="system", content="I'm doing well, thank you! How can I assist you today?\n"), + ChatMessage(role="user", content="Can you explain quantum computing?\n"), + ] + + +def test_chat_messages_cached(): + mock_cache = DefaultCache() + mock_cache.set = mock.Mock() + p_file = Path(__file__).parent / "templates" / "chat.jinja" + p = Prompt(p_file.read_text(), render_cache=mock_cache) + p.chat_messages() + mock_cache.set.assert_called_once()