From 59a3dc134e30e3db2e895879e1667b4233b3383a Mon Sep 17 00:00:00 2001 From: ZiyiXia Date: Thu, 31 Oct 2024 02:45:41 +0000 Subject: [PATCH 1/8] ignore result files --- .gitignore | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index f0c3b0d9..788dc18e 100644 --- a/.gitignore +++ b/.gitignore @@ -139,4 +139,8 @@ pic2.py .pyre/ # MacOS associated -.DS_Store \ No newline at end of file +.DS_Store + +# results +en_results +zh_results \ No newline at end of file From bcf1b0341dbc58de880021d1bd7faa14679c4ae8 Mon Sep 17 00:00:00 2001 From: ZiyiXia Date: Thu, 31 Oct 2024 03:34:27 +0000 Subject: [PATCH 2/8] eval arguments --- FlagEmbedding/abc/evaluation/arguments.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/FlagEmbedding/abc/evaluation/arguments.py b/FlagEmbedding/abc/evaluation/arguments.py index ba2dfaae..8cbe88d9 100644 --- a/FlagEmbedding/abc/evaluation/arguments.py +++ b/FlagEmbedding/abc/evaluation/arguments.py @@ -8,6 +8,9 @@ @dataclass class AbsEvalArgs: + """ + Arguments for running evaluation. + """ eval_name: str = field( default=None, metadata={"help": "The name of the evaluation task, such as msmarco, beir, miracl, etc."} @@ -77,6 +80,9 @@ class AbsEvalArgs: @dataclass class AbsEvalModelArgs: + """ + Arguments for model during evaluation. + """ embedder_name_or_path: str = field( metadata={"help": "The embedder name or path.", "required": True} ) From 5ffe17b3243be571a3d052b3b909d09c0149f7eb Mon Sep 17 00:00:00 2001 From: ZiyiXia Date: Thu, 31 Oct 2024 07:23:14 +0000 Subject: [PATCH 3/8] data loader --- FlagEmbedding/abc/evaluation/data_loader.py | 178 ++++++++++++++++++++ 1 file changed, 178 insertions(+) diff --git a/FlagEmbedding/abc/evaluation/data_loader.py b/FlagEmbedding/abc/evaluation/data_loader.py index 2cd6ad51..f61e1a02 100644 --- a/FlagEmbedding/abc/evaluation/data_loader.py +++ b/FlagEmbedding/abc/evaluation/data_loader.py @@ -12,6 +12,15 @@ class AbsEvalDataLoader(ABC): + """_summary_ + + Args: + eval_name (str): The experiment name of current evaluation. + dataset_dir (str, optional): path to the datasets. Defaults to None. + cache_dir (str, optional): Path to HuggingFace cache directory. Defaults to None. + token (str, optional): HF_TOKEN to access the private datasets/models in HF. Defaults to None. + force_redownload: If True, will force redownload the dataset to cover the local dataset. Defaults to False. + """ def __init__( self, eval_name: str, @@ -43,6 +52,17 @@ def available_splits(self, dataset_name: Optional[str] = None) -> List[str]: pass def check_dataset_names(self, dataset_names: Union[str, List[str]]) -> List[str]: + """Check the validity of dataset names + + Args: + dataset_names (Union[str, List[str]]): a dataset name (str) or a list of dataset names (List[str]) + + Raises: + ValueError + + Returns: + List[str]: List of valid dataset names. + """ available_dataset_names = self.available_dataset_names() if isinstance(dataset_names, str): dataset_names = [dataset_names] @@ -53,6 +73,15 @@ def check_dataset_names(self, dataset_names: Union[str, List[str]]) -> List[str] return dataset_names def check_splits(self, splits: Union[str, List[str]], dataset_name: Optional[str] = None) -> List[str]: + """Check whether the splits are available in the dataset. + + Args: + splits (Union[str, List[str]]): Splits to check. + dataset_name (Optional[str], optional): Name of dataset to check. Defaults to None. + + Returns: + List[str]: The available splits. 
+ """ available_splits = self.available_splits(dataset_name=dataset_name) if isinstance(splits, str): splits = [splits] @@ -65,6 +94,14 @@ def check_splits(self, splits: Union[str, List[str]], dataset_name: Optional[str return checked_splits def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDict: + """Load the corpus from the dataset. + + Args: + dataset_name (Optional[str], optional): Name of the dataset. Defaults to None. + + Returns: + datasets.DatasetDict: A dict of corpus with id as key, title and text as value. + """ if self.dataset_dir is not None: if dataset_name is None: save_dir = self.dataset_dir @@ -75,6 +112,18 @@ def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDic return self._load_remote_corpus(dataset_name=dataset_name) def load_qrels(self, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict: + """Load the corpus from the dataset. + + Args: + dataset_name (Optional[str], optional): Name of the dataset. Defaults to None. + split (str, optional): The split to load relevance from. Defaults to 'test'. + + Raises: + ValueError + + Returns: + datasets.DatasetDict: A dict of relevance of query and document. + """ if self.dataset_dir is not None: if dataset_name is None: save_dir = self.dataset_dir @@ -91,6 +140,18 @@ def load_qrels(self, dataset_name: Optional[str] = None, split: str = 'test') -> return self._load_remote_qrels(dataset_name=dataset_name, split=split) def load_queries(self, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict: + """Load the queries from the dataset. + + Args: + dataset_name (Optional[str], optional): Name of the dataset. Defaults to None. + split (str, optional): The split to load queries from. Defaults to 'test'. + + Raises: + ValueError + + Returns: + datasets.DatasetDict: A dict of queries with id as key, query text as value. + """ if self.dataset_dir is not None: if dataset_name is None: save_dir = self.dataset_dir @@ -111,6 +172,18 @@ def _load_remote_corpus( dataset_name: Optional[str] = None, save_dir: Optional[str] = None ) -> datasets.DatasetDict: + """Abstract method to load corpus from remote dataset, to be overrode in child class. + + Args: + dataset_name (Optional[str], optional): Name of the dataset. Defaults to None. + save_dir (Optional[str], optional): Path to save the new downloaded corpus. Defaults to None. + + Raises: + NotImplementedError: Loading remote corpus is not implemented. + + Returns: + datasets.DatasetDict: A dict of corpus with id as key, title and text as value. + """ raise NotImplementedError("Loading remote corpus is not implemented.") def _load_remote_qrels( @@ -119,6 +192,19 @@ def _load_remote_qrels( split: str = 'test', save_dir: Optional[str] = None ) -> datasets.DatasetDict: + """Abstract method to load relevance from remote dataset, to be overrode in child class. + + Args: + dataset_name (Optional[str], optional): Name of the dataset. Defaults to None. + split (str, optional): Split to load from the remote dataset. Defaults to 'test'. + save_dir (Optional[str], optional): Path to save the new downloaded relevance. Defaults to None. + + Raises: + NotImplementedError: Loading remote qrels is not implemented. + + Returns: + datasets.DatasetDict: A dict of relevance of query and document. 
+ """ raise NotImplementedError("Loading remote qrels is not implemented.") def _load_remote_queries( @@ -127,9 +213,31 @@ def _load_remote_queries( split: str = 'test', save_dir: Optional[str] = None ) -> datasets.DatasetDict: + """Abstract method to load queries from remote dataset, to be overrode in child class. + + Args: + dataset_name (Optional[str], optional): Name of the dataset. Defaults to None. + split (str, optional): Split to load from the remote dataset. Defaults to 'test'. + save_dir (Optional[str], optional): Path to save the new downloaded queries. Defaults to None. + + Raises: + NotImplementedError + + Returns: + datasets.DatasetDict: A dict of queries with id as key, query text as value. + """ raise NotImplementedError("Loading remote queries is not implemented.") def _load_local_corpus(self, save_dir: str, dataset_name: Optional[str] = None) -> datasets.DatasetDict: + """Load corpus from local dataset. + + Args: + save_dir (str): Path to save the loaded corpus. + dataset_name (Optional[str], optional): Name of the dataset. Defaults to None. + + Returns: + datasets.DatasetDict: A dict of corpus with id as key, title and text as value. + """ corpus_path = os.path.join(save_dir, 'corpus.jsonl') if self.force_redownload or not os.path.exists(corpus_path): logger.warning(f"Corpus not found in {corpus_path}. Trying to download the corpus from the remote and save it to {save_dir}.") @@ -144,6 +252,19 @@ def _load_local_corpus(self, save_dir: str, dataset_name: Optional[str] = None) return datasets.DatasetDict(corpus) def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict: + """Load relevance from local dataset. + + Args: + save_dir (str): Path to save the loaded relevance. + dataset_name (Optional[str], optional): Name of the dataset. Defaults to None. + split (str, optional): Split to load from the local dataset. Defaults to 'test'. + + Raises: + ValueError + + Returns: + datasets.DatasetDict: A dict of relevance of query and document. + """ checked_split = self.check_splits(split) if len(checked_split) == 0: raise ValueError(f"Split {split} not found in the dataset.") @@ -166,6 +287,19 @@ def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, s return datasets.DatasetDict(qrels) def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict: + """Load queries from local dataset. + + Args: + save_dir (str): Path to save the loaded queries. + dataset_name (Optional[str], optional): Name of the dataset. Defaults to None. + split (str, optional): Split to load from the local dataset. Defaults to 'test'. + + Raises: + ValueError + + Returns: + datasets.DatasetDict: A dict of queries with id as key, query text as value. + """ checked_split = self.check_splits(split) if len(checked_split) == 0: raise ValueError(f"Split {split} not found in the dataset.") @@ -182,6 +316,18 @@ def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None, return datasets.DatasetDict(queries) def _download_file(self, download_url: str, save_dir: str): + """Download file from provided URL. + + Args: + download_url (str): Source URL of the file. + save_dir (str): Path to the directory to save the zip file. + + Raises: + FileNotFoundError + + Returns: + str: The path of the downloaded file. 
+ """ save_path = os.path.join(save_dir, download_url.split('/')[-1]) if self.force_redownload or (not os.path.exists(save_path) or os.path.getsize(save_path) == 0): @@ -201,6 +347,14 @@ def _download_file(self, download_url: str, save_dir: str): return save_path def _get_fpath_size(self, fpath: str) -> int: + """Get the total size of the files in provided path. + + Args: + fpath (str): path of files to compute the size. + + Returns: + int: The total size in bytes. + """ if not os.path.isdir(fpath): return os.path.getsize(fpath) else: @@ -212,6 +366,18 @@ def _get_fpath_size(self, fpath: str) -> int: return total_size def _download_gz_file(self, download_url: str, save_dir: str): + """Download and unzip the gzip file from provided URL. + + Args: + download_url (str): Source URL of the gzip file. + save_dir (str): Path to the directory to save the gzip file. + + Raises: + FileNotFoundError: _description_ + + Returns: + str: The path to the file after unzip. + """ gz_file_path = self._download_file(download_url, save_dir) cmd = ["gzip", "-d", gz_file_path] try: @@ -226,6 +392,18 @@ def _download_gz_file(self, download_url: str, save_dir: str): return file_path def _download_zip_file(self, download_url: str, save_dir: str): + """Download and unzip the zip file from provided URL. + + Args: + download_url (str): Source URL of the zip file. + save_dir (str): Path to the directory to save the zip file. + + Raises: + FileNotFoundError + + Returns: + str: The path to the file after unzip. + """ zip_file_path = self._download_file(download_url, save_dir) file_path = zip_file_path.replace(".zip", "") if self.force_redownload or not os.path.exists(file_path): From bea75457b38f84cb7f26423363768664248c160c Mon Sep 17 00:00:00 2001 From: ZiyiXia Date: Thu, 31 Oct 2024 09:38:07 +0000 Subject: [PATCH 4/8] absembedder and absreranker --- FlagEmbedding/abc/inference/AbsEmbedder.py | 12 ++++++------ FlagEmbedding/abc/inference/AbsReranker.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/FlagEmbedding/abc/inference/AbsEmbedder.py b/FlagEmbedding/abc/inference/AbsEmbedder.py index 421ac1bd..5f303921 100644 --- a/FlagEmbedding/abc/inference/AbsEmbedder.py +++ b/FlagEmbedding/abc/inference/AbsEmbedder.py @@ -18,7 +18,7 @@ class AbsEmbedder(ABC): """ Base class for embedder. - Extend this class and implement `encode_queries`, `encode_passages`, `encode` for custom embedders. + Extend this class and implement :meth:`encode_queries`, :meth:`encode_passages`, :meth:`encode` for custom embedders. Args: model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and @@ -27,14 +27,14 @@ class AbsEmbedder(ABC): use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance degradation. Default: `True`. query_instruction_for_retrieval: (Optional[str], optional): Query instruction for retrieval tasks, which will be used with - with `query_instruction_format`. Default: `None`. - query_instruction_format: (str, optional): The template for `query_instruction_for_retrieval`. Default: `"{}{}"`. + with :attr:`query_instruction_format`. Default: `None`. + query_instruction_format: (str, optional): The template for :attr:`query_instruction_for_retrieval`. Default: `"{}{}"`. devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Default: `None`. batch_size (int, optional): Batch size for inference. Default: `256`. 
         query_max_length (int, optional): Maximum length for query. Default: `512`.
         passage_max_length (int, optional): Maximum length for passage. Default: `512`.
-        instruction (Optional[str], optional): Instruction for embedding. Default: `None`.
-        instruction_format (str, optional): Instruction format when using `instruction`. Default: `"{}{}"`.
+        instruction (Optional[str], optional): Instruction for embedding with :attr:`instruction_format`. Default: `None`.
+        instruction_format (str, optional): Instruction format when using :attr:`instruction`. Default: `"{}{}"`.
         convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a
             Torch Tensor. Default: `True`.
         kwargs (Dict[Any], optional): Additional parameters for HuggingFace Transformers config or children classes.
@@ -88,7 +88,7 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
             ValueError: Devices should be a string or an integer or a list of strings or a list of integers.

         Returns:
-            List[str]: A list of target devices in format
+            List[str]: A list of target devices in string format.
         """
         if devices is None:
             if torch.cuda.is_available():
diff --git a/FlagEmbedding/abc/inference/AbsReranker.py b/FlagEmbedding/abc/inference/AbsReranker.py
index 55460c39..deb35715 100644
--- a/FlagEmbedding/abc/inference/AbsReranker.py
+++ b/FlagEmbedding/abc/inference/AbsReranker.py
@@ -17,7 +17,7 @@ class AbsReranker(ABC):
     """
     Base class for Reranker.

-    Extend this class and implement `compute_score_single_gpu` for custom rerankers.
+    Extend this class and implement :meth:`compute_score_single_gpu` for custom rerankers.

     Args:
         model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
@@ -25,8 +25,8 @@ class AbsReranker(ABC):
         use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
             degradation. Default: `False`.
         query_instruction_for_rerank: (Optional[str], optional): Query instruction for reranking, which will be used with
-            with `query_instruction_format`. Default: `None`.
-        query_instruction_format: (str, optional): The template for `query_instruction_for_rerank`. Default: `"{}{}"`.
+            :attr:`query_instruction_format`. Default: `None`.
+        query_instruction_format: (str, optional): The template for :attr:`query_instruction_for_rerank`. Default: `"{}{}"`.
         passage_instruction_for_rerank (Optional[str], optional): Passage instruction for reranking. Default: `None`.
         passage_instruction_format (str, optional): Passage instruction format when using `passage_instruction_for_rerank`.
            Default: `"{}{}"`.
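Note on the two docstrings above: both the embedder and the reranker document the same default instruction template `"{}{}"`. A minimal sketch of how such a template is expected to combine an instruction with the input text; the argument order passed to str.format is an assumption, not something this patch states:

    # Hypothetical values; only the "{}{}" default template is taken from the docstrings above.
    query_instruction_for_retrieval = "Represent this sentence for searching relevant passages: "
    query_instruction_format = "{}{}"
    query = "what is a dense retriever?"

    # Assumed combination: instruction first, query second.
    formatted_query = query_instruction_format.format(query_instruction_for_retrieval, query)
    print(formatted_query)
    # Represent this sentence for searching relevant passages: what is a dense retriever?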
From 3805ad5c8e4fcc5d1cf4d47bd1c9745c5e04f71e Mon Sep 17 00:00:00 2001 From: ZiyiXia Date: Thu, 31 Oct 2024 10:31:23 +0000 Subject: [PATCH 5/8] data loader --- .gitignore | 3 ++- FlagEmbedding/abc/evaluation/arguments.py | 4 ++-- FlagEmbedding/abc/evaluation/data_loader.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 788dc18e..306095a2 100644 --- a/.gitignore +++ b/.gitignore @@ -143,4 +143,5 @@ pic2.py # results en_results -zh_results \ No newline at end of file +zh_results +docs \ No newline at end of file diff --git a/FlagEmbedding/abc/evaluation/arguments.py b/FlagEmbedding/abc/evaluation/arguments.py index 8cbe88d9..61dfa6f5 100644 --- a/FlagEmbedding/abc/evaluation/arguments.py +++ b/FlagEmbedding/abc/evaluation/arguments.py @@ -9,7 +9,7 @@ @dataclass class AbsEvalArgs: """ - Arguments for running evaluation. + Base class for evaluation arguments. """ eval_name: str = field( default=None, @@ -81,7 +81,7 @@ class AbsEvalArgs: @dataclass class AbsEvalModelArgs: """ - Arguments for model during evaluation. + Base class for model arguments during evaluation. """ embedder_name_or_path: str = field( metadata={"help": "The embedder name or path.", "required": True} diff --git a/FlagEmbedding/abc/evaluation/data_loader.py b/FlagEmbedding/abc/evaluation/data_loader.py index f61e1a02..003c95f7 100644 --- a/FlagEmbedding/abc/evaluation/data_loader.py +++ b/FlagEmbedding/abc/evaluation/data_loader.py @@ -12,7 +12,8 @@ class AbsEvalDataLoader(ABC): - """_summary_ + """ + Base class of data loader for evaluation. Args: eval_name (str): The experiment name of current evaluation. From 117f1406644b1a262f7977dc1ff0e3e8d15dfbce Mon Sep 17 00:00:00 2001 From: ZiyiXia Date: Thu, 31 Oct 2024 10:31:30 +0000 Subject: [PATCH 6/8] evaluator --- FlagEmbedding/abc/evaluation/evaluator.py | 72 ++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/FlagEmbedding/abc/evaluation/evaluator.py b/FlagEmbedding/abc/evaluation/evaluator.py index 77256965..388c76cc 100644 --- a/FlagEmbedding/abc/evaluation/evaluator.py +++ b/FlagEmbedding/abc/evaluation/evaluator.py @@ -16,6 +16,14 @@ class AbsEvaluator: + """ + Base class of Evaluator. + + Args: + eval_name (str): The experiment name of current evaluation. + data_loader (AbsEvalDataLoader): The data_loader to deal with data. + overwrite (bool): If true, will overwrite the existing results. + """ def __init__( self, eval_name: str, @@ -34,6 +42,21 @@ def check_data_info( split: str, dataset_name: Optional[str] = None, ): + """Check the validity of data info. + + Args: + data_info (Dict[str, str]): The loaded data info to be check. + model_name (str): Name of model used. + reranker_name (str): Name of reranker used. + split (str): Split used in searching. + dataset_name (Optional[str], optional): Name of dataset used. Defaults to None. + + Raises: + ValueError: eval_name mismatch + ValueError: model_name or reranker_name mismatch + ValueError: split mismatch + ValueError: dataset_name mismatch + """ if data_info["eval_name"] != self.eval_name: raise ValueError( f'eval_name mismatch: {data_info["eval_name"]} vs {self.eval_name}' @@ -61,7 +84,13 @@ def get_corpus_embd_save_dir( dataset_name: Optional[str] = None ): """ - If corpus_embd_save_dir is not None, then it will be used as the base directory to save the corpus embeddings. For dataset such as MKQA, the corpus for all languages is the same, so the subclass can override this method to save the corpus embeddings in the same directory. 
+ If corpus_embd_save_dir is not None, then it will be used as the base directory to save the corpus embeddings. For dataset such as MKQA, + the corpus for all languages is the same, so the subclass can override this method to save the corpus embeddings in the same directory. + + Args: + retriever_name (str): Name of the retriever. + corpus_embd_save_dir (str, optional): Directory that saving the corpus embedding. + dataset_name (str, optional): """ if corpus_embd_save_dir is not None: if dataset_name is not None: @@ -228,6 +257,17 @@ def save_search_results( split: str, dataset_name: Optional[str] = None, ): + """Save the metadata and search results into a file. + + Args: + eval_name (str): The experiment name of current evaluation. + model_name (str): Name of model used. + reranker_name (str): Name of reranker used. + search_results (Dict[str, Dict[str, float]]): The search results. + output_path (str): Output path to write the results. + split (str): Split used in searching. + dataset_name (Optional[str], optional): Name of dataset used. Defaults to None. + """ data = { "eval_name": eval_name, "model_name": model_name, @@ -244,6 +284,14 @@ def save_search_results( @staticmethod def load_search_results(input_path: str): + """Load search results from path. + + Args: + input_path (str): Path to load from. + + Returns: + dict, dict: data info that contains metadata and search results. + """ with open(input_path, "r", encoding="utf-8") as f: data_info = json.load(f) @@ -312,6 +360,12 @@ def evaluate_results( @staticmethod def output_eval_results_to_json(eval_results_dict: dict, output_path: str): + """Write the evaluation results into a json file. + + Args: + eval_results_dict (dict): Dictionary of the evaluation results. + output_path (str): Output path to write the json file. + """ os.makedirs(os.path.dirname(output_path), exist_ok=True) with open(output_path, 'w', encoding='utf-8') as f: @@ -320,6 +374,15 @@ def output_eval_results_to_json(eval_results_dict: dict, output_path: str): @staticmethod def get_results_df(metric: str, eval_results_dict: dict): + """Get the results from dictionary to a DataFrame. + + Args: + metric (str): Selected metric. + eval_results_dict (dict): Dictionary of the evaluation results. + + Returns: + DataFrame: DataFrame of the results. + """ results_dict = {} for model_name, model_results in eval_results_dict.items(): @@ -361,6 +424,13 @@ def get_results_df(metric: str, eval_results_dict: dict): @staticmethod def output_eval_results_to_markdown(eval_results_dict: dict, output_path: str, metrics: Union[List[str], str]): + """Write the evaluation results to a markdown file. + + Args: + eval_results_dict (dict): Dictionary that contains evaluation results. + output_path (str): Path to write the output to. + metrics (Union[List[str], str]): The metrics that will be written in the markdown file. 
+        """
         os.makedirs(os.path.dirname(output_path), exist_ok=True)

         if isinstance(metrics, str):

From 134a1ade23da32f4ced6a31b6fb56dbe13d140d2 Mon Sep 17 00:00:00 2001
From: ZiyiXia
Date: Thu, 31 Oct 2024 12:00:34 +0000
Subject: [PATCH 7/8] abc evaluator

---
 FlagEmbedding/abc/evaluation/evaluator.py | 33 ++++++++++++++++++++-
 1 file changed, 32 insertions(+), 1 deletion(-)

diff --git a/FlagEmbedding/abc/evaluation/evaluator.py b/FlagEmbedding/abc/evaluation/evaluator.py
index 388c76cc..35e86de6 100644
--- a/FlagEmbedding/abc/evaluation/evaluator.py
+++ b/FlagEmbedding/abc/evaluation/evaluator.py
@@ -111,6 +111,18 @@ def __call__(
         dataset_name: Optional[str] = None,
         **kwargs,
     ):
+        """Called to the whole evaluation process.
+
+        Args:
+            splits (Union[str, List[str]]): Splits of datasets.
+            search_results_save_dir (str): Directory to save the search results.
+            retriever (EvalRetriever): Object of :class:EvalRetriever.
+            reranker (Optional[EvalReranker], optional): Object of :class:EvalReranker. Defaults to None.
+            corpus_embd_save_dir (Optional[str], optional): Directory to save the embedded corpus. Defaults to None.
+            ignore_identical_ids (bool, optional): If True, will ignore identical ids in search results. Defaults to False.
+            k_values (List[int], optional): Cutoffs. Defaults to [1, 3, 5, 10, 100, 1000].
+            dataset_name (Optional[str], optional): Name of the datasets. Defaults to None.
+        """
         # Check Splits
         checked_splits = self.data_loader.check_splits(splits, dataset_name=dataset_name)
         if len(checked_splits) == 0:
@@ -263,7 +275,7 @@ def save_search_results(
             eval_name (str): The experiment name of current evaluation.
             model_name (str): Name of model used.
             reranker_name (str): Name of reranker used.
-            search_results (Dict[str, Dict[str, float]]): The search results.
+            search_results (Dict[str, Dict[str, float]]): Dictionary of search results.
             output_path (str): Output path to write the results.
             split (str): Split used in searching.
             dataset_name (Optional[str], optional): Name of dataset used. Defaults to None.
@@ -304,6 +316,16 @@ def compute_metrics(
         search_results: Dict[str, Dict[str, float]],
         k_values: List[int],
     ):
+        """Evaluate the model with metrics.
+
+        Args:
+            qrels (Dict[str, Dict[str, int]]): Ground truth relevance of queries and documents.
+            search_results (Dict[str, Dict[str, float]]): Dictionary of search results.
+            k_values (List[int]): Cutoffs.
+
+        Returns:
+            dict: The results of the metrics.
+        """
         ndcg, _map, recall, precision = evaluate_metrics(
             qrels=qrels,
             results=search_results,
@@ -328,6 +350,15 @@ def evaluate_results(
         search_results_save_dir: str,
         k_values: List[int] = [1, 3, 5, 10, 100, 1000]
     ):
+        """Compute metrics according to the results in the directory.
+
+        Args:
+            search_results_save_dir (str): Path to the search results.
+            k_values (List[int], optional): Cutoffs. Defaults to [1, 3, 5, 10, 100, 1000].
+ + Returns: + _type_: _description_ + """ eval_results_dict = {} for file in os.listdir(search_results_save_dir): From 7ae0ecf4f065bc625f2bee5e3e8b8fded9f7fe39 Mon Sep 17 00:00:00 2001 From: ZiyiXia Date: Thu, 31 Oct 2024 16:14:13 +0000 Subject: [PATCH 8/8] evaluation docstring --- FlagEmbedding/abc/evaluation/evaluator.py | 2 +- FlagEmbedding/abc/evaluation/runner.py | 47 +++++++++++++++++++++++ FlagEmbedding/abc/evaluation/searcher.py | 11 +++++- FlagEmbedding/abc/evaluation/utils.py | 41 ++++++++++++++++++++ 4 files changed, 99 insertions(+), 2 deletions(-) diff --git a/FlagEmbedding/abc/evaluation/evaluator.py b/FlagEmbedding/abc/evaluation/evaluator.py index 35e86de6..60b3c4f7 100644 --- a/FlagEmbedding/abc/evaluation/evaluator.py +++ b/FlagEmbedding/abc/evaluation/evaluator.py @@ -111,7 +111,7 @@ def __call__( dataset_name: Optional[str] = None, **kwargs, ): - """Called to the whole evaluation process. + """This is called during the evaluation process. Args: splits (Union[str, List[str]]): Splits of datasets. diff --git a/FlagEmbedding/abc/evaluation/runner.py b/FlagEmbedding/abc/evaluation/runner.py index 5bef0eb2..7d70dcb9 100644 --- a/FlagEmbedding/abc/evaluation/runner.py +++ b/FlagEmbedding/abc/evaluation/runner.py @@ -14,6 +14,13 @@ class AbsEvalRunner: + """ + Abstract class of evaluation runner. + + Args: + eval_args (AbsEvalArgs): :class:AbsEvalArgs object with the evaluation arguments. + model_args (AbsEvalModelArgs): :class:AbsEvalModelArgs object with the model arguments. + """ def __init__( self, eval_args: AbsEvalArgs, @@ -28,6 +35,15 @@ def __init__( @staticmethod def get_models(model_args: AbsEvalModelArgs) -> Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]: + """Get the embedding and reranker model + + Args: + model_args (AbsEvalModelArgs): :class:AbsEvalModelArgs object with the model arguments. + + Returns: + Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]: A :class:FlagAutoModel object of embedding model, and + :class:FlagAutoReranker object of reranker model if path provided. + """ embedder = FlagAutoModel.from_finetuned( model_name_or_path=model_args.embedder_name_or_path, model_class=model_args.embedder_model_class, @@ -74,6 +90,12 @@ def get_models(model_args: AbsEvalModelArgs) -> Tuple[FlagAutoModel, Union[FlagA return embedder, reranker def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalReranker, None]]: + """Load retriever and reranker for evaluation + + Returns: + Tuple[EvalDenseRetriever, Union[EvalReranker, None]]: A :class:EvalDenseRetriever object for retrieval, and a + :class:EvalReranker object if reranker provided. + """ embedder, reranker = self.get_models(self.model_args) retriever = EvalDenseRetriever( embedder, @@ -85,6 +107,11 @@ def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalRer return retriever, reranker def load_data_loader(self) -> AbsEvalDataLoader: + """Load the data loader + + Returns: + AbsEvalDataLoader: Data loader object for that specific task. + """ data_loader = AbsEvalDataLoader( eval_name=self.eval_args.eval_name, dataset_dir=self.eval_args.dataset_dir, @@ -95,6 +122,11 @@ def load_data_loader(self) -> AbsEvalDataLoader: return data_loader def load_evaluator(self) -> AbsEvaluator: + """Load the evaluator for evaluation + + Returns: + AbsEvaluator: the evaluator to run the evaluation. 
+ """ evaluator = AbsEvaluator( eval_name=self.eval_args.eval_name, data_loader=self.data_loader, @@ -109,6 +141,18 @@ def evaluate_metrics( output_path: str = "./eval_dev_results.md", metrics: Union[str, List[str]] = ["ndcg_at_10", "recall_at_10"] ): + """Evaluate the provided metrics and write the results. + + Args: + search_results_save_dir (str): Path to save the search results. + output_method (str, optional): Output results to `json` or `markdown`. Defaults to "markdown". + output_path (str, optional): Path to write the output. Defaults to "./eval_dev_results.md". + metrics (Union[str, List[str]], optional): metrics to use. Defaults to ["ndcg_at_10", "recall_at_10"]. + + Raises: + FileNotFoundError: Eval results not found + ValueError: Invalid output method + """ eval_results_dict = {} for model_name in sorted(os.listdir(search_results_save_dir)): model_search_results_save_dir = os.path.join(search_results_save_dir, model_name) @@ -136,6 +180,9 @@ def evaluate_metrics( raise ValueError(f"Invalid output method: {output_method}. Available methods: ['json', 'markdown']") def run(self): + """ + Run the whole evaluation. + """ if self.eval_args.dataset_names is None: dataset_names = self.data_loader.available_dataset_names() else: diff --git a/FlagEmbedding/abc/evaluation/searcher.py b/FlagEmbedding/abc/evaluation/searcher.py index a2931bb8..82cade65 100644 --- a/FlagEmbedding/abc/evaluation/searcher.py +++ b/FlagEmbedding/abc/evaluation/searcher.py @@ -16,6 +16,9 @@ class EvalRetriever(ABC): + """ + This is the base class for retriever. + """ def __init__(self, embedder: AbsEmbedder, search_top_k: int = 1000, overwrite: bool = False): self.embedder = embedder self.search_top_k = search_top_k @@ -45,7 +48,7 @@ def __call__( **kwargs, ) -> Dict[str, Dict[str, float]]: """ - This is called during the retrieval process. + Abstract method to be overrode. This is called during the retrieval process. Parameters: corpus: Dict[str, Dict[str, Any]]: Corpus of documents. @@ -63,6 +66,9 @@ def __call__( class EvalDenseRetriever(EvalRetriever): + """ + Child class of :class:EvalRetriever for dense retrieval. + """ def __call__( self, corpus: Dict[str, Dict[str, Any]], @@ -144,6 +150,9 @@ def __call__( class EvalReranker: + """ + Class for reranker during evaluation. + """ def __init__(self, reranker: AbsReranker, rerank_top_k: int = 100): self.reranker = reranker self.rerank_top_k = rerank_top_k diff --git a/FlagEmbedding/abc/evaluation/utils.py b/FlagEmbedding/abc/evaluation/utils.py index 2c47a53f..f5f81350 100644 --- a/FlagEmbedding/abc/evaluation/utils.py +++ b/FlagEmbedding/abc/evaluation/utils.py @@ -16,6 +16,16 @@ def evaluate_mrr( results: Dict[str, Dict[str, float]], k_values: List[int], ) -> Tuple[Dict[str, float]]: + """Compute mean reciprocal rank (MRR). + + Args: + qrels (Dict[str, Dict[str, int]]): Ground truth relevance. + results (Dict[str, Dict[str, float]]): Search results to evaluate. + k_values (List[int]): Cutoffs. + + Returns: + Tuple[Dict[str, float]]: MRR results at provided k values. + """ mrr = defaultdict(list) k_max, top_hits = max(k_values), {} @@ -53,6 +63,17 @@ def evaluate_metrics( Dict[str, float], Dict[str, float], ]: + """Evaluate the main metrics. + + Args: + qrels (Dict[str, Dict[str, int]]): Ground truth relevance. + results (Dict[str, Dict[str, float]]): Search results to evaluate. + k_values (List[int]): Cutoffs. 
+
+    Returns:
+        Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]: Results of NDCG, MAP, Recall and
+            Precision at the provided k values.
+    """
     all_ndcgs, all_aps, all_recalls, all_precisions = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)

     map_string = "map_cut." + ",".join([str(k) for k in k_values])
@@ -93,6 +114,17 @@ def index(
     load_path: Optional[str] = None,
     device: Optional[str] = None
 ):
+    """Create a Faiss index and add the corpus embeddings into it.
+
+    Args:
+        index_factory (str, optional): Type of Faiss index to create. Defaults to "Flat".
+        corpus_embeddings (Optional[np.ndarray], optional): The embedding vectors of the corpus. Defaults to None.
+        load_path (Optional[str], optional): Path to load embeddings from. Defaults to None.
+        device (Optional[str], optional): Device to hold the Faiss index. Defaults to None.
+
+    Returns:
+        faiss.Index: The Faiss index that contains all the corpus embeddings.
+    """
     if corpus_embeddings is None:
         corpus_embeddings = np.load(load_path)

@@ -127,6 +159,15 @@ def search(
     """
     1. Encode queries into dense embeddings;
     2. Search through faiss index
+
+    Args:
+        faiss_index (faiss.Index): The Faiss index that contains all the corpus embeddings.
+        k (int, optional): Number of nearest neighbours to retrieve for each query. Defaults to 100.
+        query_embeddings (Optional[np.ndarray], optional): The embedding vectors of the queries. Defaults to None.
+        load_path (Optional[str], optional): Path to load embeddings from. Defaults to None.
+
+    Returns:
+        Tuple[np.ndarray, np.ndarray]: The scores of the search results and their corresponding indices.
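For reference, a minimal usage sketch of the index and search helpers documented in this patch. The import path mirrors the file location in the diff; the toy embeddings, their dimensions, and an installed faiss/numpy are assumptions:

    import numpy as np

    # Assumed import path, matching FlagEmbedding/abc/evaluation/utils.py in this patch.
    from FlagEmbedding.abc.evaluation.utils import index, search

    # Toy corpus and query embeddings standing in for real encoder output.
    corpus_embeddings = np.random.rand(1000, 128).astype(np.float32)
    query_embeddings = np.random.rand(8, 128).astype(np.float32)

    # Build a flat Faiss index over the corpus, then retrieve the top 10 neighbours per query.
    faiss_index = index(index_factory="Flat", corpus_embeddings=corpus_embeddings)
    scores, indices = search(faiss_index=faiss_index, k=10, query_embeddings=query_embeddings)

    # Mapping `indices` back to document ids yields the results dict that
    # evaluate_metrics / evaluate_mrr expect, as described in the docstrings above.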