diff --git a/pyproject.toml b/pyproject.toml
index 411e2df..9dede81 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,7 @@ sphinxcontrib-napoleon = "^0.7"
 toml = "^0.10.2"
 
 [tool.pytest.ini_options]
-addopts = "--doctest-modules"
+addopts = "--doctest-modules -vv --showlocals"
 
 [tool.pyright]
 pythonVersion = "3.10"
diff --git a/tests/test_models/test_messages.py b/tests/test_models/test_messages.py
new file mode 100644
index 0000000..6a1571c
--- /dev/null
+++ b/tests/test_models/test_messages.py
@@ -0,0 +1,62 @@
+from textwrap import dedent
+
+from wanga.models.messages import (
+    AssistantMessage,
+    ImageContent,
+    SystemMessage,
+    ToolInvocation,
+    ToolMessage,
+    UserMessage,
+    parse_messages,
+)
+
+
+def test_consistency():
+    messages = [
+        SystemMessage(content="Be a helpful assistant!\n"),
+        UserMessage(name="Alice", content=["Hello, world!"]),
+        AssistantMessage(
+            name="Bob",
+            content="Here's a helpful image!",
+            tool_invocations=[
+                ToolInvocation(
+                    invocation_id="abc123",
+                    name="paste_image",
+                    arguments={"url": "https://example.com/image.png"},
+                )
+            ],
+        ),
+        ToolMessage(
+            invocation_id="abc123",
+            content=[
+                "Here is the image you requested:\n\n",
+                ImageContent(url="https://example.com/image.png"),
+            ],
+        ),
+        UserMessage(name="Alice", content=["Thanks for the image!"]),
+    ]
+
+    chat_str = "\n".join(str(message) for message in messages)
+    parsed_messages = parse_messages(chat_str)
+    assert parsed_messages == messages
+
+    chat_str = r"""
+    [|user|]
+    Hi! Here is the pic<|image url="https://example.com/image.png"|>
+    [|assistant|]
+    It is beautiful!
+    [|user|]
+    I love you!
+    [|assistant|]
+    I love you too!
+    """
+    chat_str = dedent(chat_str.removeprefix("\n"))
+
+    parsed_messages = parse_messages(chat_str)
+    assert "\n".join(str(message) for message in parsed_messages) == chat_str.strip()
+
+
+def test_two_images_on_one_line():
+    chat_str = '[|user|]\n<|image url="a.png"|> and <|image url="b.png"|>'
+    (message,) = parse_messages(chat_str)
+    assert message == UserMessage(content=[ImageContent(url="a.png"), " and ", ImageContent(url="b.png")])
diff --git a/wanga/models/messages.py b/wanga/models/messages.py
index 7a2d372..705b300 100644
--- a/wanga/models/messages.py
+++ b/wanga/models/messages.py
@@ -1,61 +1,66 @@
-from pprint import pformat
-from typing import TYPE_CHECKING, TypeAlias
+import json
+import logging
+import re
+from collections.abc import Iterable
+from typing import NamedTuple
 
-from attrs import frozen
-
-from wanga import _PIL_INSTALLED
+from attrs import field, frozen
 
 from ..common import JSON
 
 
 __all__ = [
-    "ImageURL",
+    "AssistantMessage",
     "ImageContent",
     "Message",
-    "UserMessage",
-    "AssistantMessage",
+    "MessageSyntaxError",
+    "SystemMessage",
     "ToolInvocation",
     "ToolMessage",
+    "UserMessage",
+    "parse_messages",
 ]
 
 
-@frozen
-class ImageURL:
-    url: str
+class TagPair(NamedTuple):
+    """Opening and closing delimiters of one kind of tag."""
+
+    open: str
+    close: str
 
 
-if TYPE_CHECKING or _PIL_INSTALLED:
-    from PIL.Image import Image
+_MESSAGE_TAGS = TagPair(r"[|", r"|]")
+_CONTENT_BLOCK_TAGS = TagPair(r"<|", r"|>")
 
-    ImageContent: TypeAlias = Image | ImageURL
-else:
-    ImageContent: TypeAlias = ImageURL
+_logger = logging.getLogger(__name__)
 
 
-_MESSAGE_TAG_OPEN = "[|"
-_MESSAGE_TAG_CLOSE = "|]"
-_CONTENT_BLOCK_TAG_OPEN = "<|"
-_CONTENT_BLOCK_TAG_CLOSE = "|>"
+def _format_header(_tags: TagPair, _name: str, /, **kwargs: str) -> str:
+    """Render a tag: the name followed by ``key="value"`` pairs, wrapped in the given delimiters."""
+    param_str = "".join(f' {key}="{value}"' for key, value in kwargs.items())
+    return f"{_tags.open}{_name}{param_str}{_tags.close}"
 
 
-def _pretty_print_message(role: str, name: str | None, content: str | list[str | ImageContent]) -> str:
-    if name:
-        role_str = f"{_MESSAGE_TAG_OPEN}{role} {name}{_MESSAGE_TAG_CLOSE}"
-    else:
-        role_str = f"{_MESSAGE_TAG_OPEN}{role}{_MESSAGE_TAG_CLOSE}"
+@frozen
+class ImageContent:
+    r"""Image embedded in a message. Can either be a URL, or a base64-encoded file."""
 
-    content_strings = []
-    if isinstance(content, str):
-        content = [content]
-    for content_block in content:
-        if isinstance(content_block, ImageURL):
-            content_block = f'{_CONTENT_BLOCK_TAG_OPEN}image url="{content_block.url}"{_CONTENT_BLOCK_TAG_CLOSE}'
-        if not isinstance(content_block, str):
-            content_block = f"{_CONTENT_BLOCK_TAG_OPEN}image{_CONTENT_BLOCK_TAG_CLOSE}"
-        content_strings.append(content_block)
-    joined_content = "\n".join(content_strings)
+    url: str | None = None
+    base64: str | None = None
 
-    return f"{role_str}\n{joined_content}"
+    def __attrs_post_init__(self):
+        if self.url is None and self.base64 is None:
+            raise ValueError("ImageContent must have either a URL or base64 data")
+        if self.url is not None and self.base64 is not None:
+            raise ValueError("ImageContent cannot have both a URL and base64 data")
+
+    def __str__(self) -> str:
+        kwargs = {}
+        if self.url:
+            kwargs["url"] = self.url
+        if self.base64:
+            kwargs["base64"] = self.base64
+        return _format_header(_CONTENT_BLOCK_TAGS, "image", **kwargs)
 
 
 @frozen
@@ -63,56 +68,243 @@ class Message:
     pass
 
 
+@frozen
 class SystemMessage(Message):
-    name: str | None
     content: str
+    name: str | None = None
 
     def __str__(self) -> str:
-        return _pretty_print_message("system", self.name, self.content)
+        kwargs = {}
+        if self.name:
+            kwargs["name"] = self.name
+        header = _format_header(_MESSAGE_TAGS, "system", **kwargs)
+        return f"{header}\n{self.content}"
 
 
 @frozen
 class UserMessage(Message):
-    name: str | None
     content: str | list[str | ImageContent]
+    name: str | None = None
 
     def __str__(self) -> str:
-        return _pretty_print_message("user", self.name, self.content)
+        kwargs = {}
+        if self.name:
+            kwargs["name"] = self.name
+        header = _format_header(_MESSAGE_TAGS, "user", **kwargs)
+        content = "".join(str(item) for item in self.content)
+        return f"{header}\n{content}"
 
 
 @frozen
 class ToolInvocation:
     invocation_id: str
-    tool_name: str
-    tool_args: JSON
+    name: str
+    arguments: JSON
 
     def __str__(self) -> str:
-        pretty_json = pformat(self.tool_args, sort_dicts=True)
-        call_header = (
-            f'{_CONTENT_BLOCK_TAG_OPEN}call {self.tool_name} id="{self.invocation_id}"{_CONTENT_BLOCK_TAG_CLOSE}'
-        )
-        return f"{call_header}\n{pretty_json}"
+        pretty_json = json.dumps(self.arguments, indent=4, sort_keys=True)
+        header = _format_header(_CONTENT_BLOCK_TAGS, "tool", name=self.name, id=self.invocation_id)
+        return f"{header}\n{pretty_json}"
 
 
 @frozen
 class AssistantMessage(Message):
-    name: str | None
-    content: str | None
-    tool_invocations: list[ToolInvocation]
+    name: str | None = None
+    content: str | None = None
+    tool_invocations: list[ToolInvocation] = field(factory=list)
 
     def __str__(self) -> str:
+        kwargs = {}
+        if self.name:
+            kwargs["name"] = self.name
+        header = _format_header(_MESSAGE_TAGS, "assistant", **kwargs)
+
         content_blocks = []
         if self.content:
             content_blocks.append(self.content)
         for tool_invocation in self.tool_invocations:
             content_blocks.append(str(tool_invocation))
-        return "\n".join(content_blocks)
+        joined_content = "\n".join(content_blocks)
+
+        return f"{header}\n{joined_content}"
 
 
 @frozen
-class ToolMessage:
+class ToolMessage(Message):
     invocation_id: str
-    tool_result: str | list[str | ImageContent]
+    content: str | list[str | ImageContent]
 
     def __str__(self) -> str:
-        return _pretty_print_message("tool", f"id={self.invocation_id}", self.tool_result)
+        header = _format_header(_MESSAGE_TAGS, "tool", id=self.invocation_id)
+        tool_result = self.content
+        if not isinstance(self.content, list):
+            tool_result = [self.content]
+        content = "".join(str(item) for item in tool_result)
+        return f"{header}\n{content}"
+
+
+class ParsedHeader(NamedTuple):
+    """Tag name and ``key="value"`` parameters extracted from a header."""
+
+    name: str
+    params: dict[str, str]
+
+
+class MessageSyntaxError(ValueError):
+    """Raised when a chat transcript cannot be parsed into messages."""
+
+
+class HeaderRegexes(NamedTuple):
+    """Compiled patterns for locating and parsing one kind of tag."""
+
+    tag_regex: re.Pattern
+    full_regex: re.Pattern
+
+    def parse(self, header_str: str) -> ParsedHeader:
+        """Parse a complete tag, raising ``MessageSyntaxError`` if it is malformed."""
+        header_match = self.full_regex.match(header_str)
+        if header_match is None:
+            raise MessageSyntaxError(f"Invalid header: {header_str}")
+        name = header_match.group("name")
+        params_str = header_match.group("params")
+        params = {}
+        for param_match in _PARAM_REGEX.finditer(params_str):
+            params[param_match.group("param_name")] = param_match.group("param_value")
+        return ParsedHeader(name, params)
+
+    def split_text(self, text: str) -> Iterable[ParsedHeader | str]:
+        """Yield parsed tags and the non-empty text found between them."""
+        for text_block in self.tag_regex.split(text):
+            if self.tag_regex.match(text_block):
+                yield self.parse(text_block)
+            elif text_block:
+                yield text_block
+
+
+_URL_SPECIAL_SYMBOLS = re.escape(r"/:!#$%&'*+-.^_`|~")
+_PARAM_KEY_SYMBOLS = r"[a-zA-Z0-9_\-]"
+_PARAM_VALUE_SYMBOLS = f"[a-zA-Z0-9{_URL_SPECIAL_SYMBOLS}]"
+
+
+_PARAM_REGEX_STR = f'(?P<param_name>{_PARAM_KEY_SYMBOLS}+) *= *["](?P<param_value>{_PARAM_VALUE_SYMBOLS}*)["]'
+_PARAM_REGEX = re.compile(_PARAM_REGEX_STR)
+
+_PARAMS_REGEX_STR = f"( +(?P<param>{_PARAM_REGEX_STR}))* *"
+_HEADER_BODY_REGEX_STR = f" *(?P<name>{_PARAM_KEY_SYMBOLS}+)(?P<params>{_PARAMS_REGEX_STR})"
+
+
+def _make_header_regex(tags: TagPair, inner_regex: str) -> HeaderRegexes:
+    """Build the tag-finding and tag-parsing patterns for a delimiter pair."""
+    open_tag = re.escape(tags.open)
+    close_tag = re.escape(tags.close)
+    # Non-greedy, so that two tags on one line are not merged into a single match.
+    tag_regex = re.compile(f"({open_tag}.*?{close_tag})")
+    full_regex = re.compile(f"(^{open_tag}{inner_regex}{close_tag}$)")
+    return HeaderRegexes(tag_regex, full_regex)
+
+
+_MESSAGE_HEADER_REGEXES = _make_header_regex(_MESSAGE_TAGS, _HEADER_BODY_REGEX_STR)
+_CONTENT_HEADER_REGEXES = _make_header_regex(_CONTENT_BLOCK_TAGS, _HEADER_BODY_REGEX_STR)
+
+
+def _parse_image_content(header: ParsedHeader) -> ImageContent:
+    """Convert an ``image`` content tag into an ``ImageContent``."""
+    if header.name != "image":
+        raise MessageSyntaxError(f"Unsupported content block: {header.name}")
+    try:
+        return ImageContent(**header.params)
+    except (TypeError, ValueError) as e:
+        raise MessageSyntaxError(f"Invalid image parameters: {header.params}") from e
+
+
+def _parse_tool_invocation(header: ParsedHeader, arg_text: str) -> ToolInvocation:
+    """Convert a ``tool`` content tag and its JSON arguments into a ``ToolInvocation``."""
+    if header.name != "tool":
+        raise MessageSyntaxError(f"Unsupported content block: {header.name}")
+    try:
+        kwargs = dict(header.params)
+        kwargs["invocation_id"] = kwargs.pop("id")
+        return ToolInvocation(**kwargs, arguments=json.loads(arg_text))
+    except (TypeError, ValueError, KeyError) as e:
+        raise MessageSyntaxError(f"Invalid tool invocation parameters: {header}\n{arg_text}") from e
+
+
+def _map_headers_to_content(blocks: Iterable[str | ParsedHeader]) -> Iterable[str | ImageContent]:
+    """Pass text through unchanged and convert content tags into images."""
+    for block in blocks:
+        if isinstance(block, str):
+            yield block
+        else:
+            yield _parse_image_content(block)
+
+
+def _parse_message(header: ParsedHeader, message_str: str) -> Message:
+    """Build the message named by ``header`` from its body text."""
+    message_blocks = list(_CONTENT_HEADER_REGEXES.split_text(message_str))
+    match header.name:
+        case "system":
+            if any(isinstance(block, ParsedHeader) for block in message_blocks):
+                raise MessageSyntaxError(f"System message cannot contain anything other than text: {message_str}")
+            return SystemMessage(name=header.params.get("name"), content=message_str)
+        case "user":
+            content = list(_map_headers_to_content(message_blocks))
+            return UserMessage(name=header.params.get("name"), content=content)
+        case "assistant":
+            if not message_blocks:
+                raise MessageSyntaxError(f"No content in assistant message: {message_str}")
+            if isinstance(message_blocks[0], ParsedHeader):
+                content = None
+            else:
+                content = message_blocks.pop(0)
+                assert isinstance(content, str)
+            tool_invocations = []
+            for block_header, arg_str in zip(message_blocks[::2], message_blocks[1::2]):  # type: ignore
+                if not isinstance(block_header, ParsedHeader):
+                    raise MessageSyntaxError(f"Invalid tool invocation header: {block_header}")
+                if not isinstance(arg_str, str):
+                    raise MessageSyntaxError(f"No arguments specified to tool invocation {block_header}")
+                tool_invocations.append(_parse_tool_invocation(block_header, arg_str))
+            if len(message_blocks) % 2:
+                # zip() stops at the shorter slice, so a trailing header without arguments was skipped.
+                raise MessageSyntaxError(f"No arguments specified to tool invocation {message_blocks[-1]}")
+            if tool_invocations and content is not None:
+                content = content.removesuffix("\n")
+            return AssistantMessage(name=header.params.get("name"), content=content, tool_invocations=tool_invocations)
+        case "tool":
+            if "id" not in header.params:
+                raise MessageSyntaxError(f"Tool message must specify an id: {header}")
+            content = list(_map_headers_to_content(message_blocks))
+            return ToolMessage(invocation_id=header.params["id"], content=content)
+        case _:
+            raise MessageSyntaxError(f"Invalid message type: {header.name}")
+
+
+def parse_messages(chat_str: str) -> list[Message]:
+    """Parse a chat transcript written in the ``[|role|]`` tag format.
+
+    Text that precedes the first message header is treated as an unnamed
+    user message.
+
+    Raises:
+        MessageSyntaxError: If a header or a message body is malformed.
+    """
+    blocks = list(_MESSAGE_HEADER_REGEXES.split_text(chat_str))
+    if not blocks:
+        return []
+    if not isinstance(blocks[0], ParsedHeader):
+        blocks = [ParsedHeader("user", {})] + blocks
+    messages = []
+    for header, message_str in zip(blocks[::2], blocks[1::2]):
+        if not isinstance(header, ParsedHeader):
+            raise MessageSyntaxError(f"Invalid message header: {header}")
+        if not isinstance(message_str, str):
+            raise MessageSyntaxError(f"No content for message: {header}")
+        message_str = message_str.removeprefix("\n").removesuffix("\n")
+        if not message_str.strip():
+            _logger.warning(
+                "Message doesn't contain non-whitespace symbols: %s. "
+                "Check newlines at the beginning and the end of the prompt.",
+                header,
+            )
+        messages.append(_parse_message(header, message_str))
+    return messages