diff --git a/src/inspect_ai/model/__init__.py b/src/inspect_ai/model/__init__.py
index 6a21dd1a8..0800e0430 100644
--- a/src/inspect_ai/model/__init__.py
+++ b/src/inspect_ai/model/__init__.py
@@ -2,6 +2,7 @@
 from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.deprecation import relocated_module_attribute
+from inspect_ai.model._error_handler import LLMCannotAssistError, PromptTooLongError
 
 from ._cache import (
     CachePolicy,
@@ -52,6 +53,7 @@
     "ChatMessageTool",
     "ChatCompletionChoice",
     "ModelOutput",
+    "LLMCannotAssistError",
     "Logprobs",
     "Logprob",
     "TopLogprob",
@@ -59,6 +61,7 @@
     "ModelAPI",
     "ModelName",
     "ModelUsage",
+    "PromptTooLongError",
     "StopReason",
     "call_tools",
     "cache_clear",
diff --git a/src/inspect_ai/model/_error_handler.py b/src/inspect_ai/model/_error_handler.py
new file mode 100644
index 000000000..a7717710a
--- /dev/null
+++ b/src/inspect_ai/model/_error_handler.py
@@ -0,0 +1,86 @@
+import re
+from dataclasses import dataclass
+from functools import wraps
+from typing import Any, Callable
+
+from ._model import ModelOutput
+
+
+@dataclass
+class PromptTooLongError(Exception):
+    message: str | None = None
+    tokens_used: int | None = None
+    model_limit_tokens: int | None = None
+
+
+CANT_ASSIST = "Sorry, but I can't assist with that."
+
+
+@dataclass
+class LLMCannotAssistError(Exception):
+    message: str | None = CANT_ASSIST
+
+
+def handle_model_errors(func: Callable[..., Any]) -> Callable[..., Any]:
+    @wraps(func)
+    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+        result = await func(*args, **kwargs)
+        if isinstance(result, tuple):
+            model_output, model_call = result
+        else:
+            model_output = result
+        raise_if_error(model_output)
+        return result
+
+    return wrapper
+
+
+def raise_if_error(model_output: ModelOutput) -> None:
+    if not model_output.choices or not model_output.choices[0].message.content:
+        return
+
+    msg = str(model_output.choices[0].message.content)
+
+    error_checkers = [
+        check_openai_too_long,
+        check_gpt4_token_limit,
+        check_claude_token_limit,
+        check_openai_cannot_assist_with_that,
+    ]
+
+    for checker in error_checkers:
+        checker(msg)
+
+
+def check_openai_too_long(msg: str) -> None:
+    if re.search(r"Invalid 'messages\[[0-9]*].content': string too long", msg):
+        raise PromptTooLongError(message=msg)
+
+
+def check_gpt4_token_limit(msg: str) -> None:
+    match = re.search(
+        r"This model's maximum context length is (\d+) tokens\. However, your messages resulted in (\d+) tokens",
+        msg,
+    )
+    if match:
+        raise PromptTooLongError(
+            message=msg,
+            tokens_used=int(match.group(2)),
+            model_limit_tokens=int(match.group(1)),
+        )
+
+
+def check_claude_token_limit(msg: str) -> None:
+    match = re.search(r"prompt is too long: (\d+) tokens > (\d+) maximum", msg)
+    if match:
+        raise PromptTooLongError(
+            message=msg,
+            tokens_used=int(match.group(1)),
+            model_limit_tokens=int(match.group(2)),
+        )
+
+
+def check_openai_cannot_assist_with_that(msg: str) -> None:
+    match = re.search(CANT_ASSIST, msg)
+    if match:
+        raise LLMCannotAssistError(message=msg)
diff --git a/src/inspect_ai/model/_providers/anthropic.py b/src/inspect_ai/model/_providers/anthropic.py
index baeaa1d06..ec3e10eba 100644
--- a/src/inspect_ai/model/_providers/anthropic.py
+++ b/src/inspect_ai/model/_providers/anthropic.py
@@ -40,6 +40,11 @@
     ChatMessageAssistant,
     ChatMessageSystem,
 )
+from .._error_handler import (
+    LLMCannotAssistError,
+    PromptTooLongError,
+    handle_model_errors,
+)
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
@@ -110,6 +115,7 @@ def __init__(
             **model_args,
         )
 
+    @handle_model_errors
     async def generate(
         self,
         input: list[ChatMessage],
@@ -220,11 +226,9 @@ def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | None:
         stop_reason: StopReason | None = None
 
         if "prompt is too long" in error:
-            content = "Sorry, but your prompt is too long."
-            stop_reason = "length"
+            raise PromptTooLongError(message=error)
         elif "content filtering" in error:
-            content = "Sorry, but I am unable to help with that request."
-            stop_reason = "content_filter"
+            raise LLMCannotAssistError()
 
         if content and stop_reason:
             return ModelOutput.from_content(
diff --git a/src/inspect_ai/model/_providers/openai.py b/src/inspect_ai/model/_providers/openai.py
index 028da2000..3e5b31baa 100644
--- a/src/inspect_ai/model/_providers/openai.py
+++ b/src/inspect_ai/model/_providers/openai.py
@@ -40,6 +40,7 @@
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
 from .._chat_message import ChatMessage, ChatMessageAssistant
+from .._error_handler import LLMCannotAssistError, handle_model_errors
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
@@ -132,6 +133,7 @@ def __init__(
             **model_args,
         )
 
+    @handle_model_errors
     async def generate(
         self,
         input: list[ChatMessage],
@@ -418,13 +420,12 @@ async def as_chat_completion_part(
 # does not exhibit this behavior (it just returns the completion
 # "Sorry, but I can't assist with that."
 def handle_content_filter_error(e: APIStatusError) -> tuple[str, object | None]:
-    CANT_ASSIST = "Sorry, but I can't assist with that."
 
     if e.status_code == 400:
         if isinstance(e.body, dict) and "message" in e.body.keys():
             message = str(e.body.get("message"))
             return message, e.body
         else:
-            return CANT_ASSIST, e.body
+            raise LLMCannotAssistError()
     else:
         raise e
diff --git a/tests/model/test_error_handler.py b/tests/model/test_error_handler.py
new file mode 100644
index 000000000..aacee8336
--- /dev/null
+++ b/tests/model/test_error_handler.py
@@ -0,0 +1,102 @@
+import pytest
+
+from inspect_ai.model import (
+    ChatCompletionChoice,
+    ChatMessageAssistant,
+    ModelOutput,
+)
+from inspect_ai.model._error_handler import (
+    CANT_ASSIST,
+    LLMCannotAssistError,
+    PromptTooLongError,
+    handle_model_errors,
+    raise_if_error,
+)
+from inspect_ai.model._model_call import ModelCall
+
+
+def create_model_output(content):
+    return ModelOutput(
+        model="test_model",
+        choices=[
+            ChatCompletionChoice(
+                message=ChatMessageAssistant(content=content, source="generate"),
+                stop_reason="stop",
+            )
+        ],
+    )
+
+
+@pytest.mark.parametrize(
+    "content,expected_exception,expected_tokens_used,expected_model_limit",
+    [
+        ("Normal response", None, None, None),
+        (
+            "Invalid 'messages[0].content': string too long",
+            PromptTooLongError,
+            None,
+            None,
+        ),
+        (
+            "This model's maximum context length is 4096 tokens. However, your messages resulted in 5000 tokens",
+            PromptTooLongError,
+            5000,
+            4096,
+        ),
+        (
+            "prompt is too long: 5000 tokens > 4096 maximum",
+            PromptTooLongError,
+            5000,
+            4096,
+        ),
+        (CANT_ASSIST, LLMCannotAssistError, None, None),
+    ],
+)
+def test_raise_if_error(
+    content, expected_exception, expected_tokens_used, expected_model_limit
+):
+    model_output = create_model_output(content)
+
+    if expected_exception:
+        with pytest.raises(expected_exception) as exc_info:
+            raise_if_error(model_output)
+
+        if expected_exception == PromptTooLongError:
+            error = exc_info.value
+            assert error.message == content
+            assert error.tokens_used == expected_tokens_used
+            assert error.model_limit_tokens == expected_model_limit
+    else:
+        raise_if_error(model_output)  # Should not raise an exception
+
+
+@pytest.mark.asyncio
+async def test_handle_model_errors_decorator():
+    @handle_model_errors
+    async def mock_generate():
+        return create_model_output(CANT_ASSIST)
+
+    with pytest.raises(LLMCannotAssistError):
+        await mock_generate()
+
+
+@pytest.mark.asyncio
+async def test_handle_model_errors_decorator_no_error():
+    @handle_model_errors
+    async def mock_generate():
+        return create_model_output("This is a normal response")
+
+    result = await mock_generate()
+    assert isinstance(result, ModelOutput)
+    assert result.choices[0].message.content == "This is a normal response"
+
+
+@pytest.mark.asyncio
+async def test_handle_model_errors_decorator_with_tuple():
+    @handle_model_errors
+    async def mock_generate():
+        model_output = create_model_output(CANT_ASSIST)
+        return model_output, ModelCall.create({}, {})
+
+    with pytest.raises(LLMCannotAssistError):
+        await mock_generate()
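
Usage sketch (not part of the diff above): a minimal illustration of how calling code might react to the new exceptions once they are exported from inspect_ai.model. It assumes the existing get_model()/generate() API; the helper name, model string, and fallback messages are hypothetical.

    from inspect_ai.model import LLMCannotAssistError, PromptTooLongError, get_model


    # Hypothetical helper: call generate() and translate the new errors into
    # fallback strings rather than letting them abort the caller.
    async def generate_with_fallback(prompt: str, model_name: str = "openai/gpt-4") -> str:
        model = get_model(model_name)
        try:
            output = await model.generate(prompt)
            return output.completion
        except PromptTooLongError as ex:
            # tokens_used / model_limit_tokens are only populated when the provider
            # message includes explicit token counts.
            return f"Prompt too long: {ex.tokens_used} tokens > {ex.model_limit_tokens} limit"
        except LLMCannotAssistError:
            # Raised when the completion is the canned "can't assist" refusal.
            return "Model declined to assist with this request"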