refactor: update project structure #7

Merged · 12 commits · May 5, 2024

49 changes: 49 additions & 0 deletions .github/workflows/docs.yml
@@ -0,0 +1,49 @@
name: Deploy to aws

on:
  push:
    branches:
      - main

jobs:
  deploy:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Setup Node.js
        uses: actions/setup-node@v3
        with:
          node-version: "20"

      - name: Setup pnpm
        uses: pnpm/[email protected]

      - name: Cache pnpm modules
        uses: actions/cache@v3
        env:
          cache-name: cache-pnpm-modules
        with:
          path: ~/.pnpm-store
          key: cache-pnpm-modules-key

      - name: Change directory
        run: cd www

      - name: Install dependencies
        run: pnpm i --no-frozen-lockfile

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v2
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: us-west-2

      - name: Setup sst
        run: curl -fsSL https://ion.sst.dev/install | bash

      - name: Deploy
        run: NO_BUN=true sst deploy --stage production
5 changes: 3 additions & 2 deletions .gitignore
@@ -14,7 +14,6 @@ dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
@@ -99,4 +98,6 @@ ENV/
.mypy_cache/
.idea/

.DS_Store
.DS_Store

node_modules/
3 changes: 3 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,3 @@
{
"python.analysis.typeCheckingMode": "off"
}
11 changes: 2 additions & 9 deletions DEVELOPMENT.md
@@ -2,17 +2,10 @@

## Quick start

Conda package manager is recommended. Create a conda environment.
Install poetry globally with pipx or another installer:

```bash
conda create -n denser-retriever python==3.10
```

Activate conda environment and install poetry

```bash
conda activate denser-retriever
pip install poetry
pipx install poetry
```

Then you can run the client using the following command:
4 changes: 2 additions & 2 deletions README.md
@@ -53,8 +53,8 @@ python -m pytest tests
After that, you can launch a simple end-to-end [streamlit](https://streamlit.io/) app with the retriever indices built from the above unit tests.

```shell
streamlit run examples/denser_chat.py
streamlit run examples/denser_search.py
poetry run streamlit run examples/denser_chat.py
poetry run streamlit run examples/denser_search.py
```

To evaluate retrievers' accuracy on [mteb](https://github.com/embeddings-benchmark/mteb) benchmark datasets, you can run the following command. It evaluates the elasticsearch, vector, and reranker baselines and reports their NDCG@10 scores on the test data. In addition, it uses the mteb training data to train an xgboost classifier that combines the elasticsearch, vector, and reranker scores. Finally, the classifier is used to generate ranked results, and the NDCG@10 score of those ranked results is also reported.
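The evaluation command itself is collapsed in this diff view. As a separate illustration of the combination step described above, here is a minimal sketch of training an XGBoost classifier on keyword, vector, and reranker scores and ranking candidates by its predicted relevance; the scores and parameters are hypothetical, and this is not the repository's evaluation code.

```python
import numpy as np
import xgboost as xgb

# Hypothetical training rows: [elasticsearch_score, vector_score, reranker_score]
# with a binary relevance label, mirroring the combination idea described above.
X_train = np.array([
    [12.3, 0.81, 2.4],
    [3.1, 0.42, -1.0],
    [9.8, 0.77, 1.9],
    [1.5, 0.30, -0.8],
])
y_train = np.array([1, 0, 1, 0])

clf = xgb.XGBClassifier(n_estimators=100, max_depth=3, eval_metric="logloss")
clf.fit(X_train, y_train)

# Rank unseen candidates by the predicted probability of relevance.
X_test = np.array([[7.5, 0.69, 1.2], [1.2, 0.35, -0.4]])
order = np.argsort(-clf.predict_proba(X_test)[:, 1])
print(order)  # candidate indices, best first
```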
3 changes: 2 additions & 1 deletion denser_retriever/reranker.py
@@ -1,6 +1,7 @@
from sentence_transformers import CrossEncoder
import copy

from sentence_transformers import CrossEncoder


class Reranker:
def __init__(self, rerank_model):
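The `Reranker` class above wraps a sentence_transformers `CrossEncoder`. For context, a minimal sketch of how a cross-encoder scores query-passage pairs; the model name is an assumption, and this is not the repository's `Reranker` implementation.

```python
from sentence_transformers import CrossEncoder

# Model name is an assumption for illustration; any cross-encoder checkpoint works.
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "what is dense retrieval?"
passages = [
    "Dense retrieval encodes queries and documents into vectors.",
    "The weather is sunny today.",
]

# predict() returns one relevance score per (query, passage) pair; higher is better.
scores = model.predict([(query, p) for p in passages])
reranked = [p for _, p in sorted(zip(scores, passages), key=lambda x: x[0], reverse=True)]
print(reranked[0])
```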
5 changes: 3 additions & 2 deletions denser_retriever/retriever.py
@@ -1,6 +1,8 @@
import os
from abc import ABC, abstractmethod

import yaml
import os


class Retriever(ABC):
"""
@@ -32,7 +34,6 @@ def __init__(self, index_name, config_file):
if not os.path.exists(self.exp_dir):
os.makedirs(self.exp_dir)


@abstractmethod
def ingest(self, data):
pass
43 changes: 20 additions & 23 deletions denser_retriever/retriever_elasticsearch.py
@@ -122,14 +122,9 @@ def retrieve(self, query_text, meta_data, query_id=None):
}
}
},
{
"match": {
"content": query_text
}
}
{"match": {"content": query_text}},
],
"must": [
]
"must": [],
}
},
"_source": True,
@@ -139,14 +134,16 @@
category_or_date = meta_data.get(field)
if category_or_date:
if isinstance(category_or_date, tuple):
query_dict["query"]["bool"]["must"].append({
"range": {
field: {
"gte": category_or_date[0],
"lte": category_or_date[1] if len(category_or_date) > 1 else category_or_date[0]
query_dict["query"]["bool"]["must"].append(
{
"range": {
field: {
"gte": category_or_date[0],
"lte": category_or_date[1] if len(category_or_date) > 1 else category_or_date[0],
}
}
}
})
)
else:
query_dict["query"]["bool"]["must"].append({"term": {field: category_or_date}})

@@ -172,17 +169,17 @@ def get_index_mappings(self):
mapping = self.es.indices.get_mapping(index=self.index_name)

# The mapping response structure can be quite nested, focusing on the 'properties' section
properties = mapping[self.index_name]['mappings']['properties']
properties = mapping[self.index_name]["mappings"]["properties"]

# Function to recursively extract fields and types
def extract_fields(fields_dict, parent_name=''):
def extract_fields(fields_dict, parent_name=""):
fields = {}
for field_name, details in fields_dict.items():
full_field_name = f"{parent_name}.{field_name}" if parent_name else field_name
if 'properties' in details:
fields.update(extract_fields(details['properties'], full_field_name))
if "properties" in details:
fields.update(extract_fields(details["properties"], full_field_name))
else:
fields[full_field_name] = details.get('type', 'notype') # Default 'notype' if no type is found
fields[full_field_name] = details.get("type", "notype") # Default 'notype' if no type is found
return fields

# Extract fields and types
@@ -196,15 +193,15 @@ def get_categories(self, field, topk):
"all_categories": {
"terms": {
"field": field,
"size": 1000 # Adjust this value based on the expected number of unique categories
"size": 1000, # Adjust this value based on the expected number of unique categories
}
}
}
},
}
response = self.es.search(index=self.index_name, body=query)
# Extract the aggregation results
categories = response['aggregations']['all_categories']['buckets']
categories = response["aggregations"]["all_categories"]["buckets"]
if topk > 0:
categories = categories[:topk]
res = [category['key'] for category in categories]
return res
res = [category["key"] for category in categories]
return res
13 changes: 7 additions & 6 deletions denser_retriever/retriever_general.py
@@ -1,6 +1,5 @@
import time


from denser_retriever.reranker import Reranker
from denser_retriever.retriever import Retriever
from denser_retriever.retriever_elasticsearch import RetrieverElasticSearch
@@ -11,15 +10,18 @@


class RetrieverGeneral(Retriever):

def __init__(self, index_name, config_file):
super().__init__(index_name, config_file)
self.retrieve_type = "general"
self.retrieverElasticSearch = (
RetrieverElasticSearch(index_name, config_file) if self.config["keyword_weight"] > 0 else None
)
self.retrieverMilvus = RetrieverMilvus(index_name, config_file) if self.config["vector_weight"] > 0 else None
self.reranker = Reranker(self.config["rerank"]["rerank_model"]) if self.config["rerank_weight"] > 0 else None
self.reranker = (
Reranker(self.config["rerank"]["rerank_model"], self.out_reranker)
if self.config["rerank_weight"] > 0
else None
)

def ingest(self, doc_or_passage_file):
# import pdb; pdb.set_trace()
@@ -55,11 +57,10 @@ def retrieve(self, query, meta_data, query_id=None):
logger.info(f"Rerank time: {rerank_time_sec:.3f} sec.")

if len(passages) > self.config["rerank"]["topk"]:
passages = passages[:self.config["rerank"]["topk"]]
passages = passages[: self.config["rerank"]["topk"]]

docs = aggregate_passages(passages)
return passages, docs


def get_field_categories(self, field, topk):
return self.retrieverElasticSearch.get_categories(field, topk)
return self.retrieverElasticSearch.get_categories(field, topk)
35 changes: 22 additions & 13 deletions denser_retriever/retriever_milvus.py
@@ -55,14 +55,15 @@ def create_index(self):
for key in self.field_types:
internal_key = self.field_internal_names[key]
# both category and date type (unix timestamp) use INT64 type
fields.append(FieldSchema(name=internal_key, dtype=DataType.INT64, max_length=self.field_max_length),)
fields.append(
FieldSchema(name=internal_key, dtype=DataType.INT64, max_length=self.field_max_length),
)
self.field_cat_to_id[key] = {}
self.field_id_to_cat[key] = []

schema = CollectionSchema(fields, "Milvus schema")
self.index = Collection(self.index_name, schema, consistency_level="Strong")


def connect_index(self):
connections.connect(
"default",
@@ -85,7 +86,9 @@ def connect_index(self):
for key in self.field_types:
internal_key = self.field_internal_names[key]
# both category and date type (unix timestamp) use INT64 type
fields.append(FieldSchema(name=internal_key, dtype=DataType.INT64, max_length=self.field_max_length), )
fields.append(
FieldSchema(name=internal_key, dtype=DataType.INT64, max_length=self.field_max_length),
)

schema = CollectionSchema(fields, "Milvus schema")
self.index = Collection(self.index_name, schema, consistency_level="Strong")
@@ -95,7 +98,7 @@ def connect_index(self):
output_prefix = self.config["output_prefix"]
exp_dir = os.path.join(output_prefix, f"exp_{self.index_name}")
fields_file = os.path.join(exp_dir, "milvus_fields.json")
with open(fields_file, 'r') as file:
with open(fields_file, "r") as file:
self.field_cat_to_id, self.field_id_to_cat = json.load(file)

def generate_embedding(self, passages):
@@ -125,15 +128,15 @@ def ingest(self, doc_or_passage_file, batch_size):
if category_or_date_str:
type = self.field_types[field]["type"]
if type == "date":
date_obj = datetime.strptime(category_or_date_str, '%Y-%m-%d')
date_obj = datetime.strptime(category_or_date_str, "%Y-%m-%d")
unix_time = int(date_obj.timestamp())
fieldss[i].append(unix_time)
else: # categorical
else: # categorical
if category_or_date_str not in self.field_cat_to_id[field]:
self.field_cat_to_id[field][category_or_date_str] = len(self.field_cat_to_id[field])
self.field_id_to_cat[field].append(category_or_date_str)
fieldss[i].append(self.field_cat_to_id[field][category_or_date_str])
else: # missing category value
else: # missing category value
fieldss[i].append(-1)
record_id += 1
if len(batch) == batch_size:
@@ -179,8 +182,10 @@ def ingest(self, doc_or_passage_file, batch_size):
if not os.path.exists(exp_dir):
os.makedirs(exp_dir)
fields_file = os.path.join(exp_dir, "milvus_fields.json")
with open(fields_file, 'w') as file:
json.dump([self.field_cat_to_id, self.field_id_to_cat], file, ensure_ascii=False, indent=4) # 'indent' for pretty printing
with open(fields_file, "w") as file:
json.dump(
[self.field_cat_to_id, self.field_id_to_cat], file, ensure_ascii=False, indent=4
) # 'indent' for pretty printing

def retrieve(self, query_text, meta_data, query_id=None):
if not self.index:
@@ -212,8 +217,12 @@ def retrieve(self, query_text, meta_data, query_id=None):
"params": {"nprobe": 10},
}
result = self.index.search(
query_embedding, "embeddings", search_params, limit=self.config["vector"]["topk"], expr=expr_str,
output_fields=["source", "title", "text", "pid"] + list(self.field_internal_names.values())
query_embedding,
"embeddings",
search_params,
limit=self.config["vector"]["topk"],
expr=expr_str,
output_fields=["source", "title", "text", "pid"] + list(self.field_internal_names.values()),
)

topk_used = min(len(result[0]), self.config["vector"]["topk"])
@@ -230,10 +239,10 @@ def retrieve(self, query_text, meta_data, query_id=None):
}
for field in self.field_types.keys():
internal_field = self.field_internal_names[field]
cat_id_or_unix_time = hit.entity.__dict__['fields'].get(internal_field)
cat_id_or_unix_time = hit.entity.__dict__["fields"].get(internal_field)
type = self.field_types[field]["type"]
if type == "date":
date = datetime.utcfromtimestamp(cat_id_or_unix_time).strftime('%Y-%m-%d')
date = datetime.utcfromtimestamp(cat_id_or_unix_time).strftime("%Y-%m-%d")
passage[field] = date
else:
passage[field] = self.field_id_to_cat[field][cat_id_or_unix_time]