Skip to content

Commit

Permalink
feat: add func to generate multiple queries (#1009)
Browse files Browse the repository at this point in the history
* feat: add func to generate multiple queries

* formatting the code

---------

Co-authored-by: Um Changyong <[email protected]>
Co-authored-by: Jeffrey (Dongkyu) Kim <[email protected]>
  • Loading branch information
3 people authored Nov 29, 2024
1 parent 1a49236 commit 7740a82
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,6 @@ pytest.ini
.DS_Store
projects/tutorial_1
!projects/tutorial_1/config.yaml

# Visual Studio Code
.vscode/
14 changes: 13 additions & 1 deletion autorag/data/qa/query/llama_gen_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.base.llms.types import ChatResponse, ChatMessage, MessageRole

from autorag.data.qa.query.prompt import QUERY_GEN_PROMPT
from autorag.data.qa.query.prompt import QUERY_GEN_PROMPT, QUERY_GEN_PROMPT_EXTRA


async def llama_index_generate_base(
Expand Down Expand Up @@ -68,3 +68,15 @@ async def custom_query_gen(
messages: List[ChatMessage],
) -> Dict:
return await llama_index_generate_base(row, llm, messages)


# Experimental feature: can only use factoid_single_hop
async def multiple_queries_gen(
    row: Dict,
    llm: BaseLLM,
    lang: str = "en",
    n: int = 3,
) -> Dict:
    """Ask the LLM for ``n`` factoid single-hop queries for one row.

    The ``factoid_single_hop`` prompt for ``lang`` is extended with the
    ``multiple_queries`` extra instruction (formatted with ``n``) and handed
    to ``llama_index_generate_base``.

    :param row: The row passed through to ``llama_index_generate_base``.
    :param llm: The LLM passed through to ``llama_index_generate_base``.
    :param lang: Language key used to look up both the base prompt and the
        extra instruction. A key missing from either mapping raises
        ``KeyError`` from the lookup.
    :param n: Number of questions requested in the extra instruction.
    :return: Whatever ``llama_index_generate_base`` returns for this row.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    import copy

    # Work on a private copy: the prompt lists stored in QUERY_GEN_PROMPT are
    # module-level shared objects. Appending to them in place would make every
    # call stack another "additional instructions" suffix onto the global
    # prompt and leak it into other users of the factoid_single_hop prompt.
    _messages = copy.deepcopy(QUERY_GEN_PROMPT["factoid_single_hop"][lang])
    _messages[0].content += QUERY_GEN_PROMPT_EXTRA["multiple_queries"][lang].format(
        n=n
    )
    return await llama_index_generate_base(row, llm, _messages)
9 changes: 9 additions & 0 deletions autorag/data/qa/query/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,12 @@
],
},
}

# Experimental feature.
# Extra instruction snippets, keyed first by feature name and then by
# language code. Each snippet is a ``str.format`` template; the
# "multiple_queries" templates take a single ``n`` placeholder (the number of
# questions to request).
QUERY_GEN_PROMPT_EXTRA = {
    "multiple_queries": {
        # English
        "en": "\nAdditional instructions:\n - Please make {n} questions.",
        # Korean
        "ko": "\n추가 지침:\n - 질문은 {n}개를 만드세요.",
        # Japanese
        "ja": "\n追加指示:\n - 質問を{n}個作成してください。",
    },
}
21 changes: 21 additions & 0 deletions autorag/data/qa/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Callable, Optional, Dict, Awaitable, Any, Tuple, List
import uuid
import pandas as pd
from autorag.utils.util import process_batch, get_event_loop, fetch_contents

Expand Down Expand Up @@ -137,6 +138,11 @@ def batch_apply(
loop = get_event_loop()
tasks = [fn(qa_dict, **kwargs) for qa_dict in qa_dicts]
results = loop.run_until_complete(process_batch(tasks, batch_size))

# Experimental feature
if fn.__name__ == "multiple_queries_gen":
return self._process_multiple_queries_gen(results)

return QA(pd.DataFrame(results), self.linked_corpus)

def batch_filter(
Expand Down Expand Up @@ -299,3 +305,18 @@ def __make_path_corpus_dict(corpus_df: pd.DataFrame) -> Dict[str, pd.DataFrame]:
path: corpus_df[corpus_df["path"] == path]
for path in corpus_df["path"].unique()
}

# Experimental feature
def _process_multiple_queries_gen(self, results: List[Dict]) -> "QA":
    """Expand multi-query results into one row per generated query.

    Each result's ``"query"`` value is split on newlines. Every non-blank
    line becomes its own row that copies the other fields of the result,
    and gets a fresh UUID4 ``"qid"`` when the result has a ``"qid"`` key.

    :param results: Result dicts, each holding a newline-separated
        ``"query"`` string.
    :return: A new ``QA`` built from the expanded rows and this instance's
        ``linked_corpus``.
    :raises KeyError: If a result has no ``"query"`` key.
    """
    data = []
    for result in results:
        for raw_query in result["query"].split("\n"):
            # Strip surrounding whitespace (including a stray "\r") and drop
            # blank lines, so empty separators or a trailing newline in the
            # LLM output do not become rows with an empty query.
            query = raw_query.strip()
            if not query:
                continue
            new_result = dict(result)
            if "qid" in new_result:
                # Each expanded row needs its own identifier.
                new_result["qid"] = str(uuid.uuid4())
            new_result["query"] = query
            data.append(new_result)
    return QA(pd.DataFrame(data), self.linked_corpus)

0 comments on commit 7740a82

Please sign in to comment.