diff --git a/tests/unit/drivers/web_search/test_tavily_web_search_driver.py b/tests/unit/drivers/web_search/test_tavily_web_search_driver.py index 04b7fb1ea..d5d8423c4 100644 --- a/tests/unit/drivers/web_search/test_tavily_web_search_driver.py +++ b/tests/unit/drivers/web_search/test_tavily_web_search_driver.py @@ -1,6 +1,7 @@ import json import pytest +from tavily import InvalidAPIKeyError, MissingAPIKeyError, UsageLimitExceededError from griptape.artifacts import ListArtifact from griptape.drivers import TavilyWebSearchDriver @@ -43,5 +44,44 @@ def test_search_returns_results(self, driver): assert output[0]["content"] == "baz" def test_search_raises_error(self, driver_with_error): - with pytest.raises(Exception, match="An error occurred while searching for test using Tavily: test_error"): + with pytest.raises(ValueError, match="An error occurred while searching for test using Tavily: test_error"): driver_with_error.search("test") + + @pytest.fixture() + def driver_missing_api_key(self, mocker): + def error(*args, **kwargs): + raise MissingAPIKeyError() + + mock_tavily = mocker.Mock(search=error) + mocker.patch("tavily.TavilyClient", return_value=mock_tavily) + return TavilyWebSearchDriver(api_key="") + + def test_search_raises_missing_api_key_error(self, driver_missing_api_key): + with pytest.raises(ValueError, match="API Key is missing, Please provide a valid Tavily API Key."): + driver_missing_api_key.search("test") + + @pytest.fixture() + def driver_usage_limit_exceeded(self, mocker): + def error(*args, **kwargs): + raise UsageLimitExceededError("Usage limit exceeded") + + mock_tavily = mocker.Mock(search=error) + mocker.patch("tavily.TavilyClient", return_value=mock_tavily) + return TavilyWebSearchDriver(api_key="test") + + def test_search_raises_usage_limit_exceeded_error(self, driver_usage_limit_exceeded): + with pytest.raises(ValueError, match="Usage Limit Exceeded, Please try again later."): + driver_usage_limit_exceeded.search("test") + + @pytest.fixture() + def driver_invalid_api_key(self, mocker): + def error(*args, **kwargs): + raise InvalidAPIKeyError() + + mock_tavily = mocker.Mock(search=error) + mocker.patch("tavily.TavilyClient", return_value=mock_tavily) + return TavilyWebSearchDriver(api_key="invalid_key") + + def test_search_raises_invalid_api_key_error(self, driver_invalid_api_key): + with pytest.raises(ValueError, match="Invalid API Key, Please provide a valid Tavily API Key."): + driver_invalid_api_key.search("test")