Skip to content

Commit

Permalink
Implement OpenAI wrapper.
Browse files Browse the repository at this point in the history
  • Loading branch information
norpadon committed Jul 2, 2024
1 parent 01c8a9c commit 30bd142
Show file tree
Hide file tree
Showing 10 changed files with 467 additions and 14 deletions.
24 changes: 21 additions & 3 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,37 @@ API Reference
.. module:: wanga


LLM API
-------

.. currentmodule:: wanga.models

Messages
^^^^^^^^

.. automodule:: wanga.models.messages
:members:

Models
^^^^^^

.. automodule:: wanga.models.model
:members:


Schema extraction and manipulation
----------------------------------

.. currentmodule:: wanga
.. currentmodule:: wanga.schema

Schema Definition
^^^^^^^^^^^^^^^^^

.. automodule:: wanga.schema.schema
.. automodule:: wanga.schema
:members:

Schema Extraction
^^^^^^^^^^^^^^^^^

.. automodule:: wanga.schema.extract
.. automodule:: wanga.schema.extractor
:members:
21 changes: 19 additions & 2 deletions tests/test_models/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def test_consistency():
messages = [
SystemMessage(content="Be a helpful assistant!\n"),
UserMessage(name="Alice", content=["Hello, world!"]),
UserMessage(name="Alice", content="Hello, world!"),
AssistantMessage(
name="Bob",
content="Here's a helpful image!",
Expand All @@ -33,7 +33,7 @@ def test_consistency():
ImageContent(url="https://example.com/image.png"),
],
),
UserMessage(name="Alice", content=["Thanks for the image!"]),
UserMessage(name="Alice", content="Thanks for the image!"),
]

chat_str = "\n".join(str(message) for message in messages)
Expand All @@ -54,3 +54,20 @@ def test_consistency():

parsed_messages = parse_messages(chat_str)
assert "\n".join(str(message) for message in parsed_messages) == chat_str.strip()


def test_num_blocks():
    """A single-text system/user exchange should parse to plain-string content."""
    raw = r"""
[|system|]
You are a helpful assistant.
[|user|]
2 + 2 = ?
"""
    raw = dedent(raw.removeprefix("\n"))
    messages = parse_messages(raw)
    first = messages[0]
    second = messages[1]
    # A message with exactly one text block should expose `content` as a str,
    # not as a one-element list of content blocks.
    assert isinstance(first, SystemMessage)
    assert isinstance(first.content, str)
    assert isinstance(second, UserMessage)
    assert isinstance(second.content, str)
53 changes: 53 additions & 0 deletions tests/test_models/test_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import asyncio
from textwrap import dedent

from wanga.models.messages import parse_messages
from wanga.models.model import ToolParams
from wanga.models.openai import OpenaAIModel
from wanga.schema import default_schema_extractor


def test_reply():
    """Smoke-test both the sync and async reply paths against a live model."""
    model = OpenaAIModel("gpt-3.5-turbo")
    prompt = r"""
[|system|]
You are a helpful assistant.
[|user|]
2 + 2 = ?
"""
    messages = parse_messages(dedent(prompt.removeprefix("\n")))

    sync_response = model.reply(messages)
    sync_text = sync_response.response_options[0].message.content
    assert isinstance(sync_text, str)
    assert "4" in sync_text

    async_response = asyncio.run(model.reply_async(messages), debug=True)
    async_text = async_response.response_options[0].message.content
    assert isinstance(async_text, str)
    assert "4" in async_text


def test_context_size():
    """Known models should report their documented context windows."""
    expected_context_lengths = {
        "gpt-4-turbo": 128000,
        "gpt-4": 8192,
    }
    for model_name, expected_length in expected_context_lengths.items():
        assert OpenaAIModel(model_name).context_length == expected_length


def test_num_tokens():
    """The cheap token estimate should agree with the exact count to within 1."""
    model = OpenaAIModel("gpt-3.5-turbo")
    prompt = r"""
[|system|]
You are a helpful assistant.
[|user|]
2 + 2 = ?
"""

    def tool(x: int, y: str):
        pass

    tool_schema = default_schema_extractor.extract_schema(tool)
    messages = parse_messages(dedent(prompt.removeprefix("\n")))
    tools = ToolParams(tools=[tool_schema])
    estimated = model.estimate_num_tokens(messages, tools)
    exact = model.calculate_num_tokens(messages, tools)
    assert abs(estimated - exact) < 2
4 changes: 3 additions & 1 deletion tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def test_normalize_schema():
typing.List: list,
typing.Union[typing.Union[int, float], str]: int | float | str,
(typing.Literal[1] | typing.Literal[2] | typing.Literal[3]): typing.Literal[1, 2, 3],
(typing.Literal[1, 2] | typing.Union[typing.Literal[2, 3], typing.Literal[3, 4]]): (typing.Literal[1, 2, 3, 4]),
(typing.Literal[1, 2] | typing.Union[typing.Literal[2, 3], typing.Literal[3, 4]]): (
typing.Literal[1, 2, 3, 4]
),
}
for annotation, result in expected.items():
assert normalize_annotation(annotation) == result
Expand Down
37 changes: 36 additions & 1 deletion wanga/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,38 @@
from typing import TypeAlias
from importlib import import_module
from types import ModuleType
from typing import TYPE_CHECKING, TypeAlias

# Public API of this module: the JSON alias plus the lazily-imported
# optional dependencies declared at the bottom of the file.
__all__ = [
    "JSON",
    "openai",
    "anthropic",
]

# Recursive alias covering any JSON-serializable value.
JSON: TypeAlias = str | float | int | bool | None | dict[str, "JSON"] | list["JSON"]


class LazyModule(ModuleType):
    """Placeholder that defers importing an optional dependency until first use.

    The first attribute access triggers the real import. If the package is not
    installed, an :class:`ImportError` with installation instructions is raised
    instead of the bare "No module named ..." message.
    """

    def __init__(self, name: str, extra: str, self_name: str):
        super().__init__(name)
        # Name of the package "extra" that provides this optional module.
        self.extra = extra
        # Name of this distribution, used to build the `pip install` hint.
        self.self_name = self_name

    def __getattr__(self, name: str):
        try:
            module = import_module(self.__name__)
        except ImportError as error:
            # Bug fix: the original passed the hint text to import_module()
            # instead of raising, so the helpful message was never shown
            # (and the "{self.extra}capability" words ran together).
            raise ImportError(
                f"Module {self.__name__} is not installed, install the {self.extra} "
                f"capability using `pip install {self.self_name}[{self.extra}]`"
            ) from error
        return getattr(module, name)


if TYPE_CHECKING:
    # Real imports for static analysis only. At runtime these names are bound
    # to LazyModule placeholders below, so the optional dependencies are not
    # required until an attribute of the module is actually accessed.
    import anthropic
    import openai
    import openai_function_tokens
else:
    openai: ModuleType = LazyModule("openai", "openai", "wanga")
    # openai_function_tokens ships with the same "openai" extra.
    openai_function_tokens = LazyModule("openai_function_tokens", "openai", "wanga")
    anthropic: ModuleType = LazyModule("anthropic", "anthropic", "wanga")
2 changes: 2 additions & 0 deletions wanga/models/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def _parse_message(header: ParsedHeader, message_str: str) -> Message:
return SystemMessage(name=header.params.get("name"), content=message_str)
case "user":
content = list(_map_headers_to_content(message_blocks))
if len(content) == 1 and isinstance(content[0], str):
content = content[0]
return UserMessage(name=header.params.get("name"), content=content)
case "assistant":
if not message_blocks:
Expand Down
24 changes: 22 additions & 2 deletions wanga/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,30 @@ class ModelResponse:
usage: UsageStats


class AuthenticationError(Exception):
    r"""Raised when the API credentials are missing or invalid."""


class PromptTooLongError(ValueError):
    # Fix: a stray `pass` preceded the string below, turning it into a dead
    # expression instead of the class docstring (`__doc__` was None).
    r"""Raised when the total length of prompt and requested response exceeds the maximum allowed number of tokens,
    or the size of requested completion exceeds the maximum response size.
    """


class ModelTimeoutError(TimeoutError):
    # Fix: a stray `pass` preceded the string below, turning it into a dead
    # expression instead of the class docstring (`__doc__` was None).
    r"""Raised when the model takes too long to generate a response."""


class InvalidJsonError(RuntimeError):
    r"""Raised when the model returns malformed JSON in response to a function call."""


class RateLimitError(Exception):
    r"""Raised when the API request rate limit is exceeded."""


class ServiceUnvailableError(RuntimeError):
    # NOTE(review): class name has a typo ("Unvailable" -> "Unavailable");
    # kept as-is because renaming would break callers that catch it.
    r"""Raised when the API is down."""


class Model:
Expand All @@ -75,6 +93,7 @@ def reply(
tools: ToolParams = ToolParams(),
params: GenerationParams = GenerationParams(),
num_options: int = 1,
user_id: str | None = None,
) -> ModelResponse:
raise NotImplementedError

Expand All @@ -84,6 +103,7 @@ async def reply_async(
tools: ToolParams = ToolParams(),
params: GenerationParams = GenerationParams(),
num_options: int = 1,
user_id: str | None = None,
) -> ModelResponse:
raise NotImplementedError

Expand Down
Loading

0 comments on commit 30bd142

Please sign in to comment.