diff --git a/tests/unit/drivers/web_search/test_tavily_web_search_driver.py b/tests/unit/drivers/web_search/test_tavily_web_search_driver.py index d5d8423c4..4038d7043 100644 --- a/tests/unit/drivers/web_search/test_tavily_web_search_driver.py +++ b/tests/unit/drivers/web_search/test_tavily_web_search_driver.py @@ -1,7 +1,6 @@ import json import pytest -from tavily import InvalidAPIKeyError, MissingAPIKeyError, UsageLimitExceededError from griptape.artifacts import ListArtifact from griptape.drivers import TavilyWebSearchDriver @@ -9,32 +8,21 @@ class TestTavilyWebSearchDriver: @pytest.fixture() - def driver(self, mocker): + def mock_tavily_client(self, mocker): + return mocker.patch("tavily.TavilyClient") + + @pytest.fixture() + def driver(self, mock_tavily_client): mock_response = { "results": [ {"title": "foo", "url": "bar", "content": "baz"}, {"title": "foo2", "url": "bar2", "content": "baz2"}, ] } - mock_tavily = mocker.Mock( - search=lambda *args, **kwargs: mock_response, - ) - mocker.patch("tavily.TavilyClient", return_value=mock_tavily) - return TavilyWebSearchDriver(api_key="test") - - @pytest.fixture() - def driver_with_error(self, mocker): - def error(*args, **kwargs): - raise Exception("test_error") - - mock_tavily = mocker.Mock( - search=error, - ) - mocker.patch("tavily.TavilyClient", return_value=mock_tavily) - + mock_tavily_client.return_value.search.return_value = mock_response return TavilyWebSearchDriver(api_key="test") - def test_search_returns_results(self, driver): + def test_search_returns_results(self, driver, mock_tavily_client): results = driver.search("test") assert isinstance(results, ListArtifact) output = [json.loads(result.value) for result in results] @@ -42,46 +30,45 @@ def test_search_returns_results(self, driver): assert output[0]["title"] == "foo" assert output[0]["url"] == "bar" assert output[0]["content"] == "baz" + mock_tavily_client.return_value.search.assert_called_once_with("test", max_results=5) - def test_search_raises_error(self, driver_with_error): - with pytest.raises(ValueError, match="An error occurred while searching for test using Tavily: test_error"): - driver_with_error.search("test") + def test_search_raises_error(self, mock_tavily_client): + mock_tavily_client.return_value.search.side_effect = Exception("test_error") + driver = TavilyWebSearchDriver(api_key="test") + with pytest.raises(Exception, match="test_error"): + driver.search("test") - @pytest.fixture() - def driver_missing_api_key(self, mocker): - def error(*args, **kwargs): - raise MissingAPIKeyError() - - mock_tavily = mocker.Mock(search=error) - mocker.patch("tavily.TavilyClient", return_value=mock_tavily) - return TavilyWebSearchDriver(api_key="") + def test_search_with_params(self, mock_tavily_client): + mock_response = { + "results": [ + {"title": "custom", "url": "custom_url", "content": "custom_content"}, + ] + } + mock_tavily_client.return_value.search.return_value = mock_response - def test_search_raises_missing_api_key_error(self, driver_missing_api_key): - with pytest.raises(ValueError, match="API Key is missing, Please provide a valid Tavily API Key."): - driver_missing_api_key.search("test") + driver = TavilyWebSearchDriver(api_key="test", params={"custom_param": "value"}) + results = driver.search("test", additional_param="extra") - @pytest.fixture() - def driver_usage_limit_exceeded(self, mocker): - def error(*args, **kwargs): - raise UsageLimitExceededError("Usage limit exceeded") + assert isinstance(results, ListArtifact) + output = json.loads(results[0].value) + assert output["title"] == "custom" + assert output["url"] == "custom_url" + assert output["content"] == "custom_content" - mock_tavily = mocker.Mock(search=error) - mocker.patch("tavily.TavilyClient", return_value=mock_tavily) - return TavilyWebSearchDriver(api_key="test") + mock_tavily_client.return_value.search.assert_called_once_with( + "test", max_results=5, custom_param="value", additional_param="extra" + ) - def test_search_raises_usage_limit_exceeded_error(self, driver_usage_limit_exceeded): - with pytest.raises(ValueError, match="Usage Limit Exceeded, Please try again later."): - driver_usage_limit_exceeded.search("test") + def test_custom_results_count(self, mock_tavily_client): + mock_response = { + "results": [{"title": f"title_{i}", "url": f"url_{i}", "content": f"content_{i}"} for i in range(5)] + } + mock_tavily_client.return_value.search.return_value = mock_response - @pytest.fixture() - def driver_invalid_api_key(self, mocker): - def error(*args, **kwargs): - raise InvalidAPIKeyError() + driver = TavilyWebSearchDriver(api_key="test", results_count=5) + results = driver.search("test") - mock_tavily = mocker.Mock(search=error) - mocker.patch("tavily.TavilyClient", return_value=mock_tavily) - return TavilyWebSearchDriver(api_key="invalid_key") + assert isinstance(results, ListArtifact) + assert len(results) == 5 - def test_search_raises_invalid_api_key_error(self, driver_invalid_api_key): - with pytest.raises(ValueError, match="Invalid API Key, Please provide a valid Tavily API Key."): - driver_invalid_api_key.search("test") + mock_tavily_client.return_value.search.assert_called_once_with("test", max_results=5)