diff --git a/CHANGELOG.md b/CHANGELOG.md index e7d8336122..3e5b3c5996 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Missing `maxTokens` inference parameter in `AmazonBedrockPromptDriver`. - Incorrect model in `OpenAiDriverConfig`'s `text_to_speech_driver`. - Crash when using `CohereRerankDriver` with `CsvRowArtifact`s. +- Crash when passing "empty" Artifacts to `CohereRerankDriver`. ## [0.30.2] - 2024-08-26 diff --git a/griptape/drivers/rerank/cohere_rerank_driver.py b/griptape/drivers/rerank/cohere_rerank_driver.py index 5ca03cf63c..c57a947e52 100644 --- a/griptape/drivers/rerank/cohere_rerank_driver.py +++ b/griptape/drivers/rerank/cohere_rerank_driver.py @@ -24,7 +24,9 @@ class CohereRerankDriver(BaseRerankDriver): ) def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]: - artifacts_dict = {str(hash(a.to_text())): a for a in artifacts} + # Cohere errors out if passed "empty" documents + artifacts_dict = {str(hash(a.to_text())): a for a in artifacts if a} + response = self.client.rerank( model=self.model, query=query, diff --git a/tests/unit/drivers/rerank/test_cohere_rerank_driver.py b/tests/unit/drivers/rerank/test_cohere_rerank_driver.py index 87a727269a..a713464b79 100644 --- a/tests/unit/drivers/rerank/test_cohere_rerank_driver.py +++ b/tests/unit/drivers/rerank/test_cohere_rerank_driver.py @@ -20,8 +20,21 @@ def mock_client(self, mocker): return mock_client + @pytest.fixture() + def mock_empty_client(self, mocker): + mock_client = mocker.patch("cohere.Client").return_value + mock_client.rerank.return_value.results = [] + + return mock_client + def test_run(self, mock_client): driver = CohereRerankDriver(api_key="api-key") result = driver.run("hello", artifacts=[TextArtifact("foo"), TextArtifact("bar")]) assert len(result) == 2 + + def test_run_empty_artifacts(self, mock_empty_client): + driver = CohereRerankDriver(api_key="api-key") + result = driver.run("hello", artifacts=[TextArtifact(""), TextArtifact(" ")]) + + assert len(result) == 0