From 951bbcf6ada9cade85ece9cc4339402336413ba2 Mon Sep 17 00:00:00 2001
From: "Bwook (Byoungwook) Kim"
Date: Tue, 30 Apr 2024 14:53:36 +0900
Subject: [PATCH] Add mrr retrieval metric (#390)

* add mrr not complete just commit

* add retrieval_ndcg_metric

* add retrieval_mrr_metric

* use hits in ndcg

* add next function at mrr

* edit ndcg solution

* change logic ndcg

* change logic mrr

* add metric funcs

* add metric funcs

---------

Co-authored-by: Jeffrey (Dongkyu) Kim
---
 autorag/evaluate/metric/__init__.py          |  2 +-
 autorag/evaluate/metric/retrieval.py         | 18 ++++++++++++++++++
 autorag/evaluate/retrieval.py                |  3 ++-
 .../evaluate/metric/test_retrieval_metric.py | 23 ++++++++++++++-------
 4 files changed, 37 insertions(+), 9 deletions(-)

diff --git a/autorag/evaluate/metric/__init__.py b/autorag/evaluate/metric/__init__.py
index 21cb1e621..eea0a44da 100644
--- a/autorag/evaluate/metric/__init__.py
+++ b/autorag/evaluate/metric/__init__.py
@@ -1,3 +1,3 @@
 from .generation import bleu, meteor, rouge, sem_score, g_eval, bert_score
-from .retrieval import retrieval_f1, retrieval_recall, retrieval_precision, retrieval_ndcg
+from .retrieval import retrieval_f1, retrieval_recall, retrieval_precision, retrieval_mrr, retrieval_ndcg
 from .retrieval_contents import retrieval_token_f1, retrieval_token_precision, retrieval_token_recall
diff --git a/autorag/evaluate/metric/retrieval.py b/autorag/evaluate/metric/retrieval.py
index 6d9b148c0..6ca460585 100644
--- a/autorag/evaluate/metric/retrieval.py
+++ b/autorag/evaluate/metric/retrieval.py
@@ -69,3 +69,21 @@ def retrieval_ndcg(gt: List[List[str]], pred: List[str]):
     ndcg = dcg / idcg if idcg > 0 else 0
 
     return ndcg
+
+
+@retrieval_metric
+def retrieval_mrr(gt: List[List[str]], pred: List[str]) -> float:
+    """
+    Reciprocal Rank (RR) is the reciprocal of the rank of the first relevant item.
+    MRR is the mean of RR over all queries.
+    """
+    # Make a frozenset of each ground-truth group so membership checks are fast
+    gt_sets = [frozenset(g) for g in gt]
+
+    rr_list = []
+    for gt_set in gt_sets:
+        for i, pred_id in enumerate(pred):
+            if pred_id in gt_set:
+                rr_list.append(1.0 / (i + 1))
+                break
+    return sum(rr_list) / len(gt_sets) if rr_list else 0.0
diff --git a/autorag/evaluate/retrieval.py b/autorag/evaluate/retrieval.py
index 76ddc51c8..6a9fab205 100644
--- a/autorag/evaluate/retrieval.py
+++ b/autorag/evaluate/retrieval.py
@@ -4,7 +4,7 @@
 
 import pandas as pd
 
-from autorag.evaluate.metric import retrieval_recall, retrieval_precision, retrieval_f1, retrieval_ndcg
+from autorag.evaluate.metric import retrieval_recall, retrieval_precision, retrieval_f1, retrieval_ndcg, retrieval_mrr
 
 
 def evaluate_retrieval(retrieval_gt: List[List[List[str]]], metrics: List[str]):
@@ -27,6 +27,7 @@ def wrapper(*args, **kwargs) -> pd.DataFrame:
                 retrieval_precision.__name__: retrieval_precision,
                 retrieval_f1.__name__: retrieval_f1,
                 retrieval_ndcg.__name__: retrieval_ndcg,
+                retrieval_mrr.__name__: retrieval_mrr,
             }
 
             metric_scores = {}
diff --git a/tests/autorag/evaluate/metric/test_retrieval_metric.py b/tests/autorag/evaluate/metric/test_retrieval_metric.py
index 1180bb2c9..7e3a246c5 100644
--- a/tests/autorag/evaluate/metric/test_retrieval_metric.py
+++ b/tests/autorag/evaluate/metric/test_retrieval_metric.py
@@ -1,6 +1,6 @@
 import pytest
 
-from autorag.evaluate.metric import retrieval_f1, retrieval_precision, retrieval_recall, retrieval_ndcg
+from autorag.evaluate.metric import retrieval_f1, retrieval_precision, retrieval_recall, retrieval_ndcg, retrieval_mrr
 
 retrieval_gt = [
     [['test-1', 'test-2'], ['test-3']],
@@ -9,7 +9,8 @@
     [['test-11'], ['test-12'], ['test-13']],
     [['test-14']],
     [[]],
-    [['']]
+    [['']],
+    [['test-15']]
 ]
 
 pred = [
@@ -19,33 +20,41 @@
     ['test-13', 'test-12', 'pred-10', 'pred-11'],  # recall: 2/3, precision: 0.5, f1: 4/7
     ['test-14', 'pred-12'],  # recall: 1.0, precision: 0.5, f1: 2/3
     ['pred-13'],  # retrieval_gt is empty so not counted
-    ['pred-14']  # retrieval_gt is empty so not counted
+    ['pred-14'],  # retrieval_gt is empty so not counted
+    ['pred-15', 'pred-16', 'test-15']  # recall: 1, precision: 1/3, f1: 0.5
 ]
 
 
 def test_retrieval_f1():
-    solution = [0.5, 2 / 7, 2 / 5, 4 / 7, 2 / 3, None, None]
+    solution = [0.5, 2 / 7, 2 / 5, 4 / 7, 2 / 3, None, None, 0.5]
     result = retrieval_f1(retrieval_gt=retrieval_gt, pred_ids=pred)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
 
 
 def test_retrieval_recall():
-    solution = [0.5, 1 / 3, 1, 2 / 3, 1, None, None]
+    solution = [0.5, 1 / 3, 1, 2 / 3, 1, None, None, 1]
     result = retrieval_recall(retrieval_gt=retrieval_gt, pred_ids=pred)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
 
 
 def test_retrieval_precision():
-    solution = [0.5, 0.25, 0.25, 0.5, 0.5, None, None]
+    solution = [0.5, 0.25, 0.25, 0.5, 0.5, None, None, 1 / 3]
     result = retrieval_precision(retrieval_gt=retrieval_gt, pred_ids=pred)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
 
 
 def test_retrieval_ndcg():
-    solution = [0.7039180890341347, 0.3903800499921017, 0.6131471927654584, 0.7653606369886217, 1, None, None]
+    solution = [0.7039180890341347, 0.3903800499921017, 0.6131471927654584, 0.7653606369886217, 1, None, None, 0.5]
     result = retrieval_ndcg(retrieval_gt=retrieval_gt, pred_ids=pred)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
+
+
+def test_retrieval_mrr():
+    solution = [1 / 2, 1 / 3, 1, 1 / 2, 1, None, None, 1 / 3]
+    result = retrieval_mrr(retrieval_gt=retrieval_gt, pred_ids=pred)
+    for gt, res in zip(solution, result):
+        assert gt == pytest.approx(res, rel=1e-4)
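
Below is a minimal usage sketch of the metric added by this patch; it is not part of the diff itself. It assumes only what the patch shows: retrieval_mrr is exported from autorag.evaluate.metric and takes the same retrieval_gt / pred_ids keyword arguments the tests exercise, with the retrieval_metric decorator applying the single-query function to each query. The document ids and expected scores are illustrative only.

from autorag.evaluate.metric import retrieval_mrr

# Two hypothetical queries. Each retrieval_gt entry is a list of ground-truth
# groups, mirroring the structure used in the tests above.
gt = [
    [['doc-1', 'doc-2']],  # query 1: a relevant id first appears at rank 2
    [['doc-9']],           # query 2: no relevant id is retrieved
]
pred = [
    ['doc-7', 'doc-1', 'doc-3'],  # ranked predictions for query 1 -> RR = 1/2
    ['doc-4', 'doc-5'],           # ranked predictions for query 2 -> RR = 0.0
]

scores = retrieval_mrr(retrieval_gt=gt, pred_ids=pred)
print(list(scores))  # per-query reciprocal ranks: [0.5, 0.0]

Per the implementation above, a query whose ground truth contains several groups averages the reciprocal ranks found for its groups, and a query for which no relevant id is retrieved scores 0.0.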