Commit: Add handle_model_errors decorator

skinnerjc committed Sep 20, 2024
1 parent 61e3415 commit a28f10b
Showing 5 changed files with 202 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/inspect_ai/model/__init__.py
@@ -2,6 +2,7 @@

from inspect_ai._util.content import Content, ContentImage, ContentText
from inspect_ai._util.deprecation import relocated_module_attribute
from inspect_ai.model._error_handler import LLMCannotAssistError, PromptTooLongError

from ._cache import (
    CachePolicy,
@@ -52,13 +53,15 @@
"ChatMessageTool",
"ChatCompletionChoice",
"ModelOutput",
"LLMCannotAssistError",
"Logprobs",
"Logprob",
"TopLogprob",
"Model",
"ModelAPI",
"ModelName",
"ModelUsage",
"PromptTooLongError",
"StopReason",
"call_tools",
"cache_clear",
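
With these exports in place, callers can catch the new error types from the public inspect_ai.model package rather than importing from the private _error_handler module. A minimal sketch of what that enables (the field values below are illustrative):

from inspect_ai.model import LLMCannotAssistError, PromptTooLongError

# Both errors are plain dataclass exceptions, so handlers can inspect their fields.
err = PromptTooLongError(
    message="prompt is too long: 5000 tokens > 4096 maximum",
    tokens_used=5000,
    model_limit_tokens=4096,
)
print(err.tokens_used, err.model_limit_tokens)  # 5000 4096
print(LLMCannotAssistError().message)           # default is the CANT_ASSIST text
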
86 changes: 86 additions & 0 deletions src/inspect_ai/model/_error_handler.py
@@ -0,0 +1,86 @@
import re
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable

from ._model import ModelOutput


@dataclass
class PromptTooLongError(Exception):
    message: str | None = None
    tokens_used: int | None = None
    model_limit_tokens: int | None = None


CANT_ASSIST = "Sorry, but I can't assist with that."


@dataclass
class LLMCannotAssistError(Exception):
    message: str | None = CANT_ASSIST


def handle_model_errors(func: Callable[..., Any]) -> Callable[..., Any]:
    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = await func(*args, **kwargs)
        if isinstance(result, tuple):
            model_output, model_call = result
        else:
            model_output = result
        raise_if_error(model_output)
        return result

    return wrapper


def raise_if_error(model_output: ModelOutput) -> None:
    if not model_output.choices or not model_output.choices[0].message.content:
        return

    msg = str(model_output.choices[0].message.content)

    error_checkers = [
        check_openai_too_long,
        check_gpt4_token_limit,
        check_claude_token_limit,
        check_openai_cannot_assist_with_that,
    ]

    for checker in error_checkers:
        checker(msg)


def check_openai_too_long(msg: str) -> None:
    if re.search(r"Invalid 'messages\[[0-9]*].content': string too long", msg):
        raise PromptTooLongError(message=msg)


def check_gpt4_token_limit(msg: str) -> None:
    match = re.search(
        r"This model's maximum context length is (\d+) tokens\. However, your messages resulted in (\d+) tokens",
        msg,
    )
    if match:
        raise PromptTooLongError(
            message=msg,
            tokens_used=int(match.group(2)),
            model_limit_tokens=int(match.group(1)),
        )


def check_claude_token_limit(msg: str) -> None:
    match = re.search(r"prompt is too long: (\d+) tokens > (\d+) maximum", msg)
    if match:
        raise PromptTooLongError(
            message=msg,
            tokens_used=int(match.group(1)),
            model_limit_tokens=int(match.group(2)),
        )


def check_openai_cannot_assist_with_that(msg: str) -> None:
    match = re.search(CANT_ASSIST, msg)
    if match:
        raise LLMCannotAssistError(message=msg)
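
The error_checkers list is the extension point here: each checker is a plain str -> None function that raises when it recognizes a provider's error text in the completion. A hypothetical sketch of an additional checker follows (the check_gemini_token_limit name and its regex are illustrative assumptions, not part of this commit); because error_checkers is a local list, registering it would mean adding it to that list inside raise_if_error:

import re

from inspect_ai.model import PromptTooLongError


# Hypothetical checker: the pattern below is an assumed error format, not a documented one.
def check_gemini_token_limit(msg: str) -> None:
    match = re.search(r"input token count \((\d+)\) exceeds the maximum \((\d+)\)", msg)
    if match:
        raise PromptTooLongError(
            message=msg,
            tokens_used=int(match.group(1)),
            model_limit_tokens=int(match.group(2)),
        )
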
12 changes: 8 additions & 4 deletions src/inspect_ai/model/_providers/anthropic.py
@@ -40,6 +40,11 @@
    ChatMessageAssistant,
    ChatMessageSystem,
)
from .._error_handler import (
    LLMCannotAssistError,
    PromptTooLongError,
    handle_model_errors,
)
from .._generate_config import GenerateConfig
from .._model import ModelAPI
from .._model_call import ModelCall
@@ -110,6 +115,7 @@ def __init__(
            **model_args,
        )

    @handle_model_errors
    async def generate(
        self,
        input: list[ChatMessage],
@@ -220,11 +226,9 @@ def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | None:
        stop_reason: StopReason | None = None

        if "prompt is too long" in error:
-            content = "Sorry, but your prompt is too long."
-            stop_reason = "length"
+            raise PromptTooLongError(message=content)
        elif "content filtering" in error:
-            content = "Sorry, but I am unable to help with that request."
-            stop_reason = "content_filter"
+            raise LLMCannotAssistError

        if content and stop_reason:
            return ModelOutput.from_content(
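
The net effect for the Anthropic provider: a too-long prompt or a content-filter refusal now raises instead of returning a synthesized ModelOutput with a "length" or "content_filter" stop reason, so callers see the error explicitly. A minimal sketch of catching it at the call site (assumes the usual get_model/Model.generate entry points; the model name and the crude truncation fallback are illustrative):

from inspect_ai.model import PromptTooLongError, get_model


async def generate_with_truncation(prompt: str) -> str:
    model = get_model("anthropic/claude-3-5-sonnet-20240620")
    try:
        output = await model.generate(prompt)
    except PromptTooLongError:
        # Illustrative fallback only: retry with half the prompt; a real caller
        # would trim by token count (tokens_used/model_limit_tokens, when populated).
        output = await model.generate(prompt[: len(prompt) // 2])
    return output.completion
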
5 changes: 3 additions & 2 deletions src/inspect_ai/model/_providers/openai.py
@@ -40,6 +40,7 @@
from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

from .._chat_message import ChatMessage, ChatMessageAssistant
from .._error_handler import LLMCannotAssistError, handle_model_errors
from .._generate_config import GenerateConfig
from .._model import ModelAPI
from .._model_call import ModelCall
@@ -132,6 +133,7 @@ def __init__(
            **model_args,
        )

    @handle_model_errors
    async def generate(
        self,
        input: list[ChatMessage],
@@ -418,13 +420,12 @@ async def as_chat_completion_part(
# does not exhibit this behavior (it just returns the completion
# "Sorry, but I can't assist with that."
def handle_content_filter_error(e: APIStatusError) -> tuple[str, object | None]:
CANT_ASSIST = "Sorry, but I can't assist with that."
if e.status_code == 400:
if isinstance(e.body, dict) and "message" in e.body.keys():
message = str(e.body.get("message"))
return message, e.body
else:
return CANT_ASSIST, e.body
raise LLMCannotAssistError()
else:
raise e

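
For the OpenAI provider, both refusal paths now converge on the same exception: the 400 content-filter response without a message (handled above) and the literal "Sorry, but I can't assist with that." completion (caught by the handle_model_errors decorator via check_openai_cannot_assist_with_that). A minimal sketch of handling either uniformly (model name and fallback text are illustrative):

from inspect_ai.model import LLMCannotAssistError, get_model


async def generate_or_refuse(prompt: str) -> str:
    model = get_model("openai/gpt-4o")
    try:
        output = await model.generate(prompt)
        return output.completion
    except LLMCannotAssistError as err:
        # Surface the refusal explicitly rather than treating it as a normal completion.
        return f"[refused] {err.message}"
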
102 changes: 102 additions & 0 deletions tests/model/test_error_handler.py
@@ -0,0 +1,102 @@
import pytest

from inspect_ai.model import (
    ChatCompletionChoice,
    ChatMessageAssistant,
    ModelOutput,
)
from inspect_ai.model._error_handler import (
    CANT_ASSIST,
    LLMCannotAssistError,
    PromptTooLongError,
    handle_model_errors,
    raise_if_error,
)
from inspect_ai.model._model_call import ModelCall


def create_model_output(content):
    return ModelOutput(
        model="test_model",
        choices=[
            ChatCompletionChoice(
                message=ChatMessageAssistant(content=content, source="generate"),
                stop_reason="stop",
            )
        ],
    )


@pytest.mark.parametrize(
"content,expected_exception,expected_tokens_used,expected_model_limit",
[
("Normal response", None, None, None),
(
"Invalid 'messages[0].content': string too long",
PromptTooLongError,
None,
None,
),
(
"This model's maximum context length is 4096 tokens. However, your messages resulted in 5000 tokens",
PromptTooLongError,
5000,
4096,
),
(
"prompt is too long: 5000 tokens > 4096 maximum",
PromptTooLongError,
5000,
4096,
),
(CANT_ASSIST, LLMCannotAssistError, None, None),
],
)
def test_raise_if_error(
    content, expected_exception, expected_tokens_used, expected_model_limit
):
    model_output = create_model_output(content)

    if expected_exception:
        with pytest.raises(expected_exception) as exc_info:
            raise_if_error(model_output)

        if expected_exception == PromptTooLongError:
            error = exc_info.value
            assert error.message == content
            assert error.tokens_used == expected_tokens_used
            assert error.model_limit_tokens == expected_model_limit
    else:
        raise_if_error(model_output)  # Should not raise an exception


@pytest.mark.asyncio
async def test_handle_model_errors_decorator():
    @handle_model_errors
    async def mock_generate():
        return create_model_output(CANT_ASSIST)

    with pytest.raises(LLMCannotAssistError):
        await mock_generate()


@pytest.mark.asyncio
async def test_handle_model_errors_decorator_no_error():
    @handle_model_errors
    async def mock_generate():
        return create_model_output("This is a normal response")

    result = await mock_generate()
    assert isinstance(result, ModelOutput)
    assert result.choices[0].message.content == "This is a normal response"


@pytest.mark.asyncio
async def test_handle_model_errors_decorator_with_tuple():
    @handle_model_errors
    async def mock_generate():
        model_output = create_model_output(CANT_ASSIST)
        return model_output, ModelCall.create({}, {})

    with pytest.raises(LLMCannotAssistError):
        await mock_generate()
