Skip to content

Commit

Permalink
Merge pull request #20 from srisudarsan/main
Browse files Browse the repository at this point in the history
feat: adds compatibility for reranking using text-embeddings-inference server
  • Loading branch information
bclavie authored Jul 26, 2024
2 parents dab1eed + 3061c75 commit b407dc5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
44 changes: 27 additions & 17 deletions rerankers/models/api_rankers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,26 @@
"mixedbread.ai": "https://api.mixedbread.ai/v1/reranking",
}

DOCUMENT_KEY_MAPPING = {
"mixedbread.ai": "input",
"text-embeddings-inference":"texts"
}
RETURN_DOCUMENTS_KEY_MAPPING = {
"mixedbread.ai":"return_input",
"text-embeddings-inference":"return_text"
}
RESULTS_KEY_MAPPING = {
"voyage": "data",
"mixedbread.ai": "data",
"text-embeddings-inference": None
}
SCORE_KEY_MAPPING = {
"mixedbread.ai": "score",
"text-embeddings-inference":"score"
}

class APIRanker(BaseRanker):
def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1):
def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1, url: str = None):
self.api_key = api_key
self.model = model
self.api_provider = api_provider.lower()
Expand All @@ -29,34 +46,31 @@ def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1
"content-type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
self.url = URLS[self.api_provider]
self.url = url if url else URLS[self.api_provider]


def _get_document_text(self, r: dict) -> str:
if self.api_provider == "voyage":
return r["document"]
elif self.api_provider == "mixedbread.ai":
return r["input"]
elif self.api_provider == "text-embeddings-inference":
return r["text"]
else:
return r["document"]["text"]

def _get_score(self, r: dict) -> float:
if self.api_provider == "mixedbread.ai":
return r["score"]
return r["relevance_score"]
score_key = SCORE_KEY_MAPPING.get(self.api_provider,"relevance_score")
return r[score_key]

def _parse_response(
self, response: dict, docs: List[Document],
) -> RankedResults:
ranked_docs = []
results_key = (
"results"
if self.api_provider not in ["voyage", "mixedbread.ai"]
else "data"
)
results_key = RESULTS_KEY_MAPPING.get(self.api_provider,"results")
print(response)

for i, r in enumerate(response[results_key]):
for i, r in enumerate(response[results_key] if results_key else response):
ranked_docs.append(
Result(
document=docs[r["index"]],
Expand Down Expand Up @@ -86,12 +100,8 @@ def _format_payload(self, query: str, docs: List[str]) -> str:
top_key = (
"top_n" if self.api_provider not in ["voyage", "mixedbread.ai"] else "top_k"
)
documents_key = "documents" if self.api_provider != "mixedbread.ai" else "input"
return_documents_key = (
"return_documents"
if self.api_provider != "mixedbread.ai"
else "return_input"
)
documents_key = DOCUMENT_KEY_MAPPING.get(self.api_provider,"documents")
return_documents_key = RETURN_DOCUMENTS_KEY_MAPPING.get(self.api_provider,"return_documents")

payload = {
"model": self.model,
Expand Down
5 changes: 4 additions & 1 deletion rerankers/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"es": "AdrienB134/ColBERTv2.0-spanish-mmarcoES",
},
"flashrank": {"en": "ms-marco-MiniLM-L-12-v2", "other": "ms-marco-MultiBERT-L-12"},
"text-embeddings-inference": {"other": "BAAI/bge-reranker-base"}
}

DEPS_MAPPING = {
Expand All @@ -43,7 +44,7 @@
"RankLLMRanker": "rankllm",
}

PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai"]
PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai", "text-embeddings-inference"]


def _get_api_provider(model_name: str, model_type: Optional[str] = None) -> str:
Expand All @@ -69,6 +70,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
"cohere": "APIRanker",
"jina": "APIRanker",
"voyage": "APIRanker",
"text-embeddings-inference": "APIRanker",
"rankgpt": "RankGPTRanker",
"lit5": "LiT5Ranker",
"t5": "T5Ranker",
Expand All @@ -92,6 +94,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
"cohere": "APIRanker",
"jina": "APIRanker",
"voyage": "APIRanker",
"text-embeddings-inference": "APIRanker",
"ms-marco-minilm-l-12-v2": "FlashRankRanker",
"ms-marco-multibert-l-12": "FlashRankRanker",
"vicuna": "RankLLMRanker",
Expand Down

0 comments on commit b407dc5

Please sign in to comment.