Skip to content

Commit

Permalink
Implement message parsing.
Browse files Browse the repository at this point in the history
  • Loading branch information
norpadon committed Jul 1, 2024
1 parent cd59c71 commit 5b5c22d
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 55 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ sphinxcontrib-napoleon = "^0.7"
toml = "^0.10.2"

[tool.pytest.ini_options]
addopts = "--doctest-modules"
addopts = "--doctest-modules -vv --showlocals"

[tool.pyright]
pythonVersion = "3.10"
Expand Down
56 changes: 56 additions & 0 deletions tests/test_models/test_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from textwrap import dedent

from wanga.models.messages import (
AssistantMessage,
ImageContent,
SystemMessage,
ToolInvocation,
ToolMessage,
UserMessage,
parse_messages,
)


def test_consistency():
messages = [
SystemMessage(content="Be a helpful assistant!\n"),
UserMessage(name="Alice", content=["Hello, world!"]),
AssistantMessage(
name="Bob",
content="Here's a helpful image!",
tool_invocations=[
ToolInvocation(
invocation_id="abc123",
name="paste_image",
arguments={"url": "https://example.com/image.png"},
)
],
),
ToolMessage(
invocation_id="abc123",
content=[
"Here is the image you requested:\n\n",
ImageContent(url="https://example.com/image.png"),
],
),
UserMessage(name="Alice", content=["Thanks for the image!"]),
]

chat_str = "\n".join(str(message) for message in messages)
parsed_messages = parse_messages(chat_str)
assert parsed_messages == messages

chat_str = r"""
[|user|]
Hi! Here is the pic<|image url="https://example.com/image.png"|>
[|assistant|]
It is beautiful!
[|user|]
I love you!
[|assistant|]
I love you too!
"""
chat_str = dedent(chat_str.removeprefix("\n"))

parsed_messages = parse_messages(chat_str)
assert "\n".join(str(message) for message in parsed_messages) == chat_str.strip()
266 changes: 212 additions & 54 deletions wanga/models/messages.py
Original file line number Diff line number Diff line change
@@ -1,118 +1,276 @@
from pprint import pformat
from typing import TYPE_CHECKING, TypeAlias
import json
import logging
import re
from collections.abc import Iterable
from typing import NamedTuple

from attrs import frozen

from wanga import _PIL_INSTALLED
from attrs import field, frozen

from ..common import JSON

__all__ = [
"ImageURL",
"AssistantMessage",
"ImageContent",
"Message",
"UserMessage",
"AssistantMessage",
"ToolInvocation",
"ToolMessage",
"UserMessage",
]


@frozen
class ImageURL:
url: str
class TagPair(NamedTuple):
open: str
close: str


if TYPE_CHECKING or _PIL_INSTALLED:
from PIL.Image import Image
_MESSAGE_TAGS = TagPair(r"[|", r"|]")
_CONTENT_BLOCK_TAGS = TagPair(r"<|", r"|>")

ImageContent: TypeAlias = Image | ImageURL
else:
ImageContent: TypeAlias = ImageURL

_logger = logging.getLogger(__name__)

_MESSAGE_TAG_OPEN = "[|"
_MESSAGE_TAG_CLOSE = "|]"

_CONTENT_BLOCK_TAG_OPEN = "<|"
_CONTENT_BLOCK_TAG_CLOSE = "|>"
def _format_header(_tags: TagPair, _name: str, /, **kwargs: str) -> str:
param_str = "".join(f' {key}="{value}"' for key, value in kwargs.items())
return f"{_tags.open}{_name}{param_str}{_tags.close}"


def _pretty_print_message(role: str, name: str | None, content: str | list[str | ImageContent]) -> str:
if name:
role_str = f"{_MESSAGE_TAG_OPEN}{role} {name}{_MESSAGE_TAG_CLOSE}"
else:
role_str = f"{_MESSAGE_TAG_OPEN}{role}{_MESSAGE_TAG_CLOSE}"
@frozen
class ImageContent:
r"""Image embedded in a message. Can either be a URL, or a base64-encoded file."""

content_strings = []
if isinstance(content, str):
content = [content]
for content_block in content:
if isinstance(content_block, ImageURL):
content_block = f'{_CONTENT_BLOCK_TAG_OPEN}image url="{content_block.url}"{_CONTENT_BLOCK_TAG_CLOSE}'
if not isinstance(content_block, str):
content_block = f"{_CONTENT_BLOCK_TAG_OPEN}image{_CONTENT_BLOCK_TAG_CLOSE}"
content_strings.append(content_block)
joined_content = "\n".join(content_strings)
url: str | None = None
base64: str | None = None

return f"{role_str}\n{joined_content}"
def __attrs_post_init__(self):
if self.url is None and self.base64 is None:
raise ValueError("ImageURL must have either a URL or base64 data")
if self.url is not None and self.base64 is not None:
raise ValueError("ImageURL cannot have both a URL and base64 data")

def __str__(self) -> str:
kwargs = {}
if self.url:
kwargs["url"] = self.url
if self.base64:
kwargs["base64"] = self.base64
return _format_header(_CONTENT_BLOCK_TAGS, "image", **kwargs)


@frozen
class Message:
pass


@frozen
class SystemMessage(Message):
name: str | None
content: str
name: str | None = None

def __str__(self) -> str:
return _pretty_print_message("system", self.name, self.content)
kwargs = {}
if self.name:
kwargs["name"] = self.name
header = _format_header(_MESSAGE_TAGS, "system", **kwargs)
return f"{header}\n{self.content}"


@frozen
class UserMessage(Message):
name: str | None
content: str | list[str | ImageContent]
name: str | None = None

def __str__(self) -> str:
return _pretty_print_message("user", self.name, self.content)
kwargs = {}
if self.name:
kwargs["name"] = self.name
header = _format_header(_MESSAGE_TAGS, "user", **kwargs)
content = "".join(str(item) for item in self.content)
return f"{header}\n{content}"


@frozen
class ToolInvocation:
invocation_id: str
tool_name: str
tool_args: JSON
name: str
arguments: JSON

def __str__(self) -> str:
pretty_json = pformat(self.tool_args, sort_dicts=True)
call_header = (
f'{_CONTENT_BLOCK_TAG_OPEN}call {self.tool_name} id="{self.invocation_id}"{_CONTENT_BLOCK_TAG_CLOSE}'
)
return f"{call_header}\n{pretty_json}"
pretty_json = json.dumps(self.arguments, indent=4, sort_keys=True)
header = _format_header(_CONTENT_BLOCK_TAGS, "tool", name=self.name, id=self.invocation_id)
return f"{header}\n{pretty_json}"


@frozen
class AssistantMessage(Message):
name: str | None
content: str | None
tool_invocations: list[ToolInvocation]
name: str | None = None
content: str | None = None
tool_invocations: list[ToolInvocation] = field(factory=list)

def __str__(self) -> str:
kwargs = {}
if self.name:
kwargs["name"] = self.name
header = _format_header(_MESSAGE_TAGS, "assistant", **kwargs)

content_blocks = []
if self.content:
content_blocks.append(self.content)
for tool_invocation in self.tool_invocations:
content_blocks.append(str(tool_invocation))
return "\n".join(content_blocks)
joined_content = "\n".join(content_blocks)

return f"{header}\n{joined_content}"


@frozen
class ToolMessage:
class ToolMessage(Message):
invocation_id: str
tool_result: str | list[str | ImageContent]
content: str | list[str | ImageContent]

def __str__(self) -> str:
return _pretty_print_message("tool", f"id={self.invocation_id}", self.tool_result)
header = _format_header(_MESSAGE_TAGS, "tool", id=self.invocation_id)
tool_result = self.content
if not isinstance(self.content, list):
tool_result = [self.content]
content = "".join(str(item) for item in tool_result)
return f"{header}\n{content}"


class ParsedHeader(NamedTuple):
name: str
params: dict[str, str]


class MessageSyntaxError(ValueError):
pass


class HeaderRegexes(NamedTuple):
tag_regex: re.Pattern
full_regex: re.Pattern

def parse(self, header_str: str) -> ParsedHeader:
header_match = self.full_regex.match(header_str)
if header_match is None:
raise MessageSyntaxError(f"Invalid header: {header_str}")
name = header_match.group("name")
params_str = header_match.group("params")
params = {}
for param_match in _PARAM_REGEX.finditer(params_str):
params[param_match.group("param_name")] = param_match.group("param_value")
return ParsedHeader(name, params)

def split_text(self, text: str) -> Iterable[ParsedHeader | str]:
for text_block in self.tag_regex.split(text):
if self.tag_regex.match(text_block):
yield self.parse(text_block)
elif text_block:
yield text_block


_URL_SPECIAL_SYMBOLS = re.escape(r"/:!#$%&'*+-.^_`|~")
_PARAM_KEY_SYMBOLS = r"[a-zA-Z0-9_\-]"
_PARAM_VALUE_SYMBOLS = f"[a-zA-Z0-9{_URL_SPECIAL_SYMBOLS}]"


_PARAM_REGEX_STR = f'(?P<param_name>{_PARAM_KEY_SYMBOLS}+) *= *["](?P<param_value>{_PARAM_VALUE_SYMBOLS}*)["]'
_PARAM_REGEX = re.compile(_PARAM_REGEX_STR)

_PARAMS_REGEX_STR = f"( +(?P<param>{_PARAM_REGEX_STR}))* *"
_HEADER_BODY_REGEX_STR = f" *(?P<name>{_PARAM_KEY_SYMBOLS}+)(?P<params>{_PARAMS_REGEX_STR})"


def _make_header_regex(tags: TagPair, inner_regex: str) -> HeaderRegexes:
open_tag = re.escape(tags.open)
close_tag = re.escape(tags.close)
tag_regex = re.compile(f"({open_tag}.*{close_tag})")
full_regex = re.compile(f"(^{open_tag}{_HEADER_BODY_REGEX_STR}{close_tag}$)")
return HeaderRegexes(tag_regex, full_regex)


_MESSAGE_HEADER_REGEXES = _make_header_regex(_MESSAGE_TAGS, _HEADER_BODY_REGEX_STR)
_CONTENT_HEADER_REGEXES = _make_header_regex(_CONTENT_BLOCK_TAGS, _HEADER_BODY_REGEX_STR)


def _parse_image_content(header: ParsedHeader) -> ImageContent:
assert header.name == "image"
try:
return ImageContent(**header.params)
except ValueError as e:
raise MessageSyntaxError(f"Invalid image parameters: {header.params}") from e


def _parse_tool_invocation(header: ParsedHeader, arg_text: str) -> ToolInvocation:
assert header.name == "tool"
try:
kwargs = dict(header.params)
kwargs["invocation_id"] = kwargs.pop("id")
return ToolInvocation(**kwargs, arguments=json.loads(arg_text))
except (ValueError, KeyError) as e:
raise MessageSyntaxError(f"Invalid tool invocation parameters: {header}\n{arg_text}") from e


def _map_headers_to_content(blocks: Iterable[str | ParsedHeader]) -> Iterable[str | ImageContent]:
for block in blocks:
if isinstance(block, str):
yield block
else:
yield _parse_image_content(block)


def _parse_message(header: ParsedHeader, message_str: str) -> Message:
message_blocks = list(_CONTENT_HEADER_REGEXES.split_text(message_str))
match header.name:
case "system":
if len(message_blocks) > 1:
raise MessageSyntaxError(f"System message cannot contain anything other than text: {message_str}")
return SystemMessage(name=header.params.get("name"), content=message_str)
case "user":
content = list(_map_headers_to_content(message_blocks))
return UserMessage(name=header.params.get("name"), content=content)
case "assistant":
if not message_blocks:
raise MessageSyntaxError(f"No content in assistant message: {message_str}")
if isinstance(message_blocks[0], ParsedHeader):
content = None
else:
content = message_blocks.pop(0)
assert isinstance(content, str)
tool_invocations = []
for block_header, arg_str in zip(message_blocks[::2], message_blocks[1::2]): # type: ignore
if not isinstance(block_header, ParsedHeader):
raise MessageSyntaxError(f"Invalid tool invocation header: {block_header}")
if not isinstance(arg_str, str):
raise MessageSyntaxError(f"No arguments specified to tool invocation {block_header}")
tool_invocations.append(_parse_tool_invocation(block_header, arg_str))
if tool_invocations and content is not None:
content = content.removesuffix("\n")
return AssistantMessage(name=header.params.get("name"), content=content, tool_invocations=tool_invocations)
case "tool":
content = list(_map_headers_to_content(message_blocks))
return ToolMessage(invocation_id=header.params["id"], content=content)
case _:
raise MessageSyntaxError(f"Invalid message type: {header.name}")


def parse_messages(chat_str: str) -> list[Message]:
blocks = list(_MESSAGE_HEADER_REGEXES.split_text(chat_str))
if not blocks:
return []
if not isinstance(blocks[0], ParsedHeader):
blocks = [ParsedHeader("user", {})] + blocks
messages = []
for header, message_str in zip(blocks[::2], blocks[1::2]):
if not isinstance(header, ParsedHeader):
raise MessageSyntaxError(f"Invalid message header: {header}")
if not isinstance(message_str, str):
raise MessageSyntaxError(f"No content for message: {header}")
message_str = message_str.removeprefix("\n").removesuffix("\n")
if not message_str.strip():
_logger.warning(
f"Message doesn't contain non-whitespace symbols: {header}."
"Check newlines at the begginning and the end of the prompt."
)
messages.append(_parse_message(header, message_str))
return messages

0 comments on commit 5b5c22d

Please sign in to comment.