Commit 7ae0ecf

evaluation docstring

ZiyiXia committed Oct 31, 2024
1 parent 134a1ad commit 7ae0ecf
Showing 4 changed files with 99 additions and 2 deletions.
2 changes: 1 addition & 1 deletion FlagEmbedding/abc/evaluation/evaluator.py
@@ -111,7 +111,7 @@ def __call__(
dataset_name: Optional[str] = None,
**kwargs,
):
"""Called to the whole evaluation process.
"""This is called during the evaluation process.
Args:
splits (Union[str, List[str]]): Splits of datasets.
47 changes: 47 additions & 0 deletions FlagEmbedding/abc/evaluation/runner.py
@@ -14,6 +14,13 @@


class AbsEvalRunner:
"""
Abstract class of evaluation runner.
Args:
eval_args (AbsEvalArgs): :class:AbsEvalArgs object with the evaluation arguments.
model_args (AbsEvalModelArgs): :class:AbsEvalModelArgs object with the model arguments.
"""
def __init__(
self,
eval_args: AbsEvalArgs,
@@ -28,6 +35,15 @@ def __init__(

@staticmethod
def get_models(model_args: AbsEvalModelArgs) -> Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]:
"""Get the embedding and reranker model
Args:
model_args (AbsEvalModelArgs): :class:AbsEvalModelArgs object with the model arguments.
Returns:
Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]: A :class:FlagAutoModel object of embedding model, and
:class:FlagAutoReranker object of reranker model if path provided.
"""
embedder = FlagAutoModel.from_finetuned(
model_name_or_path=model_args.embedder_name_or_path,
model_class=model_args.embedder_model_class,
@@ -74,6 +90,12 @@ def get_models(model_args: AbsEvalModelArgs) -> Tuple[FlagAutoModel, Union[FlagA
return embedder, reranker

def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalReranker, None]]:
"""Load retriever and reranker for evaluation
Returns:
Tuple[EvalDenseRetriever, Union[EvalReranker, None]]: A :class:EvalDenseRetriever object for retrieval, and a
:class:EvalReranker object if reranker provided.
"""
embedder, reranker = self.get_models(self.model_args)
retriever = EvalDenseRetriever(
embedder,
@@ -85,6 +107,11 @@ def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalRer
return retriever, reranker

def load_data_loader(self) -> AbsEvalDataLoader:
"""Load the data loader
Returns:
AbsEvalDataLoader: Data loader object for that specific task.
"""
data_loader = AbsEvalDataLoader(
eval_name=self.eval_args.eval_name,
dataset_dir=self.eval_args.dataset_dir,
@@ -95,6 +122,11 @@ def load_data_loader(self) -> AbsEvalDataLoader:
return data_loader

def load_evaluator(self) -> AbsEvaluator:
"""Load the evaluator for evaluation
Returns:
AbsEvaluator: the evaluator to run the evaluation.
"""
evaluator = AbsEvaluator(
eval_name=self.eval_args.eval_name,
data_loader=self.data_loader,
@@ -109,6 +141,18 @@ def evaluate_metrics(
output_path: str = "./eval_dev_results.md",
metrics: Union[str, List[str]] = ["ndcg_at_10", "recall_at_10"]
):
"""Evaluate the provided metrics and write the results.
Args:
search_results_save_dir (str): Path to save the search results.
output_method (str, optional): Output results to `json` or `markdown`. Defaults to "markdown".
output_path (str, optional): Path to write the output. Defaults to "./eval_dev_results.md".
metrics (Union[str, List[str]], optional): metrics to use. Defaults to ["ndcg_at_10", "recall_at_10"].
Raises:
FileNotFoundError: Eval results not found
ValueError: Invalid output method
"""
eval_results_dict = {}
for model_name in sorted(os.listdir(search_results_save_dir)):
model_search_results_save_dir = os.path.join(search_results_save_dir, model_name)
@@ -136,6 +180,9 @@ def evaluate_metrics(
raise ValueError(f"Invalid output method: {output_method}. Available methods: ['json', 'markdown']")

def run(self):
"""
Run the whole evaluation.
"""
if self.eval_args.dataset_names is None:
dataset_names = self.data_loader.available_dataset_names()
else:
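To make the documented runner workflow concrete, a minimal usage sketch follows. It is an illustration only: the import path is inferred from the file layout in this commit, the argument fields set here are the ones referenced in this diff, and all concrete values are placeholders.

from FlagEmbedding.abc.evaluation import AbsEvalArgs, AbsEvalModelArgs, AbsEvalRunner  # import path assumed from the file layout

# Placeholder arguments; only fields referenced in this diff are set, the rest
# are left at their defaults (assumption).
eval_args = AbsEvalArgs(eval_name="example_eval", dataset_dir="./eval_data")
model_args = AbsEvalModelArgs(embedder_name_or_path="BAAI/bge-base-en-v1.5")

runner = AbsEvalRunner(eval_args=eval_args, model_args=model_args)

# Runs retrieval (and reranking if a reranker path is configured) over the
# requested datasets and saves the search results to disk.
runner.run()

# Aggregate the saved search results into a report; the defaults mirror the
# signature shown above.
runner.evaluate_metrics(
    search_results_save_dir="./search_results",  # placeholder path
    output_method="markdown",
    output_path="./eval_dev_results.md",
    metrics=["ndcg_at_10", "recall_at_10"],
)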
11 changes: 10 additions & 1 deletion FlagEmbedding/abc/evaluation/searcher.py
@@ -16,6 +16,9 @@


class EvalRetriever(ABC):
"""
This is the base class for retriever.
"""
def __init__(self, embedder: AbsEmbedder, search_top_k: int = 1000, overwrite: bool = False):
self.embedder = embedder
self.search_top_k = search_top_k
@@ -45,7 +48,7 @@ def __call__(
**kwargs,
) -> Dict[str, Dict[str, float]]:
"""
This is called during the retrieval process.
Abstract method to be overridden. This is called during the retrieval process.
Parameters:
corpus: Dict[str, Dict[str, Any]]: Corpus of documents.
@@ -63,6 +66,9 @@


class EvalDenseRetriever(EvalRetriever):
"""
Child class of :class:EvalRetriever for dense retrieval.
"""
def __call__(
self,
corpus: Dict[str, Dict[str, Any]],
@@ -144,6 +150,9 @@ def __call__(


class EvalReranker:
"""
Class for reranker during evaluation.
"""
def __init__(self, reranker: AbsReranker, rerank_top_k: int = 100):
self.reranker = reranker
self.rerank_top_k = rerank_top_k
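The retriever and reranker classes above are thin callable wrappers around an embedder and a reranker model. A minimal sketch of how they appear to fit together follows, based on the constructors shown in this diff; the document/query shapes, the `queries` argument, and the reranker call signature are assumptions rather than something this commit shows.

from FlagEmbedding import FlagAutoModel, FlagAutoReranker  # import path assumed

# Sketch only; `from_finetuned` usage mirrors the runner.py diff above.
embedder = FlagAutoModel.from_finetuned(model_name_or_path="BAAI/bge-base-en-v1.5")
retriever = EvalDenseRetriever(embedder, search_top_k=1000, overwrite=False)

corpus = {"d1": {"title": "", "text": "Paris is the capital of France."}}  # Dict[str, Dict[str, Any]]
queries = {"q1": "What is the capital of France?"}                         # assumed Dict[str, str]

# Returns Dict[str, Dict[str, float]]: query id -> {doc id: similarity score}.
search_results = retriever(corpus, queries)

# Optionally rescore the top candidates with a reranker.
cross_encoder = FlagAutoReranker.from_finetuned(model_name_or_path="BAAI/bge-reranker-base")
reranker = EvalReranker(cross_encoder, rerank_top_k=100)
reranked_results = reranker(corpus, queries, search_results)  # call signature assumed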
41 changes: 41 additions & 0 deletions FlagEmbedding/abc/evaluation/utils.py
@@ -16,6 +16,16 @@ def evaluate_mrr(
results: Dict[str, Dict[str, float]],
k_values: List[int],
) -> Tuple[Dict[str, float]]:
"""Compute mean reciprocal rank (MRR).
Args:
qrels (Dict[str, Dict[str, int]]): Ground truth relevance.
results (Dict[str, Dict[str, float]]): Search results to evaluate.
k_values (List[int]): Cutoffs.
Returns:
Tuple[Dict[str, float]]: MRR results at provided k values.
"""
mrr = defaultdict(list)

k_max, top_hits = max(k_values), {}
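As a concrete illustration of what this function computes, the following standalone snippet (a toy re-implementation for illustration, not the code above) evaluates MRR@k for a small qrels/results pair in the shapes described by the docstring:

# Toy ground truth and search results in the documented shapes.
qrels = {"q1": {"d2": 1}, "q2": {"d9": 1}}
results = {
    "q1": {"d1": 0.9, "d2": 0.8, "d3": 0.1},  # first relevant doc at rank 2 -> RR = 1/2
    "q2": {"d9": 0.7, "d4": 0.2},             # first relevant doc at rank 1 -> RR = 1
}

def mrr_at_k(qrels, results, k):
    reciprocal_ranks = []
    for qid, doc_scores in results.items():
        ranked = sorted(doc_scores, key=doc_scores.get, reverse=True)[:k]
        rr = 0.0
        for rank, doc_id in enumerate(ranked, start=1):
            if qrels.get(qid, {}).get(doc_id, 0) > 0:
                rr = 1.0 / rank
                break
        reciprocal_ranks.append(rr)
    return sum(reciprocal_ranks) / len(reciprocal_ranks)

print(mrr_at_k(qrels, results, k=10))  # (0.5 + 1.0) / 2 = 0.75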
@@ -53,6 +63,17 @@ def evaluate_metrics(
Dict[str, float],
Dict[str, float],
]:
"""Evaluate the main metrics.
Args:
qrels (Dict[str, Dict[str, int]]): Ground truth relevance.
results (Dict[str, Dict[str, float]]): Search results to evaluate.
k_values (List[int]): Cutoffs.
Returns:
Tuple[ Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], ]: Results of different metrics at
different provided k values.
"""
all_ndcgs, all_aps, all_recalls, all_precisions = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)

map_string = "map_cut." + ",".join([str(k) for k in k_values])
@@ -93,6 +114,17 @@ def index(
load_path: Optional[str] = None,
device: Optional[str] = None
):
"""Create and add embeddings into a Faiss index.
Args:
index_factory (str, optional): Type of Faiss index to create. Defaults to "Flat".
corpus_embeddings (Optional[np.ndarray], optional): The embedding vectors of the corpus. Defaults to None.
load_path (Optional[str], optional): Path to load embeddings from. Defaults to None.
device (Optional[str], optional): Device to hold Faiss index. Defaults to None.
Returns:
faiss.Index: The Faiss index that contains all the corpus embeddings.
"""
if corpus_embeddings is None:
corpus_embeddings = np.load(load_path)

@@ -127,6 +159,15 @@ def search(
"""
1. Encode queries into dense embeddings;
2. Search through faiss index
Args:
faiss_index (faiss.Index): The Faiss index that contains all the corpus embeddings.
k (int, optional): Top k numbers of closest neighbours. Defaults to 100.
query_embeddings (Optional[np.ndarray], optional): The embedding vectors of queries. Defaults to None.
load_path (Optional[str], optional): Path to load embeddings from. Defaults to None.
Returns:
Tuple[np.ndarray, np.ndarray]: The scores of search results and their corresponding indices.
"""
if query_embeddings is None:
query_embeddings = np.load(load_path)
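Taken together, the utilities in this file cover the scoring path end to end: build a Faiss index over corpus embeddings, search it with query embeddings, then score the results against qrels. A small sketch follows; the toy embeddings, the id mapping, and the import path are illustrative assumptions, and the return order of evaluate_metrics is inferred from the variable names in this diff.

import numpy as np

from FlagEmbedding.abc.evaluation.utils import index, search, evaluate_mrr, evaluate_metrics  # import path assumed

# Toy data: 4 corpus vectors and 2 query vectors of dimension 8.
corpus_embeddings = np.random.rand(4, 8).astype("float32")
query_embeddings = np.random.rand(2, 8).astype("float32")

# Build a flat (exact) Faiss index over the corpus embeddings.
faiss_index = index(index_factory="Flat", corpus_embeddings=corpus_embeddings)

# Retrieve the top 2 corpus vectors per query; both arrays have shape (num_queries, k).
scores, indices = search(faiss_index=faiss_index, k=2, query_embeddings=query_embeddings)

# Map row indices back to string ids and score against ground-truth qrels.
corpus_ids = ["d0", "d1", "d2", "d3"]
results = {
    f"q{qi}": {corpus_ids[di]: float(s) for s, di in zip(scores[qi], indices[qi])}
    for qi in range(len(query_embeddings))
}
qrels = {"q0": {"d1": 1}, "q1": {"d3": 1}}

mrr = evaluate_mrr(qrels, results, k_values=[2])
ndcg, _map, recall, precision = evaluate_metrics(qrels, results, k_values=[2])  # return order inferred from the diff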
