From 7ae0ecf4f065bc625f2bee5e3e8b8fded9f7fe39 Mon Sep 17 00:00:00 2001 From: ZiyiXia Date: Thu, 31 Oct 2024 16:14:13 +0000 Subject: [PATCH] evaluation docstring --- FlagEmbedding/abc/evaluation/evaluator.py | 2 +- FlagEmbedding/abc/evaluation/runner.py | 47 +++++++++++++++++++++++ FlagEmbedding/abc/evaluation/searcher.py | 11 +++++- FlagEmbedding/abc/evaluation/utils.py | 41 ++++++++++++++++++++ 4 files changed, 99 insertions(+), 2 deletions(-) diff --git a/FlagEmbedding/abc/evaluation/evaluator.py b/FlagEmbedding/abc/evaluation/evaluator.py index 35e86de6..60b3c4f7 100644 --- a/FlagEmbedding/abc/evaluation/evaluator.py +++ b/FlagEmbedding/abc/evaluation/evaluator.py @@ -111,7 +111,7 @@ def __call__( dataset_name: Optional[str] = None, **kwargs, ): - """Called to the whole evaluation process. + """This is called during the evaluation process. Args: splits (Union[str, List[str]]): Splits of datasets. diff --git a/FlagEmbedding/abc/evaluation/runner.py b/FlagEmbedding/abc/evaluation/runner.py index 5bef0eb2..7d70dcb9 100644 --- a/FlagEmbedding/abc/evaluation/runner.py +++ b/FlagEmbedding/abc/evaluation/runner.py @@ -14,6 +14,13 @@ class AbsEvalRunner: + """ + Abstract class of evaluation runner. + + Args: + eval_args (AbsEvalArgs): :class:AbsEvalArgs object with the evaluation arguments. + model_args (AbsEvalModelArgs): :class:AbsEvalModelArgs object with the model arguments. + """ def __init__( self, eval_args: AbsEvalArgs, @@ -28,6 +35,15 @@ def __init__( @staticmethod def get_models(model_args: AbsEvalModelArgs) -> Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]: + """Get the embedding and reranker model + + Args: + model_args (AbsEvalModelArgs): :class:AbsEvalModelArgs object with the model arguments. + + Returns: + Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]: A :class:FlagAutoModel object of embedding model, and + :class:FlagAutoReranker object of reranker model if path provided. + """ embedder = FlagAutoModel.from_finetuned( model_name_or_path=model_args.embedder_name_or_path, model_class=model_args.embedder_model_class, @@ -74,6 +90,12 @@ def get_models(model_args: AbsEvalModelArgs) -> Tuple[FlagAutoModel, Union[FlagA return embedder, reranker def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalReranker, None]]: + """Load retriever and reranker for evaluation + + Returns: + Tuple[EvalDenseRetriever, Union[EvalReranker, None]]: A :class:EvalDenseRetriever object for retrieval, and a + :class:EvalReranker object if reranker provided. + """ embedder, reranker = self.get_models(self.model_args) retriever = EvalDenseRetriever( embedder, @@ -85,6 +107,11 @@ def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalRer return retriever, reranker def load_data_loader(self) -> AbsEvalDataLoader: + """Load the data loader + + Returns: + AbsEvalDataLoader: Data loader object for that specific task. + """ data_loader = AbsEvalDataLoader( eval_name=self.eval_args.eval_name, dataset_dir=self.eval_args.dataset_dir, @@ -95,6 +122,11 @@ def load_data_loader(self) -> AbsEvalDataLoader: return data_loader def load_evaluator(self) -> AbsEvaluator: + """Load the evaluator for evaluation + + Returns: + AbsEvaluator: the evaluator to run the evaluation. + """ evaluator = AbsEvaluator( eval_name=self.eval_args.eval_name, data_loader=self.data_loader, @@ -109,6 +141,18 @@ def evaluate_metrics( output_path: str = "./eval_dev_results.md", metrics: Union[str, List[str]] = ["ndcg_at_10", "recall_at_10"] ): + """Evaluate the provided metrics and write the results. + + Args: + search_results_save_dir (str): Path to save the search results. + output_method (str, optional): Output results to `json` or `markdown`. Defaults to "markdown". + output_path (str, optional): Path to write the output. Defaults to "./eval_dev_results.md". + metrics (Union[str, List[str]], optional): metrics to use. Defaults to ["ndcg_at_10", "recall_at_10"]. + + Raises: + FileNotFoundError: Eval results not found + ValueError: Invalid output method + """ eval_results_dict = {} for model_name in sorted(os.listdir(search_results_save_dir)): model_search_results_save_dir = os.path.join(search_results_save_dir, model_name) @@ -136,6 +180,9 @@ def evaluate_metrics( raise ValueError(f"Invalid output method: {output_method}. Available methods: ['json', 'markdown']") def run(self): + """ + Run the whole evaluation. + """ if self.eval_args.dataset_names is None: dataset_names = self.data_loader.available_dataset_names() else: diff --git a/FlagEmbedding/abc/evaluation/searcher.py b/FlagEmbedding/abc/evaluation/searcher.py index a2931bb8..82cade65 100644 --- a/FlagEmbedding/abc/evaluation/searcher.py +++ b/FlagEmbedding/abc/evaluation/searcher.py @@ -16,6 +16,9 @@ class EvalRetriever(ABC): + """ + This is the base class for retriever. + """ def __init__(self, embedder: AbsEmbedder, search_top_k: int = 1000, overwrite: bool = False): self.embedder = embedder self.search_top_k = search_top_k @@ -45,7 +48,7 @@ def __call__( **kwargs, ) -> Dict[str, Dict[str, float]]: """ - This is called during the retrieval process. + Abstract method to be overrode. This is called during the retrieval process. Parameters: corpus: Dict[str, Dict[str, Any]]: Corpus of documents. @@ -63,6 +66,9 @@ def __call__( class EvalDenseRetriever(EvalRetriever): + """ + Child class of :class:EvalRetriever for dense retrieval. + """ def __call__( self, corpus: Dict[str, Dict[str, Any]], @@ -144,6 +150,9 @@ def __call__( class EvalReranker: + """ + Class for reranker during evaluation. + """ def __init__(self, reranker: AbsReranker, rerank_top_k: int = 100): self.reranker = reranker self.rerank_top_k = rerank_top_k diff --git a/FlagEmbedding/abc/evaluation/utils.py b/FlagEmbedding/abc/evaluation/utils.py index 2c47a53f..f5f81350 100644 --- a/FlagEmbedding/abc/evaluation/utils.py +++ b/FlagEmbedding/abc/evaluation/utils.py @@ -16,6 +16,16 @@ def evaluate_mrr( results: Dict[str, Dict[str, float]], k_values: List[int], ) -> Tuple[Dict[str, float]]: + """Compute mean reciprocal rank (MRR). + + Args: + qrels (Dict[str, Dict[str, int]]): Ground truth relevance. + results (Dict[str, Dict[str, float]]): Search results to evaluate. + k_values (List[int]): Cutoffs. + + Returns: + Tuple[Dict[str, float]]: MRR results at provided k values. + """ mrr = defaultdict(list) k_max, top_hits = max(k_values), {} @@ -53,6 +63,17 @@ def evaluate_metrics( Dict[str, float], Dict[str, float], ]: + """Evaluate the main metrics. + + Args: + qrels (Dict[str, Dict[str, int]]): Ground truth relevance. + results (Dict[str, Dict[str, float]]): Search results to evaluate. + k_values (List[int]): Cutoffs. + + Returns: + Tuple[ Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], ]: Results of different metrics at + different provided k values. + """ all_ndcgs, all_aps, all_recalls, all_precisions = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list) map_string = "map_cut." + ",".join([str(k) for k in k_values]) @@ -93,6 +114,17 @@ def index( load_path: Optional[str] = None, device: Optional[str] = None ): + """Create and add embeddings into a Faiss index. + + Args: + index_factory (str, optional): Type of Faiss index to create. Defaults to "Flat". + corpus_embeddings (Optional[np.ndarray], optional): The embedding vectors of the corpus. Defaults to None. + load_path (Optional[str], optional): Path to load embeddings from. Defaults to None. + device (Optional[str], optional): Device to hold Faiss index. Defaults to None. + + Returns: + faiss.Index: The Faiss index that contains all the corpus embeddings. + """ if corpus_embeddings is None: corpus_embeddings = np.load(load_path) @@ -127,6 +159,15 @@ def search( """ 1. Encode queries into dense embeddings; 2. Search through faiss index + + Args: + faiss_index (faiss.Index): The Faiss index that contains all the corpus embeddings. + k (int, optional): Top k numbers of closest neighbours. Defaults to 100. + query_embeddings (Optional[np.ndarray], optional): The embedding vectors of queries. Defaults to None. + load_path (Optional[str], optional): Path to load embeddings from. Defaults to None. + + Returns: + Tuple[np.ndarray, np.ndarray]: The scores of search results and their corresponding indices. """ if query_embeddings is None: query_embeddings = np.load(load_path)