Skip to content

Commit

Permalink
Don't rerank empty docs (#1153)
Browse files Browse the repository at this point in the history
Co-authored-by: matt <[email protected]>
  • Loading branch information
collindutter and vachillo authored Sep 6, 2024
1 parent dc569b3 commit 9735d88
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Missing `maxTokens` inference parameter in `AmazonBedrockPromptDriver`.
- Incorrect model in `OpenAiDriverConfig`'s `text_to_speech_driver`.
- Crash when using `CohereRerankDriver` with `CsvRowArtifact`s.
- Crash when passing "empty" Artifacts or no Artifacts to `CohereRerankDriver`.


## [0.30.2] - 2024-08-26
Expand Down
24 changes: 14 additions & 10 deletions griptape/drivers/rerank/cohere_rerank_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@ class CohereRerankDriver(BaseRerankDriver):
)

def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
artifacts_dict = {str(hash(a.to_text())): a for a in artifacts}
response = self.client.rerank(
model=self.model,
query=query,
documents=[a.to_text() for a in artifacts_dict.values()],
return_documents=True,
top_n=self.top_n,
)

return [artifacts_dict[str(hash(r.document.text))] for r in response.results]
# Cohere errors out if passed "empty" documents or no documents at all
artifacts_dict = {str(hash(a.to_text())): a for a in artifacts if a}

if artifacts_dict:
response = self.client.rerank(
model=self.model,
query=query,
documents=[a.to_text() for a in artifacts_dict.values()],
return_documents=True,
top_n=self.top_n,
)
return [artifacts_dict[str(hash(r.document.text))] for r in response.results]
else:
return []
16 changes: 16 additions & 0 deletions tests/unit/drivers/rerank/test_cohere_rerank_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,24 @@ def mock_client(self, mocker):

return mock_client

@pytest.fixture()
def mock_empty_client(self, mocker):
mock_client = mocker.patch("cohere.Client").return_value
mock_client.rerank.side_effect = Exception("Client should not be called")

return mock_client

def test_run(self, mock_client):
driver = CohereRerankDriver(api_key="api-key")
result = driver.run("hello", artifacts=[TextArtifact("foo"), TextArtifact("bar")])

assert len(result) == 2

def test_run_empty_artifacts(self, mock_empty_client):
driver = CohereRerankDriver(api_key="api-key")
result = driver.run("hello", artifacts=[TextArtifact(""), TextArtifact(" ")])

assert len(result) == 0

result = driver.run("hello", artifacts=[])
assert len(result) == 0

0 comments on commit 9735d88

Please sign in to comment.