Skip to content

Commit

Permalink
Add test code if QA data was created from a corpus (#382)
Browse files Browse the repository at this point in the history
* add validate_qa_from_corpus_dataset

* add validate_qa_from_corpus_dataset at evaluator class init
  • Loading branch information
bwook00 authored Apr 27, 2024
1 parent 2f2e270 commit 97476b4
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 4 deletions.
4 changes: 3 additions & 1 deletion autorag/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from autorag.nodes.retrieval.vectordb import vectordb_ingest
from autorag.schema import Node
from autorag.schema.node import module_type_exists, extract_values_from_nodes
from autorag.utils import cast_qa_dataset, cast_corpus_dataset
from autorag.utils import cast_qa_dataset, cast_corpus_dataset, validate_qa_from_corpus_dataset
from autorag.utils.util import load_summary_file, convert_string_to_tuple_in_dict, convert_env_in_dict, explode

logger = logging.getLogger("AutoRAG")
Expand Down Expand Up @@ -63,6 +63,8 @@ def __init__(self, qa_data_path: str, corpus_data_path: str, project_dir: Option
if not os.path.exists(self.project_dir):
os.makedirs(self.project_dir)

validate_qa_from_corpus_dataset(self.qa_data, self.corpus_data)

# copy dataset to project directory
if not os.path.exists(os.path.join(self.project_dir, 'data')):
os.makedirs(os.path.join(self.project_dir, 'data'))
Expand Down
3 changes: 2 additions & 1 deletion autorag/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .preprocess import validate_qa_dataset, validate_corpus_dataset, cast_qa_dataset, cast_corpus_dataset
from .preprocess import (validate_qa_dataset, validate_corpus_dataset, cast_qa_dataset, cast_corpus_dataset,
validate_qa_from_corpus_dataset)
from .util import fetch_contents, result_to_dataframe, sort_by_scores
15 changes: 15 additions & 0 deletions autorag/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,18 @@ def make_prev_next_id_metadata(x, id_type: str):
assert all('next_id' in metadata for metadata in df['metadata']), "Every metadata must have a next_id key."

return df


def validate_qa_from_corpus_dataset(qa_df: pd.DataFrame, corpus_df: pd.DataFrame):
    """
    Check that every doc_id referenced by the QA dataset exists in the corpus.

    Each cell of ``qa_df['retrieval_gt']`` is read as a sequence of groups, where
    every group is itself an iterable of doc_ids. Cells that are neither a
    ``list`` nor a ``numpy.ndarray`` are ignored.

    :param qa_df: QA dataframe that has a 'retrieval_gt' column.
    :param corpus_df: Corpus dataframe that has a 'doc_id' column.
    :raises AssertionError: If at least one referenced doc_id is missing from
        ``corpus_df``. The message reports how many references are missing
        (a doc_id referenced several times is counted several times).
    """
    qa_ids = []
    for retrieval_gt in qa_df['retrieval_gt'].tolist():
        if not isinstance(retrieval_gt, (list, np.ndarray)):
            continue
        # No emptiness guard is needed: an empty cell or an empty group simply
        # contributes no ids. Peeking at retrieval_gt[0] would raise IndexError
        # on an empty cell and would hide ids that follow an empty first group.
        for gt in retrieval_gt:
            qa_ids.extend(gt)

    # Build the lookup once instead of scanning the whole corpus for every id.
    corpus_ids = set(corpus_df['doc_id'])
    no_exist_ids = [qa_id for qa_id in qa_ids if qa_id not in corpus_ids]

    assert len(no_exist_ids) == 0, f"{len(no_exist_ids)} doc_ids in retrieval_gt do not exist in corpus_df."
25 changes: 23 additions & 2 deletions tests/autorag/utils/test_preprocess.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import os
import pathlib
from datetime import datetime

import pandas as pd
import pytest

from autorag.utils import validate_qa_dataset, validate_corpus_dataset, cast_qa_dataset, cast_corpus_dataset
from autorag.utils import (validate_qa_dataset, validate_corpus_dataset, cast_qa_dataset, cast_corpus_dataset,
validate_qa_from_corpus_dataset)


@pytest.fixture
def qa_df():
    """
    Two-row QA dataframe used by the preprocess tests.

    The first row's retrieval_gt holds two groups of doc_ids
    (``['doc1', 'doc3']`` and ``['doc2']``); the second row's retrieval_gt
    holds a single empty group. The scalar generation_gt is broadcast to
    both rows by pandas.
    """
    return pd.DataFrame({
        'qid': ['id1', 'id2'],
        'query': ['query1', 'query2'],
        # Only one 'retrieval_gt' entry: a repeated key in a dict literal is
        # silently overwritten by the last occurrence.
        'retrieval_gt': [[['doc1', 'doc3'], ['doc2']], [[]]],
        'generation_gt': 'answer1',
    })

Expand Down Expand Up @@ -75,3 +78,21 @@ def test_cast_corpus_dataset(corpus_df):
assert casted_df['metadata'].iloc[1]['prev_id'] is None
assert casted_df['metadata'].iloc[1]['next_id'] is None
assert casted_df['metadata'].iloc[2]['last_modified_datetime'] == datetime(2022, 12, 1, 3, 4, 5)


def test_validate_qa_from_corpus_dataset(qa_df, corpus_df):
    """
    Exercise validate_qa_from_corpus_dataset with matching and non-matching data.

    Covers three cases: the fixtures as given (no error expected), a copy of
    the QA fixture whose first row references three unknown ids, and the
    sample project's qa.parquet checked against the small corpus fixture.
    """
    # The fixtures as given are expected to validate without raising.
    validate_qa_from_corpus_dataset(qa_df, corpus_df)

    # Prepare the invalid frame outside pytest.raises so that only the
    # validator call itself can satisfy the expected AssertionError.
    invalid_df = qa_df.copy()
    invalid_df.at[0, 'retrieval_gt'] = [['answer1', 'answer2'], ['answer3']]
    with pytest.raises(AssertionError) as excinfo:
        validate_qa_from_corpus_dataset(invalid_df, corpus_df)
    assert "3 doc_ids in retrieval_gt do not exist in corpus_df." in str(excinfo.value)

    # Resolve <two levels above this file's directory>/resources/sample_project.
    root_dir = pathlib.PurePath(os.path.dirname(os.path.realpath(__file__))).parent.parent
    project_dir = os.path.join(root_dir, "resources", "sample_project")
    qa_parquet = pd.read_parquet(os.path.join(project_dir, "data", "qa.parquet"))

    with pytest.raises(AssertionError) as excinfo:
        validate_qa_from_corpus_dataset(qa_parquet, corpus_df)
    assert "10 doc_ids in retrieval_gt do not exist in corpus_df." in str(excinfo.value)

0 comments on commit 97476b4

Please sign in to comment.