feat: add milvus vector db integration #419
Merged
7 commits
388fd05 feat: add milvus vector db integration (abhishekrp-ai2002)
80b1512 chore: requested changes addressed (AbhishekRP2002)
3625df7 add example config in compose.env (AbhishekRP2002)
9f7b060 fix: bug fixes encountered during testing (AbhishekRP2002)
957053a Merge branch 'main' into main (mnvsk97)
f94b4ff chore: fix pre-commit checks (AbhishekRP2002)
1258f3a Merge branch 'main' of github.com-personal:AbhishekRP2002/cognita (AbhishekRP2002)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,300 @@ | ||
from typing import List | ||
|
||
from langchain.docstore.document import Document | ||
from langchain.embeddings.base import Embeddings | ||
from langchain_milvus import Milvus | ||
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient | ||
|
||
from backend.constants import ( | ||
DATA_POINT_FQN_METADATA_KEY, | ||
DATA_POINT_HASH_METADATA_KEY, | ||
DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE, | ||
) | ||
from backend.logger import logger | ||
from backend.modules.vector_db.base import BaseVectorDB | ||
from backend.types import DataPointVector, VectorDBConfig | ||
|
||
MAX_SCROLL_LIMIT = int(1e6) | ||
BATCH_SIZE = 1000 | ||
|
||
|
||
class MilvusVectorDB(BaseVectorDB): | ||
def __init__(self, config: VectorDBConfig): | ||
""" | ||
Initialize Milvus vector database client | ||
Args: | ||
:param config: VectorDBConfig | ||
- provider: str | ||
- local: bool | ||
- url: str | ||
URI of the Milvus server. | ||
- If you only need a local vector database for small scale data or prototyping, | ||
setting the uri as a local file, e.g.`./milvus.db`, is the most convenient method, | ||
as it automatically utilizes [Milvus Lite](https://milvus.io/docs/milvus_lite.md) | ||
to store all data in this file. | ||
- If you have large scale of data, say more than a million vectors, you can set up | ||
a more performant Milvus server on [Docker or Kubernetes](https://milvus.io/docs/quickstart.md). | ||
In this setup, please use the server address and port as your uri, e.g.`http://localhost:19530`. | ||
If you enable the authentication feature on Milvus, | ||
use "<your_username>:<your_password>" as the token, otherwise don't set the token. | ||
- If you use [Zilliz Cloud](https://zilliz.com/cloud), the fully managed cloud | ||
service for Milvus, adjust the `uri` and `token`, which correspond to the | ||
[Public Endpoint and API key](https://docs.zilliz.com/docs/on-zilliz-cloud-console#cluster-details) | ||
- api_key: str | ||
Token for authentication with the Milvus server. | ||
""" | ||
# TODO: create an extended config for Milvus like done in Qdrant | ||
# Exclude the API key so credentials do not end up in debug logs | ||
logger.debug(f"Connecting to Milvus using config: {config.model_dump(exclude={'api_key'})}") | ||
self.config = config | ||
self.metric_type = config.config.get("metric_type", "COSINE") | ||
# Milvus-lite is used for local == True | ||
if config.local is True: | ||
# TODO: make this path customizable | ||
self.url = "./cognita_milvus.db" | ||
self.api_key = "" | ||
self.milvus_client = MilvusClient( | ||
uri=self.url, | ||
db_name=config.config.get("db_name", "milvus_default_db"), | ||
) | ||
else: | ||
self.url = config.url | ||
self.api_key = config.api_key | ||
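# Always bind api_key (None when no key is configured) so the call below cannot raise NameError | ||
api_key = self.api_key if self.api_key else None | ||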
|
||
self.milvus_client = MilvusClient( | ||
uri=self.url, | ||
token=api_key, | ||
db_name=config.config.get("db_name", "milvus_default_db"), | ||
) | ||
|
||
def create_collection(self, collection_name: str, embeddings: Embeddings): | ||
""" | ||
Create a collection in the vector database | ||
Args: | ||
:param collection_name: str - Name of the collection | ||
:param embeddings: Embeddings - Embeddings object to be used for creating embeddings of the documents | ||
Current implementation includes Quick setup in which the collection is created, indexed and loaded into the memory. | ||
|
||
""" | ||
# TODO: Add customized setup with indexed params | ||
logger.debug(f"[Milvus] Creating new collection {collection_name}") | ||
|
||
vector_size = self.get_embedding_dimensions(embeddings) | ||
|
||
fields = [ | ||
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), | ||
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size), | ||
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), | ||
FieldSchema(name="metadata", dtype=DataType.JSON), | ||
] | ||
|
||
schema = CollectionSchema( | ||
fields=fields, description=f"Collection for {collection_name}" | ||
) | ||
|
||
self.milvus_client.create_collection( | ||
collection_name=collection_name, | ||
dimension=vector_size, | ||
metric_type=self.metric_type, # https://milvus.io/docs/metric.md#Metric-Types : check for other supported metrics | ||
schema=schema, | ||
auto_id=True, | ||
) | ||
|
||
# Can use this to create custom multiple indices | ||
index_params = self.milvus_client.prepare_index_params() | ||
index_params.add_index( | ||
field_name="vector", index_type="FLAT", metric_type=self.metric_type | ||
) | ||
self.milvus_client.create_index( | ||
collection_name=collection_name, index_params=index_params | ||
) | ||
|
||
logger.debug(f"[Milvus] Created new collection {collection_name}") | ||
|
||
def _delete_existing_documents( | ||
self, collection_name: str, documents: List[Document] | ||
): | ||
""" | ||
Delete existing documents from the collection | ||
""" | ||
# Instead of using document IDs, we'll delete based on metadata matching | ||
for doc in documents: | ||
if ( | ||
DATA_POINT_FQN_METADATA_KEY in doc.metadata | ||
and DATA_POINT_HASH_METADATA_KEY in doc.metadata | ||
): | ||
delete_expr = ( | ||
f'metadata["{DATA_POINT_FQN_METADATA_KEY}"] == "{doc.metadata[DATA_POINT_FQN_METADATA_KEY]}" && ' | ||
f'metadata["{DATA_POINT_HASH_METADATA_KEY}"] == "{doc.metadata[DATA_POINT_HASH_METADATA_KEY]}"' | ||
) | ||
|
||
logger.debug( | ||
f"[Milvus] Deleting records matching expression: {delete_expr}" | ||
) | ||
|
||
self.milvus_client.delete( | ||
collection_name=collection_name, | ||
filter=delete_expr, | ||
) | ||
|
||
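The delete expression above interpolates metadata values directly into the filter string, so a value containing a double quote or a backslash would produce a malformed expression. Below is a minimal sketch of a helper that escapes values before building the filter. The names `escape_filter_value` and `build_metadata_filter` are hypothetical and not part of this PR, and the sketch assumes Milvus accepts backslash-escaped `\` and `"` inside double-quoted string literals.

```python
from typing import Mapping


def escape_filter_value(value: str) -> str:
    """Escape backslashes and double quotes for use inside a double-quoted literal."""
    return value.replace("\\", "\\\\").replace('"', '\\"')


def build_metadata_filter(conditions: Mapping[str, str]) -> str:
    """Join equality checks on keys of the JSON `metadata` field with `&&`."""
    clauses = [
        f'metadata["{escape_filter_value(key)}"] == "{escape_filter_value(str(value))}"'
        for key, value in conditions.items()
    ]
    return " && ".join(clauses)


# Hypothetical metadata keys, used only for illustration
example_filter = build_metadata_filter({"fqn": 'docs/a"b.md', "hash": "abc123"})
```

The resulting string could then be passed as `filter=` to `milvus_client.delete`, in place of the hand-built f-string.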
def upsert_documents( | ||
self, | ||
collection_name: str, | ||
documents: List[Document], | ||
embeddings: Embeddings, | ||
incremental: bool = True, | ||
): | ||
""" | ||
Upsert documents in the database. | ||
Upsert = Insert / update | ||
- Check if collection exists or not | ||
- Check if collection is empty or not | ||
- If collection is empty, insert all documents | ||
- If collection is not empty, delete existing documents and insert new documents | ||
""" | ||
if len(documents) == 0: | ||
logger.warning("No documents to index") | ||
return | ||
|
||
logger.debug( | ||
f"[Milvus] Adding {len(documents)} documents to collection {collection_name}" | ||
) | ||
|
||
if not self.milvus_client.has_collection(collection_name): | ||
|
||
raise Exception( | ||
f"Collection {collection_name} does not exist. Please create it first using `create_collection`." | ||
) | ||
|
||
stats = self.milvus_client.get_collection_stats(collection_name=collection_name) | ||
if stats["row_count"] == 0: | ||
logger.warning( | ||
f"[Milvus] Collection {collection_name} is empty. Inserting all documents." | ||
) | ||
self.get_vector_store(collection_name, embeddings).add_documents( | ||
documents=documents | ||
) | ||
# Return early so the same documents are not inserted a second time below | ||
return | ||
|
||
if incremental and len(documents) > 0: | ||
self._delete_existing_documents(collection_name, documents) | ||
|
||
self.get_vector_store(collection_name, embeddings).add_documents( | ||
documents=documents | ||
) | ||
|
||
logger.debug( | ||
f"[Milvus] Upserted {len(documents)} documents to collection {collection_name}" | ||
) | ||
|
||
def get_collections(self) -> List[str]: | ||
logger.debug("[Milvus] Fetching collections from the vector database") | ||
collections = self.milvus_client.list_collections() | ||
logger.debug(f"[Milvus] Fetched {len(collections)} collections") | ||
return collections | ||
|
||
def delete_collection(self, collection_name: str): | ||
logger.debug(f"[Milvus] Deleting {collection_name} collection") | ||
self.milvus_client.drop_collection(collection_name) | ||
logger.debug(f"[Milvus] Deleted {collection_name} collection") | ||
|
||
def get_vector_store(self, collection_name: str, embeddings: Embeddings): | ||
logger.debug(f"[Milvus] Getting vector store for collection {collection_name}") | ||
return Milvus( | ||
collection_name=collection_name, | ||
connection_args={ | ||
"uri": self.url, | ||
"token": self.api_key, | ||
}, | ||
embedding_function=embeddings, | ||
auto_id=True, | ||
primary_field="id", | ||
text_field="text", | ||
metadata_field="metadata", | ||
) | ||
|
||
def get_vector_client(self): | ||
logger.debug("[Milvus] Getting Milvus client") | ||
return self.milvus_client | ||
|
||
def list_data_point_vectors( | ||
self, | ||
collection_name: str, | ||
data_source_fqn: str, | ||
batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE, | ||
) -> List[DataPointVector]: | ||
""" | ||
Get vectors from the collection | ||
""" | ||
logger.debug( | ||
f"[Milvus] Listing data point vectors for collection {collection_name}" | ||
) | ||
filter_expr = ( | ||
f'metadata["{DATA_POINT_FQN_METADATA_KEY}"] == "{data_source_fqn}"' | ||
) | ||
|
||
data_point_vectors: List[DataPointVector] = [] | ||
|
||
offset = 0 | ||
|
||
while True: | ||
search_result = self.milvus_client.query( | ||
collection_name=collection_name, | ||
filter=filter_expr, | ||
output_fields=[ | ||
"*" | ||
], # returning all the fields of the entity / data point | ||
limit=batch_size, | ||
offset=offset, | ||
) | ||
|
||
for result in search_result: | ||
if result.get("metadata", {}).get( | ||
DATA_POINT_FQN_METADATA_KEY | ||
) and result.get("metadata", {}).get(DATA_POINT_HASH_METADATA_KEY): | ||
data_point_vectors.append( | ||
DataPointVector( | ||
data_point_vector_id=str(result["id"]), | ||
data_point_fqn=result["metadata"][ | ||
DATA_POINT_FQN_METADATA_KEY | ||
], | ||
data_point_hash=result["metadata"][ | ||
DATA_POINT_HASH_METADATA_KEY | ||
], | ||
) | ||
) | ||
|
||
if ( | ||
len(search_result) < batch_size | ||
or len(data_point_vectors) >= MAX_SCROLL_LIMIT | ||
): | ||
break | ||
|
||
offset += batch_size | ||
|
||
logger.debug(f"[Milvus] Listed {len(data_point_vectors)} data point vectors") | ||
|
||
return data_point_vectors | ||
|
||
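The loop in `list_data_point_vectors` pages through results with `limit` and `offset` and stops on a short page or once a cap is reached. The sketch below isolates that stopping rule so it can be checked without a Milvus server. `paginate` and `fake_query` are hypothetical stand-ins: `fake_query` slices an in-memory list where the real code calls `milvus_client.query`.

```python
from typing import Callable, Dict, List


def paginate(
    query_page: Callable[[int, int], List[Dict]],
    batch_size: int,
    max_items: int,
) -> List[Dict]:
    """Collect pages until a short page arrives or max_items is reached."""
    items: List[Dict] = []
    offset = 0
    while True:
        page = query_page(batch_size, offset)
        items.extend(page)
        # A page shorter than batch_size means there is nothing left to fetch
        if len(page) < batch_size or len(items) >= max_items:
            break
        offset += batch_size
    return items


rows = [{"id": i} for i in range(25)]


def fake_query(limit: int, offset: int) -> List[Dict]:
    """Stand-in for a limit/offset query against a collection."""
    return rows[offset : offset + limit]


result = paginate(fake_query, batch_size=10, max_items=1000)
```

Note that the cap is only checked after a full page has been appended, so the result can overshoot `max_items` by up to one batch. Milvus also documents an upper bound on `offset + limit` for a single query, so a very large collection may need a different traversal strategy than plain offset paging.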
def delete_data_point_vectors( | ||
self, | ||
collection_name: str, | ||
data_point_vectors: List[DataPointVector], | ||
batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE, | ||
): | ||
""" | ||
Delete vectors from the collection | ||
""" | ||
logger.debug(f"[Milvus] Deleting {len(data_point_vectors)} data point vectors") | ||
|
||
for i in range(0, len(data_point_vectors), batch_size): | ||
batch_vectors = data_point_vectors[i : i + batch_size] | ||
|
||
delete_expr = " or ".join( | ||
[f"id == {vector.data_point_vector_id}" for vector in batch_vectors] | ||
) | ||
|
||
self.milvus_client.delete( | ||
collection_name=collection_name, filter=delete_expr | ||
) | ||
|
||
logger.debug(f"[Milvus] Deleted {len(data_point_vectors)} data point vectors") | ||
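The batched delete above builds one `id == ...` clause per vector and joins them with `or`. A more compact alternative is a single `in` list per batch, sketched below under the assumption that the primary key is an INT64 field named `id`, as in the schema in this file. `batched` and `build_id_filter` are hypothetical helper names, not part of the PR.

```python
from typing import Iterator, List, Sequence


def batched(ids: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of at most batch_size ids."""
    for start in range(0, len(ids), batch_size):
        yield ids[start : start + batch_size]


def build_id_filter(ids: Sequence[str]) -> str:
    """Build an `id in [...]` expression; int() rejects non-numeric ids early."""
    return f"id in [{', '.join(str(int(i)) for i in ids)}]"


def build_delete_filters(ids: Sequence[str], batch_size: int) -> List[str]:
    """Return one filter expression per batch of ids."""
    return [build_id_filter(batch) for batch in batched(ids, batch_size)]
```

Each returned string could be passed as `filter=` to `milvus_client.delete`. Converting every id with `int()` also keeps arbitrary text out of the expression.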
Consider adding the following 3 functions:
Hi, this is Jinhong from Milvus. We have transitioned to the new MilvusClient interface and are actively working to phase out the use of the older ORM interface. In this context, your current approach to adding fields to a collection may be updated to the following: