Commit
Use generator module at hyde and decompose instead of LLMPredictorType from LlamaIndex (#464)

* just commit
* implement generator func at hyde and decompose
* add docs for module parameter
* add no generator test and patch
* change to simple.yaml
* add prompt parameter at docs, delete batch parameter
* use bool and f-string at query_decompose.py
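For context, the refactor swaps a direct LlamaIndex LLM call for a pluggable generator function. Judging from the new hyde implementation shown below, the generator contract appears to be: the callable receives a DataFrame whose "prompts" column holds the fully formatted prompts and returns a DataFrame with a "generated_texts" column. The echo_generator below is a hypothetical stand-in used only to illustrate that inferred contract; it is not part of the commit.

import pandas as pd

def echo_generator(project_dir, previous_result: pd.DataFrame, **kwargs) -> pd.DataFrame:
    # Hypothetical stand-in for a generator module such as 'llama_index_llm':
    # it reads the 'prompts' column and returns a 'generated_texts' column.
    # Here it simply echoes each prompt back as its "generated" text.
    return pd.DataFrame({"generated_texts": previous_result["prompts"].tolist()})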
Showing 10 changed files with 148 additions and 108 deletions.
@@ -1,40 +1,31 @@
-import asyncio
-from typing import List
+from typing import List, Dict, Callable

-from llama_index.core.service_context_elements.llm_predictor import LLMPredictorType
+import pandas as pd

 from autorag.nodes.queryexpansion.base import query_expansion_node
-from autorag.utils.util import process_batch

 hyde_prompt = "Please write a passage to answer the question"


 @query_expansion_node
-def hyde(queries: List[str], llm: LLMPredictorType,
-         prompt: str = hyde_prompt,
-         batch: int = 16) -> List[List[str]]:
+def hyde(queries: List[str],
+         generator_func: Callable,
+         generator_params: Dict,
+         prompt: str = hyde_prompt) -> List[List[str]]:
     """
     HyDE, which inspired by "Precise Zero-shot Dense Retrieval without Relevance Labels" (https://arxiv.org/pdf/2212.10496.pdf)
     LLM model creates a hypothetical passage.
     And then, retrieve passages using hypothetical passage as a query.
     :param queries: List[str], queries to retrieve.
-    :param llm: llm to use for hypothetical passage generation.
+    :param generator_func: Callable, generator functions.
+    :param generator_params: Dict, generator parameters.
     :param prompt: prompt to use when generating hypothetical passage
-    :param batch: Batch size for llm.
-        Default is 16.
     :return: List[List[str]], List of hyde results.
     """
-    # Run async query_decompose_pure function
-    tasks = [hyde_pure(query, llm, prompt) for query in queries]
-    loop = asyncio.get_event_loop()
-    results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
+    full_prompts = list(
+        map(lambda x: (prompt if bool(prompt) else hyde_prompt) + f"\nQuestion: {x}\nPassage:", queries))
+    input_df = pd.DataFrame({"prompts": full_prompts})
+    result_df = generator_func(project_dir=None, previous_result=input_df, **generator_params)
+    answers = result_df['generated_texts'].tolist()
+    results = list(map(lambda x: [x], answers))
     return results
-
-
-async def hyde_pure(query: str, llm: LLMPredictorType,
-                    prompt: str = hyde_prompt) -> List[str]:
-    if prompt is "":
-        prompt = hyde_prompt
-    full_prompt = prompt + f"\nQuestion: {query}\nPassage:"
-    hyde_answer = await llm.acomplete(full_prompt)
-    return [hyde_answer.text]
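A minimal usage sketch of the new signature, mirroring the unit tests that follow: it calls the undecorated function via __wrapped__ and plugs in a stub generator in place of a real generator module such as 'llama_index_llm'. The stub_generator below is hypothetical; any callable accepting (project_dir, previous_result, **params) and returning a DataFrame with a 'generated_texts' column should fit the same slot.

import pandas as pd
from autorag.nodes.queryexpansion import hyde

def stub_generator(project_dir, previous_result: pd.DataFrame, **kwargs) -> pd.DataFrame:
    # Hypothetical stub: pretend every prompt yields the same short passage.
    passages = ["A hypothetical passage." for _ in range(len(previous_result))]
    return pd.DataFrame({"generated_texts": passages})

queries = ["How many members are in Newjeans?", "What is visconde structure?"]
expanded = hyde.__wrapped__(queries, stub_generator, {}, prompt="")
# hyde wraps each generated passage in a single-element list, one per query.
assert len(expanded) == 2 and all(len(r) == 1 for r in expanded)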
@@ -1,22 +1,25 @@
 from autorag import generator_models
 from autorag.nodes.queryexpansion import hyde
+from autorag.support import get_support_modules
 from tests.autorag.nodes.queryexpansion.test_query_expansion_base import project_dir, previous_result, \
     base_query_expansion_node_test, ingested_vectordb_node
 from tests.mock import MockLLM

 sample_query = ["How many members are in Newjeans?", "What is visconde structure?"]


 def test_hyde():
-    llm = MockLLM()
+    generator_func = get_support_modules('llama_index_llm')
+    generator_params = {'llm': 'mock'}
     original_hyde = hyde.__wrapped__
-    result = original_hyde(sample_query, llm, prompt="")
+    result = original_hyde(sample_query, generator_func, generator_params, prompt="")
     assert len(result[0]) == 1
     assert len(result) == 2


 def test_hyde_node(ingested_vectordb_node):
     generator_models['mock'] = MockLLM
+    generator_dict = {
+        'generator_module_type': 'llama_index_llm',
+        'llm': 'mock'
+    }
     result_df = hyde(project_dir=project_dir, previous_result=previous_result,
-                     llm="mock", max_tokens=64)
+                     **generator_dict)
     base_query_expansion_node_test(result_df)
@@ -1,23 +1,26 @@
 from autorag import generator_models
 from autorag.nodes.queryexpansion import query_decompose
+from autorag.support import get_support_modules
 from tests.autorag.nodes.queryexpansion.test_query_expansion_base import project_dir, previous_result, \
     base_query_expansion_node_test, ingested_vectordb_node
 from tests.mock import MockLLM

 sample_query = ["Which group has more members, Newjeans or Aespa?", "Which group has more members, STAYC or Aespa?"]


 def test_query_decompose():
-    llm = MockLLM(temperature=0.2)
+    generator_func = get_support_modules('llama_index_llm')
+    generator_params = {'llm': 'mock'}
     original_query_decompose = query_decompose.__wrapped__
-    result = original_query_decompose(sample_query, llm, prompt="")
+    result = original_query_decompose(sample_query, generator_func, generator_params, prompt="")
     assert len(result[0]) > 1
     assert len(result) == 2
     assert isinstance(result[0][0], str)


 def test_query_decompose_node(ingested_vectordb_node):
     generator_models['mock'] = MockLLM
+    generator_dict = {
+        'generator_module_type': 'llama_index_llm',
+        'llm': 'mock'
+    }
     result_df = query_decompose(project_dir=project_dir, previous_result=previous_result,
-                                llm="mock", temperature=0.2)
+                                **generator_dict)
     base_query_expansion_node_test(result_df)