Showing 5 changed files with 202 additions and 6 deletions.
@@ -0,0 +1,86 @@
import re
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable

from ._model import ModelOutput


@dataclass
class PromptTooLongError(Exception):
    message: str | None = None
    tokens_used: int | None = None
    model_limit_tokens: int | None = None


CANT_ASSIST = "Sorry, but I can't assist with that."


@dataclass
class LLMCannotAssistError(Exception):
    message: str | None = CANT_ASSIST


def handle_model_errors(func: Callable[..., Any]) -> Callable[..., Any]:
    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = await func(*args, **kwargs)
        if isinstance(result, tuple):
            model_output, model_call = result
        else:
            model_output = result
        raise_if_error(model_output)
        return result

    return wrapper


def raise_if_error(model_output: ModelOutput) -> None:
    if not model_output.choices or not model_output.choices[0].message.content:
        return

    msg = str(model_output.choices[0].message.content)

    error_checkers = [
        check_openai_too_long,
        check_gpt4_token_limit,
        check_claude_token_limit,
        check_openai_cannot_assist_with_that,
    ]

    for checker in error_checkers:
        checker(msg)


def check_openai_too_long(msg: str) -> None:
    if re.search(r"Invalid 'messages\[[0-9]*].content': string too long", msg):
        raise PromptTooLongError(message=msg)


def check_gpt4_token_limit(msg: str) -> None:
    match = re.search(
        r"This model's maximum context length is (\d+) tokens\. However, your messages resulted in (\d+) tokens",
        msg,
    )
    if match:
        raise PromptTooLongError(
            message=msg,
            tokens_used=int(match.group(2)),
            model_limit_tokens=int(match.group(1)),
        )


def check_claude_token_limit(msg: str) -> None:
    match = re.search(r"prompt is too long: (\d+) tokens > (\d+) maximum", msg)
    if match:
        raise PromptTooLongError(
            message=msg,
            tokens_used=int(match.group(1)),
            model_limit_tokens=int(match.group(2)),
        )


def check_openai_cannot_assist_with_that(msg: str) -> None:
    match = re.search(CANT_ASSIST, msg)
    if match:
        raise LLMCannotAssistError(message=msg)
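
A brief usage sketch (not part of the commit): wrapping a provider's async generate call with handle_model_errors so that refusal and context-length messages surface as typed exceptions. The ExampleProvider class and its _call_api method are hypothetical, assumed here only for illustration.

# Hypothetical usage sketch -- illustration only, not part of this commit.
from inspect_ai.model import ModelOutput
from inspect_ai.model._error_handler import handle_model_errors


class ExampleProvider:  # assumed provider class, not a real inspect_ai API
    @handle_model_errors
    async def generate(self, prompt: str) -> ModelOutput:
        # The decorator inspects the returned ModelOutput (or the first element of a
        # (ModelOutput, ModelCall) tuple) and raises PromptTooLongError or
        # LLMCannotAssistError when the message text matches a known error pattern.
        return await self._call_api(prompt)

    async def _call_api(self, prompt: str) -> ModelOutput:
        # Stand-in for the provider-specific request/response handling.
        ...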
@@ -0,0 +1,102 @@
import pytest

from inspect_ai.model import (
    ChatCompletionChoice,
    ChatMessageAssistant,
    ModelOutput,
)
from inspect_ai.model._error_handler import (
    CANT_ASSIST,
    LLMCannotAssistError,
    PromptTooLongError,
    handle_model_errors,
    raise_if_error,
)
from inspect_ai.model._model_call import ModelCall


def create_model_output(content):
    return ModelOutput(
        model="test_model",
        choices=[
            ChatCompletionChoice(
                message=ChatMessageAssistant(content=content, source="generate"),
                stop_reason="stop",
            )
        ],
    )


@pytest.mark.parametrize(
    "content,expected_exception,expected_tokens_used,expected_model_limit",
    [
        ("Normal response", None, None, None),
        (
            "Invalid 'messages[0].content': string too long",
            PromptTooLongError,
            None,
            None,
        ),
        (
            "This model's maximum context length is 4096 tokens. However, your messages resulted in 5000 tokens",
            PromptTooLongError,
            5000,
            4096,
        ),
        (
            "prompt is too long: 5000 tokens > 4096 maximum",
            PromptTooLongError,
            5000,
            4096,
        ),
        (CANT_ASSIST, LLMCannotAssistError, None, None),
    ],
)
def test_raise_if_error(
    content, expected_exception, expected_tokens_used, expected_model_limit
):
    model_output = create_model_output(content)

    if expected_exception:
        with pytest.raises(expected_exception) as exc_info:
            raise_if_error(model_output)

        if expected_exception == PromptTooLongError:
            error = exc_info.value
            assert error.message == content
            assert error.tokens_used == expected_tokens_used
            assert error.model_limit_tokens == expected_model_limit
    else:
        raise_if_error(model_output)  # Should not raise an exception


@pytest.mark.asyncio
async def test_handle_model_errors_decorator():
    @handle_model_errors
    async def mock_generate():
        return create_model_output(CANT_ASSIST)

    with pytest.raises(LLMCannotAssistError):
        await mock_generate()


@pytest.mark.asyncio
async def test_handle_model_errors_decorator_no_error():
    @handle_model_errors
    async def mock_generate():
        return create_model_output("This is a normal response")

    result = await mock_generate()
    assert isinstance(result, ModelOutput)
    assert result.choices[0].message.content == "This is a normal response"


@pytest.mark.asyncio
async def test_handle_model_errors_decorator_with_tuple():
    @handle_model_errors
    async def mock_generate():
        model_output = create_model_output(CANT_ASSIST)
        return model_output, ModelCall.create({}, {})

    with pytest.raises(LLMCannotAssistError):
        await mock_generate()