Commit

Merge pull request #371 from Marker-Inc-Korea/Feature/#368
Fix retrieval and retrieval token metrics to handle unanswerable queries
bwook00 authored Apr 26, 2024
2 parents a598015 + 6093e42 commit dd842db
Showing 11 changed files with 66 additions and 24 deletions.
8 changes: 7 additions & 1 deletion autorag/evaluate/metric/retrieval.py
@@ -5,7 +5,13 @@
 def retrieval_metric(func):
     @functools.wraps(func)
     def wrapper(retrieval_gt: List[List[List[str]]], pred_ids: List[List[str]]) -> List[float]:
-        return list(map(lambda x: func(x[0], x[1]), zip(retrieval_gt, pred_ids)))
+        results = []
+        for gt, pred in zip(retrieval_gt, pred_ids):
+            if gt == [[]] or any(bool(g_) is False for g in gt for g_ in g):
+                results.append(None)
+            else:
+                results.append(func(gt, pred))
+        return results
 
     return wrapper

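The effect of this decorator change, shown as a minimal sketch: any query whose retrieval_gt is empty or blank is treated as unanswerable and scored as None instead of being forced to 0.0. The dummy_recall metric below is hypothetical and only for illustration; the import path assumes the module shown above.

from autorag.evaluate.metric.retrieval import retrieval_metric

@retrieval_metric
def dummy_recall(gt, pred):  # hypothetical per-query metric, for illustration only
    gt_ids = set(id_ for sub in gt for id_ in sub)
    return len(gt_ids & set(pred)) / len(gt_ids)

retrieval_gt = [[['doc-1']], [[]], [['']]]
pred_ids = [['doc-1', 'doc-2'], ['doc-3'], ['doc-4']]
print(dummy_recall(retrieval_gt, pred_ids))  # [1.0, None, None]
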
8 changes: 7 additions & 1 deletion autorag/evaluate/metric/retrieval_contents.py
@@ -15,7 +15,13 @@
 def retrieval_contents_metric(func):
     @functools.wraps(func)
     def wrapper(gt_contents: List[List[str]], pred_contents: List[List[str]]) -> List[float]:
-        return list(map(lambda x: func(x[0], x[1]), zip(gt_contents, pred_contents)))
+        results = []
+        for gt, pred in zip(gt_contents, pred_contents):
+            if gt == [] or any(bool(g) is False for g in gt):
+                results.append(None)
+            else:
+                results.append(func(gt, pred))
+        return results
 
     return wrapper

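The content-level (token) metrics get the same guard. A minimal sketch mirroring the updated tests below; the import path and values are assumptions for illustration:

from autorag.evaluate.metric import retrieval_token_f1

gt_contents = [['Do you want to buy some?'], [''], []]
pred_contents = [['Do you want to buy some?'], ['anything'], ['anything else']]
scores = retrieval_token_f1(gt_contents=gt_contents, pred_contents=pred_contents)
# scores[0] is an ordinary token-F1 value; scores[1] and scores[2] are None
# because their ground truth is blank ('') or empty ([]).
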
2 changes: 2 additions & 0 deletions autorag/nodes/passageaugmenter/run.py
@@ -21,6 +21,8 @@ def run_passage_augmenter_node(modules: List[Callable],
         os.makedirs(node_line_dir)
     project_dir = pathlib.PurePath(node_line_dir).parent.parent
     retrieval_gt = pd.read_parquet(os.path.join(project_dir, "data", "qa.parquet"))['retrieval_gt'].tolist()
+    retrieval_gt = [[[str(uuid) for uuid in sub_array] if sub_array.size > 0 else [] for sub_array in inner_array]
+                    for inner_array in retrieval_gt]
 
     results, execution_times = zip(*map(lambda task: measure_speed(
         task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))
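The same two-line normalization is added to the passage filter, passage reranker, and retrieval runners below. A small sketch of what it does, assuming retrieval_gt comes back from parquet as nested numpy arrays (the toy value here is made up):

import numpy as np

raw_gt = [[np.array(['id-1', 'id-2']), np.array([])]]  # illustrative value only
normalized = [[[str(uuid) for uuid in sub_array] if sub_array.size > 0 else [] for sub_array in inner_array]
              for inner_array in raw_gt]
print(normalized)  # [[['id-1', 'id-2'], []]] - UUIDs cast to str, empty arrays become []
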
2 changes: 2 additions & 0 deletions autorag/nodes/passagefilter/run.py
@@ -31,6 +31,8 @@ def run_passage_filter_node(modules: List[Callable],
         os.makedirs(node_line_dir)
     project_dir = pathlib.PurePath(node_line_dir).parent.parent
     retrieval_gt = pd.read_parquet(os.path.join(project_dir, "data", "qa.parquet"))['retrieval_gt'].tolist()
+    retrieval_gt = [[[str(uuid) for uuid in sub_array] if sub_array.size > 0 else [] for sub_array in inner_array]
+                    for inner_array in retrieval_gt]
 
     results, execution_times = zip(*map(lambda task: measure_speed(
         task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))
2 changes: 2 additions & 0 deletions autorag/nodes/passagereranker/run.py
@@ -35,6 +35,8 @@ def run_passage_reranker_node(modules: List[Callable],
         os.makedirs(node_line_dir)
     project_dir = pathlib.PurePath(node_line_dir).parent.parent
     retrieval_gt = pd.read_parquet(os.path.join(project_dir, "data", "qa.parquet"))['retrieval_gt'].tolist()
+    retrieval_gt = [[[str(uuid) for uuid in sub_array] if sub_array.size > 0 else [] for sub_array in inner_array]
+                    for inner_array in retrieval_gt]
 
     results, execution_times = zip(*map(lambda task: measure_speed(
         task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))
3 changes: 3 additions & 0 deletions autorag/nodes/retrieval/run.py
@@ -34,6 +34,9 @@ def run_retrieval_node(modules: List[Callable],
         os.makedirs(node_line_dir)
     project_dir = pathlib.PurePath(node_line_dir).parent.parent
     retrieval_gt = pd.read_parquet(os.path.join(project_dir, "data", "qa.parquet"))['retrieval_gt'].tolist()
+    retrieval_gt = [[[str(uuid) for uuid in sub_array] if sub_array.size > 0 else [] for sub_array in inner_array]
+                    for inner_array in retrieval_gt]
+
     save_dir = os.path.join(node_line_dir, "retrieval")  # node name
     if not os.path.exists(save_dir):
         os.makedirs(save_dir)
25 changes: 17 additions & 8 deletions autorag/utils/util.py
@@ -17,18 +17,27 @@

 def fetch_contents(corpus_data: pd.DataFrame, ids: List[List[str]],
                    column_name: str = 'contents') -> List[List[Any]]:
-    flat_ids = itertools.chain.from_iterable(ids)
-    contents = list(map(lambda x: corpus_data.loc[lambda row: row['doc_id'] == x][column_name].values[0], flat_ids))
-
-    result = []
-    idx = 0
-    for sublist in ids:
-        result.append(contents[idx:idx + len(sublist)])
-        idx += len(sublist)
+    def fetch_contents_pure(ids: List[str], corpus_data: pd.DataFrame, column_name: str):
+        return list(map(lambda x: fetch_one_content(corpus_data, x, column_name), ids))
+
+    result = flatten_apply(fetch_contents_pure, ids, corpus_data=corpus_data, column_name=column_name)
     return result
 
 
+def fetch_one_content(corpus_data: pd.DataFrame, id_: str,
+                      column_name: str = 'contents') -> Any:
+    if isinstance(id_, str):
+        if id_ in ['', ""]:
+            return None
+        fetch_result = corpus_data[corpus_data['doc_id'] == id_]
+        if fetch_result.empty:
+            raise ValueError(f"doc_id: {id_} not found in corpus_data.")
+        else:
+            return fetch_result[column_name].iloc[0]
+    else:
+        return None
+
+
 def result_to_dataframe(column_names: List[str]):
     """
     Decorator for converting results to pd.DataFrame.
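A small sketch of how the new fetch_one_content helper behaves, using a toy corpus DataFrame (the doc ids and contents here are made up):

import pandas as pd
# from autorag.utils.util import fetch_one_content  # as defined in the diff above

corpus_data = pd.DataFrame({'doc_id': ['doc1', 'doc2'], 'contents': ['apple', 'banana']})

fetch_one_content(corpus_data, 'doc2')   # 'banana'
fetch_one_content(corpus_data, '')       # None - blank id is treated as unanswerable
fetch_one_content(corpus_data, None)     # None - non-string ids are skipped
fetch_one_content(corpus_data, 'doc9')   # raises ValueError: doc_id not found in corpus_data
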
14 changes: 9 additions & 5 deletions tests/autorag/evaluate/metric/test_retrieval_contents_metric.py
@@ -5,11 +5,15 @@

 gt = [
     ['Enough for drinking water', 'Just looking for a water bottle'],
-    ['Do you want to buy some?']
+    ['Do you want to buy some?'],
+    [''],
+    []
 ]
 pred = [
     ['Enough for mixing water', 'I want to do a nothing', 'Just looking is a very healthy'],
-    ['Do you want to buy some?', 'I want to buy some', 'I want to buy some water']
+    ['Do you want to buy some?', 'I want to buy some', 'I want to buy some water'],
+    ['Who is son? He is great player in the world'],
+    ['i love havertz', 'i love kai havertz']
 ]


@@ -33,14 +37,14 @@ def test_retrieval_token_f1():
     assert f1 == pytest.approx(0.797979, rel=0.001)
 
     result_f1 = retrieval_token_f1(gt_contents=gt, pred_contents=pred)
-    assert result_f1 == pytest.approx([0.38333, 0.797979], rel=0.001)
+    assert result_f1 == pytest.approx([0.38333, 0.797979, None, None], rel=0.001)
 
 
 def test_retrieval_token_precision():
     result_precision = retrieval_token_precision(gt_contents=gt, pred_contents=pred)
-    assert result_precision == pytest.approx([0.383333, 0.8222222], rel=0.001)
+    assert result_precision == pytest.approx([0.383333, 0.8222222, None, None], rel=0.001)
 
 
 def test_retrieval_token_recall():
     result_recall = retrieval_token_recall(gt_contents=gt, pred_contents=pred)
-    assert result_recall == pytest.approx([0.383333, 0.777777], rel=0.001)
+    assert result_recall == pytest.approx([0.383333, 0.777777, None, None], rel=0.001)
16 changes: 8 additions & 8 deletions tests/autorag/evaluate/metric/test_retrieval_metric.py
@@ -1,5 +1,3 @@
-import math
-
 import pytest
 
 from autorag.evaluate.metric import retrieval_f1, retrieval_precision, retrieval_recall
@@ -10,7 +8,8 @@
     [['test-9', 'test-10']],
     [['test-11'], ['test-12'], ['test-13']],
     [['test-14']],
-    [['test-15']],
+    [[]],
+    [['']]
 ]
 
 pred = [
@@ -19,26 +18,27 @@
     ['test-9', 'pred-0', 'pred-8', 'pred-9'],  # recall: 1.0, precision: 0.25, f1: 2/5
     ['test-13', 'test-12', 'pred-10', 'pred-11'],  # recall: 2/3, precision: 0.5, f1: 4/7
     ['test-14', 'pred-12'],  # recall: 1.0, precision: 0.5, f1: 2/3
-    ['pred-13'],  # recall: 0.0, precision: 0.0, f1: 0.0
+    ['pred-13'],  # retrieval_gt is empty so not counted
+    ['pred-14']  # retrieval_gt is empty so not counted
 ]
 
 
 def test_retrieval_f1():
-    solution = [0.5, 2 / 7, 2 / 5, 4 / 7, 2 / 3, 0.0]
+    solution = [0.5, 2 / 7, 2 / 5, 4 / 7, 2 / 3, None, None]
     result = retrieval_f1(retrieval_gt=retrieval_gt, pred_ids=pred)
     for gt, res in zip(solution, result):
-        assert math.isclose(gt, res, rel_tol=1e-4)
+        assert gt == pytest.approx(res, rel=1e-4)
 
 
 def test_retrieval_recall():
-    solution = [0.5, 1 / 3, 1, 2 / 3, 1, 0.0]
+    solution = [0.5, 1 / 3, 1, 2 / 3, 1, None, None]
     result = retrieval_recall(retrieval_gt=retrieval_gt, pred_ids=pred)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
 
 
 def test_retrieval_precision():
-    solution = [0.5, 0.25, 0.25, 0.5, 0.5, 0.0]
+    solution = [0.5, 0.25, 0.25, 0.5, 0.5, None, None]
     result = retrieval_precision(retrieval_gt=retrieval_gt, pred_ids=pred)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
@@ -18,7 +18,7 @@
         'query': ['query-1', 'query-2', 'query-3'],
         'retrieval_gt': [
             [['doc-1'], ['doc-2']],
-            [['doc-3'], ['doc-4']],
+            [[]],
             [['doc-5'], ['doc-6']],
         ],
         'generation_gt': [['generation-1'], ['generation-2'], ['generation-3']],
8 changes: 8 additions & 0 deletions tests/autorag/utils/test_util.py
@@ -70,6 +70,14 @@ def test_fetch_contents():
                               {'last_modified_datetime': datetime(2022, 1, 1, 0, 0, 0)}]
     assert find_metadatas[1] == [{'last_modified_datetime': datetime(2022, 1, 2, 0, 0, 0)}]
 
+    find_empty = fetch_contents(corpus_data, [[], ['doc2']])
+    assert find_empty[0] == [None]
+    assert find_empty[1] == ['banana']
+
+    find_blank = fetch_contents(corpus_data, [[''], ['doc2']])
+    assert find_blank[0] == [None]
+    assert find_blank[1] == ['banana']
+
 
 def test_load_summary_file(summary_path):
     with pytest.raises(ValueError):
