-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
269 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from textwrap import dedent | ||
|
||
from wanga.models.messages import ( | ||
AssistantMessage, | ||
ImageContent, | ||
SystemMessage, | ||
ToolInvocation, | ||
ToolMessage, | ||
UserMessage, | ||
parse_messages, | ||
) | ||
|
||
|
||
def test_consistency():
    """Round-trip check between message objects and their textual chat format."""
    # Part 1: objects -> text -> objects must reproduce the original objects.
    paste_image_call = ToolInvocation(
        invocation_id="abc123",
        name="paste_image",
        arguments={"url": "https://example.com/image.png"},
    )
    messages = [
        SystemMessage(content="Be a helpful assistant!\n"),
        UserMessage(name="Alice", content=["Hello, world!"]),
        AssistantMessage(
            name="Bob",
            content="Here's a helpful image!",
            tool_invocations=[paste_image_call],
        ),
        ToolMessage(
            invocation_id="abc123",
            content=[
                "Here is the image you requested:\n\n",
                ImageContent(url="https://example.com/image.png"),
            ],
        ),
        UserMessage(name="Alice", content=["Thanks for the image!"]),
    ]

    rendered = "\n".join(map(str, messages))
    assert parse_messages(rendered) == messages

    # Part 2: text -> objects -> text must reproduce the original text.
    raw_chat = r"""
    [|user|]
    Hi! Here is the pic<|image url="https://example.com/image.png"|>
    [|assistant|]
    It is beautiful!
    [|user|]
    I love you!
    [|assistant|]
    I love you too!
    """
    expected_text = dedent(raw_chat.removeprefix("\n"))

    reparsed = parse_messages(expected_text)
    assert "\n".join(map(str, reparsed)) == expected_text.strip()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,118 +1,276 @@ | ||
from pprint import pformat | ||
from typing import TYPE_CHECKING, TypeAlias | ||
import json | ||
import logging | ||
import re | ||
from collections.abc import Iterable | ||
from typing import NamedTuple | ||
|
||
from attrs import frozen | ||
|
||
from wanga import _PIL_INSTALLED | ||
from attrs import field, frozen | ||
|
||
from ..common import JSON | ||
|
||
__all__ = [ | ||
"ImageURL", | ||
"AssistantMessage", | ||
"ImageContent", | ||
"Message", | ||
"UserMessage", | ||
"AssistantMessage", | ||
"ToolInvocation", | ||
"ToolMessage", | ||
"UserMessage", | ||
] | ||
|
||
|
||
@frozen | ||
class ImageURL: | ||
url: str | ||
class TagPair(NamedTuple):
    """Opening and closing delimiter strings that surround a header tag."""

    open: str  # text that starts a tag
    close: str  # text that ends a tag
|
||
|
||
if TYPE_CHECKING or _PIL_INSTALLED: | ||
from PIL.Image import Image | ||
_MESSAGE_TAGS = TagPair(r"[|", r"|]") | ||
_CONTENT_BLOCK_TAGS = TagPair(r"<|", r"|>") | ||
|
||
ImageContent: TypeAlias = Image | ImageURL | ||
else: | ||
ImageContent: TypeAlias = ImageURL | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
_MESSAGE_TAG_OPEN = "[|" | ||
_MESSAGE_TAG_CLOSE = "|]" | ||
|
||
_CONTENT_BLOCK_TAG_OPEN = "<|" | ||
_CONTENT_BLOCK_TAG_CLOSE = "|>" | ||
def _format_header(_tags: TagPair, _name: str, /, **kwargs: str) -> str:
    """Render a header tag: opening delimiter, name, ``key="value"`` pairs, closing delimiter.

    The first two parameters are positional-only, so any key (including
    ``name``) may be supplied through ``kwargs``. Values are inserted verbatim,
    without escaping.
    """
    pieces = [_tags.open, _name]
    for key, value in kwargs.items():
        pieces.append(f' {key}="{value}"')
    pieces.append(_tags.close)
    return "".join(pieces)
|
||
|
||
def _pretty_print_message(role: str, name: str | None, content: str | list[str | ImageContent]) -> str: | ||
if name: | ||
role_str = f"{_MESSAGE_TAG_OPEN}{role} {name}{_MESSAGE_TAG_CLOSE}" | ||
else: | ||
role_str = f"{_MESSAGE_TAG_OPEN}{role}{_MESSAGE_TAG_CLOSE}" | ||
@frozen | ||
class ImageContent: | ||
r"""Image embedded in a message. Can either be a URL, or a base64-encoded file.""" | ||
|
||
content_strings = [] | ||
if isinstance(content, str): | ||
content = [content] | ||
for content_block in content: | ||
if isinstance(content_block, ImageURL): | ||
content_block = f'{_CONTENT_BLOCK_TAG_OPEN}image url="{content_block.url}"{_CONTENT_BLOCK_TAG_CLOSE}' | ||
if not isinstance(content_block, str): | ||
content_block = f"{_CONTENT_BLOCK_TAG_OPEN}image{_CONTENT_BLOCK_TAG_CLOSE}" | ||
content_strings.append(content_block) | ||
joined_content = "\n".join(content_strings) | ||
url: str | None = None | ||
base64: str | None = None | ||
|
||
return f"{role_str}\n{joined_content}" | ||
def __attrs_post_init__(self):
    """Validate that exactly one of ``url`` and ``base64`` is set.

    Raises:
        ValueError: If neither or both of ``url`` and ``base64`` are provided.
    """
    # The messages name this class (``ImageContent``) so the error points at
    # the object the caller actually tried to construct.
    if self.url is None and self.base64 is None:
        raise ValueError("ImageContent must have either a URL or base64 data")
    if self.url is not None and self.base64 is not None:
        raise ValueError("ImageContent cannot have both a URL and base64 data")
|
||
def __str__(self) -> str:
    """Render the image as an ``image`` content-block tag carrying whichever source is set."""
    sources = (("url", self.url), ("base64", self.base64))
    # Only truthy sources become tag parameters, in url-then-base64 order.
    params = {key: value for key, value in sources if value}
    return _format_header(_CONTENT_BLOCK_TAGS, "image", **params)
|
||
|
||
@frozen
class Message:
    """Common base class for the chat message types defined in this module."""

    pass
|
||
|
||
@frozen | ||
class SystemMessage(Message): | ||
name: str | None | ||
content: str | ||
name: str | None = None | ||
|
||
def __str__(self) -> str: | ||
return _pretty_print_message("system", self.name, self.content) | ||
kwargs = {} | ||
if self.name: | ||
kwargs["name"] = self.name | ||
header = _format_header(_MESSAGE_TAGS, "system", **kwargs) | ||
return f"{header}\n{self.content}" | ||
|
||
|
||
@frozen | ||
class UserMessage(Message): | ||
name: str | None | ||
content: str | list[str | ImageContent] | ||
name: str | None = None | ||
|
||
def __str__(self) -> str: | ||
return _pretty_print_message("user", self.name, self.content) | ||
kwargs = {} | ||
if self.name: | ||
kwargs["name"] = self.name | ||
header = _format_header(_MESSAGE_TAGS, "user", **kwargs) | ||
content = "".join(str(item) for item in self.content) | ||
return f"{header}\n{content}" | ||
|
||
|
||
@frozen | ||
class ToolInvocation: | ||
invocation_id: str | ||
tool_name: str | ||
tool_args: JSON | ||
name: str | ||
arguments: JSON | ||
|
||
def __str__(self) -> str: | ||
pretty_json = pformat(self.tool_args, sort_dicts=True) | ||
call_header = ( | ||
f'{_CONTENT_BLOCK_TAG_OPEN}call {self.tool_name} id="{self.invocation_id}"{_CONTENT_BLOCK_TAG_CLOSE}' | ||
) | ||
return f"{call_header}\n{pretty_json}" | ||
pretty_json = json.dumps(self.arguments, indent=4, sort_keys=True) | ||
header = _format_header(_CONTENT_BLOCK_TAGS, "tool", name=self.name, id=self.invocation_id) | ||
return f"{header}\n{pretty_json}" | ||
|
||
|
||
@frozen | ||
class AssistantMessage(Message): | ||
name: str | None | ||
content: str | None | ||
tool_invocations: list[ToolInvocation] | ||
name: str | None = None | ||
content: str | None = None | ||
tool_invocations: list[ToolInvocation] = field(factory=list) | ||
|
||
def __str__(self) -> str: | ||
kwargs = {} | ||
if self.name: | ||
kwargs["name"] = self.name | ||
header = _format_header(_MESSAGE_TAGS, "assistant", **kwargs) | ||
|
||
content_blocks = [] | ||
if self.content: | ||
content_blocks.append(self.content) | ||
for tool_invocation in self.tool_invocations: | ||
content_blocks.append(str(tool_invocation)) | ||
return "\n".join(content_blocks) | ||
joined_content = "\n".join(content_blocks) | ||
|
||
return f"{header}\n{joined_content}" | ||
|
||
|
||
@frozen | ||
class ToolMessage: | ||
class ToolMessage(Message): | ||
invocation_id: str | ||
tool_result: str | list[str | ImageContent] | ||
content: str | list[str | ImageContent] | ||
|
||
def __str__(self) -> str: | ||
return _pretty_print_message("tool", f"id={self.invocation_id}", self.tool_result) | ||
header = _format_header(_MESSAGE_TAGS, "tool", id=self.invocation_id) | ||
tool_result = self.content | ||
if not isinstance(self.content, list): | ||
tool_result = [self.content] | ||
content = "".join(str(item) for item in tool_result) | ||
return f"{header}\n{content}" | ||
|
||
|
||
class ParsedHeader(NamedTuple):
    """Result of parsing one header tag: its name and its ``key="value"`` parameters."""

    name: str  # identifier at the start of the tag, e.g. "user" or "image"
    params: dict[str, str]  # parameter names mapped to their quoted values
|
||
|
||
class MessageSyntaxError(ValueError):
    """Raised when chat text or one of its header tags cannot be parsed."""

    pass
|
||
|
||
class HeaderRegexes(NamedTuple):
    """Pair of compiled patterns used to locate and to parse header tags.

    ``tag_regex`` is used to split a text into tags and the text between them;
    ``full_regex`` is matched against a single tag to extract its ``name`` and
    ``params`` groups.
    """

    tag_regex: re.Pattern
    full_regex: re.Pattern

    def parse(self, header_str: str) -> ParsedHeader:
        """Parse one header tag into its name and parameter mapping.

        Raises:
            MessageSyntaxError: If ``header_str`` does not match ``full_regex``.
        """
        matched = self.full_regex.match(header_str)
        if matched is None:
            raise MessageSyntaxError(f"Invalid header: {header_str}")
        tag_name = matched.group("name")
        # Each key="value" pair found in the params section becomes one entry;
        # a repeated key keeps its last value.
        params = {
            found.group("param_name"): found.group("param_value")
            for found in _PARAM_REGEX.finditer(matched.group("params"))
        }
        return ParsedHeader(tag_name, params)

    def split_text(self, text: str) -> Iterable[ParsedHeader | str]:
        """Yield, in order, the parsed tags of ``text`` and the non-empty text between them."""
        for piece in self.tag_regex.split(text):
            if self.tag_regex.match(piece):
                yield self.parse(piece)
                continue
            if piece:
                yield piece
|
||
|
||
# Punctuation accepted (besides ASCII letters and digits) inside a quoted
# parameter value. "?" and "=" are included so that URLs with query strings and
# "="-padded base64 payloads survive a render -> parse round trip. The double
# quote and whitespace stay excluded because they delimit values and parameters.
_URL_SPECIAL_SYMBOLS = re.escape(r"/:!#$%&'*+-.^_`|~?=@,;()")
_PARAM_KEY_SYMBOLS = r"[a-zA-Z0-9_\-]"
_PARAM_VALUE_SYMBOLS = f"[a-zA-Z0-9{_URL_SPECIAL_SYMBOLS}]"


# One ``key="value"`` pair; spaces around "=" are tolerated.
_PARAM_REGEX_STR = f'(?P<param_name>{_PARAM_KEY_SYMBOLS}+) *= *["](?P<param_value>{_PARAM_VALUE_SYMBOLS}*)["]'
_PARAM_REGEX = re.compile(_PARAM_REGEX_STR)

# Zero or more space-separated pairs, then the header body: a name followed by the pairs.
_PARAMS_REGEX_STR = f"( +(?P<param>{_PARAM_REGEX_STR}))* *"
_HEADER_BODY_REGEX_STR = f" *(?P<name>{_PARAM_KEY_SYMBOLS}+)(?P<params>{_PARAMS_REGEX_STR})"
|
||
|
||
def _make_header_regex(tags: TagPair, inner_regex: str) -> HeaderRegexes:
    """Build the tag-finding and tag-parsing patterns for one kind of tag.

    Args:
        tags: Opening and closing delimiters of the tag.
        inner_regex: Pattern for the text between the delimiters. It must
            define the ``name`` and ``params`` groups read by
            ``HeaderRegexes.parse``.

    Returns:
        A ``HeaderRegexes`` whose ``tag_regex`` captures a whole tag (so that
        ``re.split`` keeps it) and whose ``full_regex`` matches exactly one tag.
    """
    open_tag = re.escape(tags.open)
    close_tag = re.escape(tags.close)
    # Non-greedy: with a greedy ".*" two tags on the same line would be merged
    # into a single match that then fails to parse as a header.
    tag_regex = re.compile(f"({open_tag}.*?{close_tag})")
    # Use the caller-supplied body pattern rather than a module-level constant.
    full_regex = re.compile(f"(^{open_tag}{inner_regex}{close_tag}$)")
    return HeaderRegexes(tag_regex, full_regex)
|
||
|
||
# Pre-built patterns for message-level tags and for content-block tags; both
# kinds of tag share the same header body grammar and differ only in delimiters.
_MESSAGE_HEADER_REGEXES = _make_header_regex(_MESSAGE_TAGS, _HEADER_BODY_REGEX_STR)
_CONTENT_HEADER_REGEXES = _make_header_regex(_CONTENT_BLOCK_TAGS, _HEADER_BODY_REGEX_STR)
|
||
|
||
def _parse_image_content(header: ParsedHeader) -> ImageContent:
    """Build an ``ImageContent`` from a parsed ``image`` content-block header.

    Args:
        header: Parsed content-block header; its parameters are passed to
            ``ImageContent`` as keyword arguments.

    Raises:
        MessageSyntaxError: If the header is not an ``image`` tag, or its
            parameters are not accepted by ``ImageContent``.
    """
    # Raise instead of asserting: asserts are stripped under ``-O``, and a wrong
    # tag in user-supplied text is a syntax error rather than a programming error.
    if header.name != "image":
        raise MessageSyntaxError(f"Expected an image block, got: {header.name}")
    try:
        return ImageContent(**header.params)
    except (TypeError, ValueError) as e:
        # TypeError: unknown parameter name; ValueError: ImageContent validation failed.
        raise MessageSyntaxError(f"Invalid image parameters: {header.params}") from e
|
||
|
||
def _parse_tool_invocation(header: ParsedHeader, arg_text: str) -> ToolInvocation:
    """Build a ``ToolInvocation`` from a ``tool`` header and its JSON argument text.

    Args:
        header: Parsed content-block header. Its ``id`` parameter becomes
            ``invocation_id``; the remaining parameters are passed through as
            keyword arguments.
        arg_text: Text following the header, decoded with ``json.loads``.

    Raises:
        MessageSyntaxError: If the header is not a ``tool`` tag, ``id`` is
            missing, the parameters are not accepted by ``ToolInvocation``, or
            ``arg_text`` is not valid JSON.
    """
    # Raise instead of asserting: asserts are stripped under ``-O``, and a wrong
    # tag in user-supplied text is a syntax error rather than a programming error.
    if header.name != "tool":
        raise MessageSyntaxError(f"Expected a tool invocation block, got: {header.name}")
    try:
        kwargs = dict(header.params)
        kwargs["invocation_id"] = kwargs.pop("id")
        return ToolInvocation(**kwargs, arguments=json.loads(arg_text))
    except (TypeError, ValueError, KeyError) as e:
        # KeyError: no "id"; TypeError: missing or unexpected parameter;
        # ValueError: covers json.JSONDecodeError.
        raise MessageSyntaxError(f"Invalid tool invocation parameters: {header}\n{arg_text}") from e
|
||
|
||
def _map_headers_to_content(blocks: Iterable[str | ParsedHeader]) -> Iterable[str | ImageContent]:
    """Lazily yield ``blocks`` with every parsed header converted to an image.

    Plain strings are passed through unchanged; anything else is handed to
    ``_parse_image_content``.
    """
    for item in blocks:
        yield item if isinstance(item, str) else _parse_image_content(item)
|
||
|
||
def _parse_message(header: ParsedHeader, message_str: str) -> Message: | ||
message_blocks = list(_CONTENT_HEADER_REGEXES.split_text(message_str)) | ||
match header.name: | ||
case "system": | ||
if len(message_blocks) > 1: | ||
raise MessageSyntaxError(f"System message cannot contain anything other than text: {message_str}") | ||
return SystemMessage(name=header.params.get("name"), content=message_str) | ||
case "user": | ||
content = list(_map_headers_to_content(message_blocks)) | ||
return UserMessage(name=header.params.get("name"), content=content) | ||
case "assistant": | ||
if not message_blocks: | ||
raise MessageSyntaxError(f"No content in assistant message: {message_str}") | ||
if isinstance(message_blocks[0], ParsedHeader): | ||
content = None | ||
else: | ||
content = message_blocks.pop(0) | ||
assert isinstance(content, str) | ||
tool_invocations = [] | ||
for block_header, arg_str in zip(message_blocks[::2], message_blocks[1::2]): # type: ignore | ||
if not isinstance(block_header, ParsedHeader): | ||
raise MessageSyntaxError(f"Invalid tool invocation header: {block_header}") | ||
if not isinstance(arg_str, str): | ||
raise MessageSyntaxError(f"No arguments specified to tool invocation {block_header}") | ||
tool_invocations.append(_parse_tool_invocation(block_header, arg_str)) | ||
if tool_invocations and content is not None: | ||
content = content.removesuffix("\n") | ||
return AssistantMessage(name=header.params.get("name"), content=content, tool_invocations=tool_invocations) | ||
case "tool": | ||
content = list(_map_headers_to_content(message_blocks)) | ||
return ToolMessage(invocation_id=header.params["id"], content=content) | ||
case _: | ||
raise MessageSyntaxError(f"Invalid message type: {header.name}") | ||
|
||
|
||
def parse_messages(chat_str: str) -> list[Message]:
    """Parse a chat transcript in the tagged text format into message objects.

    The text is split on message headers; every header must be followed by a
    body. Text that precedes the first header is treated as the body of an
    unnamed user message. One leading and one trailing newline are stripped
    from each body before it is parsed.

    Args:
        chat_str: Transcript text.

    Returns:
        The parsed messages in order of appearance; an empty list when
        ``chat_str`` yields no blocks.

    Raises:
        MessageSyntaxError: If a header is malformed, a header has no body, or
            a body does not fit its message type.
    """
    blocks = list(_MESSAGE_HEADER_REGEXES.split_text(chat_str))
    if not blocks:
        return []
    if not isinstance(blocks[0], ParsedHeader):
        blocks = [ParsedHeader("user", {})] + blocks
    messages = []
    # Consume blocks pairwise (header, body). Walking a single iterator, unlike
    # zipping two slices, also reports a trailing header without a body instead
    # of silently dropping that message.
    remaining = iter(blocks)
    for header in remaining:
        if not isinstance(header, ParsedHeader):
            raise MessageSyntaxError(f"Invalid message header: {header}")
        message_str = next(remaining, None)
        if not isinstance(message_str, str):
            raise MessageSyntaxError(f"No content for message: {header}")
        message_str = message_str.removeprefix("\n").removesuffix("\n")
        if not message_str.strip():
            _logger.warning(
                "Message doesn't contain non-whitespace symbols: %s. "
                "Check newlines at the beginning and the end of the prompt.",
                header,
            )
        messages.append(_parse_message(header, message_str))
    return messages