diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index c937e779899..406faf2f9b2 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -172,6 +172,10 @@ def __init__(self, key="xxxxxxx", model_name="", base_url=""): def similarity(self, query: str, texts: list): if len(texts) == 0: return np.array([]), 0 + pairs = [(query, truncate(t, 4096)) for t in texts] + token_count = 0 + for _, t in pairs: + token_count += num_tokens_from_string(t) data = { "model": self.model_name, "query": query, @@ -183,7 +187,7 @@ def similarity(self, query: str, texts: list): rank = np.zeros(len(texts), dtype=float) for d in res["results"]: rank[d["index"]] = d["relevance_score"] - return rank, res["meta"]["tokens"]["input_tokens"] + res["meta"]["tokens"]["output_tokens"] + return rank, token_count class LocalAIRerank(Base):