feat: add milvus vector db integration #419

Merged · 7 commits · Jan 3, 2025
2 changes: 1 addition & 1 deletion backend/modules/vector_db/base.py
@@ -90,7 +90,7 @@ def get_embedding_dimensions(self, embeddings: Embeddings) -> int:
         Fetch embedding dimensions
         """
         # Calculate embedding size
-        logger.debug(f"[VectorDB] Embedding a dummy doc to get vector dimensions")
+        logger.debug("Embedding a dummy doc to get vector dimensions")
         partial_embeddings = embeddings.embed_documents(["Initial document"])
         vector_size = len(partial_embeddings[0])
         logger.debug(f"Vector size: {vector_size}")
300 changes: 300 additions & 0 deletions backend/modules/vector_db/milvus.py
@@ -0,0 +1,300 @@
from typing import List

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain_milvus import Milvus
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient

from backend.constants import (
    DATA_POINT_FQN_METADATA_KEY,
    DATA_POINT_HASH_METADATA_KEY,
    DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
)
from backend.logger import logger
from backend.modules.vector_db.base import BaseVectorDB
from backend.types import DataPointVector, VectorDBConfig

MAX_SCROLL_LIMIT = int(1e6)
BATCH_SIZE = 1000


class MilvusVectorDB(BaseVectorDB):
    def __init__(self, config: VectorDBConfig):
        """
        Initialize Milvus vector database client
        Args:
            :param config: VectorDBConfig
                - provider: str
                - local: bool
                - url: str
                    URI of the Milvus server.
                    - If you only need a local vector database for small-scale data or prototyping,
                      setting the uri to a local file, e.g. `./milvus.db`, is the most convenient method,
                      as it automatically uses [Milvus Lite](https://milvus.io/docs/milvus_lite.md)
                      to store all data in that file.
                    - If you have a large amount of data, say more than a million vectors, you can set up
                      a more performant Milvus server on [Docker or Kubernetes](https://milvus.io/docs/quickstart.md).
                      In this setup, use the server address and port as the uri, e.g. `http://localhost:19530`.
                      If you enable the authentication feature on Milvus,
                      use "<your_username>:<your_password>" as the token; otherwise don't set the token.
                    - If you use [Zilliz Cloud](https://zilliz.com/cloud), the fully managed cloud
                      service for Milvus, adjust the `uri` and `token`, which correspond to the
                      [Public Endpoint and API key](https://docs.zilliz.com/docs/on-zilliz-cloud-console#cluster-details).
                - api_key: str
                    Token for authentication with the Milvus server.
        """
        # TODO: create an extended config for Milvus as done for Qdrant
        logger.debug(f"Connecting to Milvus using config: {config.model_dump()}")
        self.config = config
        self.metric_type = config.config.get("metric_type", "COSINE")
        # Milvus Lite is used when local == True
        if config.local is True:
            # TODO: make this path customizable
            self.url = "./cognita_milvus.db"
            self.api_key = ""
            self.milvus_client = MilvusClient(
                uri=self.url,
                db_name=config.config.get("db_name", "milvus_default_db"),
            )
        else:
            self.url = config.url
            # Fall back to None so MilvusClient skips authentication when no token is configured
            self.api_key = config.api_key or None

            self.milvus_client = MilvusClient(
                uri=self.url,
                token=self.api_key,
                db_name=config.config.get("db_name", "milvus_default_db"),
            )
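For reference, a minimal usage sketch of this constructor (an illustration, not part of the PR; it assumes `VectorDBConfig` accepts the fields referenced above, mirroring the compose.env example further down):

    from backend.types import VectorDBConfig
    from backend.modules.vector_db.milvus import MilvusVectorDB

    # Milvus Lite: all data lives in a local file, no server required.
    local_db = MilvusVectorDB(
        VectorDBConfig(provider="milvus", local=True, config={"db_name": "milvus_default_db"})
    )

    # Remote Milvus server with authentication enabled.
    remote_db = MilvusVectorDB(
        VectorDBConfig(
            provider="milvus",
            local=False,
            url="http://localhost:19530",
            api_key="<your_username>:<your_password>",
            config={"db_name": "cognita", "metric_type": "COSINE"},
        )
    )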

    def create_collection(self, collection_name: str, embeddings: Embeddings):
        """
        Create a collection in the vector database
        Args:
            :param collection_name: str - Name of the collection
            :param embeddings: Embeddings - Embeddings object used to embed the documents
        The current implementation follows the quick setup, in which the collection is created, indexed and loaded into memory.
        """
        # TODO: Add customized setup with index params
        logger.debug(f"[Milvus] Creating new collection {collection_name}")

        vector_size = self.get_embedding_dimensions(embeddings)

        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="metadata", dtype=DataType.JSON),
        ]

        schema = CollectionSchema(
            fields=fields, description=f"Collection for {collection_name}"
        )

        self.milvus_client.create_collection(
            collection_name=collection_name,
            dimension=vector_size,
            metric_type=self.metric_type,  # https://milvus.io/docs/metric.md#Metric-Types : check for other supported metrics
            schema=schema,
            auto_id=True,
        )

Hi, this is Jinhong from Milvus. We have transitioned to the new MilvusClient interface and are actively working to phase out the use of the older ORM interface. In this context, your current approach to adding fields to a collection may be updated to the following:

schema = self.milvus_client.create_schema(
    auto_id=False, enable_dynamic_field=True
)

schema.add_field(
    field_name="id",
    datatype=DataType.INT64,
    is_primary=True,
)

schema.add_field(
    field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=vector_size
)

schema.add_field(
    field_name="text",
    datatype=DataType.VARCHAR,
    max_length=65535,
)

schema.add_field(
    field_name="metadata",
    datatype=DataType.JSON,
)

index_params = self.milvus_client.prepare_index_params()

index_params.add_index(
    field_name="vector",
    index_type="FLAT",
    metric_type=self.metric_type,
)

self.milvus_client.create_collection(
    collection_name=collection_name,
    schema=schema,
    index_params=index_params,
)
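Two notes on the suggestion above: `enable_dynamic_field=True` allows inserts to carry fields that are not declared in the schema (Milvus stores them in a reserved JSON field), and passing `index_params` to `create_collection` builds the index at creation time, which would make the separate `create_index` call below redundant.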


        # Can use this to create custom multiple indices
        index_params = self.milvus_client.prepare_index_params()
        index_params.add_index(
            field_name="vector", index_type="FLAT", metric_type=self.metric_type
        )
        self.milvus_client.create_index(
            collection_name=collection_name, index_params=index_params
        )

        logger.debug(f"[Milvus] Created new collection {collection_name}")

    def _delete_existing_documents(
        self, collection_name: str, documents: List[Document]
    ):
        """
        Delete existing documents from the collection
        """
        # Instead of using document IDs, we delete based on metadata matching
        for doc in documents:
            if (
                DATA_POINT_FQN_METADATA_KEY in doc.metadata
                and DATA_POINT_HASH_METADATA_KEY in doc.metadata
            ):
                delete_expr = (
                    f'metadata["{DATA_POINT_FQN_METADATA_KEY}"] == "{doc.metadata[DATA_POINT_FQN_METADATA_KEY]}" && '
                    f'metadata["{DATA_POINT_HASH_METADATA_KEY}"] == "{doc.metadata[DATA_POINT_HASH_METADATA_KEY]}"'
                )

                logger.debug(
                    f"[Milvus] Deleting records matching expression: {delete_expr}"
                )

                self.milvus_client.delete(
                    collection_name=collection_name,
                    filter=delete_expr,
                )
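For illustration, assuming the two metadata-key constants resolve to `_data_point_fqn` and `_data_point_hash` (an assumption; the real values are defined in `backend.constants`), a rendered expression passed to `milvus_client.delete` would look like:

    metadata["_data_point_fqn"] == "localdir::my-source" && metadata["_data_point_hash"] == "a1b2c3"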

    def upsert_documents(
        self,
        collection_name: str,
        documents: List[Document],
        embeddings: Embeddings,
        incremental: bool = True,
    ):
        """
        Upsert documents in the database.
        Upsert = insert / update
        - Check whether the collection exists
        - Check whether the collection is empty
        - If the collection is empty, insert all documents
        - If the collection is not empty, delete existing documents and insert the new ones
        """
        if len(documents) == 0:
            logger.warning("No documents to index")
            return

        logger.debug(
            f"[Milvus] Adding {len(documents)} documents to collection {collection_name}"
        )

        if not self.milvus_client.has_collection(collection_name):
            raise Exception(
                f"Collection {collection_name} does not exist. Please create it first using `create_collection`."
            )

        stats = self.milvus_client.get_collection_stats(collection_name=collection_name)
        if stats["row_count"] == 0:
            logger.warning(
                f"[Milvus] Collection {collection_name} is empty. Inserting all documents."
            )
            self.get_vector_store(collection_name, embeddings).add_documents(
                documents=documents
            )
            # Return here, otherwise the incremental branch below would insert the same documents again
            return

        if incremental:
            self._delete_existing_documents(collection_name, documents)

        self.get_vector_store(collection_name, embeddings).add_documents(
            documents=documents
        )

        logger.debug(
            f"[Milvus] Upserted {len(documents)} documents to collection {collection_name}"
        )
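A minimal end-to-end sketch of the upsert flow, reusing `local_db` from the sketch above (again illustrative: `FakeEmbeddings` stands in for any LangChain `Embeddings` implementation, and the metadata keys are hypothetical stand-ins for the constants in `backend.constants`):

    from langchain.docstore.document import Document
    from langchain_community.embeddings import FakeEmbeddings  # any Embeddings implementation works

    embeddings = FakeEmbeddings(size=384)
    local_db.create_collection("docs", embeddings)
    local_db.upsert_documents(
        collection_name="docs",
        documents=[
            Document(
                page_content="hello world",
                metadata={"_data_point_fqn": "localdir::my-source", "_data_point_hash": "a1b2c3"},
            )
        ],
        embeddings=embeddings,
    )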

    def get_collections(self) -> List[str]:
        logger.debug("[Milvus] Fetching collections from the vector database")
        collections = self.milvus_client.list_collections()
        logger.debug(f"[Milvus] Fetched {len(collections)} collections")
        return collections

    def delete_collection(self, collection_name: str):
        logger.debug(f"[Milvus] Deleting {collection_name} collection")
        self.milvus_client.drop_collection(collection_name)
        logger.debug(f"[Milvus] Deleted {collection_name} collection")

    def get_vector_store(self, collection_name: str, embeddings: Embeddings):
        logger.debug(f"[Milvus] Getting vector store for collection {collection_name}")
        return Milvus(
            collection_name=collection_name,
            connection_args={
                "uri": self.url,
                "token": self.api_key,
            },
            embedding_function=embeddings,
            auto_id=True,
            primary_field="id",
            text_field="text",
            metadata_field="metadata",
        )
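With `metadata_field="metadata"`, the LangChain `Milvus` wrapper packs each document's metadata into that single JSON field rather than creating one column per metadata key, which keeps it consistent with the JSON `metadata` field declared in `create_collection` above.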

    def get_vector_client(self):
        logger.debug("[Milvus] Getting Milvus client")
        return self.milvus_client

    def list_data_point_vectors(
        self,
        collection_name: str,
        data_source_fqn: str,
        batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
    ) -> List[DataPointVector]:
        """
        Get vectors from the collection
        """
        logger.debug(
            f"[Milvus] Listing data point vectors for collection {collection_name}"
        )
        filter_expr = (
            f'metadata["{DATA_POINT_FQN_METADATA_KEY}"] == "{data_source_fqn}"'
        )

        data_point_vectors: List[DataPointVector] = []

        offset = 0

        while True:
            search_result = self.milvus_client.query(
                collection_name=collection_name,
                filter=filter_expr,
                output_fields=["*"],  # return all fields of the entity / data point
                limit=batch_size,
                offset=offset,
            )

            for result in search_result:
                if result.get("metadata", {}).get(
                    DATA_POINT_FQN_METADATA_KEY
                ) and result.get("metadata", {}).get(DATA_POINT_HASH_METADATA_KEY):
                    data_point_vectors.append(
                        DataPointVector(
                            data_point_vector_id=str(result["id"]),
                            data_point_fqn=result["metadata"][
                                DATA_POINT_FQN_METADATA_KEY
                            ],
                            data_point_hash=result["metadata"][
                                DATA_POINT_HASH_METADATA_KEY
                            ],
                        )
                    )

            if (
                len(search_result) < batch_size
                or len(data_point_vectors) >= MAX_SCROLL_LIMIT
            ):
                break

            offset += batch_size

        logger.debug(f"[Milvus] Listed {len(data_point_vectors)} data point vectors")

        return data_point_vectors
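One caveat on this offset-based pagination (and on the suggested functions further down): Milvus by default caps the query window so that `offset + limit` cannot exceed 16384, so deep scrolls over large collections may need the `query_iterator` API instead; worth verifying against the pinned pymilvus version.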

    def delete_data_point_vectors(
        self,
        collection_name: str,
        data_point_vectors: List[DataPointVector],
        batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
    ):
        """
        Delete vectors from the collection
        """
        logger.debug(f"[Milvus] Deleting {len(data_point_vectors)} data point vectors")

        for i in range(0, len(data_point_vectors), batch_size):
            batch_vectors = data_point_vectors[i : i + batch_size]

            delete_expr = " or ".join(
                [f"id == {vector.data_point_vector_id}" for vector in batch_vectors]
            )

            self.milvus_client.delete(
                collection_name=collection_name, filter=delete_expr
            )

        logger.debug(f"[Milvus] Deleted {len(data_point_vectors)} data point vectors")

Consider adding the following three functions: `list_documents_in_collection`, `delete_documents`, and `list_document_vector_points`. Here is my implementation based on your code:

def list_documents_in_collection(
    self, collection_name: str, base_document_id: str = None
) -> List[str]:
    """
    List all documents in a collection.
    Args:
        collection_name (str): The name of the collection.
        base_document_id (str, optional): Base document ID for filtering. Defaults to None.
    Returns:
        List[str]: List of document IDs.
    """
    logger.debug(
        f"[Milvus] Listing all documents with base document ID {base_document_id} for collection {collection_name}"
    )

    stop = False
    offset = 0
    document_ids_set = set()

    while not stop:
        filter_expr = (
            f'metadata["{DATA_POINT_FQN_METADATA_KEY}"] == "{base_document_id}"'
            if base_document_id
            else ""
        )

        search_result = self.milvus_client.query(
            collection_name=collection_name,
            filter=filter_expr,
            output_fields=["metadata"],
            limit=BATCH_SIZE,
            offset=offset,
        )

        if not search_result:
            stop = True
            break

        for doc in search_result:
            metadata = doc.get("metadata", {})
            if metadata.get(DATA_POINT_FQN_METADATA_KEY):
                document_ids_set.add(metadata.get(DATA_POINT_FQN_METADATA_KEY))
            if len(document_ids_set) > MAX_SCROLL_LIMIT:
                stop = True
                break

        if len(search_result) < BATCH_SIZE:
            stop = True
        else:
            offset += BATCH_SIZE

    logger.debug(
        f"[Milvus] Found {len(document_ids_set)} documents with base document ID {base_document_id} in collection {collection_name}"
    )
    return list(document_ids_set)

def delete_documents(self, collection_name: str, document_ids: List[str]):
    """
    Delete documents from a collection based on document IDs.
    Args:
        collection_name (str): The name of the collection.
        document_ids (List[str]): List of document IDs to delete.
    """
    logger.debug(
        f"[Milvus] Deleting {len(document_ids)} documents from collection {collection_name}"
    )

    if not document_ids:
        logger.warning("[Milvus] No document IDs provided for deletion.")
        return

    try:
        for i in range(0, len(document_ids), BATCH_SIZE):
            document_ids_to_delete = document_ids[i : i + BATCH_SIZE]

            delete_expr = " or ".join(
                [
                    f'metadata["{DATA_POINT_FQN_METADATA_KEY}"] == "{doc_id}"'
                    for doc_id in document_ids_to_delete
                ]
            )

            self.milvus_client.delete(
                collection_name=collection_name, filter=delete_expr
            )

        logger.debug(
            f"[Milvus] Deleted {len(document_ids)} documents from collection {collection_name}"
        )
    except Exception as exp:
        logger.error(f"[Milvus] Error deleting documents: {exp}")

def list_document_vector_points(
    self, collection_name: str
) -> List[DataPointVector]:
    """
    List all document vector points in a collection.
    Args:
        collection_name (str): The name of the collection.
    Returns:
        List[DataPointVector]: List of vector points with metadata.
    """
    logger.debug(
        f"[Milvus] Listing all document vector points for collection {collection_name}"
    )

    stop = False
    offset = 0
    document_vector_points: List[DataPointVector] = []

    while not stop:
        search_result = self.milvus_client.query(
            collection_name=collection_name,
            output_fields=["id", "metadata"],
            limit=BATCH_SIZE,
            offset=offset,
        )

        if not search_result:
            stop = True
            break

        for doc in search_result:
            metadata = doc.get("metadata", {})
            if metadata.get(DATA_POINT_FQN_METADATA_KEY) and metadata.get(
                DATA_POINT_HASH_METADATA_KEY
            ):
                document_vector_points.append(
                    DataPointVector(
                        data_point_vector_id=str(doc["id"]),
                        data_point_fqn=metadata.get(DATA_POINT_FQN_METADATA_KEY),
                        data_point_hash=metadata.get(DATA_POINT_HASH_METADATA_KEY),
                    )
                )
            if len(document_vector_points) > MAX_SCROLL_LIMIT:
                stop = True
                break

        if len(search_result) < BATCH_SIZE:
            stop = True
        else:
            offset += BATCH_SIZE

    logger.debug(
        f"[Milvus] Listed {len(document_vector_points)} document vector points for collection {collection_name}"
    )
    return document_vector_points

6 changes: 3 additions & 3 deletions backend/modules/vector_db/mongo.py
@@ -64,14 +64,14 @@ def _create_search_index(self, collection_name: str, embeddings: Embeddings):
         result = self.db[collection_name].create_search_index(model=search_index_model)
         logger.debug(f"New search index named {result} is building.")

-        # Immediate avaialbility of the index is not guaranteed upon creation.
+        # Immediate availability of the index is not guaranteed upon creation.
         # MongoDB documentation recommends polling for the index to be ready.
         # Ensure this check to provide a seamless experience.
         # TODO (mnvsk97): We might want to introduce a new status in the ingestion runs to reflect this.
         logger.debug(
             "Polling to check if the index is ready. This may take up to a minute."
         )
-        predicate = lambda index: index.get("queryable") is True
+        predicate = lambda index: index.get("queryable") is True  # noqa: E731
         while True:
             indices = list(
                 self.db[collection_name].list_search_indexes("vector_search_index")
@@ -96,7 +96,7 @@ def upsert_documents(
f"[Mongo] Adding {len(documents)} documents to collection {collection_name}"
)

"""Upsert documenlots with their embeddings"""
"""Upsert documents with their embeddings"""
collection = self.db[collection_name]

data_point_fqns = []
Expand Down
5 changes: 2 additions & 3 deletions backend/modules/vector_db/qdrant.py
@@ -44,7 +44,6 @@ def create_collection(self, collection_name: str, embeddings: Embeddings):
logger.debug(f"[Qdrant] Creating new collection {collection_name}")

# Calculate embedding size
logger.debug(f"[Qdrant] Embedding a dummy doc to get vector dimensions")
partial_embeddings = embeddings.embed_documents(["Initial document"])
vector_size = len(partial_embeddings[0])
logger.debug(f"Vector size: {vector_size}")
@@ -166,7 +165,7 @@ def upsert_documents(
         )

     def get_collections(self) -> List[str]:
-        logger.debug(f"[Qdrant] Fetching collections")
+        logger.debug("[Qdrant] Fetching collections")
         collections = self.qdrant_client.get_collections().collections
         logger.debug(f"[Qdrant] Fetched {len(collections)} collections")
         return [collection.name for collection in collections]
@@ -185,7 +184,7 @@ def get_vector_store(self, collection_name: str, embeddings: Embeddings):
         )

     def get_vector_client(self):
-        logger.debug(f"[Qdrant] Getting Qdrant client")
+        logger.debug("[Qdrant] Getting Qdrant client")
         return self.qdrant_client

     def list_data_point_vectors(
5 changes: 5 additions & 0 deletions backend/vectordb.requirements.txt
@@ -7,3 +7,8 @@ weaviate-client==3.25.3
 ### MongoDB
 pymongo==4.10.1
 langchain-mongodb==0.2.0
+
+
+### Milvus
+pymilvus==2.4.10
+langchain-milvus==0.1.7
2 changes: 2 additions & 0 deletions compose.env
@@ -15,6 +15,8 @@ ML_REPO_NAME=''
 VECTOR_DB_CONFIG='{"provider":"qdrant","url":"http://qdrant-server:6333", "config": {"grpc_port": 6334, "prefer_grpc": false}}'
 # MONGO Example
 # VECTOR_DB_CONFIG='{"provider":"mongo","url":"connection_uri", "config": {"database_name": "cognita"}}'
+# Milvus Example
+# VECTOR_DB_CONFIG='{"provider":"milvus", "url":"connection_uri", "api_key":"milvus_auth_token", "config":{"db_name":"cognita", "metric_type":"COSINE"}}'
 COGNITA_BACKEND_PORT=8000

 UNSTRUCTURED_IO_URL=http://unstructured-io-parsers:9500/