Support aggregations
zhiheng-huang committed Dec 12, 2024
1 parent 74da4a0 commit 79577b7
Showing 19 changed files with 218 additions and 481 deletions.
27 changes: 10 additions & 17 deletions README.md
@@ -17,28 +17,13 @@ An enterprise-grade AI retriever designed to streamline AI integration into your

</div>

## 📝 Description

Denser Retriever combines multiple search technologies into a single platform. It uses a **gradient boosting (xgboost)** machine learning technique to combine the following (a minimal sketch of the combination appears after the list):

- **Keyword-based searches** that focus on fetching precisely what the query mentions.
- **Vector databases** that are great for finding a wide range of potentially relevant answers.
- **Machine Learning rerankers** that fine-tune the results to ensure the most relevant answers top the list.
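
Below is a minimal, hypothetical sketch of this idea: an xgboost ranker trained on made-up keyword, vector, and reranker scores for a handful of candidate passages. It is not Denser Retriever's actual training pipeline; the feature layout, labels, and values are illustrative assumptions only.

```python
# Illustrative only: combine keyword, vector, and reranker scores with xgboost.
import numpy as np
import xgboost as xgb

# One row per (query, passage) candidate: [keyword_score, vector_score, reranker_score]
X = np.array([
    [12.3, 0.81, 0.92],
    [ 8.7, 0.77, 0.35],
    [ 4.1, 0.55, 0.10],
    [15.0, 0.62, 0.88],
    [ 3.2, 0.49, 0.07],
])
y = np.array([1, 0, 0, 1, 0])  # 1 = relevant, 0 = not relevant (made-up labels)
groups = [3, 2]                # first 3 rows belong to query 1, the next 2 to query 2

ranker = xgb.XGBRanker(objective="rank:pairwise", n_estimators=50)
ranker.fit(X, y, group=groups)

# At query time, score the candidates of one query and sort descending.
scores = ranker.predict(X[:3])
print(np.argsort(-scores))
```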

* Our experiments on MTEB datasets show that combining keyword search, vector search, and a reranker via an xgboost model (denoted as ES+VS+RR_n) significantly improves over the vector search (VS) baseline.

![mteb_ndcg_plot](https://github.com/denser-org/denser-retriever/blob/main/mteb_ndcg_plot.png?raw=true)

* **Check out Denser Retriever experiments on the Anthropic Contextual Retrieval dataset [here](https://github.com/denser-org/denser-retriever/tree/main/experiments/data/contextual-embeddings).**
## 🚀 Features

The initial release of Denser Retriever provides the following features.

- Supporting heterogeneous retrievers such as **keyword search**, **vector search**, and **ML model reranking**
- Leveraging the **xgboost** ML technique to effectively combine heterogeneous retrievers
- **State-of-the-art accuracy** on [MTEB](https://github.com/embeddings-benchmark/mteb) Retrieval benchmarking
- **Comprehensive benchmark** on [MTEB](https://github.com/embeddings-benchmark/mteb) Retrieval dataset
- Demonstrating how to use Denser Retriever to power **end-to-end applications** such as chatbots and semantic search
![mteb_ndcg_plot](https://github.com/denser-org/denser-retriever/blob/main/mteb_ndcg_plot.png?raw=true)

## 📦 Installation

@@ -56,6 +41,14 @@ pip install denser-retriever
poetry add denser-retriever
```

## Quick Start

## 📝 Experiments

### [Anthropic Contextual Retrieval experiment](https://github.com/denser-org/denser-retriever/tree/main/experiments/data/contextual-embeddings)

### [MTEB Retrieval experiment](https://retriever-docs.denser.ai/docs/core/experiments/mteb_retrieval)

## 📃 Documentation

The official documentation is hosted on [retriever.denser.ai](https://retriever.denser.ai).
21 changes: 21 additions & 0 deletions denser_retriever/embeddings.py
@@ -40,6 +40,27 @@ def embed_query(self, text):
return embeddings


class BGEEmbeddings(DenserEmbeddings):
    def __init__(self, model_name: str, embedding_size: int):
        try:
            from FlagEmbedding import FlagICLModel
        except ImportError as exc:
            raise ImportError(
                "Could not import FlagEmbedding python package."
            ) from exc

        self.client = FlagICLModel(
            model_name,
            query_instruction_for_retrieval="Given a web search query, retrieve relevant passages that answer the query.",
            examples_for_task=None,  # set `examples_for_task=None` to use model without examples
            use_fp16=True,  # setting use_fp16 to True speeds up computation with a slight performance degradation
        )
        self.embedding_size = embedding_size

    def embed_documents(self, texts):
        return self.client.encode_corpus(texts)

    def embed_query(self, text):
        return self.client.encode_queries(text)
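
For context, a hypothetical usage sketch of the wrapper above; the `BAAI/bge-en-icl` checkpoint and the 4096-dimensional embedding size are assumptions for illustration and not part of this diff.

```python
# Hypothetical usage of the BGEEmbeddings wrapper defined above.
embeddings = BGEEmbeddings(model_name="BAAI/bge-en-icl", embedding_size=4096)  # assumed values

doc_vectors = embeddings.embed_documents([
    "Denser Retriever combines keyword search, vector search, and reranking.",
    "Aggregations summarize metadata fields over the retrieved documents.",
])
query_vector = embeddings.embed_query("What does Denser Retriever combine?")
print(len(doc_vectors), embeddings.embedding_size)
```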

class VoyageAPIEmbeddings(DenserEmbeddings):
def __init__(self, api_key: str, model_name: str, embedding_size: int):
try:
136 changes: 77 additions & 59 deletions denser_retriever/keyword.py
@@ -79,9 +79,6 @@ def retrieve(
def get_index_mappings(self) -> Dict[Any, Any]:
raise NotImplementedError

@abstractmethod
def get_categories(self, field: str, k: int = 10) -> List[Any]:
raise NotImplementedError

@abstractmethod
def delete(
@@ -123,11 +120,12 @@ def __init__(
self.analysis = analysis
self.client = es_connection

def create_index(self, index_name: str, search_fields: List[str], **args: Any):
def create_index(self, index_name: str, search_fields: List[str], date_fields: List[str]=[], **args: Any):

# Define the index settings and mappings
self.index_name = index_name
self.search_fields = FieldMapper(search_fields)
self.date_fields = date_fields

logger.info("ES analysis %s", self.analysis)
if self.analysis == "default":
@@ -238,14 +236,15 @@ def add_documents(
"source": metadata.get("source"),
"pid": metadata.get("pid"),
}
for filter in self.search_fields.get_keys():
value = metadata.get(filter, "")
for filter_key in metadata.keys():
value = metadata.get(filter_key, "")
if isinstance(value, list):
value = [v.strip() for v in value]
elif value is not None:
value = value.strip()
value = [str(v).strip() for v in value if v is not None]
else:
if value is not None:
value = str(value).strip()
if value:
request[filter] = value
request[filter_key] = value
requests.append(request)

if len(requests) > 0:
@@ -271,37 +270,48 @@ def add_documents(
return []
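
As an aside, a small hypothetical illustration of what the new metadata loop in `add_documents` does: every metadata value is stripped and stringified (lists element-wise), and empty values are dropped before indexing. The field names and values below are made up.

```python
# Made-up metadata for one document, normalized the same way as in add_documents above.
metadata = {"category": " blog ", "tags": ["search ", None, " ranking"], "draft": ""}

request = {}
for filter_key in metadata.keys():
    value = metadata.get(filter_key, "")
    if isinstance(value, list):
        value = [str(v).strip() for v in value if v is not None]
    else:
        if value is not None:
            value = str(value).strip()
    if value:
        request[filter_key] = value

print(request)  # {'category': 'blog', 'tags': ['search', 'ranking']}
```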

def retrieve(
self,
query: str,
k: int = 100,
filter: Dict[str, Any] = {},
) -> List[Tuple[Document, float]]:
self,
query: str,
k: int = 100,
filter: Dict[str, Any] = {},
aggregation: bool = False, # Aggregate metadata
) -> Tuple[List[Tuple[Document, float]], Dict]:
assert self.client.indices.exists(index=self.index_name)
start_time = time.time()

# Build the query with title and content matching and a minimum_should_match condition
query_dict = {
"query": {
"bool": {
"should": [
"must": [
{
"match": {
"title": {
"query": query,
"boost": 2.0,
}
"bool": {
"should": [
{
"match": {
"title": {
"query": query,
"boost": 2.0,
}
}
},
{
"match": {
"content": query,
}
}
],
"minimum_should_match": 1 # Ensure at least one of the should conditions is matched
}
},
{
"match": {
"content": query,
},
},
],
"must": [],
}
]
}
},
"_source": True,
"aggs": {}, # This will be populated with aggregations for fields
}

# Add filters if provided
for field in filter:
category_or_date = filter.get(field)
if category_or_date:
@@ -313,7 +323,7 @@ def retrieve(
"gte": category_or_date[0],
"lte": category_or_date[1]
if len(category_or_date) > 1
else category_or_date[0], # type: ignore
else category_or_date[0],
}
}
}
@@ -323,32 +333,59 @@ def retrieve(
{"term": {field: category_or_date}}
)

# Add aggregations for each field provided in 'fields' if aggregation is True
if aggregation:
for field in self.search_fields.get_keys():
query_dict["aggs"][f"{field}_aggregation"] = {
"terms": {
"field": f"{field}", # Use keyword type for aggregations
"size": 50 # Adjust size as needed
}
}

# Execute search query
res = self.client.search(
index=self.index_name,
body=query_dict,
size=k,
)

# Process search hits (documents)
top_k_used = min(len(res["hits"]["hits"]), k)
docs = []
for id in range(top_k_used):
_source = res["hits"]["hits"][id]["_source"]
doc = Document(
page_content=_source["content"],
metadata={
"source": _source["source"],
"title": _source["title"],
"pid": _source["pid"],
},
page_content=_source.pop("content"),
metadata=_source,
)
score = res["hits"]["hits"][id]["_score"]
for field in self.search_fields.get_keys():
if _source.get(field):
doc.metadata[field] = _source.get(field)
# import pdb; pdb.set_trace()
# for field in self.search_fields.get_keys():
# if _source.get(field):
# doc.metadata[field] = _source.get(field)
docs.append((doc, score))

# Process aggregations for the specified fields
aggregations = {}
for field in self.search_fields.get_keys():
field_agg = res.get("aggregations", {}).get(f"{field}_aggregation", {}).get("buckets", [])
cat_keys = [cat['key'] for cat in field_agg]
cat_counts = [cat['doc_count'] for cat in field_agg]
if len(cat_keys) > 0:
if field in self.date_fields:
sorted_data = sorted(zip(cat_keys, cat_counts), key=lambda x: x[0], reverse=True)
sorted_keys, sorted_counts = zip(*sorted_data)
cat_keys = list(sorted_keys)
cat_counts = list(sorted_counts)
aggregations[field] = (cat_keys, cat_counts)

retrieve_time_sec = time.time() - start_time
logger.info(f"Keyword retrieve time: {retrieve_time_sec:.3f} sec.")
logger.info(f"Retrieved {len(docs)} documents.")
return docs

# Return both documents and aggregation results
return docs, aggregations
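
A hypothetical usage sketch of the aggregation-aware `retrieve` above. It assumes an instance of the Elasticsearch-backed keyword store implemented in this file (here called `keyword_search`), an index created with a `category` search field and a `release_date` date field, and documents already added; all of those names are assumptions for illustration.

```python
# Hypothetical: the variable `keyword_search` and the field names "category"
# and "release_date" are assumptions; they stand for the Elasticsearch-backed
# keyword store implemented in this file.
docs, aggregations = keyword_search.retrieve(
    "open source retrievers",
    k=20,
    filter={"category": "blog"},  # term filter on a metadata field
    aggregation=True,             # also return per-field bucket counts
)

for doc, score in docs:
    print(score, doc.metadata.get("title"))

# Each aggregated field maps to (bucket_keys, doc_counts); date fields are
# sorted newest-first by the code above.
print(aggregations.get("category"))
print(aggregations.get("release_date"))
```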

def get_index_mappings(self):
mapping = self.client.indices.get_mapping(index=self.index_name)
@@ -377,25 +414,6 @@ def extract_fields(fields_dict, parent_name=""):
all_fields = extract_fields(properties)
return all_fields

def get_categories(self, field: str, k: int = 10):
query = {
"size": 0, # No actual documents are needed, just the aggregation results
"aggs": {
"all_categories": {
"terms": {
"field": field,
"size": 1000, # Adjust this value based on the expected number of unique categories
}
}
},
}
response = self.client.search(index=self.index_name, body=query)
# Extract the aggregation results
categories = response["aggregations"]["all_categories"]["buckets"]
if k > 0:
categories = categories[:k]
res = [category["key"] for category in categories]
return res

def delete(
self,