Commit: Legal aggregations

zhiheng-huang committed Dec 6, 2024
1 parent 76256ee commit 3ade108
Showing 12 changed files with 260 additions and 127 deletions.
21 changes: 21 additions & 0 deletions denser_retriever/embeddings.py
@@ -40,6 +40,27 @@ def embed_query(self, text):
return embeddings


class BGEEmbeddings(DenserEmbeddings):
    def __init__(self, model_name: str, embedding_size: int):
        try:
            from FlagEmbedding import FlagICLModel
        except ImportError as exc:
            raise ImportError(
                "Could not import FlagEmbedding python package."
            ) from exc

        self.client = FlagICLModel(
            model_name,
            query_instruction_for_retrieval="Given a web search query, retrieve relevant passages that answer the query.",
            examples_for_task=None,  # set examples_for_task=None to use the model without in-context examples
            use_fp16=True,  # use_fp16=True speeds up computation with a slight performance degradation
        )
        self.embedding_size = embedding_size

    def embed_documents(self, texts):
        return self.client.encode_corpus(texts)

    def embed_query(self, text):
        return self.client.encode_queries(text)
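
A minimal usage sketch for the new class; the checkpoint name and embedding size are illustrative assumptions (not values pinned by this commit):

    # Hypothetical model name and size; adjust to the checkpoint actually deployed.
    embeddings = BGEEmbeddings(model_name="BAAI/bge-en-icl", embedding_size=4096)
    doc_vectors = embeddings.embed_documents(["Contract law governs private agreements."])
    query_vector = embeddings.embed_query("which law governs private agreements")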

class VoyageAPIEmbeddings(DenserEmbeddings):
def __init__(self, api_key: str, model_name: str, embedding_size: int):
try:
126 changes: 72 additions & 54 deletions denser_retriever/keyword.py
@@ -83,9 +83,6 @@ def retrieve(
def get_index_mappings(self) -> Dict[Any, Any]:
raise NotImplementedError

@abstractmethod
def get_categories(self, field: str, k: int = 10) -> List[Any]:
raise NotImplementedError

@abstractmethod
def delete(
@@ -128,11 +125,12 @@ def __init__(
self.analysis = analysis
self.client = es_connection

    def create_index(self, index_name: str, search_fields: List[str], date_fields: List[str], **args: Any):

# Define the index settings and mappings
self.index_name = index_name
self.search_fields = FieldMapper(search_fields)
self.date_fields = date_fields

logger.info("ES analysis %s", self.analysis)
if self.analysis == "default":
@@ -243,7 +241,8 @@ def add_documents(
"source": metadata.get("source"),
"pid": metadata.get("pid"),
}
            for filter in metadata.keys():
value = metadata.get(filter, "")
if isinstance(value, list):
value = [v.strip() for v in value]
@@ -276,37 +275,48 @@ def add_documents(
return []

def retrieve(
        self,
        query: str,
        k: int = 100,
        filter: Dict[str, Any] = {},
        aggregation: bool = False,  # Aggregate metadata when True
    ) -> Tuple[List[Tuple[Document, float]], Dict]:
assert self.client.indices.exists(index=self.index_name)
start_time = time.time()

        # Build the query with title and content matching and a minimum_should_match condition
        query_dict = {
            "query": {
                "bool": {
                    "must": [
                        {
                            "bool": {
                                "should": [
                                    {
                                        "match": {
                                            "title": {
                                                "query": query,
                                                "boost": 2.0,
                                            }
                                        }
                                    },
                                    {
                                        "match": {
                                            "content": query,
                                        }
                                    },
                                ],
                                "minimum_should_match": 1,  # Ensure at least one of the should conditions is matched
                            }
                        }
                    ]
                }
            },
            "_source": True,
            "aggs": {},  # This will be populated with aggregations for fields
        }

# Add filters if provided
for field in filter:
category_or_date = filter.get(field)
if category_or_date:
@@ -318,7 +328,7 @@ def retrieve(
"gte": category_or_date[0],
"lte": category_or_date[1]
if len(category_or_date) > 1
else category_or_date[0],
}
}
}
@@ -328,32 +338,59 @@ def retrieve(
{"term": {field: category_or_date}}
)
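
As an illustration (field names hypothetical, and assuming date-valued fields take the range path above), a filter like this appends one range clause and one term clause to the bool query's must list:

    filter = {
        "publish_date": ["2020-01-01", "2024-12-31"],  # two values -> range clause with gte/lte
        "court": "Supreme Court",                      # scalar -> term clause
    }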

# Add a terms aggregation for each search field when aggregation is True
if aggregation:
for field in self.search_fields.get_keys():
query_dict["aggs"][f"{field}_aggregation"] = {
"terms": {
"field": f"{field}", # Use keyword type for aggregations
"size": 50 # Adjust size as needed
}
}
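
For a hypothetical search field named court, the loop above adds a request-body entry of the form:

    # query_dict["aggs"]["court_aggregation"] = {"terms": {"field": "court", "size": 50}}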

# Execute search query
res = self.client.search(
index=self.index_name,
body=query_dict,
size=k,
)

# Process search hits (documents)
top_k_used = min(len(res["hits"]["hits"]), k)
docs = []
for id in range(top_k_used):
_source = res["hits"]["hits"][id]["_source"]
            doc = Document(
                page_content=_source.pop("content"),
                metadata=_source,
            )
            score = res["hits"]["hits"][id]["_score"]
docs.append((doc, score))

# Process aggregations for the specified fields
aggregations = {}
for field in self.search_fields.get_keys():
field_agg = res.get("aggregations", {}).get(f"{field}_aggregation", {}).get("buckets", [])
cat_keys = [cat['key'] for cat in field_agg]
cat_counts = [cat['doc_count'] for cat in field_agg]
if len(cat_keys) > 0:
if field in self.date_fields:
sorted_data = sorted(zip(cat_keys, cat_counts), key=lambda x: x[0], reverse=True)
sorted_keys, sorted_counts = zip(*sorted_data)
cat_keys = list(sorted_keys)
cat_counts = list(sorted_counts)
aggregations[field] = (cat_keys, cat_counts)
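
The resulting aggregations map each search field to parallel lists of bucket keys and document counts, with date-field buckets re-sorted newest first; an illustrative shape (field names and values hypothetical):

    # {"court": (["Supreme Court", "High Court"], [120, 45]),
    #  "publish_date": (["2024-05-01", "2023-11-12"], [7, 3])}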

retrieve_time_sec = time.time() - start_time
logger.info(f"Keyword retrieve time: {retrieve_time_sec:.3f} sec.")
logger.info(f"Retrieved {len(docs)} documents.")
        # Return both documents and aggregation results
        return docs, aggregations

def get_index_mappings(self):
mapping = self.client.indices.get_mapping(index=self.index_name)
@@ -382,25 +419,6 @@ def extract_fields(fields_dict, parent_name=""):
all_fields = extract_fields(properties)
return all_fields

def get_categories(self, field: str, k: int = 10):
query = {
"size": 0, # No actual documents are needed, just the aggregation results
"aggs": {
"all_categories": {
"terms": {
"field": field,
"size": 1000, # Adjust this value based on the expected number of unique categories
}
}
},
}
response = self.client.search(index=self.index_name, body=query)
# Extract the aggregation results
categories = response["aggregations"]["all_categories"]["buckets"]
if k > 0:
categories = categories[:k]
res = [category["key"] for category in categories]
return res

def delete(
self,
36 changes: 12 additions & 24 deletions denser_retriever/retriever.py
@@ -42,7 +42,8 @@ def __init__(
gradient_boost: Optional[DenserGradientBoost],
combine_mode: str = "linear",
xgb_model_features: str = "es+vs+rr_n",
        search_fields: List[str] = [],
        date_fields: List[str] = [],
):
# config parameters
self.index_name = index_name
@@ -61,7 +62,7 @@ def __init__(
assert embeddings
self.vector_db.create_index(index_name, embeddings, search_fields)
if self.keyword_search:
            self.keyword_search.create_index(index_name, search_fields, date_fields)
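
A sketch of wiring the new date_fields parameter through the constructor (other arguments elided, field names hypothetical):

    # retriever = DenserRetriever(
    #     index_name="legal_docs",
    #     ...,  # embeddings, vector_db, keyword_search, reranker, gradient_boost as before
    #     search_fields=["court", "case_type"],
    #     date_fields=["publish_date"],
    # )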

def ingest(self, docs: List[Document], overwrite_pid: bool = True) -> List[str]:
# add pid into metadata for each document
@@ -80,22 +81,23 @@ def ingest(self, docs: List[Document], overwrite_pid: bool = True) -> List[str]:
return [doc.metadata["pid"] for doc in docs]

def retrieve(
        self, query: str, k: int = 100, filter: Dict[str, Any] = {}, aggregation: bool = False, **kwargs: Any
):
logger.info(f"Retrieve query: {query} top_k: {k}")
if self.combine_mode in ["linear", "rank"]:
            return self._retrieve_by_linear_or_rank(query, k, filter, aggregation, **kwargs)
else:
            return self._retrieve_by_model(query, k, filter, aggregation, **kwargs)

def _retrieve_by_linear_or_rank(
        self, query: str, k: int = 100, filter: Dict[str, Any] = {}, aggregation: bool = False, **kwargs: Any
):
passages = []
aggregations = None

if self.keyword_search:
            es_docs, aggregations = self.keyword_search.retrieve(
                query, self.keyword_search.top_k, filter=filter, aggregation=aggregation, **kwargs
            )
es_passages = scale_results(es_docs, self.keyword_search.weight)
logger.info(f"Keyword search: {len(es_passages)}")
@@ -125,10 +127,10 @@ def _retrieve_by_linear_or_rank(
rerank_time_sec = time.time() - start_time
logger.info(f"Rerank time: {rerank_time_sec:.3f} sec.")

        return passages[:k], aggregations
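
Callers now unpack a (passages, aggregations) pair from retrieve(); a sketch with a hypothetical query:

    passages, aggs = retriever.retrieve("breach of contract", k=20, aggregation=True)
    for doc, score in passages:
        print(score, doc.metadata.get("title"))
    if aggs:
        for field, (keys, counts) in aggs.items():
            print(field, list(zip(keys, counts)))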

def _retrieve_by_model(
        self, query: str, k: int = 100, filter: Dict[str, Any] = {}, aggregation: bool = False, **kwargs: Any
) -> List[Tuple[Document, float]]:
docs, doc_features = self._retrieve_with_features(query, filter, **kwargs)

@@ -262,20 +264,6 @@ def delete_all(self):
if self.keyword_search:
self.keyword_search.delete_all()

def get_field_categories(self, field, k: int = 10):
"""
Get the categories of a field.
Args:
field: The field to get the categories of.
k: The number of categories to return.
Returns:
A list of categories.
"""
if not self.keyword_search:
raise ValueError("Keyword search not initialized")
return self.keyword_search.get_categories(field, k)

def get_filter_fields(self):
"""Get the filter fields."""
9 changes: 5 additions & 4 deletions denser_retriever/utils.py
@@ -2,7 +2,7 @@
from typing import Dict, List, Optional, Tuple

import numpy as np
try:
    import pytrec_eval
except ImportError:  # optional dependency; required only by evaluate()
    pytrec_eval = None

from scipy.sparse import csr_matrix
from collections import defaultdict
@@ -41,9 +41,10 @@ def evaluate(
ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
recall_string = "recall." + ",".join([str(k) for k in k_values])
precision_string = "P." + ",".join([str(k) for k in k_values])
    if pytrec_eval is None:
        raise ImportError("pytrec_eval is required to compute evaluation metrics.")
    evaluator = pytrec_eval.RelevanceEvaluator(
        qrels, {map_string, ndcg_string, recall_string, precision_string}
    )
    scores = evaluator.evaluate(results)

for query_id in scores.keys():
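
For reference, pytrec_eval consumes qrels and run results keyed by query id, mapping document ids to graded relevance judgments and to retrieval scores respectively; minimal shapes (ids and scores hypothetical):

    qrels = {"q1": {"doc1": 1, "doc2": 0}}
    results = {"q1": {"doc1": 12.3, "doc2": 7.8}}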
7 changes: 5 additions & 2 deletions denser_retriever/vectordb/milvus.py
@@ -109,7 +109,7 @@ def create_index(
self.embeddings = embeddings
self.source_max_length = 500
self.title_max_length = 500
        self.text_max_length = 30000
self.field_max_length = 500

self.connection_args = self.connection_args or DEFAULT_MILVUS_CONNECTION
@@ -207,7 +207,10 @@ def add_documents(
doc.metadata.get("source", "")[: self.source_max_length - 10]
)
titles.append(doc.metadata.get("title", "")[: self.title_max_length - 10])
            # Cap content at 10000 characters and log when truncation actually occurs
            truncated_text = doc.page_content[:10000]
            if len(doc.page_content) > len(truncated_text):
                print(f"Text length {len(doc.page_content)} truncated to {len(truncated_text)} (max_length {self.text_max_length})")
            texts.append(truncated_text)
pid_list.append(doc.metadata.get("pid", "-1"))

for i, field_original_key in enumerate(
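
The ingest-side cap of 10000 characters sits well below the 30000 text_max_length; presumably this leaves headroom for multi-byte text, though that rationale is an assumption rather than something this commit states:

    # 10000 chars x 3 bytes/char (worst-case UTF-8) = 30000, matching the configured
    # maximum if the limit is byte-based.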
3 changes: 2 additions & 1 deletion examples/denser_search.py
@@ -65,7 +65,8 @@ def denser_search():
format="MM.DD.YYYY",
)
else:
                _, aggregations = retriever.retrieve("", 0, {}, True)  # TODO
                categories = aggregations.get(field, ([], []))[0]
option = st.sidebar.selectbox(
field,
tuple(categories),
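
With categories now sourced from retrieve() aggregations, the selectbox options are the bucket keys for the field; for example (hypothetical values):

    # aggregations.get("court") -> (["Supreme Court", "High Court"], [120, 45])
    # selectbox options         -> ("Supreme Court", "High Court")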
3 changes: 2 additions & 1 deletion examples/denser_search_cpws.py
@@ -75,7 +75,8 @@ def denser_search():
format="MM.DD.YYYY",
)
else:
                _, aggregations = retriever.retrieve("", 0, {}, True)
                categories = aggregations.get(field, ([], []))[0]
option = st.sidebar.selectbox(
field,
tuple(categories),