Skip to content

Commit

Permalink
Update MilvusVectorStore to Pydantic (#11432)
Browse files: browse the repository at this point in the history
  • Loading branch information
ravi03071991 authored Feb 27, 2024
1 parent bf2c8a4 commit c79f36d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 48 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,10 +7,11 @@
from typing import Any, Dict, List, Optional, Union

import pymilvus # noqa
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.schema import BaseNode, TextNode
from llama_index.core.vector_stores.types import (
BasePydanticVectorStore,
MetadataFilters,
VectorStore,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
Expand Down Expand Up @@ -39,7 +40,7 @@ def _to_milvus_filter(standard_filters: MetadataFilters) -> List[str]:
return filters


class MilvusVectorStore(VectorStore):
class MilvusVectorStore(BasePydanticVectorStore):
"""The Milvus Vector Store.
In this vector store we store the text, its embedding and
Expand Down Expand Up @@ -87,11 +88,27 @@ class MilvusVectorStore(VectorStore):
stores_text: bool = True
stores_node: bool = True

uri: str = "http://localhost:19530"
token: str = ""
collection_name: str = "llamacollection"
dim: Optional[int]
embedding_field: str = DEFAULT_EMBEDDING_KEY
doc_id_field: str = DEFAULT_DOC_ID_KEY
similarity_metric: str = "IP"
consistency_level: str = "Strong"
overwrite: bool = False
text_key: Optional[str]
index_config: Optional[dict]
search_config: Optional[dict]

_milvusclient: MilvusClient = PrivateAttr()
_collection: Any = PrivateAttr()

def __init__(
self,
uri: str = "http://localhost:19530",
token: str = "",
collection_name: str = "llamalection",
collection_name: str = "llamacollection",
dim: Optional[int] = None,
embedding_field: str = DEFAULT_EMBEDDING_KEY,
doc_id_field: str = DEFAULT_DOC_ID_KEY,
Expand All @@ -104,64 +121,56 @@ def __init__(
**kwargs: Any,
) -> None:
"""Init params."""
self.collection_name = collection_name
self.dim = dim
self.embedding_field = embedding_field
self.doc_id_field = doc_id_field
self.consistency_level = consistency_level
self.overwrite = overwrite
self.text_key = text_key
self.index_config: Dict[str, Any] = index_config.copy() if index_config else {}
# Note: The search configuration is set at construction to avoid having
# to change the API for usage of the vector store (i.e. to pass the
# search config along with the rest of the query).
self.search_config: Dict[str, Any] = (
search_config.copy() if search_config else {}
super().__init__(
collection_name=collection_name,
dim=dim,
embedding_field=embedding_field,
doc_id_field=doc_id_field,
consistency_level=consistency_level,
overwrite=overwrite,
text_key=text_key,
index_config=index_config if index_config else {},
search_config=search_config if search_config else {},
)

# Select the similarity metric
if similarity_metric.lower() in ("ip"):
self.similarity_metric = "IP"
elif similarity_metric.lower() in ("l2", "euclidean"):
self.similarity_metric = "L2"
similarity_metrics_map = {"ip": "IP", "l2": "L2", "euclidean": "L2"}
similarity_metric = similarity_metrics_map.get(similarity_metric.lower(), "L2")

# Connect to Milvus instance
self.milvusclient = MilvusClient(
self._milvusclient = MilvusClient(
uri=uri,
token=token,
**kwargs, # pass additional arguments such as server_pem_path
)

# Delete previous collection if overwriting
if self.overwrite and self.collection_name in self.client.list_collections():
self.milvusclient.drop_collection(self.collection_name)
if overwrite and collection_name in self.client.list_collections():
self._milvusclient.drop_collection(collection_name)

# Create the collection if it does not exist
if self.collection_name not in self.client.list_collections():
if self.dim is None:
if collection_name not in self.client.list_collections():
if dim is None:
raise ValueError("Dim argument required for collection creation.")
self.milvusclient.create_collection(
collection_name=self.collection_name,
dimension=self.dim,
self._milvusclient.create_collection(
collection_name=collection_name,
dimension=dim,
primary_field_name=MILVUS_ID_FIELD,
vector_field_name=self.embedding_field,
vector_field_name=embedding_field,
id_type="string",
metric_type=self.similarity_metric,
metric_type=similarity_metric,
max_length=65_535,
consistency_level=self.consistency_level,
consistency_level=consistency_level,
)

self.collection = Collection(
self.collection_name, using=self.milvusclient._using
)
self._collection = Collection(collection_name, using=self._milvusclient._using)
self._create_index_if_required()

logger.debug(f"Successfully created a new collection: {self.collection_name}")

@property
def client(self) -> Any:
"""Get client."""
return self.milvusclient
return self._milvusclient

def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
"""Add the embeddings and their nodes into Milvus.
Expand Down Expand Up @@ -189,8 +198,8 @@ def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
insert_list.append(entry)

# Insert the data into milvus
self.collection.insert(insert_list)
self.collection.flush()
self._collection.insert(insert_list)
self._collection.flush()
self._create_index_if_required()
logger.debug(
f"Successfully inserted embeddings into: {self.collection_name} "
Expand All @@ -217,13 +226,13 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:

# Begin by querying for the primary keys to delete
doc_ids = ['"' + entry + '"' for entry in doc_ids]
entries = self.milvusclient.query(
entries = self._milvusclient.query(
collection_name=self.collection_name,
filter=f"{self.doc_id_field} in [{','.join(doc_ids)}]",
)
if len(entries) > 0:
ids = [entry["id"] for entry in entries]
self.milvusclient.delete(collection_name=self.collection_name, pks=ids)
self._milvusclient.delete(collection_name=self.collection_name, pks=ids)
logger.debug(f"Successfully deleted embedding with doc_id: {doc_ids}")

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
Expand Down Expand Up @@ -267,7 +276,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
string_expr = " and ".join(expr)

# Perform the search
res = self.milvusclient.search(
res = self._milvusclient.search(
collection_name=self.collection_name,
data=[query.query_embedding],
filter=string_expr,
Expand Down Expand Up @@ -317,17 +326,17 @@ def _create_index_if_required(self, force: bool = False) -> None:
# provided to ensure that the index is created in the constructor even
# if self.overwrite is false. In the `add` method, the index is
# recreated only if self.overwrite is true.
if (self.collection.has_index() and self.overwrite) or force:
self.collection.release()
self.collection.drop_index()
if (self._collection.has_index() and self.overwrite) or force:
self._collection.release()
self._collection.drop_index()
base_params: Dict[str, Any] = self.index_config.copy()
index_type: str = base_params.pop("index_type", "FLAT")
index_params: Dict[str, Union[str, Dict[str, Any]]] = {
"params": base_params,
"metric_type": self.similarity_metric,
"index_type": index_type,
}
self.collection.create_index(
self._collection.create_index(
self.embedding_field, index_params=index_params
)
self.collection.load()
self._collection.load()
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
from llama_index.core.vector_stores.types import VectorStore
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.milvus import MilvusVectorStore


def test_class():
names_of_base_classes = [b.__name__ for b in MilvusVectorStore.__mro__]
assert VectorStore.__name__ in names_of_base_classes
assert BasePydanticVectorStore.__name__ in names_of_base_classes

0 comments on commit c79f36d

Please sign in to comment.