Skip to content

Commit

Permalink
perf: update retriever
Browse files Browse the repository at this point in the history
  • Loading branch information
datvodinh committed May 13, 2024
1 parent 19788e5 commit 1ea9fc4
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 36 deletions.
146 changes: 113 additions & 33 deletions rag_chatbot/core/engine/retriever.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import List
from dotenv import load_dotenv
from llama_index.core.retrievers import (
BaseRetriever,
QueryFusionRetriever,
VectorIndexRetriever,
RouterRetriever
)
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.retrievers.fusion_retriever import FUSION_MODES
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.tools import RetrieverTool
from llama_index.core.selectors import LLMSingleSelector
from llama_index.core.schema import BaseNode
from llama_index.core.schema import BaseNode, NodeWithScore, QueryBundle, IndexNode
from llama_index.core.llms.llm import LLM
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core import Settings, VectorStoreIndex
Expand All @@ -17,10 +21,53 @@
load_dotenv()


class TwoStageRetriever:
def __init__(self) -> None:
# TODO
pass
class TwoStageRetriever(QueryFusionRetriever):
    """Two-stage retriever: fuse sub-retriever results, then rerank.

    Stage 1 runs every sub-retriever (optionally with LLM-generated query
    variations when ``num_queries > 1``) and merges the hits with the base
    class's simple fusion (score-based dedup).
    Stage 2 reranks the fused nodes with a SentenceTransformer cross-encoder
    and keeps the top ``setting.retriever.top_k_rerank`` nodes.
    """

    def __init__(
        self,
        retrievers: List[BaseRetriever],
        setting: RAGSettings | None = None,
        llm: str | None = None,
        query_gen_prompt: str | None = None,
        mode: FUSION_MODES = FUSION_MODES.SIMPLE,
        # BUGFIX: the default was the Ellipsis literal (`...`); omitting the
        # argument forwarded `Ellipsis` to QueryFusionRetriever as a top-k
        # count. Use a small sane integer default instead.
        similarity_top_k: int = 2,
        num_queries: int = 4,
        use_async: bool = True,
        verbose: bool = False,
        callback_manager: CallbackManager | None = None,
        objects: List[IndexNode] | None = None,
        object_map: dict | None = None,
        retriever_weights: List[float] | None = None
    ) -> None:
        # Pass by keyword so a base-class signature reorder cannot silently
        # misassign arguments.
        super().__init__(
            retrievers=retrievers,
            llm=llm,
            query_gen_prompt=query_gen_prompt,
            mode=mode,
            similarity_top_k=similarity_top_k,
            num_queries=num_queries,
            use_async=use_async,
            verbose=verbose,
            callback_manager=callback_manager,
            objects=objects,
            object_map=object_map,
            retriever_weights=retriever_weights,
        )
        self._setting = setting or RAGSettings()
        # Cross-encoder reranker applied after fusion (stage 2).
        self._rerank_model = SentenceTransformerRerank(
            top_n=self._setting.retriever.top_k_rerank,
            model=self._setting.retriever.rerank_llm,
        )

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Synchronously retrieve, fuse, and rerank nodes for *query_bundle*."""
        queries: List[QueryBundle] = [query_bundle]
        if self.num_queries > 1:
            queries.extend(self._get_queries(query_bundle.query_str))

        if self.use_async:
            results = self._run_nested_async_queries(queries)
        else:
            results = self._run_sync_queries(queries)
        # NOTE(review): fusion is always _simple_fusion here, regardless of
        # the `mode` argument — confirm that is intentional.
        results = self._simple_fusion(results)
        return self._rerank_model.postprocess_nodes(results, query_bundle)

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async variant of :meth:`_retrieve` (fuse, then rerank)."""
        queries: List[QueryBundle] = [query_bundle]
        if self.num_queries > 1:
            queries.extend(self._get_queries(query_bundle.query_str))

        results = await self._run_async_queries(queries)
        results = self._simple_fusion(results)
        return self._rerank_model.postprocess_nodes(results, query_bundle)


class LocalRetriever:
Expand All @@ -33,24 +80,26 @@ def __init__(
self._setting = setting or RAGSettings()
self._host = host

def _get_two_stage_retriever(
def _get_normal_retriever(
self,
llm: LLM,
vector_index: VectorStoreIndex,
language: str,
llm: LLM | None = None,
language: str = "eng",
):
vector_retriever = VectorIndexRetriever(
llm = llm or Settings.llm
return VectorIndexRetriever(
index=vector_index,
similarity_top_k=self._setting.retriever.similarity_top_k,
embed_model=Settings.embed_model,
verbose=True
) # TODO
)

def _get_fusion_retriever(
def _get_hybrid_retriever(
self,
llm: LLM,
vector_index: VectorStoreIndex,
language: str,
llm: LLM | None = None,
language: str = "eng",
gen_query: bool = True
):
# VECTOR INDEX RETRIEVER
vector_retriever = VectorIndexRetriever(
Expand All @@ -67,18 +116,57 @@ def _get_fusion_retriever(
)

# FUSION RETRIEVER
fusion_retriever = QueryFusionRetriever(
retrievers=[bm25_retriever, vector_retriever],
retriever_weights=self._setting.retriever.retriever_weights,
llm=llm,
query_gen_prompt=get_query_gen_prompt(language),
similarity_top_k=self._setting.retriever.top_k_rerank,
num_queries=self._setting.retriever.num_queries,
mode=self._setting.retriever.fusion_mode,
verbose=True
if gen_query:
hybrid_retriever = QueryFusionRetriever(
retrievers=[bm25_retriever, vector_retriever],
retriever_weights=self._setting.retriever.retriever_weights,
llm=llm,
query_gen_prompt=get_query_gen_prompt(language),
similarity_top_k=self._setting.retriever.top_k_rerank,
num_queries=self._setting.retriever.num_queries,
mode=self._setting.retriever.fusion_mode,
verbose=True
)
else:
hybrid_retriever = TwoStageRetriever(
retrievers=[bm25_retriever, vector_retriever],
retriever_weights=self._setting.retriever.retriever_weights,
llm=llm,
query_gen_prompt=None,
similarity_top_k=self._setting.retriever.similarity_top_k,
num_queries=1,
mode=self._setting.retriever.fusion_mode,
verbose=True
)

return hybrid_retriever

def _get_router_retriever(
self,
vector_index: VectorStoreIndex,
llm: LLM | None = None,
language: str = "eng",
):
fusion_tool = RetrieverTool.from_defaults(
retriever=self._get_hybrid_retriever(
vector_index, llm, language, gen_query=True
),
description="Use this tool when the user's query is ambiguous or unclear.",
name="Fusion Retriever with BM25 and Vector Retriever and LLM Query Generation."
)
two_stage_tool = RetrieverTool.from_defaults(
retriever=self._get_hybrid_retriever(
vector_index, llm, language, gen_query=False
),
description="Use this tool when the user's query is clear and unambiguous.",
name="Two Stage Retriever with BM25 and Vector Retriever and LLM Rerank."
)

return fusion_retriever
return RouterRetriever.from_defaults(
selector=LLMSingleSelector.from_defaults(llm=llm),
retriever_tools=[fusion_tool, two_stage_tool],
llm=llm
)

def get_retrievers(
self,
Expand All @@ -88,16 +176,8 @@ def get_retrievers(
):
vector_index = VectorStoreIndex(nodes=nodes)
if len(nodes) > self._setting.retriever.top_k_rerank:
retriever = self._get_fusion_retriever(llm, vector_index, language)
retriever = self._get_router_retriever(vector_index, llm, language)
else:
retriever = VectorIndexRetriever(
index=vector_index,
similarity_top_k=self._setting.retriever.top_k_rerank,
verbose=True
)
retriever = self._get_normal_retriever(vector_index, llm, language)

return retriever

# TODO: new router retriever
# Ambiguous query: vector + bm25 + query fusion
# Good query: vector + bm25 + rerank
2 changes: 1 addition & 1 deletion rag_chatbot/core/ingestion/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def store_nodes(
secondary_chunking_regex=self._setting.ingestion.chunking_regex
)
excluded_keys = [
"doc_id", "file_path", "file_type",
"doc_id", "file_path", "file_type", "page_label", "file_name",
"file_size", "creation_date", "last_modified_date"
]
if embed_nodes:
Expand Down
4 changes: 2 additions & 2 deletions rag_chatbot/setting/setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class RetrieverSettings(BaseModel):
default=[0.4, 0.6], description="Weights for retriever"
)
top_k_rerank: int = Field(
default=4, description="Top k rerank"
default=5, description="Top k rerank"
)
rerank_llm: str = Field(
default="BAAI/bge-reranker-large", description="Rerank LLM model"
Expand All @@ -64,7 +64,7 @@ class RetrieverSettings(BaseModel):

class IngestionSettings(BaseModel):
embed_llm: str = Field(
default="BAAI/bge-base-en-v1.5", description="Embedding LLM model"
default="BAAI/bge-large-en-v1.5", description="Embedding LLM model"
)
embed_batch_size: int = Field(
default=4, description="Embedding batch size"
Expand Down

0 comments on commit 1ea9fc4

Please sign in to comment.