Add mrr retrieval metric (#390)
* add mrr not complete just commit

* add retrieval_ndcg_metric

* add retrieval_mrr_metric

* use hits in ndcg

* add next function at mrr

* edit ndcg solution

* change logic ndcg

* change logic mrr

* add metric funcs

* add metric funcs

---------

Co-authored-by: Jeffrey (Dongkyu) Kim <[email protected]>
bwook00 and vkehfdl1 authored Apr 30, 2024
1 parent e1e244c commit 951bbcf
Showing 4 changed files with 37 additions and 9 deletions.
2 changes: 1 addition & 1 deletion autorag/evaluate/metric/__init__.py
@@ -1,3 +1,3 @@
 from .generation import bleu, meteor, rouge, sem_score, g_eval, bert_score
-from .retrieval import retrieval_f1, retrieval_recall, retrieval_precision, retrieval_ndcg
+from .retrieval import retrieval_f1, retrieval_recall, retrieval_precision, retrieval_mrr, retrieval_ndcg
 from .retrieval_contents import retrieval_token_f1, retrieval_token_precision, retrieval_token_recall
18 changes: 18 additions & 0 deletions autorag/evaluate/metric/retrieval.py
@@ -69,3 +69,21 @@ def retrieval_ndcg(gt: List[List[str]], pred: List[str]):
 
     ndcg = dcg / idcg if idcg > 0 else 0
     return ndcg
+
+
+@retrieval_metric
+def retrieval_mrr(gt: List[List[str]], pred: List[str]) -> float:
+    """
+    Reciprocal Rank (RR) is the reciprocal of the rank of the first relevant item.
+    The mean of RR over all queries (and over ground-truth groups within a query) is MRR.
+    """
+    # Build one frozenset of relevant ids per ground-truth group
+    gt_sets = [frozenset(g) for g in gt]
+
+    rr_list = []
+    for gt_set in gt_sets:
+        for i, pred_id in enumerate(pred):
+            if pred_id in gt_set:
+                rr_list.append(1.0 / (i + 1))
+                break
+    return sum(rr_list) / len(gt_sets) if rr_list else 0.0
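For readers skimming the diff: below is a minimal standalone sketch of the same per-query logic, with the @retrieval_metric decorator omitted and made-up document ids. Note that within a single query the reciprocal ranks of the individual ground-truth groups are averaged, which is why a query whose second group is never retrieved cannot score 1.0:

from typing import List


def mrr_single_query(gt: List[List[str]], pred: List[str]) -> float:
    # One frozenset of relevant ids per ground-truth group
    gt_sets = [frozenset(g) for g in gt]
    rr_list = []
    for gt_set in gt_sets:
        for i, pred_id in enumerate(pred):
            if pred_id in gt_set:
                rr_list.append(1.0 / (i + 1))  # ranks are 1-based
                break
    return sum(rr_list) / len(gt_sets) if rr_list else 0.0


# {'doc-a', 'doc-b'} is first hit at rank 2 (RR = 1/2), {'doc-c'} at rank 3 (RR = 1/3),
# so the per-query score is (1/2 + 1/3) / 2 = 5/12.
print(mrr_single_query([['doc-a', 'doc-b'], ['doc-c']], ['doc-x', 'doc-a', 'doc-c']))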
3 changes: 2 additions & 1 deletion autorag/evaluate/retrieval.py
@@ -4,7 +4,7 @@
 
 import pandas as pd
 
-from autorag.evaluate.metric import retrieval_recall, retrieval_precision, retrieval_f1, retrieval_ndcg
+from autorag.evaluate.metric import retrieval_recall, retrieval_precision, retrieval_f1, retrieval_ndcg, retrieval_mrr
 
 
 def evaluate_retrieval(retrieval_gt: List[List[List[str]]], metrics: List[str]):
@@ -27,6 +27,7 @@ def wrapper(*args, **kwargs) -> pd.DataFrame:
                 retrieval_precision.__name__: retrieval_precision,
                 retrieval_f1.__name__: retrieval_f1,
                 retrieval_ndcg.__name__: retrieval_ndcg,
+                retrieval_mrr.__name__: retrieval_mrr,
             }
 
             metric_scores = {}
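The wrapper above resolves metric names against a dict keyed by each function's __name__, so registering retrieval_mrr here is what makes the new metric selectable through the metrics argument. A rough usage sketch follows; the data is borrowed from the new test case further down, and the exact container the decorated metrics return is not shown in this diff (the tests simply iterate over a per-query sequence of scores):

from autorag.evaluate.metric import (retrieval_f1, retrieval_mrr, retrieval_ndcg,
                                     retrieval_precision, retrieval_recall)

metric_funcs = {fn.__name__: fn for fn in
                (retrieval_recall, retrieval_precision, retrieval_f1, retrieval_ndcg, retrieval_mrr)}

retrieval_gt = [[['test-15']]]                  # one query with a single relevant doc
pred_ids = [['pred-15', 'pred-16', 'test-15']]  # that doc is retrieved at rank 3

for name in ['retrieval_mrr', 'retrieval_recall']:   # metric names as they would arrive via `metrics`
    print(name, metric_funcs[name](retrieval_gt=retrieval_gt, pred_ids=pred_ids))
# roughly: retrieval_mrr -> [0.333...], retrieval_recall -> [1.0]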
23 changes: 16 additions & 7 deletions tests/autorag/evaluate/metric/test_retrieval_metric.py
@@ -1,6 +1,6 @@
 import pytest
 
-from autorag.evaluate.metric import retrieval_f1, retrieval_precision, retrieval_recall, retrieval_ndcg
+from autorag.evaluate.metric import retrieval_f1, retrieval_precision, retrieval_recall, retrieval_ndcg, retrieval_mrr
 
 retrieval_gt = [
     [['test-1', 'test-2'], ['test-3']],
@@ -9,7 +9,8 @@
     [['test-11'], ['test-12'], ['test-13']],
     [['test-14']],
     [[]],
-    [['']]
+    [['']],
+    [['test-15']]
 ]
 
 pred = [
@@ -19,33 +20,41 @@
     ['test-13', 'test-12', 'pred-10', 'pred-11'], # recall: 2/3, precision: 0.5, f1: 4/7
     ['test-14', 'pred-12'], # recall: 1.0, precision: 0.5, f1: 2/3
     ['pred-13'], # retrieval_gt is empty so not counted
-    ['pred-14'] # retrieval_gt is empty so not counted
+    ['pred-14'], # retrieval_gt is empty so not counted
+    ['pred-15', 'pred-16', 'test-15'] # recall:1, precision: 1/3, f1: 0.5
 ]
 
 
 def test_retrieval_f1():
-    solution = [0.5, 2 / 7, 2 / 5, 4 / 7, 2 / 3, None, None]
+    solution = [0.5, 2 / 7, 2 / 5, 4 / 7, 2 / 3, None, None, 0.5]
     result = retrieval_f1(retrieval_gt=retrieval_gt, pred_ids=pred)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
 
 
 def test_retrieval_recall():
-    solution = [0.5, 1 / 3, 1, 2 / 3, 1, None, None]
+    solution = [0.5, 1 / 3, 1, 2 / 3, 1, None, None, 1]
     result = retrieval_recall(retrieval_gt=retrieval_gt, pred_ids=pred)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
 
 
 def test_retrieval_precision():
-    solution = [0.5, 0.25, 0.25, 0.5, 0.5, None, None]
+    solution = [0.5, 0.25, 0.25, 0.5, 0.5, None, None, 1 / 3]
     result = retrieval_precision(retrieval_gt=retrieval_gt, pred_ids=pred)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
 
 
 def test_retrieval_ndcg():
-    solution = [0.7039180890341347, 0.3903800499921017, 0.6131471927654584, 0.7653606369886217, 1, None, None]
+    solution = [0.7039180890341347, 0.3903800499921017, 0.6131471927654584, 0.7653606369886217, 1, None, None, 0.5]
     result = retrieval_ndcg(retrieval_gt=retrieval_gt, pred_ids=pred)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
+
+
+def test_retrieval_mrr():
+    solution = [1 / 2, 1 / 3, 1, 1 / 2, 1, None, None, 1 / 3]
+    result = retrieval_mrr(retrieval_gt=retrieval_gt, pred_ids=pred)
+    for gt, res in zip(solution, result):
+        assert gt == pytest.approx(res, rel=1e-4)
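As a quick sanity check on the newly added expected value (the final 1/3 in the MRR solution list): for the eighth query the only relevant document, 'test-15', appears at rank 3 of its prediction list, so its reciprocal rank is 1/3; the remaining entries follow the same per-query averaging shown in the metric above. A tiny illustrative check (the variable name is made up):

pred_ids_q8 = ['pred-15', 'pred-16', 'test-15']
rank = pred_ids_q8.index('test-15') + 1  # the only relevant doc sits at rank 3
print(1.0 / rank)                        # 0.333..., matching the 1/3 in the solution list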
