feat: milvus vectordb integration
jinhonglin-ryan committed Dec 25, 2024
1 parent 89d25c9 commit ac62359
Showing 2 changed files with 101 additions and 72 deletions.
163 changes: 95 additions & 68 deletions backend/modules/vector_db/milvus.py
@@ -1,13 +1,19 @@
import json
from typing import List
from uuid import uuid4

from langchain.embeddings.base import Embeddings
from langchain_milvus import Milvus
from langchain_core.vectorstores import VectorStore
from langchain.docstore.document import Document
from pymilvus import MilvusClient, DataType

from backend.logger import logger
from backend.constants import DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE, DATA_POINT_FQN_METADATA_KEY, \
DATA_POINT_HASH_METADATA_KEY
from backend.constants import (
DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
DATA_POINT_FQN_METADATA_KEY,
DATA_POINT_HASH_METADATA_KEY,
)
from backend.modules.vector_db import BaseVectorDB
from backend.types import DataPointVector, VectorDBConfig

@@ -20,28 +26,23 @@ def __init__(self, config: VectorDBConfig):
logger.debug(f"Connecting to milvus using config: {config.model_dump()}")
if config.local:
self.uri = config.url if config.url else "./milvus_local.db"
self.milvus_client = MilvusClient(
uri=self.uri
)
self.milvus_client = MilvusClient(uri=self.uri)
else:
self.uri = config.url
self.token = config.api_key
self.milvus_client = MilvusClient(
uri=self.uri,
token=self.token
)
self.milvus_client = MilvusClient(uri=self.uri, token=self.token)

def create_collection(self, collection_name: str, embeddings: Embeddings):
"""
Create a new collection in Milvus with the given schema and embedding configuration.
Create a new collection in Milvus with the given schema and embedding configuration.
Args:
collection_name (str): The name of the collection to be created.
embeddings (Embeddings): An embedding function to determine vector size.
Args:
collection_name (str): The name of the collection to be created.
embeddings (Embeddings): An embedding function to determine vector size.
Returns:
None
"""
Returns:
None
"""
logger.debug(f"[Milvus] Creating new collection {collection_name}")

# Calculate embedding size
@@ -50,11 +51,24 @@ def create_collection(self, collection_name: str, embeddings: Embeddings):
vector_size = len(partial_embeddings[0])
logger.debug(f"Vector size: {vector_size}")

schema = self.milvus_client.create_schema(auto_id=False, enable_dynamic_field=True)

schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=65535)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=vector_size)
schema = self.milvus_client.create_schema(
auto_id=False, enable_dynamic_field=True
)

schema.add_field(
field_name="id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=65535,
)
schema.add_field(
field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=vector_size
)
schema.add_field(
field_name="text",
datatype=DataType.VARCHAR,
max_length=65535,
)
index_params = self.milvus_client.prepare_index_params()
index_params.add_index(
field_name="vector",
@@ -92,19 +106,20 @@ def _get_records_to_be_upserted(
)

try:
filter_condition = f'metadata.{DATA_POINT_FQN_METADATA_KEY} in {data_point_fqns}'
if not data_point_fqns:
return []
filter_condition = (
f"{DATA_POINT_FQN_METADATA_KEY} in {json.dumps(data_point_fqns)}"
)
results = self.milvus_client.query(
collection_name=collection_name,
filter=filter_condition,
output_fields=["id", "metadata"],
)
record_ids_to_be_upserted = []

for record in results:
metadata = record.get("metadata")
if metadata:
if metadata.get(DATA_POINT_FQN_METADATA_KEY):
record_ids_to_be_upserted.append(record["id"])
if record.get(DATA_POINT_FQN_METADATA_KEY):
record_ids_to_be_upserted.append(record["id"])

logger.debug(
f"[Milvus] Incremental Ingestion: collection={collection_name}, Updates={len(record_ids_to_be_upserted)}"
@@ -115,11 +130,11 @@ def _get_records_to_be_upserted(
return []

def upsert_documents(
self,
collection_name: str,
documents: List[Document],
embeddings: Embeddings,
incremental: bool = True,
):
"""
Upsert (update or insert) documents into the specified Milvus collection.
Expand All @@ -134,7 +149,6 @@ def upsert_documents(
None
"""


if len(documents) == 0:
logger.warning("No documents to index")
return
@@ -166,15 +180,25 @@ upsert_documents(
logger.debug(
f"[Milvus] Deleted {len(record_ids_to_be_upserted)} outdated documents from collection {collection_name}"
)
for doc in documents:
fqn = doc.metadata.get(DATA_POINT_FQN_METADATA_KEY)
if not fqn:
raise ValueError(
"[Milvus] Each document must have a unique data_point_fqn"
)
doc.metadata["id"] = fqn
doc.metadata["text"] = doc.page_content

ids = [doc.metadata["id"] for doc in documents]

Milvus(
collection_name=collection_name,
embedding_function=embeddings,
connection_args={
'uri': self.uri,
'token': self.token,
}
).add_documents(documents=documents)
"uri": self.uri,
"token": self.token,
},
).add_documents(documents=documents, ids=ids)
logger.debug(
f"[Milvus] Added {len(documents)} documents to collection {collection_name}"
)
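A note on the upsert change above: each document's primary key is now derived from its data_point_fqn, and those ids are passed explicitly to add_documents, so re-ingesting a document overwrites its previous vector instead of inserting a duplicate. A minimal sketch of the idea, assuming a vector store like the one returned by get_vector_store below; the FQN value is illustrative, not from this commit:

from backend.constants import DATA_POINT_FQN_METADATA_KEY
from langchain.docstore.document import Document

# Illustrative document; real ones come from the ingestion pipeline.
docs = [
    Document(
        page_content="hello world",
        metadata={DATA_POINT_FQN_METADATA_KEY: "local::docs/readme.md"},
    ),
]
# Stable IDs derived from the FQN: running this twice updates rather than duplicates.
ids = [doc.metadata[DATA_POINT_FQN_METADATA_KEY] for doc in docs]
# vector_store.add_documents(documents=docs, ids=ids)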
@@ -190,15 +214,17 @@ def delete_collection(self, collection_name: str):
self.milvus_client.drop_collection(collection_name=collection_name)
logger.debug(f"[Milvus] Deleted {collection_name} collection")

def get_vector_store(self, collection_name: str, embeddings: Embeddings) -> VectorStore:
def get_vector_store(
self, collection_name: str, embeddings: Embeddings
) -> VectorStore:
logger.debug(f"[Milvus] Getting vector store for collection {collection_name}")
return Milvus(
embedding_function=embeddings,
collection_name=collection_name,
connection_args={
'uri': self.uri,
'token': self.token,
}
"uri": self.uri,
"token": self.token,
},
)

def get_vector_client(self):
@@ -231,27 +257,28 @@ def list_data_point_vectors(

while not stop:
try:
filter_condition = f'metadata.{DATA_POINT_FQN_METADATA_KEY} == "{data_source_fqn}"'
filter_condition = (
f'{DATA_POINT_FQN_METADATA_KEY} == "{data_source_fqn}"'
)

results = self.milvus_client.query(
collection_name=collection_name,
filter=filter_condition,
output_fields=["id", "metadata"],
limit=batch_size,
offset=offset,
)

for record in results:
metadata = record.get("metadata", {})
if ( metadata
and metadata.get(DATA_POINT_FQN_METADATA_KEY)
and metadata.get(DATA_POINT_HASH_METADATA_KEY)
if record.get(DATA_POINT_FQN_METADATA_KEY) and record.get(
DATA_POINT_HASH_METADATA_KEY
):
data_point_vectors.append(
DataPointVector(
data_point_vector_id=record["id"],
data_point_fqn=metadata.get(DATA_POINT_FQN_METADATA_KEY),
data_point_hash=metadata.get(DATA_POINT_HASH_METADATA_KEY),
data_point_fqn=record.get(DATA_POINT_FQN_METADATA_KEY),
data_point_hash=record.get(
DATA_POINT_HASH_METADATA_KEY
),
)
)

@@ -294,7 +321,7 @@ def delete_data_point_vectors(
deleted_vectors_count = 0

for i in range(0, vectors_to_be_deleted_count, batch_size):
data_point_vectors_to_be_processed = data_point_vectors[i: i + batch_size]
data_point_vectors_to_be_processed = data_point_vectors[i : i + batch_size]

try:
self.milvus_client.delete(
@@ -319,7 +346,10 @@
)

def list_documents_in_collection(
self, collection_name: str, base_document_id: str = None, batch_size: int = BATCH_SIZE
self,
collection_name: str,
base_document_id: str = None,
batch_size: int = BATCH_SIZE,
) -> List[str]:
"""
List all documents in a Milvus collection, optionally filtering by base document ID.
@@ -343,24 +373,23 @@ def list_documents_in_collection(
while not stop:
try:
if base_document_id:
filter_condition = f'metadata.{DATA_POINT_FQN_METADATA_KEY} == "{base_document_id}"'
filter_condition = (
f'{DATA_POINT_FQN_METADATA_KEY} == "{base_document_id}"'
)
else:
filter_condition = None

results = self.milvus_client.query(
collection_name=collection_name,
filter=filter_condition,
output_fields=["metadata"],
limit=batch_size,
offset=offset,
)

for record in results:
metadata = record.get("metadata", {})
if metadata:
document_id = metadata.get(DATA_POINT_FQN_METADATA_KEY)
if document_id:
document_ids_set.add(document_id)
document_id = record.get(DATA_POINT_FQN_METADATA_KEY)
if document_id:
document_ids_set.add(document_id)

if len(results) < batch_size:
stop = True
@@ -396,10 +425,12 @@ def delete_documents(self, collection_name: str, document_ids: List[str]):
)

for i in range(0, len(document_ids), BATCH_SIZE):
document_ids_to_be_processed = document_ids[i: i + BATCH_SIZE]
document_ids_to_be_processed = document_ids[i : i + BATCH_SIZE]

try:
filter_condition = f'metadata.{DATA_POINT_FQN_METADATA_KEY} in {document_ids_to_be_processed}'
if not document_ids_to_be_processed:
continue
filter_condition = f"{DATA_POINT_FQN_METADATA_KEY} in {json.dumps(document_ids_to_be_processed)}"
self.milvus_client.delete(
collection_name=collection_name,
filter=filter_condition,
@@ -411,7 +442,7 @@ def delete_documents(self, collection_name: str, document_ids: List[str]):
logger.error(f"[Milvus] Failed to delete documents: {str(e)}")

def list_document_vector_points(
self, collection_name: str
) -> List[DataPointVector]:
"""
List all document vector points in a Milvus collection.
@@ -434,22 +465,19 @@ def list_document_vector_points(
try:
results = self.milvus_client.query(
collection_name=collection_name,
output_fields=["id", "metadata"],
limit=BATCH_SIZE,
offset=offset,
)

for record in results:
metadata = record.get("metadata", {})
if ( metadata
and metadata.get(DATA_POINT_FQN_METADATA_KEY)
and metadata.get(DATA_POINT_HASH_METADATA_KEY)
if record.get(DATA_POINT_FQN_METADATA_KEY) and record.get(
DATA_POINT_HASH_METADATA_KEY
):
document_vector_points.append(
DataPointVector(
point_id=record["id"],
document_id=metadata.get(DATA_POINT_FQN_METADATA_KEY),
document_hash=metadata.get(DATA_POINT_HASH_METADATA_KEY),
document_id=record.get(DATA_POINT_FQN_METADATA_KEY),
document_hash=record.get(DATA_POINT_HASH_METADATA_KEY),
)
)

Expand All @@ -470,4 +498,3 @@ def list_document_vector_points(
f"[Milvus] Listed {len(document_vector_points)} document vector points for collection {collection_name}"
)
return document_vector_points
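A note on the filter expressions throughout this file: the old code interpolated Python lists directly into "metadata.<key> in {...}" filters, which produces single-quoted Python repr output; the new code targets the fields directly and serializes list values with json.dumps, which yields double-quoted string literals that Milvus's boolean expression parser accepts. A small sketch of the pattern, assuming DATA_POINT_FQN_METADATA_KEY resolves to a top-level dynamic field name (the FQN values and collection name are illustrative):

import json

from backend.constants import DATA_POINT_FQN_METADATA_KEY

fqns = ["local::a.md", "local::b.md"]  # illustrative values
# str(fqns) would give Python-repr single quotes; json.dumps gives
# ["local::a.md", "local::b.md"], a valid list literal in a Milvus `in` expression.
filter_condition = f"{DATA_POINT_FQN_METADATA_KEY} in {json.dumps(fqns)}"
# milvus_client.query(collection_name="my_collection", filter=filter_condition, output_fields=["id"])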

10 changes: 6 additions & 4 deletions compose.env
@@ -12,11 +12,13 @@ POSTGRES_PASSWORD=test
MODELS_CONFIG_PATH="./models_config.yaml"
METADATA_STORE_CONFIG='{"provider":"prisma"}'
ML_REPO_NAME=''
# VECTOR_DB_CONFIG='{"provider":"qdrant","url":"http://qdrant-server:6333", "config": {"grpc_port": 6334, "prefer_grpc": false}}'
VECTOR_DB_CONFIG='{"provider":"qdrant","url":"http://qdrant-server:6333", "config": {"grpc_port": 6334, "prefer_grpc": false}}'
# MONGO Example
# VECTOR_DB_CONFIG='{"provider":"mongo","url":"connection_uri", "config": {"database_name": "cognita"}}'
# MILVUS Example
VECTOR_DB_CONFIG={"provider":"milvus","local":true}
# MILVUS CLOUD Example
# VECTOR_DB_CONFIG={"provider": "milvus", "local": false, "api_key": "YOUR_ZILLIZ_CLOUD_TOKEN", "url": "YOUR_ZIILIZ_CLOUD_URI"}
# MILVUS Local Example
# VECTOR_DB_CONFIG={"provider": "milvus", "local": true}
COGNITA_BACKEND_PORT=8000

UNSTRUCTURED_IO_URL=http://unstructured-io-parsers:9500/
@@ -31,7 +33,7 @@ VITE_DOCS_QA_ENABLE_REDIRECT=false
VITE_DOCS_QA_MAX_UPLOAD_SIZE_MB=200

## OpenAI
OPENAI_API_KEY=
OPENAI_API_KEY=""

## OLLAMA VARS
OLLAMA_MODEL=qwen2:1.5b
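For reference, a rough sketch of how the two Milvus VECTOR_DB_CONFIG shapes above map onto the client construction in milvus.py; the field names come from the diff, and the parsing shown here is illustrative rather than the actual backend code:

import json

from pymilvus import MilvusClient

raw = '{"provider": "milvus", "local": true}'  # or the cloud variant with url/api_key
cfg = json.loads(raw)
if cfg.get("local"):
    # Embedded Milvus Lite; falls back to a local db file when no url is given.
    client = MilvusClient(uri=cfg.get("url") or "./milvus_local.db")
else:
    # Remote Milvus / Zilliz Cloud: url is the endpoint, api_key is the token.
    client = MilvusClient(uri=cfg["url"], token=cfg["api_key"])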
