-
Notifications
You must be signed in to change notification settings - Fork 175
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
362b667
commit c0a062d
Showing
1 changed file
with
40 additions
and
53 deletions.
There are no files selected for viewing
93 changes: 40 additions & 53 deletions
93
tests/unit/drivers/web_search/test_tavily_web_search_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,87 +1,74 @@ | ||
import json | ||
|
||
import pytest | ||
from tavily import InvalidAPIKeyError, MissingAPIKeyError, UsageLimitExceededError | ||
|
||
from griptape.artifacts import ListArtifact | ||
from griptape.drivers import TavilyWebSearchDriver | ||
|
||
|
||
class TestTavilyWebSearchDriver: | ||
@pytest.fixture() | ||
def driver(self, mocker): | ||
def mock_tavily_client(self, mocker): | ||
return mocker.patch("tavily.TavilyClient") | ||
|
||
@pytest.fixture() | ||
def driver(self, mock_tavily_client): | ||
mock_response = { | ||
"results": [ | ||
{"title": "foo", "url": "bar", "content": "baz"}, | ||
{"title": "foo2", "url": "bar2", "content": "baz2"}, | ||
] | ||
} | ||
mock_tavily = mocker.Mock( | ||
search=lambda *args, **kwargs: mock_response, | ||
) | ||
mocker.patch("tavily.TavilyClient", return_value=mock_tavily) | ||
return TavilyWebSearchDriver(api_key="test") | ||
|
||
@pytest.fixture() | ||
def driver_with_error(self, mocker): | ||
def error(*args, **kwargs): | ||
raise Exception("test_error") | ||
|
||
mock_tavily = mocker.Mock( | ||
search=error, | ||
) | ||
mocker.patch("tavily.TavilyClient", return_value=mock_tavily) | ||
|
||
mock_tavily_client.return_value.search.return_value = mock_response | ||
return TavilyWebSearchDriver(api_key="test") | ||
|
||
def test_search_returns_results(self, driver): | ||
def test_search_returns_results(self, driver, mock_tavily_client): | ||
results = driver.search("test") | ||
assert isinstance(results, ListArtifact) | ||
output = [json.loads(result.value) for result in results] | ||
assert len(output) == 2 | ||
assert output[0]["title"] == "foo" | ||
assert output[0]["url"] == "bar" | ||
assert output[0]["content"] == "baz" | ||
mock_tavily_client.return_value.search.assert_called_once_with("test", max_results=5) | ||
|
||
def test_search_raises_error(self, driver_with_error): | ||
with pytest.raises(ValueError, match="An error occurred while searching for test using Tavily: test_error"): | ||
driver_with_error.search("test") | ||
def test_search_raises_error(self, mock_tavily_client): | ||
mock_tavily_client.return_value.search.side_effect = Exception("test_error") | ||
driver = TavilyWebSearchDriver(api_key="test") | ||
with pytest.raises(Exception, match="test_error"): | ||
driver.search("test") | ||
|
||
@pytest.fixture() | ||
def driver_missing_api_key(self, mocker): | ||
def error(*args, **kwargs): | ||
raise MissingAPIKeyError() | ||
|
||
mock_tavily = mocker.Mock(search=error) | ||
mocker.patch("tavily.TavilyClient", return_value=mock_tavily) | ||
return TavilyWebSearchDriver(api_key="") | ||
def test_search_with_params(self, mock_tavily_client): | ||
mock_response = { | ||
"results": [ | ||
{"title": "custom", "url": "custom_url", "content": "custom_content"}, | ||
] | ||
} | ||
mock_tavily_client.return_value.search.return_value = mock_response | ||
|
||
def test_search_raises_missing_api_key_error(self, driver_missing_api_key): | ||
with pytest.raises(ValueError, match="API Key is missing, Please provide a valid Tavily API Key."): | ||
driver_missing_api_key.search("test") | ||
driver = TavilyWebSearchDriver(api_key="test", params={"custom_param": "value"}) | ||
results = driver.search("test", additional_param="extra") | ||
|
||
@pytest.fixture() | ||
def driver_usage_limit_exceeded(self, mocker): | ||
def error(*args, **kwargs): | ||
raise UsageLimitExceededError("Usage limit exceeded") | ||
assert isinstance(results, ListArtifact) | ||
output = json.loads(results[0].value) | ||
assert output["title"] == "custom" | ||
assert output["url"] == "custom_url" | ||
assert output["content"] == "custom_content" | ||
|
||
mock_tavily = mocker.Mock(search=error) | ||
mocker.patch("tavily.TavilyClient", return_value=mock_tavily) | ||
return TavilyWebSearchDriver(api_key="test") | ||
mock_tavily_client.return_value.search.assert_called_once_with( | ||
"test", max_results=5, custom_param="value", additional_param="extra" | ||
) | ||
|
||
def test_search_raises_usage_limit_exceeded_error(self, driver_usage_limit_exceeded): | ||
with pytest.raises(ValueError, match="Usage Limit Exceeded, Please try again later."): | ||
driver_usage_limit_exceeded.search("test") | ||
def test_custom_results_count(self, mock_tavily_client): | ||
mock_response = { | ||
"results": [{"title": f"title_{i}", "url": f"url_{i}", "content": f"content_{i}"} for i in range(5)] | ||
} | ||
mock_tavily_client.return_value.search.return_value = mock_response | ||
|
||
@pytest.fixture() | ||
def driver_invalid_api_key(self, mocker): | ||
def error(*args, **kwargs): | ||
raise InvalidAPIKeyError() | ||
driver = TavilyWebSearchDriver(api_key="test", results_count=5) | ||
results = driver.search("test") | ||
|
||
mock_tavily = mocker.Mock(search=error) | ||
mocker.patch("tavily.TavilyClient", return_value=mock_tavily) | ||
return TavilyWebSearchDriver(api_key="invalid_key") | ||
assert isinstance(results, ListArtifact) | ||
assert len(results) == 5 | ||
|
||
def test_search_raises_invalid_api_key_error(self, driver_invalid_api_key): | ||
with pytest.raises(ValueError, match="Invalid API Key, Please provide a valid Tavily API Key."): | ||
driver_invalid_api_key.search("test") | ||
mock_tavily_client.return_value.search.assert_called_once_with("test", max_results=5) |