feat: add CustomRetrievalEvaluator as a mteb wrapper + update `ColModelTraining`
tonywu71 committed Sep 10, 2024
1 parent 254e5fb commit 67135cc
Showing 2 changed files with 44 additions and 4 deletions.
11 changes: 7 additions & 4 deletions colpali_engine/trainer/colmodel_training.py
@@ -25,6 +25,7 @@
ColbertPairwiseNegativeCELoss,
)
from colpali_engine.trainer.contrastive_trainer import ContrastiveNegativeTrainer, ContrastiveTrainer
+from colpali_engine.trainer.eval_utils import CustomRetrievalEvaluator
from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary


@@ -117,13 +118,14 @@ def __init__(self, config: ColModelTrainingConfig) -> None:
max_length=self.config.max_length,
)
self.current_git_hash = os.popen("git rev-parse HEAD").read().strip()
-self.retriever_evaluator = RetrievalScorer(
+self.retriever_scorer = RetrievalScorer(
is_multi_vector=(
isinstance(self.config.loss_func, ColbertLoss)
or isinstance(self.config.loss_func, ColbertPairwiseCELoss)
or isinstance(self.config.loss_func, ColbertPairwiseNegativeCELoss)
)
)
+self.retrieval_evaluator = CustomRetrievalEvaluator()

def train(self) -> None:
if isinstance(self.collator, HardNegCollator):
@@ -231,7 +233,7 @@ def eval_dataset(self, test_dataset):
qs.extend(list(torch.unbind(query.to("cpu"))))

print("Embeddings computed, evaluating")
-scores = self.retriever_evaluator.evaluate(qs, ps)
+scores = self.retriever_scorer.evaluate(qs, ps)
# scores is 2d array of shape (n_queries, n_docs)
# turn it into a dict
results = {}
@@ -242,8 +244,9 @@ def eval_dataset(self, test_dataset):
}

# evaluate
-metrics = self.retriever_evaluator.compute_metrics(relevant_docs, results)
-print(metrics)
+metrics = self.retrieval_evaluator.compute_mteb_metrics(relevant_docs, results)
+print("MTEB metrics:", metrics)

return metrics

def eval(self) -> None:
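In the updated `eval_dataset`, the score matrix returned by `retriever_scorer.evaluate` is reshaped into the nested `{query_id: {doc_id: score}}` dict that `compute_mteb_metrics` consumes; that conversion is collapsed in the hunk above. A minimal sketch of the step, assuming query and document indices are simply stringified into hypothetical IDs (the committed code may build its IDs differently):

# Hypothetical illustration only: build the {query_id: {doc_id: score}}
# mapping expected by MTEB-style evaluators from the (n_queries, n_docs) scores.
results = {}
for q_idx, query_scores in enumerate(scores):
    results[f"query_{q_idx}"] = {
        f"doc_{d_idx}": float(score) for d_idx, score in enumerate(query_scores)
    }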
37 changes: 37 additions & 0 deletions colpali_engine/trainer/eval_utils.py
@@ -0,0 +1,37 @@
from typing import Dict

from mteb.evaluation.evaluators import RetrievalEvaluator


class CustomRetrievalEvaluator(RetrievalEvaluator):
"""
Wrapper class for the MTEB retrieval evaluator.
"""

def compute_mteb_metrics(
self,
relevant_docs: Dict[str, dict[str, int]],
results: Dict[str, dict[str, float]],
**kwargs,
) -> Dict[str, float]:
"""
Compute the MTEB retrieval metrics.
"""
ndcg, _map, recall, precision, naucs = self.evaluate(
relevant_docs,
results,
self.k_values,
ignore_identical_ids=kwargs.get("ignore_identical_ids", True),
)

mrr = self.evaluate_custom(relevant_docs, results, self.k_values, "mrr")

scores = {
**{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
**{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
**{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
**{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
**{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr[0].items()},
**{f"naucs_at_{k.split('@')[1]}": v for (k, v) in naucs.items()},
}
return scores
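For reference, the new wrapper can also be exercised outside the trainer. A minimal standalone sketch, assuming the default `k_values` inherited from the MTEB base class (matching the trainer's no-argument instantiation above); the query/document IDs and scores are illustrative only:

from colpali_engine.trainer.eval_utils import CustomRetrievalEvaluator

evaluator = CustomRetrievalEvaluator()

# Ground-truth relevance judgments (qrels) and retrieval scores, keyed by string IDs.
relevant_docs = {"q1": {"d1": 1}, "q2": {"d3": 1}}
results = {
    "q1": {"d1": 0.9, "d2": 0.2, "d3": 0.1},
    "q2": {"d1": 0.3, "d2": 0.4, "d3": 0.8},
}

metrics = evaluator.compute_mteb_metrics(relevant_docs, results)
print(metrics)  # keys like "ndcg_at_1", "map_at_10", "mrr_at_5", depending on k_values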
