Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update MilvusVectorStore to Pydantic #11432

Merged
merged 1 commit into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from typing import Any, Dict, List, Optional, Union

import pymilvus # noqa
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.schema import BaseNode, TextNode
from llama_index.core.vector_stores.types import (
BasePydanticVectorStore,
MetadataFilters,
VectorStore,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
Expand Down Expand Up @@ -39,7 +40,7 @@ def _to_milvus_filter(standard_filters: MetadataFilters) -> List[str]:
return filters


class MilvusVectorStore(VectorStore):
class MilvusVectorStore(BasePydanticVectorStore):
"""The Milvus Vector Store.

In this vector store we store the text, its embedding and
Expand Down Expand Up @@ -87,11 +88,27 @@ class MilvusVectorStore(VectorStore):
stores_text: bool = True
stores_node: bool = True

uri: str = "http://localhost:19530"
token: str = ""
collection_name: str = "llamacollection"
dim: Optional[int]
embedding_field: str = DEFAULT_EMBEDDING_KEY
doc_id_field: str = DEFAULT_DOC_ID_KEY
similarity_metric: str = "IP"
consistency_level: str = "Strong"
overwrite: bool = False
text_key: Optional[str]
index_config: Optional[dict]
search_config: Optional[dict]

_milvusclient: MilvusClient = PrivateAttr()
_collection: Any = PrivateAttr()

def __init__(
self,
uri: str = "http://localhost:19530",
token: str = "",
collection_name: str = "llamalection",
collection_name: str = "llamacollection",
dim: Optional[int] = None,
embedding_field: str = DEFAULT_EMBEDDING_KEY,
doc_id_field: str = DEFAULT_DOC_ID_KEY,
Expand All @@ -104,64 +121,56 @@ def __init__(
**kwargs: Any,
) -> None:
"""Init params."""
self.collection_name = collection_name
self.dim = dim
self.embedding_field = embedding_field
self.doc_id_field = doc_id_field
self.consistency_level = consistency_level
self.overwrite = overwrite
self.text_key = text_key
self.index_config: Dict[str, Any] = index_config.copy() if index_config else {}
# Note: The search configuration is set at construction to avoid having
# to change the API for usage of the vector store (i.e. to pass the
# search config along with the rest of the query).
self.search_config: Dict[str, Any] = (
search_config.copy() if search_config else {}
super().__init__(
collection_name=collection_name,
dim=dim,
embedding_field=embedding_field,
doc_id_field=doc_id_field,
consistency_level=consistency_level,
overwrite=overwrite,
text_key=text_key,
index_config=index_config if index_config else {},
search_config=search_config if search_config else {},
)

# Select the similarity metric
if similarity_metric.lower() in ("ip"):
self.similarity_metric = "IP"
elif similarity_metric.lower() in ("l2", "euclidean"):
self.similarity_metric = "L2"
similarity_metrics_map = {"ip": "IP", "l2": "L2", "euclidean": "L2"}
similarity_metric = similarity_metrics_map.get(similarity_metric.lower(), "L2")

# Connect to Milvus instance
self.milvusclient = MilvusClient(
self._milvusclient = MilvusClient(
uri=uri,
token=token,
**kwargs, # pass additional arguments such as server_pem_path
)

# Delete previous collection if overwriting
if self.overwrite and self.collection_name in self.client.list_collections():
self.milvusclient.drop_collection(self.collection_name)
if overwrite and collection_name in self.client.list_collections():
self._milvusclient.drop_collection(collection_name)

# Create the collection if it does not exist
if self.collection_name not in self.client.list_collections():
if self.dim is None:
if collection_name not in self.client.list_collections():
if dim is None:
raise ValueError("Dim argument required for collection creation.")
self.milvusclient.create_collection(
collection_name=self.collection_name,
dimension=self.dim,
self._milvusclient.create_collection(
collection_name=collection_name,
dimension=dim,
primary_field_name=MILVUS_ID_FIELD,
vector_field_name=self.embedding_field,
vector_field_name=embedding_field,
id_type="string",
metric_type=self.similarity_metric,
metric_type=similarity_metric,
max_length=65_535,
consistency_level=self.consistency_level,
consistency_level=consistency_level,
)

self.collection = Collection(
self.collection_name, using=self.milvusclient._using
)
self._collection = Collection(collection_name, using=self._milvusclient._using)
self._create_index_if_required()

logger.debug(f"Successfully created a new collection: {self.collection_name}")

@property
def client(self) -> Any:
"""Get client."""
return self.milvusclient
return self._milvusclient

def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
"""Add the embeddings and their nodes into Milvus.
Expand Down Expand Up @@ -189,8 +198,8 @@ def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
insert_list.append(entry)

# Insert the data into milvus
self.collection.insert(insert_list)
self.collection.flush()
self._collection.insert(insert_list)
self._collection.flush()
self._create_index_if_required()
logger.debug(
f"Successfully inserted embeddings into: {self.collection_name} "
Expand All @@ -217,13 +226,13 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:

# Begin by querying for the primary keys to delete
doc_ids = ['"' + entry + '"' for entry in doc_ids]
entries = self.milvusclient.query(
entries = self._milvusclient.query(
collection_name=self.collection_name,
filter=f"{self.doc_id_field} in [{','.join(doc_ids)}]",
)
if len(entries) > 0:
ids = [entry["id"] for entry in entries]
self.milvusclient.delete(collection_name=self.collection_name, pks=ids)
self._milvusclient.delete(collection_name=self.collection_name, pks=ids)
logger.debug(f"Successfully deleted embedding with doc_id: {doc_ids}")

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
Expand Down Expand Up @@ -267,7 +276,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
string_expr = " and ".join(expr)

# Perform the search
res = self.milvusclient.search(
res = self._milvusclient.search(
collection_name=self.collection_name,
data=[query.query_embedding],
filter=string_expr,
Expand Down Expand Up @@ -317,17 +326,17 @@ def _create_index_if_required(self, force: bool = False) -> None:
# provided to ensure that the index is created in the constructor even
# if self.overwrite is false. In the `add` method, the index is
# recreated only if self.overwrite is true.
if (self.collection.has_index() and self.overwrite) or force:
self.collection.release()
self.collection.drop_index()
if (self._collection.has_index() and self.overwrite) or force:
self._collection.release()
self._collection.drop_index()
base_params: Dict[str, Any] = self.index_config.copy()
index_type: str = base_params.pop("index_type", "FLAT")
index_params: Dict[str, Union[str, Dict[str, Any]]] = {
"params": base_params,
"metric_type": self.similarity_metric,
"index_type": index_type,
}
self.collection.create_index(
self._collection.create_index(
self.embedding_field, index_params=index_params
)
self.collection.load()
self._collection.load()
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from llama_index.core.vector_stores.types import VectorStore
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.milvus import MilvusVectorStore


def test_class():
names_of_base_classes = [b.__name__ for b in MilvusVectorStore.__mro__]
assert VectorStore.__name__ in names_of_base_classes
assert BasePydanticVectorStore.__name__ in names_of_base_classes
Loading