Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for weaviate vector database #353

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 94 additions & 55 deletions backend/modules/vector_db/weaviate.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from typing import Any, Dict, List

import weaviate
import weaviate.classes as wvc
from langchain.embeddings.base import Embeddings
from langchain_community.vectorstores.weaviate import Weaviate
from langchain_weaviate.vectorstores import WeaviateVectorStore
from langchain_core.documents import Document

from backend.constants import DATA_POINT_FQN_METADATA_KEY
from backend.modules.vector_db.base import BaseVectorDB
from backend.types import DataPointVector, VectorDBConfig
from backend.logger import logger

BATCH_SIZE = 1000
MAX_SCROLL_LIMIT = int(1e6)

def decapitalize(s):
if not s:
Expand All @@ -18,29 +22,56 @@ def decapitalize(s):

class WeaviateVectorDB(BaseVectorDB):
def __init__(self, config: VectorDBConfig):
self.url = config.url
self.api_key = config.api_key
self.weaviate_client = weaviate.Client(
url=self.url,
**(
{"auth_client_secret": weaviate.AuthApiKey(api_key=self.api_key)}
if self.api_key
else {}
logger.debug(f"[Weaviate] Connecting using config: {config.model_dump()}")
if config.local is True:
self.weaviate_client = weaviate.connect_to_local()
else:
self.weaviate_client = weaviate.connect_to_weaviate_cloud(
cluster_url=config.url,
auth_credentials=wvc.init.Auth.api_key(config.api_key)
)

def create_collection(self, collection_name: str, embeddings: Embeddings):
logger.debug(f"[Weaviate] Creating new collection {collection_name}")
self.weaviate_client.collections.create(
name=collection_name.capitalize(),
replication_config=wvc.config.Configure.replication(
factor=1
),
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(name=DATA_POINT_FQN_METADATA_KEY, data_type=wvc.config.DataType.TEXT)
]
)
logger.debug(f"[Weaviate] Created new collection {collection_name}")

def create_collection(self, collection_name: str, embeddings: Embeddings):
self.weaviate_client.schema.create_class(
{
"class": collection_name.capitalize(),
"properties": [
{
"name": f"{DATA_POINT_FQN_METADATA_KEY}",
"dataType": ["text"],
},
],
}
def _get_records_to_be_updated(self, collection_name: str, data_point_fqns: List[str]):
logger.debug(
f"[Weaviate] Incremental Ingestion: Fetching documents for {len(data_point_fqns)} data point fqns for collection {collection_name}"
)
stop = False
offset = 0
record_ids_to_be_updated = []
while stop is not True:
records = self.weaviate_client.collections \
.get(collection_name.capitalize()).query \
.fetch_objects(
limit=BATCH_SIZE,
filters=wvc.query.Filter.by_property(DATA_POINT_FQN_METADATA_KEY).contains_any(data_point_fqns),
offset=offset,
return_properties=[DATA_POINT_FQN_METADATA_KEY]
)
if not records or len(records.objects) < BATCH_SIZE or len(record_ids_to_be_updated) > MAX_SCROLL_LIMIT:
stop = True
for record in records.objects:
record_ids_to_be_updated.append(record.uuid)
offset += BATCH_SIZE
logger.debug(
f"[Weaviate] Incremental Ingestion: collection={collection_name} Addition={len(data_point_fqns)}, Updates={len(record_ids_to_be_updated)}"
)
return record_ids_to_be_updated



def upsert_documents(
self,
Expand All @@ -59,25 +90,59 @@ def upsert_documents(
Returns:
- None
"""
Weaviate.from_documents(
if len(documents) == 0:
logger.warning("No documents to index")
return
logger.debug(
f"[Weaviate] Adding {len(documents)} documents to collection {collection_name}"
)

data_point_fqns = []
for document in documents:
if document.metadata.get(DATA_POINT_FQN_METADATA_KEY):
data_point_fqns.append(
document.metadata.get(DATA_POINT_FQN_METADATA_KEY)
)
records_to_be_updated:List[str]
if incremental:
records_to_be_updated = self._get_records_to_be_updated(collection_name, data_point_fqns)

WeaviateVectorStore.from_documents(
documents=documents,
embedding=embeddings,
client=self.weaviate_client,
index_name=collection_name.capitalize(),
)
logger.debug(
f"[Weaviate] Added {len(documents)} documents to collection {collection_name}"
)

if len(records_to_be_updated) > 0:
logger.debug(
f"[Weaviate] Deleting {len(records_to_be_updated)} outdated documents from collection {collection_name}"
)
collection = self.weaviate_client.collections.get(collection_name.capitalize())
for i in range(0, len(records_to_be_updated), BATCH_SIZE):
record_ids_to_be_processed = records_to_be_updated[i : i + BATCH_SIZE]
collection.data.delete_many(
where=wvc.query.Filter.by_id().contains_any(record_ids_to_be_processed)
)
logger.debug(
f"[Weaviate] Deleted {len(records_to_be_updated)} outdated documents from collection {collection_name}"
)

def get_collections(self) -> List[str]:
collections = self.weaviate_client.schema.get().get("classes", [])
return [decapitalize(collection["class"]) for collection in collections]
collections = self.weaviate_client.collections.list_all(simple=True)
return list(collections.keys())

def delete_collection(
self,
collection_name: str,
):
return self.weaviate_client.schema.delete_class(collection_name.capitalize())
return self.weaviate_client.collections.delete(collection_name.capitalize())

def get_vector_store(self, collection_name: str, embeddings: Embeddings):
return Weaviate(
return WeaviateVectorStore(
client=self.weaviate_client,
embedding=embeddings,
index_name=collection_name.capitalize(), # Weaviate stores the index name as capitalized
Expand All @@ -92,40 +157,13 @@ def list_documents_in_collection(
"""
List all documents in a collection
"""
# https://weaviate.io/developers/weaviate/search/aggregate#retrieve-groupedby-properties
response = (
self.weaviate_client.query.aggregate(collection_name.capitalize())
.with_group_by_filter([f"{DATA_POINT_FQN_METADATA_KEY}"])
.with_fields("groupedBy { value }")
.do()
)
groups: List[Dict[Any, Any]] = (
response.get("data", {})
.get("Aggregate", {})
.get(collection_name.capitalize(), [])
)
document_ids = set()
for group in groups:
# TODO (chiragjn): Revisit this, we should not be letting `value` be empty
document_ids.add(group.get("groupedBy", {}).get("value", "") or "")
return list(document_ids)
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function can be removed as it is not in the base class.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def delete_documents(self, collection_name: str, document_ids: List[str]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function can be removed as it is not in the base class.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""
Delete documents from the collection that match given `document_id_match`
"""
# https://weaviate.io/developers/weaviate/manage-data/delete#delete-multiple-objects
res = self.weaviate_client.batch.delete_objects(
class_name=collection_name.capitalize(),
where={
"path": [f"{DATA_POINT_FQN_METADATA_KEY}"],
"operator": "ContainsAny",
"valueTextArray": document_ids,
},
)
deleted_vectors = res.get("results", {}).get("successful", None)
if deleted_vectors:
print(f"Deleted {len(document_ids)} documents from the collection")
pass

def get_vector_client(self):
return self.weaviate_client
Expand All @@ -136,7 +174,8 @@ def list_data_point_vectors(
data_source_fqn: str,
batch_size: int = 1000,
) -> List[DataPointVector]:
pass
document_vector_points: List[DataPointVector] = []
return document_vector_points

def delete_data_point_vectors(
self,
Expand Down
1 change: 1 addition & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ langchain==0.1.9
langchain-community==0.0.24
langchain-openai==0.1.7
langchain-core==0.1.46
langchain-weaviate==0.0.3
openai==1.35.3
tiktoken==0.7.0
uvicorn[standard]==0.23.2
Expand Down
2 changes: 1 addition & 1 deletion backend/vectordb.requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
singlestoredb==1.0.4

### Weaviate client (in progress)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can remove this in progress comment.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

weaviate-client==3.25.3
weaviate-client==4.7.1