Skip to content

Commit

Permalink
Fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
william-price01 committed Sep 18, 2024
1 parent 362b667 commit c0a062d
Showing 1 changed file with 40 additions and 53 deletions.
93 changes: 40 additions & 53 deletions tests/unit/drivers/web_search/test_tavily_web_search_driver.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,74 @@
import json

import pytest
from tavily import InvalidAPIKeyError, MissingAPIKeyError, UsageLimitExceededError

from griptape.artifacts import ListArtifact
from griptape.drivers import TavilyWebSearchDriver


class TestTavilyWebSearchDriver:
@pytest.fixture()
def driver(self, mocker):
def mock_tavily_client(self, mocker):
return mocker.patch("tavily.TavilyClient")

@pytest.fixture()
def driver(self, mock_tavily_client):
mock_response = {
"results": [
{"title": "foo", "url": "bar", "content": "baz"},
{"title": "foo2", "url": "bar2", "content": "baz2"},
]
}
mock_tavily = mocker.Mock(
search=lambda *args, **kwargs: mock_response,
)
mocker.patch("tavily.TavilyClient", return_value=mock_tavily)
return TavilyWebSearchDriver(api_key="test")

@pytest.fixture()
def driver_with_error(self, mocker):
def error(*args, **kwargs):
raise Exception("test_error")

mock_tavily = mocker.Mock(
search=error,
)
mocker.patch("tavily.TavilyClient", return_value=mock_tavily)

mock_tavily_client.return_value.search.return_value = mock_response
return TavilyWebSearchDriver(api_key="test")

def test_search_returns_results(self, driver):
def test_search_returns_results(self, driver, mock_tavily_client):
results = driver.search("test")
assert isinstance(results, ListArtifact)
output = [json.loads(result.value) for result in results]
assert len(output) == 2
assert output[0]["title"] == "foo"
assert output[0]["url"] == "bar"
assert output[0]["content"] == "baz"
mock_tavily_client.return_value.search.assert_called_once_with("test", max_results=5)

def test_search_raises_error(self, driver_with_error):
with pytest.raises(ValueError, match="An error occurred while searching for test using Tavily: test_error"):
driver_with_error.search("test")
def test_search_raises_error(self, mock_tavily_client):
mock_tavily_client.return_value.search.side_effect = Exception("test_error")
driver = TavilyWebSearchDriver(api_key="test")
with pytest.raises(Exception, match="test_error"):
driver.search("test")

@pytest.fixture()
def driver_missing_api_key(self, mocker):
def error(*args, **kwargs):
raise MissingAPIKeyError()

mock_tavily = mocker.Mock(search=error)
mocker.patch("tavily.TavilyClient", return_value=mock_tavily)
return TavilyWebSearchDriver(api_key="")
def test_search_with_params(self, mock_tavily_client):
mock_response = {
"results": [
{"title": "custom", "url": "custom_url", "content": "custom_content"},
]
}
mock_tavily_client.return_value.search.return_value = mock_response

def test_search_raises_missing_api_key_error(self, driver_missing_api_key):
with pytest.raises(ValueError, match="API Key is missing, Please provide a valid Tavily API Key."):
driver_missing_api_key.search("test")
driver = TavilyWebSearchDriver(api_key="test", params={"custom_param": "value"})
results = driver.search("test", additional_param="extra")

@pytest.fixture()
def driver_usage_limit_exceeded(self, mocker):
def error(*args, **kwargs):
raise UsageLimitExceededError("Usage limit exceeded")
assert isinstance(results, ListArtifact)
output = json.loads(results[0].value)
assert output["title"] == "custom"
assert output["url"] == "custom_url"
assert output["content"] == "custom_content"

mock_tavily = mocker.Mock(search=error)
mocker.patch("tavily.TavilyClient", return_value=mock_tavily)
return TavilyWebSearchDriver(api_key="test")
mock_tavily_client.return_value.search.assert_called_once_with(
"test", max_results=5, custom_param="value", additional_param="extra"
)

def test_search_raises_usage_limit_exceeded_error(self, driver_usage_limit_exceeded):
with pytest.raises(ValueError, match="Usage Limit Exceeded, Please try again later."):
driver_usage_limit_exceeded.search("test")
def test_custom_results_count(self, mock_tavily_client):
mock_response = {
"results": [{"title": f"title_{i}", "url": f"url_{i}", "content": f"content_{i}"} for i in range(5)]
}
mock_tavily_client.return_value.search.return_value = mock_response

@pytest.fixture()
def driver_invalid_api_key(self, mocker):
def error(*args, **kwargs):
raise InvalidAPIKeyError()
driver = TavilyWebSearchDriver(api_key="test", results_count=5)
results = driver.search("test")

mock_tavily = mocker.Mock(search=error)
mocker.patch("tavily.TavilyClient", return_value=mock_tavily)
return TavilyWebSearchDriver(api_key="invalid_key")
assert isinstance(results, ListArtifact)
assert len(results) == 5

def test_search_raises_invalid_api_key_error(self, driver_invalid_api_key):
with pytest.raises(ValueError, match="Invalid API Key, Please provide a valid Tavily API Key."):
driver_invalid_api_key.search("test")
mock_tavily_client.return_value.search.assert_called_once_with("test", max_results=5)

0 comments on commit c0a062d

Please sign in to comment.