Skip to content

Commit

Permalink
feat: add chat message support
Browse files Browse the repository at this point in the history
  • Loading branch information
masci committed Oct 2, 2024
1 parent 452aa29 commit b25eca8
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 17 deletions.
9 changes: 2 additions & 7 deletions src/banks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@
# SPDX-License-Identifier: MIT
from .config import config
from .env import env
from .prompt import AsyncPrompt, Prompt
from .prompt import AsyncPrompt, ChatMessage, Prompt

__all__ = (
"env",
"Prompt",
"AsyncPrompt",
"config",
)
__all__ = ("env", "Prompt", "AsyncPrompt", "config", "ChatMessage")
13 changes: 6 additions & 7 deletions src/banks/extensions/chat.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import html
import os
import json

import requests
from jinja2 import TemplateSyntaxError, nodes
from jinja2.ext import Extension


SUPPORTED_TYPES = ("system", "user")


class ChatMessage(Extension):
"""
`chat` can be used to render prompt text as structured ChatMessage objects.
Example:
```
{% chat role="system" %}
You are a helpful assistant.
{% endchat %}
```
"""

Expand Down Expand Up @@ -65,5 +65,4 @@ def _store_chat_messages(self, role, caller):
"""
Helper callback.
"""
print({"role": role, "content": caller()})
return caller()
return json.dumps({"role": role, "content": caller()})
34 changes: 32 additions & 2 deletions src/banks/prompt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import uuid
from typing import Any

from pydantic import BaseModel, ValidationError

from .cache import DefaultCache, RenderCache
from .config import config
from .env import env
Expand All @@ -12,6 +15,11 @@
DEFAULT_VERSION = "0"


class ChatMessage(BaseModel):
role: str
content: str


class BasePrompt:
def __init__(
self,
Expand All @@ -37,7 +45,7 @@ def __init__(
be used.
"""
self._metadata = metadata or {}
self._name = name
self._name = name or str(uuid.uuid4())
self._raw: str = text
self._render_cache = render_cache or DefaultCache()
self._template = env.from_string(text)
Expand All @@ -55,7 +63,7 @@ def metadata(self) -> dict[str, Any]:
return self._metadata

@property
def name(self) -> str | None:
def name(self) -> str:
return self._name

@property
Expand Down Expand Up @@ -105,6 +113,28 @@ def text(self, data: dict[str, Any] | None = None) -> str:
self._render_cache.set(data, rendered)
return rendered

def chat_messages(self, data: dict[str, Any] | None = None) -> list[ChatMessage]:
"""
Render the prompt using variables present in `data`
Parameters:
data: A dictionary containing the context variables.
"""
data = self._get_context(data)
rendered = self._render_cache.get(data)
if not rendered:
rendered = self._template.render(data)
self._render_cache.set(data, rendered)

messages: list[ChatMessage] = []
for line in rendered.strip().split("\n"):
try:
messages.append(ChatMessage.model_validate_json(line))
except ValidationError:
# Ignore lines that are not a message
pass
return messages


class AsyncPrompt(BasePrompt):
"""
Expand Down
2 changes: 2 additions & 0 deletions tests/templates/chat.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ I'm doing well, thank you! How can I assist you today?
{% chat role="user" %}
Can you explain quantum computing?
{% endchat %}

Some random text.
19 changes: 19 additions & 0 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest
from jinja2 import TemplateSyntaxError

from banks import Prompt


def test_wrong_tag():
with pytest.raises(TemplateSyntaxError):
Prompt("{% chat %}{% endchat %}")


def test_wrong_tag_params():
with pytest.raises(TemplateSyntaxError):
Prompt('{% chat foo="bar" %}{% endchat %}')


def test_wrong_role_type():
with pytest.raises(TemplateSyntaxError):
Prompt('{% chat role="does not exist" %}{% endchat %}')
35 changes: 34 additions & 1 deletion tests/test_prompt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from pathlib import Path
from unittest import mock

import pytest
import regex as re
from jinja2 import Environment

from banks import AsyncPrompt, Prompt
from banks import AsyncPrompt, Prompt, ChatMessage
from banks.cache import DefaultCache
from banks.errors import AsyncError

Expand Down Expand Up @@ -78,3 +79,35 @@ def test__get_context():
assert p._get_context(None) == p.defaults
data = {"foo": 42}
assert p._get_context(data) == data | p.defaults


def test_chat_messages():
p_file = Path(__file__).parent / "templates" / "chat.jinja"
p = Prompt(p_file.read_text())

assert (
p.text()
== """
{"role": "system", "content": "You are a helpful assistant.\\n"}
{"role": "user", "content": "Hello, how are you?\\n"}
{"role": "system", "content": "I'm doing well, thank you! How can I assist you today?\\n"}
{"role": "user", "content": "Can you explain quantum computing?\\n"}
Some random text.
""".strip()
)

assert p.chat_messages() == [
ChatMessage(role="system", content="You are a helpful assistant.\n"),
ChatMessage(role="user", content="Hello, how are you?\n"),
ChatMessage(role="system", content="I'm doing well, thank you! How can I assist you today?\n"),
ChatMessage(role="user", content="Can you explain quantum computing?\n"),
]


def test_chat_messages_cached():
mock_cache = DefaultCache()
mock_cache.set = mock.Mock()
p_file = Path(__file__).parent / "templates" / "chat.jinja"
p = Prompt(p_file.read_text(), render_cache=mock_cache)
p.chat_messages()
mock_cache.set.assert_called_once()

0 comments on commit b25eca8

Please sign in to comment.