diff --git a/llama_index/evaluation/benchmarks/beir.py b/llama_index/evaluation/benchmarks/beir.py index ddba34fa2ef65..6bab13b3c35ab 100644 --- a/llama_index/evaluation/benchmarks/beir.py +++ b/llama_index/evaluation/benchmarks/beir.py @@ -1,11 +1,12 @@ import os from shutil import rmtree -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Optional import tqdm from llama_index.core import BaseRetriever -from llama_index.schema import Document +from llama_index.postprocessor.types import BaseNodePostprocessor +from llama_index.schema import Document, QueryBundle from llama_index.utils import get_cache_dir @@ -52,6 +53,7 @@ def run( create_retriever: Callable[[List[Document]], BaseRetriever], datasets: List[str] = ["nfcorpus"], metrics_k_values: List[int] = [3, 10], + node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, ) -> None: from beir.datasets.data_loader import GenericDataLoader from beir.retrieval.evaluation import EvaluateRetrieval @@ -82,6 +84,11 @@ def run( results = {} for key, query in tqdm.tqdm(queries.items()): nodes_with_score = retriever.retrieve(query) + node_postprocessors = node_postprocessors or [] + for node_postprocessor in node_postprocessors: + nodes_with_score = node_postprocessor.postprocess_nodes( + nodes_with_score, query_bundle=QueryBundle(query_str=query) + ) results[key] = { node.node.metadata["doc_id"]: node.score for node in nodes_with_score