Implement backoff logic for openai models.
norpadon committed Sep 2, 2024
1 parent c5d411f commit fa2cf3b
Showing 8 changed files with 632 additions and 545 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@
.venv
.mypy_cache
.pytest_cache
.ruff_cache
**/.DS_Store
**/__pycache__

4 changes: 2 additions & 2 deletions flake.nix
@@ -24,8 +24,8 @@
];
shellHook = ''
poetry env use ${python}/bin/python
poetry install --no-root
export PYTHONPATH=$PYTHONPATH:$PWD/nuggets
poetry install --no-root --with dev
ln -sf $(poetry env info -p) .venv
'';
};
};
964 changes: 441 additions & 523 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pyproject.toml
@@ -12,6 +12,7 @@ docstring-parser = "^0.16"
jinja2 = "^3.1.4"
python = ">3.10,<4.0"
tiktoken = "^0.7.0"
tenacity = "^8.5.0"

[tool.poetry.extras]
openai = ["openai", "openai-function-tokens"]
@@ -22,14 +23,13 @@ optional = true

[tool.poetry.group.dev.dependencies]
anthropic = "^0.30.0"
isort = "^5.13.2"
openai = "^1.35.7"
pillow = "^10.3.0"
pydantic = "^2.7.4"
pyright = "^1.1.369"
pytest = "^8.2.2"
pytest-asyncio = "^0.24.0"
pytest-dotenv = "^0.5.2"
ruff-lsp = "^0.0.53"
openai-function-tokens = "^0.1.2"
# Documentation
sphinx = "^7.3.7"
@@ -41,6 +41,8 @@ toml = "^0.10.2"
addopts = "--doctest-modules -vv --showlocals"

[tool.pyright]
venv = "."
venvPath = ".venv"
pythonVersion = "3.10"
typeCheckingMode = "standard"

129 changes: 128 additions & 1 deletion tests/test_models/test_openai.py
@@ -1,8 +1,14 @@
import asyncio
from textwrap import dedent
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from httpx import Request
from openai import APIConnectionError, RateLimitError

from wanga.models.messages import parse_messages
from wanga.models.model import ToolParams
from wanga.models.model import FinishReason, ModelResponse, ToolParams
from wanga.models.model import RateLimitError as WangaRateLimitError
from wanga.models.openai import OpenaAIModel
from wanga.schema import default_schema_extractor

@@ -51,3 +57,124 @@ def tool(x: int, y: str):
messages = parse_messages(prompt)
tools = ToolParams(tools=[tool_schema])
assert abs(model.estimate_num_tokens(messages, tools) - model.calculate_num_tokens(messages, tools)) < 2


@pytest.fixture
def model():
return OpenaAIModel("gpt-3.5-turbo", num_retries=2, retry_on_request_limit=True)


def test_retry_on_rate_limit(model):
with patch.object(model._client.chat.completions, "create") as mock_create:
mock_create.side_effect = [
RateLimitError(
message="Rate limit exceeded",
response=MagicMock(status_code=429),
body={"error": {"message": "Rate limit exceeded"}},
),
RateLimitError(
message="Rate limit exceeded",
response=MagicMock(status_code=429),
body={"error": {"message": "Rate limit exceeded"}},
),
MagicMock(
choices=[MagicMock(message=MagicMock(content="4"), finish_reason="stop", index=0)],
usage=MagicMock(prompt_tokens=10, completion_tokens=5),
),
]

messages = parse_messages("2 + 2 = ?")
response = model.reply(messages)

assert mock_create.call_count == 3
assert isinstance(response, ModelResponse)
assert len(response.response_options) == 1
assert response.response_options[0].finish_reason == FinishReason.STOP
assert "4" in response.response_options[0].message.content # type: ignore
assert response.usage.prompt_tokens == 10
assert response.usage.response_tokens == 5


def test_retry_on_service_unavailable(model):
with patch.object(model._client.chat.completions, "create") as mock_create:
mock_create.side_effect = [
APIConnectionError(message="Service unavailable", request=Request("get", "https://api.openai.com")),
APIConnectionError(message="Service unavailable", request=Request("get", "https://api.openai.com")),
MagicMock(
choices=[MagicMock(message=MagicMock(content="4"), finish_reason="stop", index=0)],
usage=MagicMock(prompt_tokens=10, completion_tokens=5),
),
]

messages = parse_messages("2 + 2 = ?")
response = model.reply(messages)

assert mock_create.call_count == 3
assert isinstance(response, ModelResponse)
assert len(response.response_options) == 1
assert response.response_options[0].finish_reason == FinishReason.STOP
assert "4" in response.response_options[0].message.content # type: ignore
assert response.usage.prompt_tokens == 10
assert response.usage.response_tokens == 5


def test_no_retry_on_other_errors(model):
with patch.object(model._client.chat.completions, "create") as mock_create:
mock_create.side_effect = ValueError("Some other error")

messages = parse_messages("2 + 2 = ?")
with pytest.raises(ValueError):
model.reply(messages)

assert mock_create.call_count == 1


@pytest.mark.asyncio
async def test_async_retry_on_rate_limit(model):
mock_response = MagicMock(
choices=[MagicMock(message=MagicMock(content="4"), finish_reason="stop", index=0)],
usage=MagicMock(prompt_tokens=10, completion_tokens=5),
)

call_count = 0

async def side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count < 3:
raise RateLimitError(
message="Rate limit exceeded",
response=MagicMock(status_code=429),
body={"error": {"message": "Rate limit exceeded"}},
)
return mock_response

with patch.object(model._async_client.chat.completions, "create", new_callable=AsyncMock) as mock_create:
mock_create.side_effect = side_effect

messages = parse_messages("2 + 2 = ?")
response = await model.reply_async(messages)

assert call_count == 3
assert isinstance(response, ModelResponse)
assert len(response.response_options) == 1
assert response.response_options[0].finish_reason == FinishReason.STOP
assert "4" in response.response_options[0].message.content # type: ignore
assert response.usage.prompt_tokens == 10
assert response.usage.response_tokens == 5


def test_no_retry_when_disabled(model):
model._retry_on_request_limit = False
with patch.object(model._client.chat.completions, "create") as mock_create:
mock_create.side_effect = RateLimitError(
message="Rate limit exceeded",
response=MagicMock(status_code=429),
body={"error": {"message": "Rate limit exceeded"}},
)

messages = parse_messages("2 + 2 = ?")
with pytest.raises(WangaRateLimitError):
model.reply(messages)

assert mock_create.call_count == 1
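
For orientation, a minimal usage sketch of the retry configuration these tests exercise. The model name, prompt, and retry count below are illustrative, not part of the commit:

from wanga.models.messages import parse_messages
from wanga.models.openai import OpenaAIModel

# Retry up to 3 times, backing off on rate limits as well as outages.
model = OpenaAIModel("gpt-3.5-turbo", num_retries=3, retry_on_request_limit=True)
response = model.reply(parse_messages("2 + 2 = ?"))
print(response.response_options[0].message.content)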
8 changes: 6 additions & 2 deletions wanga/models/messages.py
@@ -185,7 +185,7 @@ def _make_header_regex(tags: TagPair, inner_regex: str) -> HeaderRegexes:
open_tag = re.escape(tags.open)
close_tag = re.escape(tags.close)
tag_regex = re.compile(f"({open_tag}.*{close_tag})")
full_regex = re.compile(f"(^{open_tag}{_HEADER_BODY_REGEX_STR}{close_tag}$)")
full_regex = re.compile(f"(^{open_tag}{inner_regex}{close_tag}$)")
return HeaderRegexes(tag_regex, full_regex)


@@ -248,7 +248,11 @@ def _parse_message(header: ParsedHeader, message_str: str) -> Message:
tool_invocations.append(_parse_tool_invocation(block_header, arg_str))
if tool_invocations and content is not None:
content = content.removesuffix("\n")
return AssistantMessage(name=header.params.get("name"), content=content, tool_invocations=tool_invocations)
return AssistantMessage(
name=header.params.get("name"),
content=content,
tool_invocations=tool_invocations,
)
case "tool":
content = list(_map_headers_to_content(message_blocks))
return ToolMessage(invocation_id=header.params["id"], content=content)
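The one-line regex change fixes `_make_header_regex` silently ignoring its `inner_regex` argument in favor of the hard-coded `_HEADER_BODY_REGEX_STR`. A minimal sketch of the corrected behavior; the tags and inner pattern here are illustrative:

import re

def make_full_regex(open_tag: str, close_tag: str, inner_regex: str) -> re.Pattern:
    # The caller-supplied inner_regex is now interpolated, so different
    # header grammars yield different full-match patterns.
    return re.compile(f"(^{re.escape(open_tag)}{inner_regex}{re.escape(close_tag)}$)")

assert make_full_regex("<|", "|>", r"\w+").match("<|user|>")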
63 changes: 49 additions & 14 deletions wanga/models/openai.py
@@ -1,6 +1,8 @@
import json
import re

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from ..common import openai, openai_function_tokens
from ..schema import CallableSchema, JsonSchemaFlavor
from .messages import AssistantMessage, ImageContent, Message, SystemMessage, ToolInvocation, ToolMessage, UserMessage
@@ -21,7 +23,6 @@
UsageStats,
)

_TOKENIZER = "cl100k_base"
_TOO_MANY_TOKENS = 100_000_000_000

_NUM_TOKENS_ERR_RE = re.compile(r"\((?P<messages>\d+) in the messages(, (?P<functions>\d+) in the functions,)?")
@@ -37,9 +38,18 @@ class OpenaAIModel(Model):
# We sort the keys such that 'gpt-4-turbo' comes before 'gpt-4'.
_NAME_PREFIX_TO_CONTEXT_LENGTH = {k: v for k, v in sorted(_NAME_PREFIX_TO_CONTEXT_LENGTH.items(), reverse=True)}

def __init__(self, model_name: str, api_key: str | None = None, timeout: float = 10 * 60, num_retries: int = 0):
self._client = openai.OpenAI(api_key=api_key, timeout=timeout, max_retries=num_retries)
self._async_client = openai.AsyncOpenAI(api_key=api_key, timeout=timeout, max_retries=num_retries)
def __init__(
self,
model_name: str,
api_key: str | None = None,
request_timeout: float = 10 * 60,
num_retries: int = 0,
retry_on_request_limit: bool = True,
):
self._retry_on_request_limit = retry_on_request_limit
self._num_retries = num_retries
self._client = openai.OpenAI(api_key=api_key, timeout=request_timeout, max_retries=0)
self._async_client = openai.AsyncOpenAI(api_key=api_key, timeout=request_timeout, max_retries=0)
if model_name not in self._list_available_models():
raise ValueError(f"Model {model_name} is not available")
self._model_name = model_name
@@ -129,11 +139,16 @@ def reply(
user_id: str | None = None,
) -> ModelResponse:
kwargs = self._get_reply_kwargs(messages, tools, params, num_options, user_id)
try:
response = self._client.chat.completions.create(**kwargs)
except openai.OpenAIError as e:
raise _wrap_error(e)
return _parse_response(response)

@self._create_retry_decorator()
def _reply_with_retry():
try:
response = self._client.chat.completions.create(**kwargs)
return _parse_response(response)
except openai.OpenAIError as e:
raise _wrap_error(e)

return _reply_with_retry()

async def reply_async(
self,
@@ -144,11 +159,28 @@ async def reply_async(
user_id: str | None = None,
) -> ModelResponse:
kwargs = self._get_reply_kwargs(messages, tools, params, num_options, user_id)
try:
response = await self._async_client.chat.completions.create(**kwargs)
except Exception as e:
raise _wrap_error(e)
return _parse_response(response)

@self._create_retry_decorator()
async def _reply_async_with_retry():
try:
response = await self._async_client.chat.completions.create(**kwargs)
return _parse_response(response)
except openai.OpenAIError as e:
raise _wrap_error(e)

return await _reply_async_with_retry()

def _create_retry_decorator(self):
retry_conditions = retry_if_exception_type(ServiceUnvailableError)
if self._retry_on_request_limit:
retry_conditions |= retry_if_exception_type(RateLimitError)

return retry(
stop=stop_after_attempt(self._num_retries + 1),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_conditions,
retry_error_callback=lambda retry_state: retry_state.outcome.exception() if retry_state.outcome else None,
)


def _wrap_error(error: Exception) -> Exception:
@@ -163,6 +195,9 @@ def _wrap_error(error: Exception) -> Exception:
return RateLimitError(e)
case openai.InternalServerError() | openai.APIConnectionError() as e:
return ServiceUnvailableError(e)
case openai.APIError() as e:
# Handle other API errors that might be retryable
return ServiceUnvailableError(e)
case _:
return error

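As a standalone illustration of the tenacity pattern introduced above; everything below is a hypothetical example, not part of wanga:

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

class TransientError(Exception):
    """Stands in for ServiceUnvailableError / RateLimitError."""

attempts = 0

@retry(
    stop=stop_after_attempt(3),  # the initial call plus two retries
    wait=wait_exponential(multiplier=1, min=4, max=10),  # exponential backoff, clamped to [4s, 10s]
    retry=retry_if_exception_type(TransientError),  # anything else propagates immediately
)
def flaky_call() -> str:
    global attempts
    attempts += 1
    if attempts < 3:
        raise TransientError("try again")
    return "ok"

print(flaky_call())  # succeeds on the third attempt, after two backoff waits

Note that the commit also passes max_retries=0 to both OpenAI clients, so tenacity, rather than the OpenAI SDK, owns the whole retry policy.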
2 changes: 1 addition & 1 deletion wanga/schema/normalize.py
@@ -1,6 +1,6 @@
import collections
import collections.abc
import typing # noqa
import typing # noqa: F401
from types import NoneType, UnionType
from typing import Annotated, Literal, Union, get_args, get_origin

