
Commit

Add additional error propagation to tests
morganmcg1 committed Jan 28, 2025
1 parent 4157c7a commit 8c95706
Showing 2 changed files with 94 additions and 30 deletions.
28 changes: 28 additions & 0 deletions tests/test_config.py
@@ -0,0 +1,28 @@
from wandbot.configs.chat_config import ChatConfig

class TestConfig(ChatConfig):
    """Test configuration with minimal retry settings for faster tests"""

    # Override LLM retry settings
    llm_max_retries: int = 1
    llm_retry_min_wait: int = 1
    llm_retry_max_wait: int = 2
    llm_retry_multiplier: int = 1

    # Override Embedding retry settings
    embedding_max_retries: int = 1
    embedding_retry_min_wait: int = 1
    embedding_retry_max_wait: int = 2
    embedding_retry_multiplier: int = 1

    # Override Reranker retry settings
    reranker_max_retries: int = 1
    reranker_retry_min_wait: int = 1
    reranker_retry_max_wait: int = 2
    reranker_retry_multiplier: int = 1

    # Override retry settings for faster tests
    max_retries: int = 1  # Only try once
    retry_min_wait: int = 1  # Wait 1 second minimum
    retry_max_wait: int = 2  # Wait 2 seconds maximum
    retry_multiplier: int = 1  # No exponential increase
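
The settings above exist only to make the failure-path tests fast: a single attempt, one- to two-second waits, and no exponential backoff. As a rough illustration of how such values are commonly consumed, the sketch below wires them into a tenacity retry policy; the decorator wiring and import path are assumptions for illustration, not wandbot's actual retry code (which is not part of this commit).

# Illustrative sketch only: wandbot's real retry wiring is not shown in this
# commit, and the import path for TestConfig is assumed.
from tenacity import retry, stop_after_attempt, wait_exponential

from tests.test_config import TestConfig

config = TestConfig()

@retry(
    stop=stop_after_attempt(config.llm_max_retries),  # give up after the configured attempt count
    wait=wait_exponential(
        multiplier=config.llm_retry_multiplier,  # multiplier of 1 keeps waits from growing
        min=config.llm_retry_min_wait,           # wait at least 1 second between attempts
        max=config.llm_retry_max_wait,           # and at most 2 seconds
    ),
)
def call_llm_api():
    ...  # the guarded LLM call would go here
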
96 changes: 66 additions & 30 deletions tests/test_llm.py
@@ -3,15 +3,14 @@
from dotenv import load_dotenv
from pydantic import BaseModel
import asyncio
-import json

from wandbot.models.llm import (
    AsyncOpenAILLMModel,
    AsyncAnthropicLLMModel,
    LLMModel,
    extract_system_and_messages,
-    LLMError
)
+from wandbot.utils import ErrorInfo

# Load environment variables from .env
load_dotenv()
@@ -77,10 +76,11 @@ async def test_openai_llm_creation(model_name):
@pytest.mark.parametrize("model_name", openai_models)
async def test_openai_llm_create(model_name):
    model = AsyncOpenAILLMModel(model_name=model_name, temperature=0)
-    response = await model.create([
+    result, error_info = await model.create([
        {"role": "user", "content": "What is 2+2? Answer with just the number."}
    ])
-    assert response.strip() == "4"
+    assert result.strip() == "4"
+    assert not error_info.has_error

@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", openai_models)
@@ -90,18 +90,24 @@ async def test_openai_llm_create_with_response_model(model_name):
        temperature=0,
        response_model=SimpleResponse
    )
-    response = await model.create([
+    result, error_info = await model.create([
        {"role": "user", "content": "Return the number 4 as a JSON object with the key 'answer'. Respond with only valid JSON."}
    ])

-    assert isinstance(response, SimpleResponse)
-    assert response.answer == 4
+    assert isinstance(result, SimpleResponse)
+    assert result.answer == 4
+    assert not error_info.has_error

@pytest.mark.asyncio
async def test_openai_invalid_model():
-    model = AsyncOpenAILLMModel(model_name="invalid-model", temperature=0)
-    with pytest.raises(Exception):  # OpenAI will raise an error on API call
-        await model.create([{"role": "user", "content": "test"}])
+    client = AsyncOpenAILLMModel(model_name="invalid-model")
+    result = await client.create([{"role": "user", "content": "test"}])
+    assert isinstance(result, tuple)
+    assert result[0] is None
+    assert isinstance(result[1], ErrorInfo)
+    assert result[1].has_error is True
+    assert "model_not_found" in result[1].error_message
+    assert result[1].component == "openai"

# Anthropic model tests
@pytest.mark.asyncio
@@ -116,16 +122,22 @@ async def test_anthropic_llm_creation(model_name):
@pytest.mark.parametrize("model_name", anthropic_models)
async def test_anthropic_llm_create(model_name):
    model = AsyncAnthropicLLMModel(model_name=model_name, temperature=0)
-    response = await model.create([
+    result, error_info = await model.create([
        {"role": "user", "content": "What is 2+2? Answer with just the number."}
    ], max_tokens=4000)
-    assert response.strip() == "4"
+    assert result.strip() == "4"
+    assert not error_info.has_error

@pytest.mark.asyncio
async def test_anthropic_invalid_model():
-    model = AsyncAnthropicLLMModel(model_name="invalid-model", temperature=0)
-    with pytest.raises(Exception):  # Anthropic will raise an error on API call
-        await model.create([{"role": "user", "content": "test"}], max_tokens=4000)
+    client = AsyncAnthropicLLMModel(model_name="invalid-model")
+    result = await client.create([{"role": "user", "content": "test"}])
+    assert isinstance(result, tuple)
+    assert result[0] is None
+    assert isinstance(result[1], ErrorInfo)
+    assert result[1].has_error is True
+    assert "not_found_error" in result[1].error_message
+    assert result[1].component == "anthropic"

@pytest.mark.asyncio
async def test_anthropic_llm_create_with_response_model():
@@ -134,12 +146,13 @@ async def test_anthropic_llm_create_with_response_model():
        temperature=0,
        response_model=SimpleResponse
    )
-    response = await model.create([
+    result, error_info = await model.create([
        {"role": "user", "content": "Return the number 4 as a JSON object with the key 'answer'. Respond with only valid JSON."}
    ])

-    assert isinstance(response, SimpleResponse)
-    assert response.answer == 4
+    assert isinstance(result, SimpleResponse)
+    assert result.answer == 4
+    assert not error_info.has_error

# LLMModel wrapper tests
def test_llm_model_invalid_provider():
@@ -149,18 +162,38 @@ def test_llm_model_invalid_provider():
@pytest.mark.asyncio
async def test_llm_model_invalid_openai_model():
    model = LLMModel(provider="openai", model_name="invalid-model")
-    response = await model.create([{"role": "user", "content": "test"}])
-    assert isinstance(response, LLMError)
-    assert response.error
-    assert "model_not_found" in response.error_message
+    response, error_info = await model.create([{"role": "user", "content": "test"}])
+    assert response is None
+    assert isinstance(error_info, ErrorInfo)
+    assert error_info.has_error is True
+    assert "model_not_found" in error_info.error_message
+    assert error_info.component == "openai"
+    assert error_info.error_type is not None
+    assert error_info.stacktrace is not None
+    assert error_info.file_path is not None

@pytest.mark.asyncio
async def test_llm_model_invalid_anthropic_model():
    model = LLMModel(provider="anthropic", model_name="invalid-model")
-    response = await model.create([{"role": "user", "content": "test"}])
-    assert isinstance(response, LLMError)
-    assert response.error
-    assert "not_found_error" in response.error_message
+    response, error_info = await model.create([{"role": "user", "content": "test"}])
+    assert response is None
+    assert isinstance(error_info, ErrorInfo)
+    assert error_info.has_error is True
+    assert "not_found_error" in error_info.error_message
+    assert error_info.component == "anthropic"
+    assert error_info.error_type is not None
+    assert error_info.stacktrace is not None
+    assert error_info.file_path is not None

+@pytest.mark.asyncio
+async def test_successful_call_error_info():
+    model = LLMModel(provider="openai", model_name="gpt-4-1106-preview")
+    result, error_info = await model.create([{"role": "user", "content": "Say 'test'"}])
+    assert result is not None
+    assert isinstance(error_info, ErrorInfo)
+    assert error_info.has_error is False
+    assert error_info.error_message is None
+    assert error_info.component == "llm"

@pytest.mark.parametrize("provider,model_name", [
    ("openai", "gpt-4-1106-preview"),
@@ -174,14 +207,16 @@ def test_llm_model_valid_models(provider, model_name):
@pytest.mark.asyncio
async def test_llm_model_create_with_response_model():
    model = LLMModel(
-        provider="openai", 
+        provider="openai",
        model_name="gpt-4o-2024-08-06",
        response_model=SimpleResponse
    )
-    response = await model.create([
+    response, error_info = await model.create([
        {"role": "user", "content": "Return the number 4 as a JSON object with the key 'answer'. Respond with only valid JSON."}
    ])
    assert isinstance(response, SimpleResponse)
+    assert isinstance(error_info, ErrorInfo)
+    assert error_info.has_error is False
    assert response.answer == 4

@pytest.mark.asyncio
@@ -194,5 +229,6 @@ async def test_parallel_api_calls():
        ]))

    responses = await asyncio.gather(*tasks)
-    for i, response in enumerate(responses):
-        assert response.strip() == str(i + i)
+    for i, (result, error_info) in enumerate(responses):
+        assert result.strip() == str(i + i)
+        assert not error_info.has_error
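
Taken together, the updated assertions assume that create() now returns a (result, error_info) tuple instead of raising, where error_info exposes has_error, error_message, error_type, stacktrace, file_path, and component. The sketch below is a minimal reading of that contract; the field names are inferred from the assertions above, and the dataclass and wrapper are assumptions, not the actual wandbot.utils.ErrorInfo implementation.

# Minimal sketch of the (result, ErrorInfo) contract exercised by these tests.
# Field names come from the assertions above; implementation details are assumed.
import traceback
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class ErrorInfo:
    has_error: bool = False
    error_message: Optional[str] = None
    error_type: Optional[str] = None
    stacktrace: Optional[str] = None
    file_path: Optional[str] = None
    component: str = ""


async def call_with_error_info(api_call, *args, component: str = "llm", **kwargs) -> Tuple[Any, ErrorInfo]:
    """Run an async API call and return (result, ErrorInfo) instead of raising."""
    try:
        result = await api_call(*args, **kwargs)
        return result, ErrorInfo(has_error=False, component=component)
    except Exception as e:
        return None, ErrorInfo(
            has_error=True,
            error_message=str(e),
            error_type=type(e).__name__,
            stacktrace=traceback.format_exc(),
            file_path=__file__,
            component=component,
        )

Returning errors as data rather than raising also keeps asyncio.gather-style batches, like the one in test_parallel_api_calls, from being aborted by a single failing request.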
