-
-
Notifications
You must be signed in to change notification settings - Fork 270
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Mixedbread AI Reranker Module (#805)
* Add mxbai.py and its test code * Add func annotation * Add mxbai_reranker at support.py * Add docs and rst * change name mxbai to mixedbreadai * Add Full name * Add full name at docs * contents use text * change API use * change API use * change toctree name * Add mixedbreadai at requirements.txt * use mock api key at node test * Add func annotation
- Loading branch information
Showing
10 changed files
with
283 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import os | ||
from typing import List, Tuple | ||
|
||
import pandas as pd | ||
from mixedbread_ai.client import AsyncMixedbreadAI | ||
|
||
from autorag.nodes.passagereranker.base import BasePassageReranker | ||
from autorag.utils.util import ( | ||
result_to_dataframe, | ||
get_event_loop, | ||
process_batch, | ||
pop_params, | ||
) | ||
|
||
|
||
class MixedbreadAIReranker(BasePassageReranker):
    """Passage reranker node that delegates scoring to the Mixedbread AI reranking API."""

    def __init__(
        self,
        project_dir: str,
        *args,
        **kwargs,
    ):
        """
        Initialize mixedbread-ai rerank node.

        :param project_dir: The project directory path.
        :param api_key: The API key for MixedbreadAI rerank.
            You can set it in the environment variable MXBAI_API_KEY.
            Or, you can directly set it on the config YAML file using this parameter.
            Default is env variable "MXBAI_API_KEY".
        :param kwargs: Extra arguments that are not affected
        :raises KeyError: If no API key is given and MXBAI_API_KEY is not set.
        """
        super().__init__(project_dir)
        api_key = kwargs.pop("api_key", None)
        if api_key is None:
            # Fall back to the environment only when the config gave no key.
            api_key = os.getenv("MXBAI_API_KEY", None)
        if api_key is None:
            raise KeyError(
                "Please set the API key for Mixedbread AI rerank in the environment variable MXBAI_API_KEY "
                "or directly set it on the config YAML file."
            )
        self.client = AsyncMixedbreadAI(api_key=api_key)

    def __del__(self):
        """Drop the API client reference, then run the parent finalizer."""
        # __init__ can raise before self.client is assigned (e.g. missing API key);
        # the finalizer still runs on that half-built instance, so guard the delete.
        if hasattr(self, "client"):
            del self.client
        super().__del__()

    @result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        """
        Rerank the passages carried by the previous node result.

        :param previous_result: The dataframe produced by the previous node.
        :param kwargs: Must contain ``top_k``. May contain ``batch`` (default 8) and
            ``model`` (default "mixedbread-ai/mxbai-rerank-large-v1").
            Remaining entries selected by ``pop_params`` for ``client.reranking``
            are forwarded to the reranking request.
        :return: The reranked contents, ids and scores returned by ``_pure``.
        """
        # The scores of the previous retrieval are not needed for reranking.
        queries, contents, _, ids = self.cast_to_run(previous_result)
        top_k = kwargs.pop("top_k")
        batch = kwargs.pop("batch", 8)
        model = kwargs.pop("model", "mixedbread-ai/mxbai-rerank-large-v1")
        # presumably keeps only the kwargs that client.reranking accepts - verify against pop_params
        rerank_params = pop_params(self.client.reranking, kwargs)
        return self._pure(queries, contents, ids, top_k, model, batch, **rerank_params)

    def _pure(
        self,
        queries: List[str],
        contents_list: List[List[str]],
        ids_list: List[List[str]],
        top_k: int,
        model: str = "mixedbread-ai/mxbai-rerank-large-v1",
        batch: int = 8,
        **rerank_params,
    ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
        """
        Rerank a list of contents with mixedbread-ai rerank models.
        You can get the API key from https://www.mixedbread.ai/api-reference#quick-start-guide and set it in the environment variable MXBAI_API_KEY.

        :param queries: The list of queries to use for reranking
        :param contents_list: The list of lists of contents to rerank
        :param ids_list: The list of lists of ids retrieved from the initial ranking
        :param top_k: The number of passages to be retrieved
        :param model: The model name for mixedbread-ai rerank.
            You can choose between "mixedbread-ai/mxbai-rerank-large-v1", "mixedbread-ai/mxbai-rerank-base-v1" and "mixedbread-ai/mxbai-rerank-xsmall-v1".
            Default is "mixedbread-ai/mxbai-rerank-large-v1".
        :param batch: The number of queries to be processed in a batch
        :param rerank_params: Extra keyword arguments forwarded to the reranking request.
        :return: Tuple of lists containing the reranked contents, ids, and scores
        """
        tasks = [
            mixedbreadai_rerank_pure(
                self.client,
                query,
                contents,
                ids,
                top_k=top_k,
                model=model,
                **rerank_params,
            )
            for query, contents, ids in zip(queries, contents_list, ids_list)
        ]
        if not tasks:
            # zip(*[]) yields nothing, so the three-way unpacking below would fail.
            return [], [], []

        loop = get_event_loop()
        results = loop.run_until_complete(process_batch(tasks, batch))

        content_result, id_result, score_result = zip(*results)

        return list(content_result), list(id_result), list(score_result)


async def mixedbreadai_rerank_pure(
    client: AsyncMixedbreadAI,
    query: str,
    documents: List[str],
    ids: List[str],
    top_k: int,
    model: str = "mixedbread-ai/mxbai-rerank-large-v1",
    **rerank_params,
) -> Tuple[List[str], List[str], List[float]]:
    """
    Rerank a list of contents with mixedbread-ai rerank models.

    :param client: The mixedbread-ai client to use for reranking
    :param query: The query to use for reranking
    :param documents: The list of contents to rerank
    :param ids: The list of ids corresponding to the documents
    :param top_k: The number of passages to be retrieved
    :param model: The model name for mixedbread-ai rerank.
        You can choose between "mixedbread-ai/mxbai-rerank-large-v1", "mixedbread-ai/mxbai-rerank-base-v1" and "mixedbread-ai/mxbai-rerank-xsmall-v1".
        Default is "mixedbread-ai/mxbai-rerank-large-v1".
    :param rerank_params: Extra keyword arguments passed through to ``client.reranking``.
    :return: Tuple of lists containing the reranked contents, ids, and scores
    """
    results = await client.reranking(
        query=query,
        input=documents,
        top_k=top_k,
        model=model,
        **rerank_params,
    )

    reranked_contents: List[str] = []
    reranked_ids: List[str] = []
    reranked_scores: List[float] = []
    for ranked in results.data:
        # Each ranked entry points back into the original documents/ids by index.
        reranked_contents.append(documents[ranked.index])
        reranked_ids.append(ids[ranked.index])
        reranked_scores.append(float(ranked.score))
    return reranked_contents, reranked_ids, reranked_scores
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
50 changes: 50 additions & 0 deletions
50
docs/source/nodes/passage_reranker/mixedbreadai_reranker.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
--- | ||
myst: | ||
html_meta: | ||
title: AutoRAG - Mixedbread Reranker | ||
description: Learn about Mixedbread reranker module in AutoRAG | ||
keywords: AutoRAG,RAG,Advanced RAG,Reranker,Mixedbread Reranker | ||
--- | ||
# Mixedbread AI Reranker | ||
|
||
The `Mixedbread AI Reranker` module is a reranker that uses the mixedbread-ai rerank model. This model reranks passages based on their relevance to a | ||
given query. | ||
|
||
## Before Usage | ||
|
||
First, you need to get the Mixedbread AI API key from [MixedbreadAI](https://www.mixedbread.ai/api-reference#quick-start-guide). | ||
|
||
Next, you can set your Mixedbread AI API key in the environment variable. | ||
|
||
```bash | ||
export MXBAI_API_KEY=your_mixedbread_api_key | ||
``` | ||
|
||
Or, you can set your Mixedbread AI API key in the config.yaml file directly. | ||
|
||
```yaml | ||
- module_type: mixedbreadai_reranker | ||
api_key: your_mixedbread_api_key | ||
``` | ||
## **Module Parameters** | ||
- (Optional) `model`: | ||
- The name of the Mixedbread AI rerank model to use. | ||
- default is `mixedbread-ai/mxbai-rerank-large-v1` | ||
- (Optional) `api_key`: The Mixedbread AI API key. If omitted, the `MXBAI_API_KEY` environment variable is used. | ||
|
||
## **Example config.yaml** | ||
|
||
```yaml | ||
modules: | ||
- module_type: mixedbreadai_reranker | ||
``` | ||
|
||
### Supported Model Names | ||
|
||
| Model Name | | ||
|:------------------------------------------:| | ||
| mixedbread-ai/mxbai-rerank-xsmall-v1 | | ||
| mixedbread-ai/mxbai-rerank-large-v1 | | ||
| mixedbread-ai/mxbai-rerank-base-v1 | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
86 changes: 86 additions & 0 deletions
86
tests/autorag/nodes/passagereranker/test_mixedbreadai_reranker.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
|
||
import mixedbread_ai.client | ||
from mixedbread_ai.types.reranking_response import RerankingResponse | ||
from mixedbread_ai.types.ranked_document import RankedDocument | ||
from mixedbread_ai.types.usage import Usage | ||
|
||
from autorag.nodes.passagereranker import MixedbreadAIReranker | ||
from tests.autorag.nodes.passagereranker.test_passage_reranker_base import ( | ||
queries_example, | ||
contents_example, | ||
ids_example, | ||
base_reranker_test, | ||
project_dir, | ||
previous_result, | ||
base_reranker_node_test, | ||
) | ||
|
||
|
||
async def mock_mixedbreadai_reranker(
    self,
    *,
    query,
    input,
    model,
    top_k,
    **kwargs,
):
    """
    Stand-in for ``AsyncMixedbreadAI.reranking`` used by the tests below.

    Ignores the query, input and model, and returns a fixed ranking of three
    documents truncated to ``top_k`` entries.
    """
    usage = Usage(prompt_tokens=100, total_tokens=150, completion_tokens=50)
    # (index, score, input) triples of the canned ranking, best score first.
    canned_ranking = (
        (1, 0.8, "Document 1"),
        (2, 0.2, "Document 2"),
        (0, 0.1, "Document 3"),
    )
    ranked_documents = [
        RankedDocument(index=position, score=relevance, input=text, object=None)
        for position, relevance, text in canned_ranking
    ]
    return RerankingResponse(
        usage=usage,
        model="mock-model",
        data=ranked_documents[:top_k],
        object=None,
        top_k=top_k,
        return_input=False,
    )
|
||
|
||
@pytest.fixture
def mixedbreadai_reranker_instance():
    """Provide a reranker built with a dummy API key, so no real credential is needed."""
    reranker = MixedbreadAIReranker(project_dir=project_dir, api_key="mock_api_key")
    return reranker
|
||
|
||
@patch.object(
    mixedbread_ai.client.AsyncMixedbreadAI, "reranking", mock_mixedbreadai_reranker
)
def test_mixedbreadai_reranker(mixedbreadai_reranker_instance):
    """Check ``_pure`` with the default batch size against the mocked reranking call."""
    top_k = 1
    reranked = mixedbreadai_reranker_instance._pure(
        queries_example, contents_example, ids_example, top_k
    )
    contents_result, id_result, score_result = reranked
    base_reranker_test(contents_result, id_result, score_result, top_k)
|
||
|
||
@patch.object(
    mixedbread_ai.client.AsyncMixedbreadAI, "reranking", mock_mixedbreadai_reranker
)
def test_mixedbreadai_reranker_batch_one(mixedbreadai_reranker_instance):
    """Check ``_pure`` when the batch size is one, against the mocked reranking call."""
    top_k, batch = 1, 1
    reranked = mixedbreadai_reranker_instance._pure(
        queries_example, contents_example, ids_example, top_k, batch=batch
    )
    contents_result, id_result, score_result = reranked
    base_reranker_test(contents_result, id_result, score_result, top_k)
|
||
|
||
@patch.object(
    mixedbread_ai.client.AsyncMixedbreadAI, "reranking", mock_mixedbreadai_reranker
)
def test_mixedbreadai_node():
    """Run the node through ``run_evaluator`` with a dummy API key and the mocked reranking call."""
    top_k = 1
    evaluator_kwargs = {
        "project_dir": project_dir,
        "previous_result": previous_result,
        "top_k": top_k,
        "api_key": "mock",
    }
    result_df = MixedbreadAIReranker.run_evaluator(**evaluator_kwargs)
    base_reranker_node_test(result_df, top_k)