Skip to content

Commit

Permalink
Merge pull request #6 from lancedb/rag
Browse files Browse the repository at this point in the history
Add reranker latency support
  • Loading branch information
AyushExel authored Sep 4, 2024
2 parents e7ce715 + e8adddf commit 58ff234
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 2 deletions.
50 changes: 50 additions & 0 deletions ragged/metrics/reranker/rerank_latency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import copy
from ..retriever.base import Metric
import lancedb
import logging
import pandas as pd
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
import time

from ...search_utils import QueryType, deduce_query_type, search_table, QueryConfigError
from ..retriever.hit_rate import Metric
from ...results import RerankerResult

# Fetch the shared "lancedb" logger and set its level to INFO. Loggers are
# looked up by name, so this applies to every user of the "lancedb" logger
# in the process, not only to this module.
logger = logging.getLogger("lancedb")
logger.setLevel(logging.INFO)

class RerankerLatency:
    """Measure what a reranker costs in time and changes in accuracy.

    Wraps a metric that has a reranker configured and evaluates it twice for
    a base query type: once with reranking and once without.
    """

    def __init__(self, metric: Metric) -> None:
        """Store the metric to benchmark.

        Args:
            metric: Metric to evaluate; its ``reranker`` attribute must be set.

        Raises:
            ValueError: If ``metric.reranker`` is ``None``.
        """
        if metric.reranker is None:
            raise ValueError("Reranker must be provided to evaluate reranker latency")
        self.metric = metric

    def evaluate(self, query_type: str, top_k: int = 5) -> RerankerResult:
        """Compare reranked and non-reranked evaluation for ``query_type``.

        Args:
            query_type: Base query type; must equal ``QueryType.VECTOR`` or
                ``QueryType.FTS``.
            top_k: Value forwarded to ``metric.evaluate_query_type``.

        Returns:
            A ``RerankerResult`` where ``latency`` is the extra elapsed time
            (in seconds) of the reranked run over the plain run, divided by
            the number of dataset rows, and ``delta_accuracy`` is the
            reranked score minus the plain score.

        Raises:
            ValueError: If ``query_type`` is HYBRID or otherwise unsupported,
                or if the dataset has no rows (per-row latency is undefined).
        """
        if query_type == QueryType.HYBRID:
            raise ValueError("HYBRID query type is not supported for reranker latency evaluation")
        elif query_type == QueryType.VECTOR:
            qt, reranked_qt = QueryType.VECTOR, QueryType.RERANK_VECTOR
        elif query_type == QueryType.FTS:
            qt, reranked_qt = QueryType.FTS, QueryType.RERANK_FTS
        else:
            raise ValueError(f"Invalid query type: {query_type}. Supported query types are VECTOR and FTS")

        # perf_counter() is monotonic and high-resolution; wall-clock
        # time.time() can jump (e.g. NTP adjustments) and skew the benchmark.
        # NOTE(review): if the metric lazily ingests/indexes data on its first
        # evaluate_query_type call, that one-off cost lands in the reranked
        # timing because it runs first - TODO confirm and warm up if so.
        reranked_start = time.perf_counter()
        eval_reranker = self.metric.evaluate_query_type(reranked_qt, top_k)
        reranker_latency = time.perf_counter() - reranked_start

        baseline_start = time.perf_counter()
        eval_wo_reranker = self.metric.evaluate_query_type(qt, top_k)
        wo_reranker_latency = time.perf_counter() - baseline_start

        delta_accuracy = eval_reranker - eval_wo_reranker

        data_size = len(self.metric.dataset.to_pandas())
        if data_size == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError.
            raise ValueError("Cannot compute per-row reranker latency on an empty dataset")
        latency = (reranker_latency - wo_reranker_latency) / data_size

        return RerankerResult(latency=latency, delta_accuracy=delta_accuracy)


2 changes: 1 addition & 1 deletion ragged/metrics/retriever/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .hit_rate import HitRate
from .base import QueryType
from .base import Metric
6 changes: 5 additions & 1 deletion ragged/metrics/retriever/hit_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
logger.setLevel(logging.INFO)

class HitRate(Metric):
def evaluate_query_type(self,query_type:str, top_k:5) -> float:
def evaluate_query_type(self,query_type:str, top_k:int = 5) -> float:
if not self.table:
self.ingest_docs()
self.table.create_fts_index("text", replace=True)

eval_results = []
ds = self.dataset.to_pandas()
for idx in tqdm.tqdm(range(len(ds))):
Expand Down
4 changes: 4 additions & 0 deletions ragged/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ class RAGResult(BaseModel):
context_precision: float = -inf
context_recall: float = -inf
harmfulness: float = -inf

class RerankerResult(BaseModel):
    """Result container for a reranker latency evaluation.

    Both fields default to ``-inf``.
    NOTE(review): presumably a sentinel meaning "not computed" - confirm.
    """
    # Latency figure attributed to the reranker.
    # NOTE(review): RerankerLatency.evaluate appears to fill this with the
    # extra elapsed seconds per dataset row - confirm against that caller.
    latency: float = -inf
    # Accuracy difference attributed to the reranker.
    # NOTE(review): RerankerLatency.evaluate appears to fill this with the
    # reranked score minus the non-reranked score - confirm.
    delta_accuracy: float = -inf

0 comments on commit 58ff234

Please sign in to comment.