-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
527 additions
and
37 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,5 @@ | ||
from importlib.util import find_spec | ||
|
||
_PIL_INSTALLED = find_spec("Pillow") is None | ||
_OPENAI_INSTALLED = find_spec("openai") is None | ||
_ANTHROPIC_INSTALLED = find_spec("anthropic") is None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from typing import TypeAlias

# Recursive alias covering any JSON-serializable value: scalars, null, and
# arbitrarily nested objects/arrays. The string forward references make the
# recursion legal at module-evaluation time.
JSON: TypeAlias = str | float | int | bool | None | dict[str, "JSON"] | list["JSON"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from pprint import pformat | ||
from typing import TYPE_CHECKING, TypeAlias | ||
|
||
from attrs import frozen | ||
|
||
from wanga import _PIL_INSTALLED | ||
|
||
from ..common import JSON | ||
|
||
__all__ = [ | ||
"ImageURL", | ||
"ImageContent", | ||
"Message", | ||
"UserMessage", | ||
"AssistantMessage", | ||
"ToolInvocation", | ||
"ToolMessage", | ||
] | ||
|
||
|
||
@frozen
class ImageURL:
    """Reference to an image by URL (as opposed to inline PIL image data)."""

    # URL of the image; passed through verbatim to the pretty-printer.
    url: str
|
||
|
||
# Pillow is an optional dependency: PIL images are part of ImageContent only
# when PIL is importable. The TYPE_CHECKING branch keeps annotations resolvable
# for static analysis even without Pillow installed.
if TYPE_CHECKING or _PIL_INSTALLED:
    from PIL.Image import Image

    ImageContent: TypeAlias = Image | ImageURL
else:
    ImageContent: TypeAlias = ImageURL
|
||
|
||
# Delimiters used by the pretty-printers below: message-level role headers are
# wrapped in [|...|], individual content blocks in <|...|>.
_MESSAGE_TAG_OPEN = "[|"
_MESSAGE_TAG_CLOSE = "|]"

_CONTENT_BLOCK_TAG_OPEN = "<|"
_CONTENT_BLOCK_TAG_CLOSE = "|>"
|
||
|
||
def _pretty_print_message(role: str, name: str | None, content: str | list[str | ImageContent]) -> str:
    """Render a message as a human-readable string.

    Produces a `[|role name|]` header line (name omitted when falsy) followed
    by one line per content block. String blocks pass through verbatim,
    ImageURL blocks become `<|image url="..."|>` tags, and any other block
    type (e.g. an inline PIL image) collapses to a bare `<|image|>` tag.
    """
    header_body = f"{role} {name}" if name else role
    role_str = f"{_MESSAGE_TAG_OPEN}{header_body}{_MESSAGE_TAG_CLOSE}"

    blocks = [content] if isinstance(content, str) else content
    rendered: list[str] = []
    for block in blocks:
        if isinstance(block, ImageURL):
            rendered.append(f'{_CONTENT_BLOCK_TAG_OPEN}image url="{block.url}"{_CONTENT_BLOCK_TAG_CLOSE}')
        elif isinstance(block, str):
            rendered.append(block)
        else:
            # Non-URL, non-string content: only its presence is printable.
            rendered.append(f"{_CONTENT_BLOCK_TAG_OPEN}image{_CONTENT_BLOCK_TAG_CLOSE}")

    joined_content = "\n".join(rendered)
    return f"{role_str}\n{joined_content}"
|
||
|
||
@frozen
class Message:
    """Marker base class for all chat message types."""

    pass
|
||
|
||
@frozen
class SystemMessage(Message):
    """System prompt message.

    Fix: the `@frozen` decorator was missing, unlike every other message
    class in this module. Without it, attrs generates no `__init__` for
    `name`/`content`, so instances could not be constructed with these
    fields and `__str__` would fail on attribute access.
    """

    # Optional author name shown in the rendered header.
    name: str | None
    # System messages are text-only (no image content).
    content: str

    def __str__(self) -> str:
        return _pretty_print_message("system", self.name, self.content)
|
||
|
||
@frozen
class UserMessage(Message):
    """Message authored by the end user; may mix text and image content."""

    # Optional author name shown in the rendered header.
    name: str | None
    # A plain string, or an ordered list of text/image content blocks.
    content: str | list[str | ImageContent]

    def __str__(self) -> str:
        return _pretty_print_message("user", self.name, self.content)
|
||
|
||
@frozen
class ToolInvocation:
    """A single tool call requested by the assistant."""

    # Backend-assigned id correlating this call with its ToolMessage result.
    invocation_id: str
    # Name of the tool being invoked.
    tool_name: str
    # Arguments as a JSON-compatible value.
    tool_args: JSON

    def __str__(self) -> str:
        # sort_dicts=True keeps the rendering deterministic across runs.
        pretty_json = pformat(self.tool_args, sort_dicts=True)
        call_header = (
            f'{_CONTENT_BLOCK_TAG_OPEN}call {self.tool_name} id="{self.invocation_id}"{_CONTENT_BLOCK_TAG_CLOSE}'
        )
        return f"{call_header}\n{pretty_json}"
|
||
|
||
@frozen
class AssistantMessage(Message):
    """Message authored by the model: optional text plus zero or more tool calls."""

    # Optional author name shown in the rendered header.
    name: str | None
    # Text content; None when the reply consists solely of tool calls.
    content: str | None
    # Tool calls requested alongside (or instead of) the text content.
    tool_invocations: list[ToolInvocation]

    def __str__(self) -> str:
        # Consistency fix: previously this was the only message type whose
        # rendering omitted the "[|role|]" header. Route through
        # _pretty_print_message like every sibling class.
        blocks: list[str | ImageContent] = []
        if self.content:
            blocks.append(self.content)
        blocks.extend(str(invocation) for invocation in self.tool_invocations)
        return _pretty_print_message("assistant", self.name, blocks)
|
||
|
||
@frozen
class ToolMessage(Message):
    """Result of a tool invocation, fed back to the model.

    Consistency fix: this was the only message type in `__all__` not
    deriving from `Message`, even though `Model.reply` accepts
    `list[Message]`; adding the base class is backward compatible.
    """

    # Id of the ToolInvocation this result answers.
    invocation_id: str
    # Tool output: plain text or a list of text/image content blocks.
    tool_result: str | list[str | ImageContent]

    def __str__(self) -> str:
        return _pretty_print_message("tool", f"id={self.invocation_id}", self.tool_result)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from enum import Enum | ||
|
||
from attrs import field, frozen | ||
|
||
from ..schema.schema import CallableSchema | ||
from .messages import AssistantMessage, Message | ||
|
||
|
||
class ToolUseMode(Enum):
    """Policy controlling whether the model may/must call tools."""

    # Model decides whether to call a tool.
    AUTO = "auto"
    # Model must produce a tool call.
    FORCE = "force"
    # Tool calls are disallowed.
    NEVER = "never"
|
||
|
||
@frozen
class ToolParams:
    """Tool configuration for a model request."""

    # Schemas of the tools made available to the model.
    tools: list[CallableSchema] = field(factory=list)
    # Policy, or a string. NOTE(review): presumably a plain string forces the
    # named tool — confirm against backend adapters before relying on it.
    tool_use_mode: ToolUseMode | str = ToolUseMode.AUTO
    # Whether the model may emit several tool calls in one response.
    allow_parallel_calls: bool = False
|
||
|
||
@frozen
class GenerationParams:
    """Sampling/decoding parameters; None fields defer to backend defaults."""

    max_tokens: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    frequency_penalty: float | None = None
    presence_penalty: float | None = None
    # Generation stops when any of these strings is produced.
    stop_sequences: list[str] = field(factory=list)
    random_seed: int | None = None
    # Request JSON-constrained output, where the backend supports it.
    force_json: bool = False
|
||
|
||
@frozen
class UsageStats:
    """Token accounting for a single model call."""

    # Tokens consumed by the prompt (input side).
    prompt_tokens: int
    # Tokens generated in the response.
    response_tokens: int

    @property
    def total_tokens(self) -> int:
        """Combined prompt + response token count."""
        return self.prompt_tokens + self.response_tokens
|
||
|
||
class FinishReason(Enum):
    """Why the backend stopped generating a response option."""

    # Natural end of message or a stop sequence was hit.
    STOP = "stop"
    # max_tokens (or the context limit) was reached.
    LENGTH = "length"
    # Generation ended to hand control to a tool call.
    TOOL_CALL = "tool_call"
    # The backend's content filter interrupted generation.
    CONTENT_FILTER = "content_filter"
|
||
|
||
@frozen
class ResponseOption:
    """One candidate completion together with its stop reason."""

    message: AssistantMessage
    finish_reason: FinishReason
|
||
|
||
@frozen
class ModelResponse:
    """Full result of a model call: all candidates plus usage accounting."""

    # One entry per requested option (see Model.reply(num_options=...)).
    response_options: list[ResponseOption]
    usage: UsageStats
|
||
|
||
class PromptTooLongError(ValueError):
    """Raised when the prompt does not fit into the model's context window."""

    pass
|
||
|
||
class ModelTimeoutError(TimeoutError):
    """Raised when the backend does not answer within the allotted time."""

    pass
|
||
|
||
class Model:
    """Abstract interface for a chat-completion backend.

    Concrete adapters (e.g. OpenAI, Anthropic) override every method; the
    base implementations all raise NotImplementedError.
    """

    def reply(
        self,
        messages: list[Message],
        # NOTE: the ToolParams()/GenerationParams() defaults are single shared
        # instances created at def time; they are frozen, so sharing is safe
        # as long as callers do not mutate ToolParams.tools in place.
        tools: ToolParams = ToolParams(),
        params: GenerationParams = GenerationParams(),
        num_options: int = 1,
    ) -> ModelResponse:
        """Request `num_options` completions for the given conversation."""
        raise NotImplementedError

    async def reply_async(
        self,
        messages: list[Message],
        tools: ToolParams = ToolParams(),
        params: GenerationParams = GenerationParams(),
        num_options: int = 1,
    ) -> ModelResponse:
        """Async variant of `reply` with identical semantics."""
        raise NotImplementedError

    @property
    def context_length(self) -> int:
        """Maximum number of tokens the model's context window can hold."""
        raise NotImplementedError

    def estimate_num_tokens(self, messages: list[Message], tools: ToolParams) -> int:
        r"""Returns the rough estimate of the total number of tokens in the prompt.
        May return inaccurate results, use `calculate_num_tokens` instead for precise results.
        Note that `calculate_num_tokens` may send request to the LLM api, which will count towards your usage bill.
        """
        raise NotImplementedError

    def calculate_num_tokens(self, messages: list[Message], tools: ToolParams) -> int:
        r"""Returns the precise number of tokens in the prompt.
        This method will send a request to the LLM api, which will count towards your usage bill.
        """
        raise NotImplementedError