Skip to content

Commit

Permalink
Add Anthropic contextual retrieval experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiheng-huang committed Oct 2, 2024
1 parent 18783f4 commit 8ee2efd
Show file tree
Hide file tree
Showing 17 changed files with 7,538 additions and 40 deletions.
37 changes: 37 additions & 0 deletions denser_retriever/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,40 @@ def embed_query(self, text):
else:
embeddings = self.client.encode([text], prompt_name="query")
return embeddings


class VoyageAPIEmbeddings(DenserEmbeddings):
    """Embedding backend that delegates to the Voyage AI embeddings API.

    The `voyageai` package is imported inside `__init__`, so this module can
    still be imported when that optional dependency is not installed.
    """

    # Maximum number of texts sent in a single embed request.
    # NOTE(review): Voyage documents a cap on the input list per request
    # (128 at the time of writing) -- confirm against the current API
    # reference before raising this value.
    _MAX_BATCH_SIZE = 128

    def __init__(self, api_key: str, model_name: str, embedding_size: int):
        """
        Creates the Voyage client used by the embed methods.
        Args:
            api_key: The Voyage API key passed to `voyageai.Client`.
            model_name: The model name sent with every embed request.
            embedding_size: Stored on the instance; not otherwise read by
                this class.
        Raises:
            ImportError: If the `voyageai` package cannot be imported.
        """
        try:
            import voyageai
        except ImportError as exc:
            raise ImportError(
                "Could not import voyage python package. "
                "Please install it with `pip install voyageai`."
            ) from exc

        self.client = voyageai.Client(api_key)
        self.model_name = model_name
        self.embedding_size = embedding_size

    def embed_documents(self, texts):
        """
        Embeds multiple documents using the Voyage API.

        The texts are sent in consecutive slices of at most
        `_MAX_BATCH_SIZE` items so that large inputs do not exceed the
        per-request limit; results are concatenated in input order. An empty
        input returns an empty list without calling the API.
        Args:
            texts: A list of document texts.
        Returns:
            A list of document embeddings, one per input text.
        """
        embeddings = []
        for start in range(0, len(texts), self._MAX_BATCH_SIZE):
            batch = texts[start:start + self._MAX_BATCH_SIZE]
            response = self.client.embed(batch, model=self.model_name)
            embeddings.extend(response.embeddings)
        return embeddings

    def embed_query(self, text):
        """
        Embeds a single query using the Voyage API.
        Args:
            text: The query text.
        Returns:
            The `embeddings` field of the response for the one-element input
            `[text]` -- presumably a list holding the single query embedding
            rather than a bare vector (verify against callers).
        """
        return self.client.embed([text], model=self.model_name).embeddings
12 changes: 8 additions & 4 deletions denser_retriever/keyword.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,13 @@ def add_documents(
"pid": metadata.get("pid"),
}
for filter in self.search_fields.get_keys():
v = metadata.get(filter, "").strip()
if v:
request[filter] = v
value = metadata.get(filter, "")
if isinstance(value, list):
value = [v.strip() for v in value]
elif value is not None:
value = value.strip()
if value:
request[filter] = value
requests.append(request)

if len(requests) > 0:
Expand Down Expand Up @@ -342,7 +346,7 @@ def retrieve(
},
)
score = res["hits"]["hits"][id]["_score"]
for field in filter:
for field in self.search_fields.get_keys():
if _source.get(field):
doc.metadata[field] = _source.get(field)
docs.append((doc, score))
Expand Down
64 changes: 59 additions & 5 deletions denser_retriever/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import List, Sequence, Tuple
import time
import logging
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
import cohere
from sentence_transformers import CrossEncoder
from langchain_core.documents import Document

logger = logging.getLogger(__name__)
Expand All @@ -25,9 +26,9 @@ def rerank(
class HFReranker(DenserReranker):
"""Rerank documents using a HuggingFaceCrossEncoder model."""

def __init__(self, model_name: str, model_kwargs: dict = {}, **kwargs):
super().__init__()
self.model = HuggingFaceCrossEncoder(model_name=model_name, model_kwargs = model_kwargs)
def __init__(self, model_name: str, top_k: int, **kwargs):
super().__init__(top_k=top_k)
self.model = CrossEncoder(model_name, **kwargs)

def rerank(
self,
Expand All @@ -47,11 +48,64 @@ def rerank(
if not documents:
return []
start_time = time.time()
scores = self.model.score([(query, doc.page_content) for doc in documents])
scores = self.model.predict([(query, doc.page_content) for doc in documents], convert_to_tensor=True)
docs_with_scores = list(zip(documents, scores))
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
rerank_time_sec = time.time() - start_time
logger.info(f"Rerank time: {rerank_time_sec:.3f} sec.")
logger.info(f"Reranked {len(result)} documents.")
return result


class CohereReranker(DenserReranker):
    """Rerank documents using the Cohere API."""

    def __init__(self, api_key: str, model_name: str = "rerank-english-v3.0", **kwargs):
        """
        Initialize Cohere reranker.
        Args:
            api_key: The API key for Cohere.
            model_name: The name of the Cohere model to use for reranking.
            **kwargs: If `top_k` is present (and not None) it is forwarded to
                the base reranker; any other keys are ignored.
        """
        # HFReranker forwards top_k to the base class; do the same here so a
        # caller-supplied top_k is not silently dropped. When it is absent the
        # base class is initialized with its own defaults, as before.
        top_k = kwargs.get("top_k")
        if top_k is not None:
            super().__init__(top_k=top_k)
        else:
            super().__init__()
        self.client = cohere.Client(api_key)
        self.model_name = model_name

    def rerank(
        self,
        documents: Sequence[Document],
        query: str,
    ) -> List[Tuple[Document, float]]:
        """
        Rerank documents using Cohere's reranking model.
        Args:
            documents: A sequence of documents to rerank.
            query: The query to use for ranking the documents.
        Returns:
            A list of (document, relevance score) tuples sorted by score in
            descending order. Empty if `documents` is empty.
        """
        if not documents:
            return []

        start_time = time.time()

        # Only the text of each document is sent to the API.
        texts = [doc.page_content for doc in documents]
        response = self.client.rerank(
            model=self.model_name,
            query=query,
            documents=texts,
        )
        # Each result carries the index of the input text it refers to; use it
        # to map the score back onto the original Document object.
        docs_with_scores = [
            (documents[item.index], item.relevance_score)
            for item in response.results
        ]

        # Sort explicitly so the return order does not depend on the order in
        # which the API lists its results.
        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)

        rerank_time_sec = time.time() - start_time
        logger.info(f"Cohere Rerank time: {rerank_time_sec:.3f} sec.")
        logger.info(f"Reranked {len(result)} documents.")
        return result
18 changes: 17 additions & 1 deletion denser_retriever/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def evaluate(
qrels: Dict[str, Dict[str, int]],
results: Dict[str, Dict[str, float]],
metric_file: Optional[str] = None,
k_values: List[int] = [1, 3, 5, 10, 100, 1000],
k_values: List[int] = [1, 3, 5, 10, 20, 100, 1000],
ignore_identical_ids: bool = True,
) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]:
if ignore_identical_ids:
Expand Down Expand Up @@ -76,6 +76,11 @@ def save_queries(queries, output_file: str):
json.dump(data, out, ensure_ascii=False)
out.write("\n")

def load_queries(in_file: str):
    """Load queries from a JSON Lines file.

    Args:
        in_file: Path to a text file holding one JSON value per line.

    Returns:
        A list with the decoded value of every non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # The context manager closes the file even if decoding fails part-way.
    with open(in_file, "r") as f:
        # Blank lines (e.g. a trailing newline at end of file) are skipped
        # instead of being handed to json.loads, which would reject them.
        return [json.loads(line) for line in f if line.strip()]

def save_qrels(qrels, output_file: str):
out = open(output_file, "w")
Expand All @@ -84,6 +89,17 @@ def save_qrels(qrels, output_file: str):
json.dump(data, out, ensure_ascii=False)
out.write("\n")

def save_qrels_from_trec(trec_file, qrels_file):
    """Convert a whitespace-separated TREC-style qrels file and save it.

    Each non-blank line of `trec_file` must have exactly four
    whitespace-separated fields: query id, an ignored field, passage id and
    an integer relevance label.

    Args:
        trec_file: Path of the input file to read.
        qrels_file: Output path handed to `save_qrels`.

    Returns:
        A dict mapping query id -> {passage id -> int relevance}.

    Raises:
        ValueError: If a non-blank line does not have exactly four fields or
            its relevance label is not an integer.
    """
    qrels = {}
    with open(trec_file, "r") as f:
        for line in f:
            # Skip blank lines (such as a trailing newline) rather than
            # failing on the 4-way unpack below.
            if not line.strip():
                continue
            qid, _, pid, rel = line.split()
            qrels.setdefault(qid, {})[pid] = int(rel)

    save_qrels(qrels, qrels_file)
    return qrels

def load_qrels(in_file: str):
res = {}
Expand Down
Loading

0 comments on commit 8ee2efd

Please sign in to comment.