From 3061c75c5b7d2b9bae738c484c0ddaf510063e90 Mon Sep 17 00:00:00 2001 From: Sri Sudarsan Date: Thu, 18 Jul 2024 18:44:15 +0530 Subject: [PATCH] feat: adds compatibility for reranking using text-embeddings-inference server --- rerankers/models/api_rankers.py | 44 ++++++++++++++++++++------------- rerankers/reranker.py | 5 +++- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/rerankers/models/api_rankers.py b/rerankers/models/api_rankers.py index 76a5f06..bf3a667 100644 --- a/rerankers/models/api_rankers.py +++ b/rerankers/models/api_rankers.py @@ -16,9 +16,26 @@ "mixedbread.ai": "https://api.mixedbread.ai/v1/reranking", } +DOCUMENT_KEY_MAPPING = { + "mixedbread.ai": "input", + "text-embeddings-inference":"texts" +} +RETURN_DOCUMENTS_KEY_MAPPING = { + "mixedbread.ai":"return_input", + "text-embeddings-inference":"return_text" +} +RESULTS_KEY_MAPPING = { + "voyage": "data", + "mixedbread.ai": "data", + "text-embeddings-inference": None +} +SCORE_KEY_MAPPING = { + "mixedbread.ai": "score", + "text-embeddings-inference":"score" +} class APIRanker(BaseRanker): - def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1): + def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1, url: str = None): self.api_key = api_key self.model = model self.api_provider = api_provider.lower() @@ -29,7 +46,7 @@ def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1 "content-type": "application/json", "Authorization": f"Bearer {self.api_key}", } - self.url = URLS[self.api_provider] + self.url = url if url else URLS[self.api_provider] def _get_document_text(self, r: dict) -> str: @@ -37,26 +54,23 @@ def _get_document_text(self, r: dict) -> str: return r["document"] elif self.api_provider == "mixedbread.ai": return r["input"] + elif self.api_provider == "text-embeddings-inference": + return r["text"] else: return r["document"]["text"] def _get_score(self, r: dict) -> float: - if self.api_provider == "mixedbread.ai": - return r["score"] - return r["relevance_score"] + score_key = SCORE_KEY_MAPPING.get(self.api_provider,"relevance_score") + return r[score_key] def _parse_response( self, response: dict, docs: List[Document], ) -> RankedResults: ranked_docs = [] - results_key = ( - "results" - if self.api_provider not in ["voyage", "mixedbread.ai"] - else "data" - ) + results_key = RESULTS_KEY_MAPPING.get(self.api_provider,"results") print(response) - for i, r in enumerate(response[results_key]): + for i, r in enumerate(response[results_key] if results_key else response): ranked_docs.append( Result( document=docs[r["index"]], @@ -86,12 +100,8 @@ def _format_payload(self, query: str, docs: List[str]) -> str: top_key = ( "top_n" if self.api_provider not in ["voyage", "mixedbread.ai"] else "top_k" ) - documents_key = "documents" if self.api_provider != "mixedbread.ai" else "input" - return_documents_key = ( - "return_documents" - if self.api_provider != "mixedbread.ai" - else "return_input" - ) + documents_key = DOCUMENT_KEY_MAPPING.get(self.api_provider,"documents") + return_documents_key = RETURN_DOCUMENTS_KEY_MAPPING.get(self.api_provider,"return_documents") payload = { "model": self.model, diff --git a/rerankers/reranker.py b/rerankers/reranker.py index 009bef0..7be39d9 100644 --- a/rerankers/reranker.py +++ b/rerankers/reranker.py @@ -30,6 +30,7 @@ "es": "AdrienB134/ColBERTv2.0-spanish-mmarcoES", }, "flashrank": {"en": "ms-marco-MiniLM-L-12-v2", "other": "ms-marco-MultiBERT-L-12"}, + "text-embeddings-inference": {"other": "BAAI/bge-reranker-base"} } DEPS_MAPPING = { @@ -43,7 +44,7 @@ "RankLLMRanker": "rankllm", } -PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai"] +PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai", "text-embeddings-inference"] def _get_api_provider(model_name: str, model_type: Optional[str] = None) -> str: @@ -69,6 +70,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None) "cohere": "APIRanker", "jina": "APIRanker", "voyage": "APIRanker", + "text-embeddings-inference": "APIRanker", "rankgpt": "RankGPTRanker", "lit5": "LiT5Ranker", "t5": "T5Ranker", @@ -92,6 +94,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None) "cohere": "APIRanker", "jina": "APIRanker", "voyage": "APIRanker", + "text-embeddings-inference": "APIRanker", "ms-marco-minilm-l-12-v2": "FlashRankRanker", "ms-marco-multibert-l-12": "FlashRankRanker", "vicuna": "RankLLMRanker",