Add model scaffolding.
norpadon committed Jun 29, 2024
1 parent f219ee3 commit cd59c71
Showing 8 changed files with 527 additions and 37 deletions.
301 changes: 273 additions & 28 deletions poetry.lock

Large diffs are not rendered by default.

27 changes: 20 additions & 7 deletions pyproject.toml
@@ -7,20 +7,33 @@
 license = "MIT"
 readme = "README.md"

 [tool.poetry.dependencies]
-python = ">3.10,<4.0"
-jinja2 = "^3.1.4"
 attrs = "^23.2.0"
 docstring-parser = "^0.16"
+jinja2 = "^3.1.4"
+python = ">3.10,<4.0"

+[tool.poetry.extras]
+openai = ["openai", "openai-function-tokens", "tiktoken"]
+anthropic = ["anthropic"]
+
 [tool.poetry.group.dev]
 optional = true

 [tool.poetry.group.dev.dependencies]
+anthropic = "^0.30.0"
+isort = "^5.13.2"
+openai = "^1.35.7"
+pillow = "^10.3.0"
+pydantic = "^2.7.4"
+pyright = "^1.1.369"
 pytest = "^8.2.2"
 pytest-dotenv = "^0.5.2"
-openai = "^1.35.3"
-anthropic = "^0.29.0"
 ruff-lsp = "^0.0.53"
-isort = "^5.13.2"
-pyright = "^1.1.369"
-pydantic = "^2.7.4"
+openai-function-tokens = "^0.1.2"
+tiktoken = "^0.7.0"

 [tool.poetry.group.docs]
 optional = true

 [tool.poetry.group.docs.dependencies]
 sphinx = "^7.3.7"
Empty file added tests/test_models/__init__.py
4 changes: 2 additions & 2 deletions tests/test_schema.py
@@ -197,14 +197,14 @@ class Qux:
         baz: The baz.
     """

-    x: int
+    x: "int"
     baz: Baz

 qux_schema = CallableSchema(
     return_schema=UndefinedNode(original_annotation=None),
     call_schema=ObjectNode(
         constructor_fn=Qux,
-        constructor_signature=inspect.signature(Qux),
+        constructor_signature=inspect.signature(Qux, eval_str=True),
         name="Qux",
         hint="I am Qux.",
         fields=[
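The switch to `inspect.signature(Qux, eval_str=True)` matters because `Qux` now declares `x: "int"` as a string annotation. A minimal illustration of the difference, using a hypothetical class rather than the test's actual fixture:

import inspect

class Point:
    def __init__(self, x: "int"):  # string annotation, as with Qux above
        self.x = x

# Without eval_str, the annotation is left as the raw string 'int':
print(inspect.signature(Point).parameters["x"].annotation)                 # 'int'
# With eval_str=True (available since Python 3.10), it is evaluated to the type:
print(inspect.signature(Point, eval_str=True).parameters["x"].annotation)  # <class 'int'>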
4 changes: 4 additions & 0 deletions wanga/__init__.py
@@ -1 +1,5 @@
from importlib.util import find_spec

# Feature flags for optional dependencies. Note that Pillow is imported as "PIL",
# and find_spec returns a ModuleSpec (not None) when the package is installed.
_PIL_INSTALLED = find_spec("PIL") is not None
_OPENAI_INSTALLED = find_spec("openai") is not None
_ANTHROPIC_INSTALLED = find_spec("anthropic") is not None
3 changes: 3 additions & 0 deletions wanga/common.py
@@ -0,0 +1,3 @@
from typing import TypeAlias

JSON: TypeAlias = str | float | int | bool | None | dict[str, "JSON"] | list["JSON"]
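For illustration (not part of the commit), a value that type-checks against the recursive alias:

from wanga.common import JSON

def tool_result_to_json(status: str, attempts: int) -> JSON:
    # Arbitrarily nested dicts and lists of primitives are all valid JSON values.
    return {"status": status, "attempts": attempts, "tags": ["retry", None]}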
118 changes: 118 additions & 0 deletions wanga/models/messages.py
@@ -0,0 +1,118 @@
from pprint import pformat
from typing import TYPE_CHECKING, TypeAlias

from attrs import frozen

from wanga import _PIL_INSTALLED

from ..common import JSON

__all__ = [
    "ImageURL",
    "ImageContent",
    "Message",
    "SystemMessage",
    "UserMessage",
    "AssistantMessage",
    "ToolInvocation",
    "ToolMessage",
]


@frozen
class ImageURL:
    url: str


if TYPE_CHECKING or _PIL_INSTALLED:
    from PIL.Image import Image

    ImageContent: TypeAlias = Image | ImageURL
else:
    ImageContent: TypeAlias = ImageURL


_MESSAGE_TAG_OPEN = "[|"
_MESSAGE_TAG_CLOSE = "|]"

_CONTENT_BLOCK_TAG_OPEN = "<|"
_CONTENT_BLOCK_TAG_CLOSE = "|>"


def _pretty_print_message(role: str, name: str | None, content: str | list[str | ImageContent]) -> str:
    if name:
        role_str = f"{_MESSAGE_TAG_OPEN}{role} {name}{_MESSAGE_TAG_CLOSE}"
    else:
        role_str = f"{_MESSAGE_TAG_OPEN}{role}{_MESSAGE_TAG_CLOSE}"

    content_strings = []
    if isinstance(content, str):
        content = [content]
    for content_block in content:
        if isinstance(content_block, ImageURL):
            content_block = f'{_CONTENT_BLOCK_TAG_OPEN}image url="{content_block.url}"{_CONTENT_BLOCK_TAG_CLOSE}'
        elif not isinstance(content_block, str):
            # A PIL image (or other non-string block) is rendered as an opaque tag.
            content_block = f"{_CONTENT_BLOCK_TAG_OPEN}image{_CONTENT_BLOCK_TAG_CLOSE}"
        content_strings.append(content_block)
    joined_content = "\n".join(content_strings)

    return f"{role_str}\n{joined_content}"


@frozen
class Message:
    pass


@frozen
class SystemMessage(Message):
    name: str | None
    content: str

    def __str__(self) -> str:
        return _pretty_print_message("system", self.name, self.content)


@frozen
class UserMessage(Message):
    name: str | None
    content: str | list[str | ImageContent]

    def __str__(self) -> str:
        return _pretty_print_message("user", self.name, self.content)


@frozen
class ToolInvocation:
    invocation_id: str
    tool_name: str
    tool_args: JSON

    def __str__(self) -> str:
        pretty_json = pformat(self.tool_args, sort_dicts=True)
        call_header = (
            f'{_CONTENT_BLOCK_TAG_OPEN}call {self.tool_name} id="{self.invocation_id}"{_CONTENT_BLOCK_TAG_CLOSE}'
        )
        return f"{call_header}\n{pretty_json}"


@frozen
class AssistantMessage(Message):
    name: str | None
    content: str | None
    tool_invocations: list[ToolInvocation]

    def __str__(self) -> str:
        content_blocks = []
        if self.content:
            content_blocks.append(self.content)
        for tool_invocation in self.tool_invocations:
            content_blocks.append(str(tool_invocation))
        return "\n".join(content_blocks)


@frozen
class ToolMessage(Message):
    invocation_id: str
    tool_result: str | list[str | ImageContent]

    def __str__(self) -> str:
        return _pretty_print_message("tool", f"id={self.invocation_id}", self.tool_result)
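A usage sketch (illustrative only, assuming the module path `wanga.models.messages`) showing how the message and content-block tags surface in the pretty-printed output:

from wanga.models.messages import AssistantMessage, ToolInvocation, UserMessage

user = UserMessage(name=None, content="What is 2 + 2?")
print(user)
# [|user|]
# What is 2 + 2?

reply = AssistantMessage(
    name=None,
    content=None,
    tool_invocations=[ToolInvocation(invocation_id="call_1", tool_name="add", tool_args={"a": 2, "b": 2})],
)
print(reply)
# <|call add id="call_1"|>
# {'a': 2, 'b': 2}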
107 changes: 107 additions & 0 deletions wanga/models/model.py
@@ -0,0 +1,107 @@
from enum import Enum

from attrs import field, frozen

from ..schema.schema import CallableSchema
from .messages import AssistantMessage, Message


class ToolUseMode(Enum):
    AUTO = "auto"
    FORCE = "force"
    NEVER = "never"


@frozen
class ToolParams:
    tools: list[CallableSchema] = field(factory=list)
    tool_use_mode: ToolUseMode | str = ToolUseMode.AUTO
    allow_parallel_calls: bool = False


@frozen
class GenerationParams:
    max_tokens: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    frequency_penalty: float | None = None
    presence_penalty: float | None = None
    stop_sequences: list[str] = field(factory=list)
    random_seed: int | None = None
    force_json: bool = False


@frozen
class UsageStats:
    prompt_tokens: int
    response_tokens: int

    @property
    def total_tokens(self) -> int:
        return self.prompt_tokens + self.response_tokens


class FinishReason(Enum):
    STOP = "stop"
    LENGTH = "length"
    TOOL_CALL = "tool_call"
    CONTENT_FILTER = "content_filter"


@frozen
class ResponseOption:
    message: AssistantMessage
    finish_reason: FinishReason


@frozen
class ModelResponse:
    response_options: list[ResponseOption]
    usage: UsageStats


class PromptTooLongError(ValueError):
    pass


class ModelTimeoutError(TimeoutError):
    pass


class Model:
    def reply(
        self,
        messages: list[Message],
        tools: ToolParams = ToolParams(),
        params: GenerationParams = GenerationParams(),
        num_options: int = 1,
    ) -> ModelResponse:
        raise NotImplementedError

    async def reply_async(
        self,
        messages: list[Message],
        tools: ToolParams = ToolParams(),
        params: GenerationParams = GenerationParams(),
        num_options: int = 1,
    ) -> ModelResponse:
        raise NotImplementedError

    @property
    def context_length(self) -> int:
        raise NotImplementedError

    def estimate_num_tokens(self, messages: list[Message], tools: ToolParams) -> int:
        r"""Return a rough estimate of the total number of tokens in the prompt.

        The estimate may be inaccurate; use `calculate_num_tokens` for precise results.
        Note that `calculate_num_tokens` may send a request to the LLM API, which
        counts towards your usage bill.
        """
        raise NotImplementedError

    def calculate_num_tokens(self, messages: list[Message], tools: ToolParams) -> int:
        r"""Return the precise number of tokens in the prompt.

        This method sends a request to the LLM API, which counts towards your usage bill.
        """
        raise NotImplementedError
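
As a hypothetical sketch (not part of this commit), the smallest conceivable backend, one that simply echoes the last message, shows how the pieces of the interface fit together:

from wanga.models.messages import AssistantMessage, Message
from wanga.models.model import (
    FinishReason,
    GenerationParams,
    Model,
    ModelResponse,
    ResponseOption,
    ToolParams,
    UsageStats,
)


class EchoModel(Model):
    def reply(
        self,
        messages: list[Message],
        tools: ToolParams = ToolParams(),
        params: GenerationParams = GenerationParams(),
        num_options: int = 1,
    ) -> ModelResponse:
        # Echo the pretty-printed form of the last message back to the caller.
        echo = AssistantMessage(name=None, content=str(messages[-1]), tool_invocations=[])
        options = [ResponseOption(message=echo, finish_reason=FinishReason.STOP) for _ in range(num_options)]
        return ModelResponse(response_options=options, usage=UsageStats(prompt_tokens=0, response_tokens=0))

    @property
    def context_length(self) -> int:
        return 8192  # arbitrary placeholder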
