diff --git a/autorag/data/beta/extract_evidence.py b/autorag/data/beta/extract_evidence.py
new file mode 100644
index 000000000..e77529736
--- /dev/null
+++ b/autorag/data/beta/extract_evidence.py
@@ -0,0 +1 @@
+# This module extracts evidence from the given retrieval gt passages.
diff --git a/autorag/data/beta/schema.py b/autorag/data/beta/schema.py
index 8d2c359d4..b162fc316 100644
--- a/autorag/data/beta/schema.py
+++ b/autorag/data/beta/schema.py
@@ -1,9 +1,9 @@
 import logging
-from typing import Callable, Optional, Dict, Awaitable, Any
+from typing import Callable, Optional, Dict, Awaitable, Any, Tuple, List
 
 import pandas as pd
 
+from autorag.utils.util import process_batch, get_event_loop, fetch_contents
 from autorag.support import get_support_modules
-from autorag.utils.util import process_batch, get_event_loop, fetch_contents
 
 logger = logging.getLogger("AutoRAG")
 
@@ -68,6 +68,20 @@ def linked_raw(self) -> Raw:
 	def linked_raw(self, raw: Raw):
 		raise NotImplementedError("linked_raw is read-only.")
 
+	def to_parquet(self, save_path: str):
+		"""
+		Save the corpus to an AutoRAG-compatible parquet file.
+		It is intended for running AutoRAG, not for data creation.
+		If you want to save the data without the column selection,
+		use `corpus.data.to_parquet(save_path)` instead.
+
+		:param save_path: The path to save the corpus. Must end with '.parquet'.
+		"""
+		if not save_path.endswith(".parquet"):
+			raise ValueError("save_path must end with .parquet")
+		save_df = self.data[["doc_id", "contents", "metadata"]].reset_index(drop=True)
+		save_df.to_parquet(save_path)
+
 	def batch_apply(
 		self, fn: Callable[[Dict, Any], Awaitable[Dict]], batch_size: int = 32, **kwargs
 	) -> "Corpus":
@@ -152,6 +166,26 @@ def make_retrieval_gt_contents(self) -> "QA":
 		)
 		return self
 
+	def to_parquet(self, qa_save_path: str, corpus_save_path: str):
+		"""
+		Save the qa and corpus to AutoRAG-compatible parquet files.
+		It is intended for running AutoRAG, not for data creation.
+		If you want to save the data without the column selection,
+		use `qa.data.to_parquet(save_path)` instead.
+
+		:param qa_save_path: The path to save the qa dataset. Must end with '.parquet'.
+		:param corpus_save_path: The path to save the corpus. Must end with '.parquet'.
+		"""
+		if not qa_save_path.endswith(".parquet"):
+			raise ValueError("qa_save_path must end with .parquet")
+		if not corpus_save_path.endswith(".parquet"):
+			raise ValueError("corpus_save_path must end with .parquet")
+		save_df = self.data[
+			["qid", "query", "retrieval_gt", "generation_gt"]
+		].reset_index(drop=True)
+		save_df.to_parquet(qa_save_path)
+		self.linked_corpus.to_parquet(corpus_save_path)
+
 	def update_corpus(self, new_corpus: Corpus) -> "QA":
 		"""
 		Update linked corpus.
@@ -163,4 +197,105 @@ def update_corpus(self, new_corpus: Corpus) -> "QA":
 		Must have valid `linked_raw` and `raw_id`, `raw_start_idx`, `raw_end_idx` columns.
 		:return: The QA instance that updated linked corpus.
""" - pass + self.data["evidence_path"] = ( + self.data["retrieval_gt"] + .apply( + lambda x: fetch_contents( + self.linked_corpus.data, + x, + column_name="path", + ) + ) + .tolist() + ) + self.data["evidence_page"] = self.data["retrieval_gt"].apply( + lambda x: list( + map( + lambda lst: list(map(lambda x: x.get("page", -1), lst)), + fetch_contents(self.linked_corpus.data, x, column_name="metadata"), + ) + ) + ) + if "evidence_start_end_idx" not in self.data.columns: + # make evidence start_end_idx + self.data["evidence_start_end_idx"] = ( + self.data["retrieval_gt"] + .apply( + lambda x: fetch_contents( + self.linked_corpus.data, + x, + column_name="start_end_idx", + ) + ) + .tolist() + ) + + # matching the new corpus with the old corpus + path_corpus_dict = QA.__make_path_corpus_dict(new_corpus.data) + new_retrieval_gt = self.data.apply( + lambda row: QA.__match_index_row( + row["evidence_start_end_idx"], + row["evidence_path"], + row["evidence_page"], + path_corpus_dict, + ), + axis=1, + ).tolist() + new_qa = self.data.copy(deep=True)[["qid", "query", "generation_gt"]] + new_qa["retrieval_gt"] = new_retrieval_gt + return QA(new_qa, new_corpus) + + @staticmethod + def __match_index(target_idx: Tuple[int, int], dst_idx: Tuple[int, int]) -> bool: + """ + Check if the target_idx is overlap by the dst_idx. + """ + target_start, target_end = target_idx + dst_start, dst_end = dst_idx + return ( + dst_start <= target_start <= dst_end or dst_start <= target_end <= dst_end + ) + + @staticmethod + def __match_index_row( + evidence_indices: List[List[Tuple[int, int]]], + evidence_paths: List[List[str]], + evidence_pages: List[List[int]], + path_corpus_dict: Dict, + ) -> List[List[str]]: + """ + Find the matched passage from new_corpus. + + :param evidence_indices: The evidence indices at the corresponding Raw. + Its shape is the same as the retrieval_gt. + :param evidence_paths: The evidence paths at the corresponding Raw. + Its shape is the same as the retrieval_gt. + :param path_corpus_dict: The key is the path name, and the value is the corpus dataframe that only contains the path in the key. + You can make it using `QA.__make_path_corpus_dict`. 
+		:return: The new retrieval_gt, where each inner list contains the doc_ids
+			of the new corpus passages that overlap the original evidence span.
+		"""
+		result = []
+		for i, idx_list in enumerate(evidence_indices):
+			sub_result = []
+			for j, idx in enumerate(idx_list):
+				path_corpus_df = path_corpus_dict[evidence_paths[i][j]]
+				if evidence_pages[i][j] >= 0:
+					path_corpus_df = path_corpus_df.loc[
+						path_corpus_df["metadata"].apply(lambda x: x.get("page", -1))
+						== evidence_pages[i][j]
+					]
+				matched_corpus = path_corpus_df.loc[
+					path_corpus_df["start_end_idx"].apply(
+						lambda x: QA.__match_index(idx, x)
+					)
+				]
+				sub_result.extend(matched_corpus["doc_id"].tolist())
+			result.append(sub_result)
+		return result
+
+	@staticmethod
+	def __make_path_corpus_dict(corpus_df: pd.DataFrame) -> Dict[str, pd.DataFrame]:
+		return {
+			path: corpus_df[corpus_df["path"] == path]
+			for path in corpus_df["path"].unique()
+		}
diff --git a/tests/autorag/data/beta/test_schema.py b/tests/autorag/data/beta/test_schema.py
index ebd0a3f51..f8f45698d 100644
--- a/tests/autorag/data/beta/test_schema.py
+++ b/tests/autorag/data/beta/test_schema.py
@@ -1,6 +1,8 @@
+import tempfile
+
 import pandas as pd
 
-from autorag.data.beta.schema import Raw, Corpus
+from autorag.data.beta.schema import Raw, Corpus, QA
 
 from tests.autorag.data.beta.test_data_creation_piepline import initial_raw
 
@@ -59,3 +61,126 @@ def test_raw_chunk():
 		origin_path in initial_raw.data["path"].tolist()
 		for origin_path in corpus.data["path"].tolist()
 	)
+
+
+def test_update_corpus():
+	raw = Raw(
+		pd.DataFrame(
+			{
+				"texts": ["hello", "world", "jax"],
+				"path": ["path1", "path1", "path2"],
+				"page": [1, 2, -1],
+				"last_modified_datetime": [
+					"2021-08-01",
+					"2021-08-02",
+					"2021-08-03",
+				],
+			}
+		)
+	)
+	original_corpus = Corpus(
+		pd.DataFrame(
+			{
+				"doc_id": ["id1", "id2", "id3", "id4", "id5", "id6"],
+				"contents": ["hello", "world", "foo", "bar", "baz", "jax"],
+				"path": ["path1", "path1", "path1", "path1", "path2", "path2"],
+				"start_end_idx": [
+					(0, 120),
+					(90, 200),
+					(0, 40),
+					(35, 75),
+					(0, 100),
+					(150, 200),
+				],
+				"metadata": [
+					{"page": 1, "last_modified_datetime": "2021-08-01"},
+					{"page": 1, "last_modified_datetime": "2021-08-01"},
+					{"page": 2, "last_modified_datetime": "2021-08-02"},
+					{"page": 2, "last_modified_datetime": "2021-08-02"},
+					{"last_modified_datetime": "2021-08-01"},
+					{"last_modified_datetime": "2021-08-01"},
+				],
+			}
+		),
+		raw,
+	)
+
+	qa = QA(
+		pd.DataFrame(
+			{
+				"qid": ["qid1", "qid2", "qid3", "qid4"],
+				"query": ["hello", "world", "foo", "bar"],
+				"retrieval_gt": [
+					[["id1"]],
+					[["id1"], ["id2"]],
+					[["id3", "id4"]],
+					[["id6", "id2"], ["id5"]],
+				],
+				"generation_gt": ["world", "foo", "bar", "jax"],
+			}
+		),
+		original_corpus,
+	)
+
+	new_corpus = Corpus(
+		pd.DataFrame(
+			{
+				"doc_id": [
+					"new_id1",
+					"new_id2",
+					"new_id3",
+					"new_id4",
+					"new_id5",
+					"new_id6",
+				],
+				"contents": ["hello", "world", "foo", "bar", "baz", "jax"],
+				"path": ["path1", "path1", "path1", "path1", "path2", "path2"],
+				"start_end_idx": [
+					(0, 80),
+					(80, 150),
+					(15, 50),
+					(50, 80),
+					(0, 200),
+					(201, 400),
+				],
+				"metadata": [
+					{"page": 1, "last_modified_datetime": "2021-08-01"},
+					{"page": 1, "last_modified_datetime": "2021-08-01"},
+					{"page": 2, "last_modified_datetime": "2021-08-02"},
+					{"page": 2, "last_modified_datetime": "2021-08-02"},
+					{"last_modified_datetime": "2021-08-01"},
+					{"last_modified_datetime": "2021-08-01"},
+				],
+			}
+		),
+		raw,
+	)
+
+	new_qa = qa.update_corpus(new_corpus)
+
+	expected_dataframe = pd.DataFrame(
+		{
+			"qid": ["qid1", "qid2", "qid3", "qid4"],
+			"retrieval_gt": [
+				[["new_id1", "new_id2"]],
["new_id2"]], + [["new_id3", "new_id3", "new_id4"]], + [["new_id5", "new_id2"], ["new_id5"]], + ], + } + ) + pd.testing.assert_frame_equal( + new_qa.data[["qid", "retrieval_gt"]], expected_dataframe + ) + with tempfile.NamedTemporaryFile(suffix=".parquet") as qa_path: + with tempfile.NamedTemporaryFile(suffix=".parquet") as corpus_path: + new_qa.to_parquet(qa_path.name, corpus_path.name) + loaded_qa = pd.read_parquet(qa_path.name, engine="pyarrow") + assert set(loaded_qa.columns) == { + "qid", + "query", + "retrieval_gt", + "generation_gt", + } + loaded_corpus = pd.read_parquet(corpus_path.name, engine="pyarrow") + assert set(loaded_corpus.columns) == {"doc_id", "contents", "metadata"}