Skip to content

Commit

Permalink
feat: Add support for chat messages via custom tag {% chat %} (#15)
Browse files Browse the repository at this point in the history
* wip

* feat: add chat message support

* fix linting

* cleaner assignment
  • Loading branch information
masci authored Oct 2, 2024
1 parent 03808da commit 60b978a
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 10 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ ignore = [
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
# Avoid conflicts with the formatter
"ISC001",
# Magic numbers
"PLR2004",
]
unfixable = [
# Don't touch unused imports
Expand Down
9 changes: 2 additions & 7 deletions src/banks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@
# SPDX-License-Identifier: MIT
from .config import config
from .env import env
from .prompt import AsyncPrompt, Prompt
from .prompt import AsyncPrompt, ChatMessage, Prompt

__all__ = (
"env",
"Prompt",
"AsyncPrompt",
"config",
)
__all__ = ("env", "Prompt", "AsyncPrompt", "config", "ChatMessage")
2 changes: 2 additions & 0 deletions src/banks/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ def _add_extensions(_env):
For example, we use banks to manage the system prompt in `GenerateExtension`
"""
from .extensions.chat import ChatMessage # pylint: disable=import-outside-toplevel
from .extensions.generate import GenerateExtension # pylint: disable=import-outside-toplevel
from .extensions.inference_endpoint import HFInferenceEndpointsExtension # pylint: disable=import-outside-toplevel

_env.add_extension(GenerateExtension)
_env.add_extension(HFInferenceEndpointsExtension)
_env.add_extension(ChatMessage)


# Init the Jinja env
Expand Down
69 changes: 69 additions & 0 deletions src/banks/extensions/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import json

from jinja2 import TemplateSyntaxError, nodes
from jinja2.ext import Extension

SUPPORTED_TYPES = ("system", "user")


class ChatMessage(Extension):
"""
`chat` can be used to render prompt text as structured ChatMessage objects.
Example:
```
{% chat role="system" %}
You are a helpful assistant.
{% endchat %}
```
"""

# a set of names that trigger the extension.
tags = {"chat"} # noqa

def parse(self, parser):
# We get the line number of the first token for error reporting
lineno = next(parser.stream).lineno

# Gather tokens up to the next block_end ('%}')
gathered = []
while parser.stream.current.type != "block_end":
gathered.append(next(parser.stream))

# If all has gone well, we will have one triplet of tokens:
# (type='name, value='role'),
# (type='assign', value='='),
# (type='string', value='user'),
# Anything else is a parse error
error_msg = f"Invalid syntax for chat attribute, got '{gathered}', expected role=\"value\""
try:
attr_name, attr_assign, attr_value = gathered # pylint: disable=unbalanced-tuple-unpacking
except ValueError:
raise TemplateSyntaxError(error_msg, lineno) from None

# Validate tag attributes
if attr_name.value != "role" or attr_assign.value != "=":
raise TemplateSyntaxError(error_msg, lineno)

if attr_value.value not in SUPPORTED_TYPES:
types = ",".join(SUPPORTED_TYPES)
msg = f"Unknown role type '{attr_value}', use one of ({types})"
raise TemplateSyntaxError(msg, lineno)

# Pass the role name to the CallBlock node
args: list[nodes.Expr] = [nodes.Const(attr_value.value)]

# Message body
body = parser.parse_statements(("name:endchat",), drop_needle=True)

# Build messages list
return nodes.CallBlock(self.call_method("_store_chat_messages", args), [], [], body).set_lineno(lineno)

def _store_chat_messages(self, role, caller):
"""
Helper callback.
"""
return json.dumps({"role": role, "content": caller()})
34 changes: 32 additions & 2 deletions src/banks/prompt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import uuid
from typing import Any

from pydantic import BaseModel, ValidationError

from .cache import DefaultCache, RenderCache
from .config import config
from .env import env
Expand All @@ -12,6 +15,11 @@
DEFAULT_VERSION = "0"


class ChatMessage(BaseModel):
role: str
content: str


class BasePrompt:
def __init__(
self,
Expand All @@ -37,7 +45,7 @@ def __init__(
be used.
"""
self._metadata = metadata or {}
self._name = name
self._name = name or str(uuid.uuid4())
self._raw: str = text
self._render_cache = render_cache or DefaultCache()
self._template = env.from_string(text)
Expand All @@ -55,7 +63,7 @@ def metadata(self) -> dict[str, Any]:
return self._metadata

@property
def name(self) -> str | None:
def name(self) -> str:
return self._name

@property
Expand Down Expand Up @@ -105,6 +113,28 @@ def text(self, data: dict[str, Any] | None = None) -> str:
self._render_cache.set(data, rendered)
return rendered

def chat_messages(self, data: dict[str, Any] | None = None) -> list[ChatMessage]:
"""
Render the prompt using variables present in `data`
Parameters:
data: A dictionary containing the context variables.
"""
data = self._get_context(data)
rendered = self._render_cache.get(data)
if not rendered:
rendered = self._template.render(data)
self._render_cache.set(data, rendered)

messages: list[ChatMessage] = []
for line in rendered.strip().split("\n"):
try:
messages.append(ChatMessage.model_validate_json(line))
except ValidationError:
# Ignore lines that are not a message
pass
return messages


class AsyncPrompt(BasePrompt):
"""
Expand Down
17 changes: 17 additions & 0 deletions tests/templates/chat.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{% chat role="system" %}
You are a helpful assistant.
{% endchat %}

{% chat role="user" %}
Hello, how are you?
{% endchat %}

{% chat role="system" %}
I'm doing well, thank you! How can I assist you today?
{% endchat %}

{% chat role="user" %}
Can you explain quantum computing?
{% endchat %}

Some random text.
19 changes: 19 additions & 0 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest
from jinja2 import TemplateSyntaxError

from banks import Prompt


def test_wrong_tag():
with pytest.raises(TemplateSyntaxError):
Prompt("{% chat %}{% endchat %}")


def test_wrong_tag_params():
with pytest.raises(TemplateSyntaxError):
Prompt('{% chat foo="bar" %}{% endchat %}')


def test_wrong_role_type():
with pytest.raises(TemplateSyntaxError):
Prompt('{% chat role="does not exist" %}{% endchat %}')
35 changes: 34 additions & 1 deletion tests/test_prompt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from pathlib import Path
from unittest import mock

import pytest
import regex as re
from jinja2 import Environment

from banks import AsyncPrompt, Prompt
from banks import AsyncPrompt, ChatMessage, Prompt
from banks.cache import DefaultCache
from banks.errors import AsyncError

Expand Down Expand Up @@ -78,3 +79,35 @@ def test__get_context():
assert p._get_context(None) == p.defaults
data = {"foo": 42}
assert p._get_context(data) == data | p.defaults


def test_chat_messages():
p_file = Path(__file__).parent / "templates" / "chat.jinja"
p = Prompt(p_file.read_text())

assert (
p.text()
== """
{"role": "system", "content": "You are a helpful assistant.\\n"}
{"role": "user", "content": "Hello, how are you?\\n"}
{"role": "system", "content": "I'm doing well, thank you! How can I assist you today?\\n"}
{"role": "user", "content": "Can you explain quantum computing?\\n"}
Some random text.
""".strip()
)

assert p.chat_messages() == [
ChatMessage(role="system", content="You are a helpful assistant.\n"),
ChatMessage(role="user", content="Hello, how are you?\n"),
ChatMessage(role="system", content="I'm doing well, thank you! How can I assist you today?\n"),
ChatMessage(role="user", content="Can you explain quantum computing?\n"),
]


def test_chat_messages_cached():
mock_cache = DefaultCache()
mock_cache.set = mock.Mock()
p_file = Path(__file__).parent / "templates" / "chat.jinja"
p = Prompt(p_file.read_text(), render_cache=mock_cache)
p.chat_messages()
mock_cache.set.assert_called_once()

0 comments on commit 60b978a

Please sign in to comment.