feat: retrieval params
jotyy committed Oct 28, 2024
1 parent b90deac commit c41e23c
Showing 11 changed files with 378 additions and 215 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -7,7 +7,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.8", "3.9"]
+        python-version: ["3.10"]
 
     steps:
       - uses: actions/checkout@v4
4 changes: 0 additions & 4 deletions denser_retriever/keyword.py
@@ -54,10 +54,6 @@ class DenserKeywordSearch(ABC):
     Denser keyword search interface.
     """
 
-    def __init__(self, top_k: int = 100, weight: float = 0.5):
-        self.top_k = top_k
-        self.weight = weight
-
     @abstractmethod
     def create_index(self, index_name: str, search_fields: List[str], **args: Any):
         raise NotImplementedError
28 changes: 14 additions & 14 deletions denser_retriever/reranker.py
@@ -9,11 +9,8 @@
 
 logger = logging.getLogger(__name__)
 
-class DenserReranker(ABC):
-    def __init__(self, top_k: int = 50, weight: float = 0.5):
-        self.top_k = top_k
-        self.weight = weight
 
+class DenserReranker(ABC):
     @abstractmethod
     def rerank(
         self,
@@ -26,8 +23,8 @@ def rerank(
 class HFReranker(DenserReranker):
     """Rerank documents using a HuggingFaceCrossEncoder model."""
 
-    def __init__(self, model_name: str, top_k: int, **kwargs):
-        super().__init__(top_k=top_k)
+    def __init__(self, model_name: str, **kwargs):
+        super().__init__()
         self.model = CrossEncoder(model_name, **kwargs)
 
     def rerank(
@@ -48,7 +45,9 @@ def rerank(
         if not documents:
             return []
         start_time = time.time()
-        scores = self.model.predict([(query, doc.page_content) for doc in documents], convert_to_tensor=True)
+        scores = self.model.predict(
+            [(query, doc.page_content) for doc in documents], convert_to_tensor=True
+        )
         docs_with_scores = list(zip(documents, scores))
         result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
         rerank_time_sec = time.time() - start_time
@@ -73,9 +72,9 @@ def __init__(self, api_key: str, model_name: str = "rerank-english-v3.0", **kwargs):
         self.model_name = model_name
 
     def rerank(
-        self,
-        documents: Sequence[Document],
-        query: str,
+        self,
+        documents: Sequence[Document],
+        query: str,
     ) -> List[Tuple[Document, float]]:
         """
         Rerank documents using Cohere's reranking model.
@@ -95,12 +94,13 @@ def rerank(
         # Prepare documents for reranking
         texts = [doc.page_content for doc in documents]
         response = self.client.rerank(
-            model=self.model_name,
-            query=query,
-            documents=texts
+            model=self.model_name, query=query, documents=texts
         )
         # Combine documents with scores from the rerank response
-        docs_with_scores = [(documents[result.index], result.relevance_score) for result in response.results]
+        docs_with_scores = [
+            (documents[result.index], result.relevance_score)
+            for result in response.results
+        ]
 
         # Sort the documents by their scores in descending order
         result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
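
With top_k and weight gone from the reranker constructors, an HFReranker is now built from a model name alone; how many candidates it sees is decided by the caller. A minimal sketch of the new construction, assuming the API shown in this diff (the model name and documents are illustrative, not part of the commit):

    from langchain_core.documents import Document
    from denser_retriever.reranker import HFReranker

    # Model name is illustrative; any sentence-transformers CrossEncoder id works.
    reranker = HFReranker("cross-encoder/ms-marco-MiniLM-L-6-v2")

    docs = [
        Document(page_content="Milvus is a vector database."),
        Document(page_content="BM25 is a keyword ranking function."),
    ]
    # Returns (document, score) pairs sorted by score, highest first.
    ranked = reranker.rerank(docs, "what is a vector database?")
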
78 changes: 61 additions & 17 deletions denser_retriever/retriever.py
@@ -4,6 +4,7 @@
 import time
 
 from langchain_core.documents import Document
+from pydantic import BaseModel
 
 from denser_retriever.embeddings import DenserEmbeddings
 from denser_retriever.gradient_boost import DenserGradientBoost
@@ -31,6 +32,17 @@
 }
 
 
+class RetrievalConfig(BaseModel):
+    top_k: int = 100
+    weight: float = 0.5
+
+
+class RetrievalParams(BaseModel):
+    vector_db: RetrievalConfig = RetrievalConfig()
+    keyword: RetrievalConfig = RetrievalConfig()
+    reranker: RetrievalConfig = RetrievalConfig(top_k=50)
+
+
 class DenserRetriever:
     def __init__(
         self,
@@ -42,7 +54,7 @@ def __init__(
         gradient_boost: Optional[DenserGradientBoost],
         combine_mode: str = "linear",
         xgb_model_features: str = "es+vs+rr_n",
-        search_fields: List[str] = []
+        search_fields: List[str] = [],
     ):
         # config parameters
         self.index_name = index_name
@@ -80,46 +92,62 @@ def ingest(self, docs: List[Document], overwrite_pid: bool = True) -> List[str]:
         return [doc.metadata["pid"] for doc in docs]
 
     def retrieve(
-        self, query: str, k: int = 100, filter: Dict[str, Any] = {}, **kwargs: Any
+        self,
+        query: str,
+        k: int = 100,
+        filter: Dict[str, Any] = {},
+        retrieval_params: RetrievalParams = RetrievalParams(),
+        **kwargs: Any,
     ):
         logger.info(f"Retrieve query: {query} top_k: {k}")
         if self.combine_mode in ["linear", "rank"]:
-            return self._retrieve_by_linear_or_rank(query, k, filter, **kwargs)
+            return self._retrieve_by_linear_or_rank(
+                query, k, filter, retrieval_params, **kwargs
+            )
         else:
-            return self._retrieve_by_model(query, k, filter, **kwargs)
+            return self._retrieve_by_model(query, k, filter, retrieval_params, **kwargs)
 
     def _retrieve_by_linear_or_rank(
-        self, query: str, k: int = 100, filter: Dict[str, Any] = {}, **kwargs: Any
+        self,
+        query: str,
+        k: int = 100,
+        filter: Dict[str, Any] = {},
+        retrieval_params: RetrievalParams = RetrievalParams(),
+        **kwargs: Any,
     ):
         passages = []
 
         if self.keyword_search:
             es_docs = self.keyword_search.retrieve(
-                query, self.keyword_search.top_k, filter=filter, **kwargs
+                query, retrieval_params.keyword.top_k, filter=filter, **kwargs
             )
-            es_passages = scale_results(es_docs, self.keyword_search.weight)
+            es_passages = scale_results(es_docs, retrieval_params.keyword.weight)
             logger.info(f"Keyword search: {len(es_passages)}")
             passages.extend(es_passages)
 
         if self.vector_db:
             vector_docs = self.vector_db.similarity_search_with_score(
-                query, self.vector_db.top_k, filter, **kwargs
+                query, retrieval_params.vector_db.top_k, filter, **kwargs
             )
             logger.info(f"Vector search: {len(vector_docs)}")
             passages = merge_results(
-                passages, vector_docs, 1.0, self.vector_db.weight, self.combine_mode
+                passages,
+                vector_docs,
+                1.0,
+                retrieval_params.vector_db.weight,
+                self.combine_mode,
             )
 
         if self.reranker:
             start_time = time.time()
-            docs = [doc for doc, _ in passages[: self.reranker.top_k]]
+            docs = [doc for doc, _ in passages[: retrieval_params.reranker.top_k]]
             reranked_docs = self.reranker.rerank(docs, query)
 
             passages = merge_results(
                 passages,
                 reranked_docs,
                 1.0,
-                self.reranker.weight,
+                retrieval_params.reranker.weight,
                 self.combine_mode,
             )
             rerank_time_sec = time.time() - start_time
@@ -128,9 +156,16 @@ def _retrieve_by_linear_or_rank(
         return passages[:k]
 
     def _retrieve_by_model(
-        self, query: str, k: int = 100, filter: Dict[str, Any] = {}, **kwargs: Any
+        self,
+        query: str,
+        k: int = 100,
+        filter: Dict[str, Any] = {},
+        retrieval_params: RetrievalParams = RetrievalParams(),
+        **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
-        docs, doc_features = self._retrieve_with_features(query, filter, **kwargs)
+        docs, doc_features = self._retrieve_with_features(
+            query, filter, retrieval_params, **kwargs
+        )
 
         if not self.gradient_boost:
             raise ValueError("Gradient Boost model not provided")
@@ -147,17 +182,21 @@ def _retrieve_by_model(
         return reranked_docs[:k]
 
     def _retrieve_with_features(
-        self, query: str, filter: Dict[str, Any] = {}, **kwargs: Any
+        self,
+        query: str,
+        filter: Dict[str, Any] = {},
+        retrieval_params: RetrievalParams = RetrievalParams(),
+        **kwargs: Any,
     ) -> Tuple[List[Document], List[List[str]]]:
         ks_docs = []
         if self.keyword_search:
             ks_docs = self.keyword_search.retrieve(
-                query, self.keyword_search.top_k, filter=filter, **kwargs
+                query, retrieval_params.keyword.top_k, filter=filter, **kwargs
            )
         vs_docs = []
         if self.vector_db:
             vs_docs = self.vector_db.similarity_search_with_score(
-                query, k=self.vector_db.top_k, filter=filter, **kwargs
+                query, retrieval_params.vector_db.top_k, filter=filter, **kwargs
             )
 
         combined = []
@@ -248,7 +287,12 @@ def _retrieve_with_features(
 
         return docs, non_zero_normalized_features
 
-    def delete(self, ids: Optional[List[str]] = None, source_id: Optional[str] = None, **kwargs: str):
+    def delete(
+        self,
+        ids: Optional[List[str]] = None,
+        source_id: Optional[str] = None,
+        **kwargs: str,
+    ):
         """Clear the retriever."""
         if self.vector_db:
             self.vector_db.delete(ids=ids, source_id=source_id, **kwargs)
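
Taken together: the per-request knobs that previously lived on each component now travel through a single RetrievalParams object. A minimal usage sketch, assuming a DenserRetriever wired up elsewhere (the query text and parameter values are illustrative):

    from denser_retriever.retriever import RetrievalConfig, RetrievalParams

    # Tighten keyword recall and lean harder on the reranker for this request;
    # vector_db keeps the defaults (top_k=100, weight=0.5).
    params = RetrievalParams(
        keyword=RetrievalConfig(top_k=20, weight=0.3),
        reranker=RetrievalConfig(top_k=10, weight=1.0),
    )

    # `retriever` is assumed to be an already-configured DenserRetriever.
    passages = retriever.retrieve(
        "what is dense retrieval?",
        k=5,
        retrieval_params=params,
    )

Since the RetrievalParams defaults mirror the old constructor defaults (top_k=100, weight=0.5 for keyword and vector search, top_k=50 for the reranker), call sites that pass nothing keep their previous behavior.
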
4 changes: 0 additions & 4 deletions denser_retriever/vectordb/base.py
@@ -10,10 +10,6 @@ class DenserVectorDB(ABC):
     Interface for a denser vector database.
     """
 
-    def __init__(self, top_k: int = 100, weight: float = 0.5):
-        self.top_k = top_k
-        self.weight = weight
-
     def create_index(
         self,
         index_name: str,
2 changes: 1 addition & 1 deletion denser_retriever/vectordb/milvus.py
@@ -364,7 +364,7 @@ def similarity_search_with_score(
         logger.info(f"Vector DB retrieve time: {retrieve_time_sec:.3f} sec.")
         logger.info(f"Retrieved {len(result[0])} documents.")
 
-        top_k_used = min(len(result[0]), self.top_k)  # type: ignore
+        top_k_used = min(len(result[0]), k)  # type: ignore
 
         ret = []
         for id in range(top_k_used):
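
The milvus.py change is the follow-through: with the constructor-level self.top_k removed, truncation now honors the k argument passed into each call. A sketch of the effect, assuming a configured vector store instance (argument values are illustrative):

    # Previously capped at the constructor's self.top_k regardless of k;
    # now at most `k` documents come back.
    hits = vector_db.similarity_search_with_score("what is dense retrieval?", 20, {})
    assert len(hits) <= 20
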
(Diffs for the remaining 5 changed files are not shown.)