From 67135cc14749d13aa4dc75ce9cbe559f8526991b Mon Sep 17 00:00:00 2001
From: Tony Wu <28306721+tonywu71@users.noreply.github.com>
Date: Tue, 10 Sep 2024 11:45:04 +0200
Subject: [PATCH] feat: add `CustomRetrievalEvaluator` as a `mteb` wrapper + update `ColModelTraining`

---
 colpali_engine/trainer/colmodel_training.py | 11 +++---
 colpali_engine/trainer/eval_utils.py        | 37 +++++++++++++++++++++
 2 files changed, 44 insertions(+), 4 deletions(-)
 create mode 100644 colpali_engine/trainer/eval_utils.py

diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py
index 7b351397..9e0eebce 100644
--- a/colpali_engine/trainer/colmodel_training.py
+++ b/colpali_engine/trainer/colmodel_training.py
@@ -25,6 +25,7 @@
     ColbertPairwiseNegativeCELoss,
 )
 from colpali_engine.trainer.contrastive_trainer import ContrastiveNegativeTrainer, ContrastiveTrainer
+from colpali_engine.trainer.eval_utils import CustomRetrievalEvaluator
 from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary
 
 
@@ -117,13 +118,14 @@ def __init__(self, config: ColModelTrainingConfig) -> None:
             max_length=self.config.max_length,
         )
         self.current_git_hash = os.popen("git rev-parse HEAD").read().strip()
-        self.retriever_evaluator = RetrievalScorer(
+        self.retriever_scorer = RetrievalScorer(
             is_multi_vector=(
                 isinstance(self.config.loss_func, ColbertLoss)
                 or isinstance(self.config.loss_func, ColbertPairwiseCELoss)
                 or isinstance(self.config.loss_func, ColbertPairwiseNegativeCELoss)
             )
         )
+        self.retrieval_evaluator = CustomRetrievalEvaluator()
 
     def train(self) -> None:
         if isinstance(self.collator, HardNegCollator):
@@ -231,7 +233,7 @@ def eval_dataset(self, test_dataset):
                 qs.extend(list(torch.unbind(query.to("cpu"))))
 
         print("Embeddings computed, evaluating")
-        scores = self.retriever_evaluator.evaluate(qs, ps)
+        scores = self.retriever_scorer.evaluate(qs, ps)
         # scores is 2d array of shape (n_queries, n_docs)
         # turn it into a dict
         results = {}
@@ -242,8 +244,9 @@ def eval_dataset(self, test_dataset):
         }
 
         # evaluate
-        metrics = self.retriever_evaluator.compute_metrics(relevant_docs, results)
-        print(metrics)
+        metrics = self.retrieval_evaluator.compute_mteb_metrics(relevant_docs, results)
+        print("MTEB metrics:", metrics)
+        return metrics
 
 
     def eval(self) -> None:
diff --git a/colpali_engine/trainer/eval_utils.py b/colpali_engine/trainer/eval_utils.py
new file mode 100644
index 00000000..d928bcb4
--- /dev/null
+++ b/colpali_engine/trainer/eval_utils.py
@@ -0,0 +1,37 @@
+from typing import Dict
+
+from mteb.evaluation.evaluators import RetrievalEvaluator
+
+
+class CustomRetrievalEvaluator(RetrievalEvaluator):
+    """
+    Wrapper class for the MTEB retrieval evaluator.
+    """
+
+    def compute_mteb_metrics(
+        self,
+        relevant_docs: Dict[str, dict[str, int]],
+        results: Dict[str, dict[str, float]],
+        **kwargs,
+    ) -> Dict[str, float]:
+        """
+        Compute the MTEB retrieval metrics.
+        """
+        ndcg, _map, recall, precision, naucs = self.evaluate(
+            relevant_docs,
+            results,
+            self.k_values,
+            ignore_identical_ids=kwargs.get("ignore_identical_ids", True),
+        )
+
+        mrr = self.evaluate_custom(relevant_docs, results, self.k_values, "mrr")
+
+        scores = {
+            **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
+            **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
+            **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
+            **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
+            **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr[0].items()},
+            **{f"naucs_at_{k.split('@')[1]}": v for (k, v) in naucs.items()},
+        }
+        return scores