Add update corpus feature for chunking optimization (#706)
* will not work because of page

* add update_corpus feature

---------

Co-authored-by: jeffrey <[email protected]>
vkehfdl1 and jeffrey authored Sep 13, 2024
1 parent 682c354 commit caa9f00
Showing 3 changed files with 265 additions and 4 deletions.
1 change: 1 addition & 0 deletions autorag/data/beta/extract_evidence.py
@@ -0,0 +1 @@
# This module extracts evidence from the given retrieval gt passages
141 changes: 138 additions & 3 deletions autorag/data/beta/schema.py
@@ -1,9 +1,9 @@
import logging
-from typing import Callable, Optional, Dict, Awaitable, Any
+from typing import Callable, Optional, Dict, Awaitable, Any, Tuple, List
import pandas as pd
-from autorag.utils.util import process_batch, get_event_loop, fetch_contents

from autorag.support import get_support_modules
+from autorag.utils.util import process_batch, get_event_loop, fetch_contents

logger = logging.getLogger("AutoRAG")

@@ -68,6 +68,20 @@ def linked_raw(self) -> Raw:
def linked_raw(self, raw: Raw):
raise NotImplementedError("linked_raw is read-only.")

def to_parquet(self, save_path: str):
"""
Save the corpus to the AutoRAG compatible parquet file.
It is not for the data creation, for running AutoRAG.
If you want to save it directly, use the below code.
`corpus.data.to_parquet(save_path)`
:param save_path: The path to save the corpus.
"""
if not save_path.endswith(".parquet"):
raise ValueError("save_path must end with .parquet")
save_df = self.data[["doc_id", "contents", "metadata"]].reset_index(drop=True)
save_df.to_parquet(save_path)
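# Editor's sketch (not part of this commit): example use of
# Corpus.to_parquet. Hedged: the DataFrame and output path are
# illustrative, and building a Corpus from a bare DataFrame is an
# assumption (the tests below also pass a linked Raw as a second argument).
#
# >>> import pandas as pd
# >>> corpus = Corpus(pd.DataFrame({
# ...     "doc_id": ["d1"],
# ...     "contents": ["hello"],
# ...     "metadata": [{"last_modified_datetime": "2021-08-01"}],
# ... }))
# >>> corpus.to_parquet("./example_corpus.parquet")  # must end with .parquet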

def batch_apply(
self, fn: Callable[[Dict, Any], Awaitable[Dict]], batch_size: int = 32, **kwargs
) -> "Corpus":
@@ -152,6 +166,26 @@ def make_retrieval_gt_contents(self) -> "QA":
)
return self

def to_parquet(self, qa_save_path: str, corpus_save_path: str):
"""
Save the qa and corpus to the AutoRAG compatible parquet file.
It is not for the data creation, for running AutoRAG.
If you want to save it directly, use the below code.
`qa.data.to_parquet(save_path)`
:param qa_save_path: The path to save the qa dataset.
:param corpus_save_path: The path to save the corpus.
"""
if not qa_save_path.endswith(".parquet"):
raise ValueError("qa_save_path must end with .parquet")
if not corpus_save_path.endswith(".parquet"):
raise ValueError("corpus_save_path must end with .parquet")
save_df = self.data[
["qid", "query", "retrieval_gt", "generation_gt"]
].reset_index(drop=True)
save_df.to_parquet(qa_save_path)
self.linked_corpus.to_parquet(corpus_save_path)
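# Editor's sketch (not part of this commit): QA.to_parquet writes both
# files and validates both extensions first. Hedged; `qa` is assumed to
# be a QA instance like the one built in the tests below.
#
# >>> qa.to_parquet("./qa.parquet", "./corpus.parquet")
# >>> qa.to_parquet("./qa.csv", "./corpus.parquet")
# Traceback (most recent call last):
#     ...
# ValueError: qa_save_path must end with .parquet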

def update_corpus(self, new_corpus: Corpus) -> "QA":
"""
Update linked corpus.
@@ -163,4 +197,105 @@ def update_corpus(self, new_corpus: Corpus) -> "QA":
Both corpora must have a valid `linked_raw`, with `path` and `start_end_idx` columns and an optional `page` key in metadata.
:return: The QA instance that updated linked corpus.
"""
-pass
self.data["evidence_path"] = (
self.data["retrieval_gt"]
.apply(
lambda x: fetch_contents(
self.linked_corpus.data,
x,
column_name="path",
)
)
.tolist()
)
self.data["evidence_page"] = self.data["retrieval_gt"].apply(
lambda x: list(
map(
lambda lst: list(map(lambda x: x.get("page", -1), lst)),
fetch_contents(self.linked_corpus.data, x, column_name="metadata"),
)
)
)
if "evidence_start_end_idx" not in self.data.columns:
# build evidence_start_end_idx from the linked corpus when it is absent
self.data["evidence_start_end_idx"] = (
self.data["retrieval_gt"]
.apply(
lambda x: fetch_contents(
self.linked_corpus.data,
x,
column_name="start_end_idx",
)
)
.tolist()
)

# match the old evidence spans against the new corpus
path_corpus_dict = QA.__make_path_corpus_dict(new_corpus.data)
new_retrieval_gt = self.data.apply(
lambda row: QA.__match_index_row(
row["evidence_start_end_idx"],
row["evidence_path"],
row["evidence_page"],
path_corpus_dict,
),
axis=1,
).tolist()
new_qa = self.data.copy(deep=True)[["qid", "query", "generation_gt"]]
new_qa["retrieval_gt"] = new_retrieval_gt
return QA(new_qa, new_corpus)
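# Editor's sketch (not part of this commit): the intended chunking-
# optimization flow. Hedged; how `corpus_b` is produced (e.g., by
# re-chunking the same Raw with different settings) is elided here.
#
# >>> qa_b = qa.update_corpus(corpus_b)  # remap retrieval_gt onto corpus_b
# >>> qa_b.to_parquet("./qa_b.parquet", "./corpus_b.parquet")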

@staticmethod
def __match_index(target_idx: Tuple[int, int], dst_idx: Tuple[int, int]) -> bool:
"""
Check whether the target_idx overlaps the dst_idx.
"""
target_start, target_end = target_idx
dst_start, dst_end = dst_idx
return (
dst_start <= target_start <= dst_end or dst_start <= target_end <= dst_end
)
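# Editor's note (not part of this commit): worked cases for the overlap
# rule above, using spans that appear in the tests below.
#
# >>> QA._QA__match_index((35, 75), (15, 50))   # target starts inside dst
# True
# >>> QA._QA__match_index((35, 75), (50, 80))   # target ends inside dst
# True
# >>> QA._QA__match_index((90, 200), (0, 40))   # disjoint spans
# False
#
# A dst span strictly inside the target (e.g. (30, 40) inside (0, 120))
# does overlap but is not caught, since neither target endpoint falls
# inside dst.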

@staticmethod
def __match_index_row(
evidence_indices: List[List[Tuple[int, int]]],
evidence_paths: List[List[str]],
evidence_pages: List[List[int]],
path_corpus_dict: Dict,
) -> List[List[str]]:
"""
Find the matched passage from new_corpus.
:param evidence_indices: The evidence indices at the corresponding Raw.
Its shape is the same as the retrieval_gt.
:param evidence_paths: The evidence paths at the corresponding Raw.
Its shape is the same as the retrieval_gt.
:param path_corpus_dict: The key is the path name, and the value is the corpus dataframe that only contains the path in the key.
You can make it using `QA.__make_path_corpus_dict`.
:return:
"""
result = []
for i, idx_list in enumerate(evidence_indices):
sub_result = []
for j, idx in enumerate(idx_list):
path_corpus_df = path_corpus_dict[evidence_paths[i][j]]
if evidence_pages[i][j] >= 0:
path_corpus_df = path_corpus_df.loc[
path_corpus_df["metadata"].apply(lambda x: x.get("page", -1))
== evidence_pages[i][j]
]
matched_corpus = path_corpus_df.loc[
path_corpus_df["start_end_idx"].apply(
lambda x: QA.__match_index(idx, x)
)
]
sub_result.extend(matched_corpus["doc_id"].tolist())
result.append(sub_result)
return result
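# Editor's note (not part of this commit): the inputs mirror the shape
# of retrieval_gt. For qid4 in the tests below, roughly:
#   evidence_indices == [[(150, 200), (90, 200)], [(0, 100)]]
#   evidence_paths   == [["path2", "path1"], ["path2"]]
#   evidence_pages   == [[-1, 1], [-1]]
# yielding [["new_id5", "new_id2"], ["new_id5"]] against the new corpus.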

@staticmethod
def __make_path_corpus_dict(corpus_df: pd.DataFrame) -> Dict[str, pd.DataFrame]:
return {
path: corpus_df[corpus_df["path"] == path]
for path in corpus_df["path"].unique()
}
127 changes: 126 additions & 1 deletion tests/autorag/data/beta/test_schema.py
@@ -1,6 +1,8 @@
import tempfile

import pandas as pd

-from autorag.data.beta.schema import Raw, Corpus
+from autorag.data.beta.schema import Raw, Corpus, QA
from tests.autorag.data.beta.test_data_creation_piepline import initial_raw


@@ -59,3 +61,126 @@ def test_raw_chunk():
origin_path in initial_raw.data["path"].tolist()
for origin_path in corpus.data["path"].tolist()
)


def test_update_corpus():
raw = Raw(
pd.DataFrame(
{
"texts": ["hello", "world", "jax"],
"path": ["path1", "path1", "path2"],
"page": [1, 2, -1],
"last_modified_datetime": [
"2021-08-01",
"2021-08-02",
"2021-08-03",
],
}
)
)
original_corpus = Corpus(
pd.DataFrame(
{
"doc_id": ["id1", "id2", "id3", "id4", "id5", "id6"],
"contents": ["hello", "world", "foo", "bar", "baz", "jax"],
"path": ["path1", "path1", "path1", "path1", "path2", "path2"],
"start_end_idx": [
(0, 120),
(90, 200),
(0, 40),
(35, 75),
(0, 100),
(150, 200),
],
"metadata": [
{"page": 1, "last_modified_datetime": "2021-08-01"},
{"page": 1, "last_modified_datetime": "2021-08-01"},
{"page": 2, "last_modified_datetime": "2021-08-02"},
{"page": 2, "last_modified_datetime": "2021-08-02"},
{"last_modified_datetime": "2021-08-01"},
{"last_modified_datetime": "2021-08-01"},
],
}
),
raw,
)

qa = QA(
pd.DataFrame(
{
"qid": ["qid1", "qid2", "qid3", "qid4"],
"query": ["hello", "world", "foo", "bar"],
"retrieval_gt": [
[["id1"]],
[["id1"], ["id2"]],
[["id3", "id4"]],
[["id6", "id2"], ["id5"]],
],
"generation_gt": ["world", "foo", "bar", "jax"],
}
),
original_corpus,
)

new_corpus = Corpus(
pd.DataFrame(
{
"doc_id": [
"new_id1",
"new_id2",
"new_id3",
"new_id4",
"new_id5",
"new_id6",
],
"contents": ["hello", "world", "foo", "bar", "baz", "jax"],
"path": ["path1", "path1", "path1", "path1", "path2", "path2"],
"start_end_idx": [
(0, 80),
(80, 150),
(15, 50),
(50, 80),
(0, 200),
(201, 400),
],
"metadata": [
{"page": 1, "last_modified_datetime": "2021-08-01"},
{"page": 1, "last_modified_datetime": "2021-08-01"},
{"page": 2, "last_modified_datetime": "2021-08-02"},
{"page": 2, "last_modified_datetime": "2021-08-02"},
{"last_modified_datetime": "2021-08-01"},
{"last_modified_datetime": "2021-08-01"},
],
}
),
raw,
)

new_qa = qa.update_corpus(new_corpus)

expected_dataframe = pd.DataFrame(
{
"qid": ["qid1", "qid2", "qid3", "qid4"],
"retrieval_gt": [
[["new_id1", "new_id2"]],
[["new_id1", "new_id2"], ["new_id2"]],
[["new_id3", "new_id3", "new_id4"]],
[["new_id5", "new_id2"], ["new_id5"]],
],
}
)
pd.testing.assert_frame_equal(
new_qa.data[["qid", "retrieval_gt"]], expected_dataframe
)
with tempfile.NamedTemporaryFile(suffix=".parquet") as qa_path:
with tempfile.NamedTemporaryFile(suffix=".parquet") as corpus_path:
new_qa.to_parquet(qa_path.name, corpus_path.name)
loaded_qa = pd.read_parquet(qa_path.name, engine="pyarrow")
assert set(loaded_qa.columns) == {
"qid",
"query",
"retrieval_gt",
"generation_gt",
}
loaded_corpus = pd.read_parquet(corpus_path.name, engine="pyarrow")
assert set(loaded_corpus.columns) == {"doc_id", "contents", "metadata"}
