Merge pull request #419 from AbhishekRP2002/main
feat: add milvus vector db integration
mnvsk97 authored Jan 3, 2025
2 parents b509cb8 + 1258f3a commit 431d682
Showing 6 changed files with 313 additions and 7 deletions.
2 changes: 1 addition & 1 deletion backend/modules/vector_db/base.py
@@ -90,7 +90,7 @@ def get_embedding_dimensions(self, embeddings: Embeddings) -> int:
         Fetch embedding dimensions
         """
         # Calculate embedding size
-        logger.debug(f"[VectorDB] Embedding a dummy doc to get vector dimensions")
+        logger.debug("Embedding a dummy doc to get vector dimensions")
         partial_embeddings = embeddings.embed_documents(["Initial document"])
         vector_size = len(partial_embeddings[0])
         logger.debug(f"Vector size: {vector_size}")
300 changes: 300 additions & 0 deletions backend/modules/vector_db/milvus.py
@@ -0,0 +1,300 @@
from typing import List

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain_milvus import Milvus
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient

from backend.constants import (
DATA_POINT_FQN_METADATA_KEY,
DATA_POINT_HASH_METADATA_KEY,
DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
)
from backend.logger import logger
from backend.modules.vector_db.base import BaseVectorDB
from backend.types import DataPointVector, VectorDBConfig

MAX_SCROLL_LIMIT = int(1e6)
BATCH_SIZE = 1000


class MilvusVectorDB(BaseVectorDB):
def __init__(self, config: VectorDBConfig):
"""
Initialize Milvus vector database client
Args:
:param config: VectorDBConfig
- provider: str
- local: bool
- url: str
URI of the Milvus server.
- If you only need a local vector database for small scale data or prototyping,
setting the uri as a local file, e.g.`./milvus.db`, is the most convenient method,
as it automatically utilizes [Milvus Lite](https://milvus.io/docs/milvus_lite.md)
to store all data in this file.
                - If you have a large amount of data, say more than a million vectors, you can set up
a more performant Milvus server on [Docker or Kubernetes](https://milvus.io/docs/quickstart.md).
In this setup, please use the server address and port as your uri, e.g.`http://localhost:19530`.
If you enable the authentication feature on Milvus,
use "<your_username>:<your_password>" as the token, otherwise don't set the token.
- If you use [Zilliz Cloud](https://zilliz.com/cloud), the fully managed cloud
service for Milvus, adjust the `uri` and `token`, which correspond to the
[Public Endpoint and API key](https://docs.zilliz.com/docs/on-zilliz-cloud-console#cluster-details)
- api_key: str
Token for authentication with the Milvus server.
"""
# TODO: create an extended config for Milvus like done in Qdrant
logger.debug(f"Connecting to Milvus using config: {config.model_dump()}")
self.config = config
self.metric_type = config.config.get("metric_type", "COSINE")
# Milvus-lite is used for local == True
if config.local is True:
# TODO: make this path customizable
self.url = "./cognita_milvus.db"
self.api_key = ""
self.milvus_client = MilvusClient(
uri=self.url,
db_name=config.config.get("db_name", "milvus_default_db"),
)
else:
self.url = config.url
            # Normalize an empty or missing key to None so the client skips auth
            self.api_key = config.api_key or None
            self.milvus_client = MilvusClient(
                uri=self.url,
                token=self.api_key,
                db_name=config.config.get("db_name", "milvus_default_db"),
            )

def create_collection(self, collection_name: str, embeddings: Embeddings):
"""
Create a collection in the vector database
Args:
:param collection_name: str - Name of the collection
:param embeddings: Embeddings - Embeddings object to be used for creating embeddings of the documents
Current implementation includes Quick setup in which the collection is created, indexed and loaded into the memory.
"""
# TODO: Add customized setup with indexed params
logger.debug(f"[Milvus] Creating new collection {collection_name}")

vector_size = self.get_embedding_dimensions(embeddings)

fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="metadata", dtype=DataType.JSON),
]

schema = CollectionSchema(
fields=fields, description=f"Collection for {collection_name}"
)

self.milvus_client.create_collection(
collection_name=collection_name,
dimension=vector_size,
metric_type=self.metric_type, # https://milvus.io/docs/metric.md#Metric-Types : check for other supported metrics
schema=schema,
auto_id=True,
)

        # prepare_index_params can be used to define multiple custom indices
index_params = self.milvus_client.prepare_index_params()
index_params.add_index(
field_name="vector", index_type="FLAT", metric_type=self.metric_type
)
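        # FLAT gives exact search but scans every vector; for larger collections an
        # approximate index could be swapped in here, e.g. (parameters illustrative):
        # index_params.add_index(
        #     field_name="vector", index_type="HNSW", metric_type=self.metric_type,
        #     params={"M": 16, "efConstruction": 200},
        # )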
self.milvus_client.create_index(
collection_name=collection_name, index_params=index_params
)

logger.debug(f"[Milvus] Created new collection {collection_name}")

def _delete_existing_documents(
self, collection_name: str, documents: List[Document]
):
"""
Delete existing documents from the collection
"""
# Instead of using document IDs, we'll delete based on metadata matching
for doc in documents:
if (
DATA_POINT_FQN_METADATA_KEY in doc.metadata
and DATA_POINT_HASH_METADATA_KEY in doc.metadata
):
delete_expr = (
f'metadata["{DATA_POINT_FQN_METADATA_KEY}"] == "{doc.metadata[DATA_POINT_FQN_METADATA_KEY]}" && '
f'metadata["{DATA_POINT_HASH_METADATA_KEY}"] == "{doc.metadata[DATA_POINT_HASH_METADATA_KEY]}"'
)

logger.debug(
f"[Milvus] Deleting records matching expression: {delete_expr}"
)

self.milvus_client.delete(
collection_name=collection_name,
filter=delete_expr,
)

def upsert_documents(
self,
collection_name: str,
documents: List[Document],
embeddings: Embeddings,
incremental: bool = True,
):
"""
Upsert documents in the database.
Upsert = Insert / update
- Check if collection exists or not
- Check if collection is empty or not
- If collection is empty, insert all documents
- If collection is not empty, delete existing documents and insert new documents
"""
if len(documents) == 0:
logger.warning("No documents to index")
return

logger.debug(
f"[Milvus] Adding {len(documents)} documents to collection {collection_name}"
)

if not self.milvus_client.has_collection(collection_name):
raise Exception(
f"Collection {collection_name} does not exist. Please create it first using `create_collection`."
)

        stats = self.milvus_client.get_collection_stats(collection_name=collection_name)
        if stats["row_count"] == 0:
            logger.warning(
                f"[Milvus] Collection {collection_name} is empty. Inserting all documents."
            )
            self.get_vector_store(collection_name, embeddings).add_documents(
                documents=documents
            )
            # Return early so the documents are not inserted a second time below
            return

        if incremental:
            self._delete_existing_documents(collection_name, documents)

        self.get_vector_store(collection_name, embeddings).add_documents(
            documents=documents
        )

logger.debug(
f"[Milvus] Upserted {len(documents)} documents to collection {collection_name}"
)

def get_collections(self) -> List[str]:
logger.debug("[Milvus] Fetching collections from the vector database")
collections = self.milvus_client.list_collections()
logger.debug(f"[Milvus] Fetched {len(collections)} collections")
return collections

def delete_collection(self, collection_name: str):
logger.debug(f"[Milvus] Deleting {collection_name} collection")
self.milvus_client.drop_collection(collection_name)
logger.debug(f"[Milvus] Deleted {collection_name} collection")

def get_vector_store(self, collection_name: str, embeddings: Embeddings):
logger.debug(f"[Milvus] Getting vector store for collection {collection_name}")
return Milvus(
collection_name=collection_name,
connection_args={
"uri": self.url,
"token": self.api_key,
},
embedding_function=embeddings,
auto_id=True,
primary_field="id",
text_field="text",
metadata_field="metadata",
)

def get_vector_client(self):
logger.debug("[Milvus] Getting Milvus client")
return self.milvus_client

def list_data_point_vectors(
self,
collection_name: str,
data_source_fqn: str,
batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
) -> List[DataPointVector]:
"""
Get vectors from the collection
"""
logger.debug(
f"[Milvus] Listing data point vectors for collection {collection_name}"
)
filter_expr = (
f'metadata["{DATA_POINT_FQN_METADATA_KEY}"] == "{data_source_fqn}"'
)

data_point_vectors: List[DataPointVector] = []

offset = 0

while True:
search_result = self.milvus_client.query(
collection_name=collection_name,
filter=filter_expr,
output_fields=[
"*"
], # returning all the fields of the entity / data point
limit=batch_size,
offset=offset,
)
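            # NOTE: Milvus caps offset + limit for paginated queries (16384 by
            # default), so very large collections may need pymilvus's
            # query_iterator instead of offset-based paging.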

for result in search_result:
if result.get("metadata", {}).get(
DATA_POINT_FQN_METADATA_KEY
) and result.get("metadata", {}).get(DATA_POINT_HASH_METADATA_KEY):
data_point_vectors.append(
DataPointVector(
data_point_vector_id=str(result["id"]),
data_point_fqn=result["metadata"][
DATA_POINT_FQN_METADATA_KEY
],
data_point_hash=result["metadata"][
DATA_POINT_HASH_METADATA_KEY
],
)
)

if (
len(search_result) < batch_size
or len(data_point_vectors) >= MAX_SCROLL_LIMIT
):
break

offset += batch_size

logger.debug(f"[Milvus] Listed {len(data_point_vectors)} data point vectors")

return data_point_vectors

def delete_data_point_vectors(
self,
collection_name: str,
data_point_vectors: List[DataPointVector],
batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
):
"""
Delete vectors from the collection
"""
logger.debug(f"[Milvus] Deleting {len(data_point_vectors)} data point vectors")

for i in range(0, len(data_point_vectors), batch_size):
batch_vectors = data_point_vectors[i : i + batch_size]

delete_expr = " or ".join(
[f"id == {vector.data_point_vector_id}" for vector in batch_vectors]
)

self.milvus_client.delete(
collection_name=collection_name, filter=delete_expr
)

logger.debug(f"[Milvus] Deleted {len(data_point_vectors)} data point vectors")
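
A minimal usage sketch for reviewers who want to exercise the new integration locally. The `VectorDBConfig` fields are assumed from the `__init__` docstring and the compose.env example below; the embedding model and metadata values are illustrative, not part of this commit.

from langchain.docstore.document import Document
from langchain_community.embeddings import HuggingFaceEmbeddings

from backend.constants import (
    DATA_POINT_FQN_METADATA_KEY,
    DATA_POINT_HASH_METADATA_KEY,
)
from backend.modules.vector_db.milvus import MilvusVectorDB
from backend.types import VectorDBConfig

# local=True uses Milvus Lite and stores everything in ./cognita_milvus.db
config = VectorDBConfig(provider="milvus", local=True, config={"db_name": "demo_db"})
vector_db = MilvusVectorDB(config)

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector_db.create_collection("demo_collection", embeddings)

documents = [
    Document(
        page_content="Milvus is an open-source vector database.",
        metadata={
            DATA_POINT_FQN_METADATA_KEY: "localdir::demo::doc-1",
            DATA_POINT_HASH_METADATA_KEY: "v1",
        },
    )
]
vector_db.upsert_documents("demo_collection", documents, embeddings, incremental=False)
print(vector_db.list_data_point_vectors("demo_collection", "localdir::demo::doc-1"))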
6 changes: 3 additions & 3 deletions backend/modules/vector_db/mongo.py
@@ -64,14 +64,14 @@ def _create_search_index(self, collection_name: str, embeddings: Embeddings):
         result = self.db[collection_name].create_search_index(model=search_index_model)
         logger.debug(f"New search index named {result} is building.")

-        # Immediate avaialbility of the index is not guaranteed upon creation.
+        # Immediate availability of the index is not guaranteed upon creation.
         # MongoDB documentation recommends polling for the index to be ready.
         # Ensure this check to provide a seamless experience.
         # TODO (mnvsk97): We might want to introduce a new status in the ingestion runs to reflect this.
         logger.debug(
             "Polling to check if the index is ready. This may take up to a minute."
         )
-        predicate = lambda index: index.get("queryable") is True
+        predicate = lambda index: index.get("queryable") is True  # noqa: E731
         while True:
             indices = list(
                 self.db[collection_name].list_search_indexes("vector_search_index")
@@ -96,7 +96,7 @@ def upsert_documents(
             f"[Mongo] Adding {len(documents)} documents to collection {collection_name}"
         )

-        """Upsert documenlots with their embeddings"""
+        """Upsert documents with their embeddings"""
         collection = self.db[collection_name]

         data_point_fqns = []
5 changes: 2 additions & 3 deletions backend/modules/vector_db/qdrant.py
@@ -44,7 +44,6 @@ def create_collection(self, collection_name: str, embeddings: Embeddings):
         logger.debug(f"[Qdrant] Creating new collection {collection_name}")

         # Calculate embedding size
-        logger.debug(f"[Qdrant] Embedding a dummy doc to get vector dimensions")
         partial_embeddings = embeddings.embed_documents(["Initial document"])
         vector_size = len(partial_embeddings[0])
         logger.debug(f"Vector size: {vector_size}")
@@ -166,7 +165,7 @@ def upsert_documents(
         )

     def get_collections(self) -> List[str]:
-        logger.debug(f"[Qdrant] Fetching collections")
+        logger.debug("[Qdrant] Fetching collections")
         collections = self.qdrant_client.get_collections().collections
         logger.debug(f"[Qdrant] Fetched {len(collections)} collections")
         return [collection.name for collection in collections]
@@ -185,7 +184,7 @@ def get_vector_store(self, collection_name: str, embeddings: Embeddings):
         )

     def get_vector_client(self):
-        logger.debug(f"[Qdrant] Getting Qdrant client")
+        logger.debug("[Qdrant] Getting Qdrant client")
         return self.qdrant_client

     def list_data_point_vectors(
5 changes: 5 additions & 0 deletions backend/vectordb.requirements.txt
@@ -7,3 +7,8 @@ weaviate-client==3.25.3
 ### MongoDB
 pymongo==4.10.1
 langchain-mongodb==0.2.0
+
+
+### Milvus
+pymilvus==2.4.10
+langchain-milvus==0.1.7
2 changes: 2 additions & 0 deletions compose.env
@@ -15,6 +15,8 @@ ML_REPO_NAME=''
 VECTOR_DB_CONFIG='{"provider":"qdrant","url":"http://qdrant-server:6333", "config": {"grpc_port": 6334, "prefer_grpc": false}}'
 # MONGO Example
 # VECTOR_DB_CONFIG='{"provider":"mongo","url":"connection_uri", "config": {"database_name": "cognita"}}'
+# Milvus Example
+# VECTOR_DB_CONFIG='{"provider":"milvus", "url":"connection_uri", "api_key":"milvus_auth_token", "config":{"db_name":"cognita", "metric_type":"COSINE"}}'
 COGNITA_BACKEND_PORT=8000

 UNSTRUCTURED_IO_URL=http://unstructured-io-parsers:9500/
