diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/llama_index/vector_stores/azureaisearch/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/llama_index/vector_stores/azureaisearch/base.py
index 2089218d715be..f865f89609491 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/llama_index/vector_stores/azureaisearch/base.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/llama_index/vector_stores/azureaisearch/base.py
@@ -7,7 +7,11 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

 from azure.search.documents import SearchClient
+from azure.search.documents.aio import SearchClient as AsyncSearchClient
 from azure.search.documents.indexes import SearchIndexClient
+from azure.search.documents.indexes.aio import (
+    SearchIndexClient as AsyncSearchIndexClient,
+)
 from llama_index.core.bridge.pydantic import PrivateAttr
 from llama_index.core.schema import BaseNode, MetadataMode, TextNode
@@ -107,7 +111,10 @@ class AzureAISearchVectorStore(BasePydanticVectorStore):
     flat_metadata: bool = True

     _index_client: SearchIndexClient = PrivateAttr()
+    _index_name: Optional[str] = PrivateAttr()
+    _async_index_client: AsyncSearchIndexClient = PrivateAttr()
     _search_client: SearchClient = PrivateAttr()
+    _async_search_client: AsyncSearchClient = PrivateAttr()
     _embedding_dimensionality: int = PrivateAttr()
     _language_analyzer: str = PrivateAttr()
     _field_mapping: Dict[str, str] = PrivateAttr()
@@ -157,6 +164,18 @@ def _create_index_if_not_exists(self, index_name: str) -> None:
             )
             self._create_index(index_name)

+    async def _acreate_index_if_not_exists(self, index_name: str) -> None:
+        list_index_names = set()
+
+        async for index in self._async_index_client.list_index_names():
+            list_index_names.add(index)
+
+        if index_name not in list_index_names:
+            logger.info(
+                f"Index {index_name} does not exist in Azure AI Search, creating index"
+            )
+            await self._acreate_index(index_name)
+
     def _create_metadata_index_fields(self) -> List[Any]:
         """Create a list of index fields for storing metadata values."""
         from azure.search.documents.indexes.models import SimpleField
@@ -283,9 +302,113 @@ def _create_index(self, index_name: Optional[str]) -> None:
             vector_search=vector_search,
             semantic_search=semantic_search,
         )
+        logger.debug(f"Creating {index_name} search index")
         self._index_client.create_index(index)

+    async def _acreate_index(self, index_name: Optional[str]) -> None:
+        """
+        Creates a default index based on the supplied index name, key field names and
+        metadata filtering keys.
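+
+        Async counterpart of ``_create_index``: the index definition is the
+        same; only the create request goes through the async index client.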
+ """ + from azure.search.documents.indexes.models import ( + ExhaustiveKnnAlgorithmConfiguration, + ExhaustiveKnnParameters, + HnswAlgorithmConfiguration, + HnswParameters, + SearchableField, + SearchField, + SearchFieldDataType, + SearchIndex, + SemanticConfiguration, + SemanticField, + SemanticPrioritizedFields, + SemanticSearch, + SimpleField, + VectorSearch, + VectorSearchAlgorithmKind, + VectorSearchAlgorithmMetric, + VectorSearchProfile, + ) + + logger.info(f"Configuring {index_name} fields for Azure AI Search") + fields = [ + SimpleField(name=self._field_mapping["id"], type="Edm.String", key=True), + SearchableField( + name=self._field_mapping["chunk"], + type="Edm.String", + analyzer_name=self._language_analyzer, + ), + SearchField( + name=self._field_mapping["embedding"], + type=SearchFieldDataType.Collection(SearchFieldDataType.Single), + searchable=True, + vector_search_dimensions=self._embedding_dimensionality, + vector_search_profile_name=self._vector_profile_name, + ), + SimpleField(name=self._field_mapping["metadata"], type="Edm.String"), + SimpleField( + name=self._field_mapping["doc_id"], type="Edm.String", filterable=True + ), + ] + logger.info(f"Configuring {index_name} metadata fields") + metadata_index_fields = self._create_metadata_index_fields() + fields.extend(metadata_index_fields) + logger.info(f"Configuring {index_name} vector search") + # Configure the vector search algorithms and profiles + vector_search = VectorSearch( + algorithms=[ + HnswAlgorithmConfiguration( + name="myHnsw", + kind=VectorSearchAlgorithmKind.HNSW, + # For more information on HNSw parameters, visit https://learn.microsoft.com//azure/search/vector-search-ranking#creating-the-hnsw-graph + parameters=HnswParameters( + m=4, + ef_construction=400, + ef_search=500, + metric=VectorSearchAlgorithmMetric.COSINE, + ), + ), + ExhaustiveKnnAlgorithmConfiguration( + name="myExhaustiveKnn", + kind=VectorSearchAlgorithmKind.EXHAUSTIVE_KNN, + parameters=ExhaustiveKnnParameters( + metric=VectorSearchAlgorithmMetric.COSINE, + ), + ), + ], + profiles=[ + VectorSearchProfile( + name="myHnswProfile", + algorithm_configuration_name="myHnsw", + ), + # Add more profiles if needed + VectorSearchProfile( + name="myExhaustiveKnnProfile", + algorithm_configuration_name="myExhaustiveKnn", + ), + # Add more profiles if needed + ], + ) + logger.info(f"Configuring {index_name} semantic search") + semantic_config = SemanticConfiguration( + name="mySemanticConfig", + prioritized_fields=SemanticPrioritizedFields( + content_fields=[SemanticField(field_name=self._field_mapping["chunk"])], + ), + ) + + semantic_search = SemanticSearch(configurations=[semantic_config]) + + index = SearchIndex( + name=index_name, + fields=fields, + vector_search=vector_search, + semantic_search=semantic_search, + ) + logger.debug(f"Creating {index_name} search index") + await self._async_index_client.create_index(index) + def _validate_index(self, index_name: Optional[str]) -> None: if self._index_client and index_name: if index_name not in self._index_client.list_index_names(): @@ -293,6 +416,18 @@ def _validate_index(self, index_name: Optional[str]) -> None: f"Validation failed, index {index_name} does not exist." 
+    async def _avalidate_index(self, index_name: Optional[str]) -> None:
+        if self._async_index_client and index_name:
+            list_index_names = set()
+
+            async for index in self._async_index_client.list_index_names():
+                list_index_names.add(index)
+
+            if index_name not in list_index_names:
+                raise ValueError(
+                    f"Validation failed, index {index_name} does not exist."
+                )
+
     def __init__(
         self,
         search_or_index_client: Any,
@@ -378,8 +513,13 @@ def __init__(
             raise ImportError(import_err_msg)

         self._index_client: SearchIndexClient = cast(SearchIndexClient, None)
+        self._async_index_client: AsyncSearchIndexClient = cast(
+            AsyncSearchIndexClient, None
+        )
         self._search_client: SearchClient = cast(SearchClient, None)
+        self._async_search_client: AsyncSearchClient = cast(AsyncSearchClient, None)
         self._embedding_dimensionality = embedding_dimensionality
+        self._index_name = index_name

         if vector_algorithm_type == "exhaustiveKnn":
             self._vector_profile_name = "myExhaustiveKnnProfile"
@@ -408,6 +548,22 @@ def __init__(
                     index_name=index_name
                 )

+            elif isinstance(search_or_index_client, AsyncSearchIndexClient):
+                # If an async SearchIndexClient is supplied, so must index_name
+                self._async_index_client = cast(
+                    AsyncSearchIndexClient, search_or_index_client
+                )
+
+                if not index_name:
+                    raise ValueError(
+                        "index_name must be supplied if search_or_index_client is of "
+                        "type azure.search.documents.indexes.aio.SearchIndexClient"
+                    )
+
+                self._async_search_client = self._async_index_client.get_search_client(
+                    index_name=index_name
+                )
+
             elif isinstance(search_or_index_client, SearchClient):
                 self._search_client = cast(SearchClient, search_or_index_client)

@@ -418,23 +574,43 @@ def __init__(
                     "is of type azure.search.documents.SearchClient"
                 )

-            if not self._index_client and not self._search_client:
-                raise ValueError(
-                    "search_or_index_client must be of type "
-                    "azure.search.documents.SearchClient or "
-                    "azure.search.documents.SearchIndexClient"
+            elif isinstance(search_or_index_client, AsyncSearchClient):
+                self._async_search_client = cast(
+                    AsyncSearchClient, search_or_index_client
                 )
+
+                # Validate index_name
+                if index_name:
+                    raise ValueError(
+                        "index_name cannot be supplied if search_or_index_client "
+                        "is of type azure.search.documents.aio.SearchClient"
+                    )
+
+            if isinstance(search_or_index_client, AsyncSearchIndexClient):
+                if not self._async_index_client and not self._async_search_client:
+                    raise ValueError(
+                        "search_or_index_client must be of type "
+                        "azure.search.documents.SearchIndexClient or "
+                        "azure.search.documents.SearchClient"
+                    )
+
+            if isinstance(search_or_index_client, SearchIndexClient):
+                if not self._index_client and not self._search_client:
+                    raise ValueError(
+                        "search_or_index_client must be of type "
+                        "azure.search.documents.SearchIndexClient or "
+                        "azure.search.documents.SearchClient"
+                    )
         else:
             raise ValueError("search_or_index_client not specified")

-        if (
-            index_management == IndexManagement.CREATE_IF_NOT_EXISTS
-            and not self._index_client
+        if index_management == IndexManagement.CREATE_IF_NOT_EXISTS and not (
+            self._index_client or self._async_index_client
         ):
             raise ValueError(
                 "index_management has value of IndexManagement.CREATE_IF_NOT_EXISTS "
                 "but search_or_index_client is not of type "
-                "azure.search.documents.SearchIndexClient"
+                "azure.search.documents.indexes.SearchIndexClient or "
+                "azure.search.documents.indexes.aio.SearchIndexClient"
             )

         self._index_management = index_management
@@ -459,12 +635,14 @@ def __init__(
             filterable_metadata_field_keys
         )

-        if self._index_management == IndexManagement.CREATE_IF_NOT_EXISTS:
-            if index_name:
-                self._create_index_if_not_exists(index_name)
+        # need to do lazy init for async client
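+        # (for async clients, creation/validation happens at call time in async_add)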
+        if not isinstance(search_or_index_client, AsyncSearchIndexClient):
+            if self._index_management == IndexManagement.CREATE_IF_NOT_EXISTS:
+                if index_name:
+                    self._create_index_if_not_exists(index_name)

-        if self._index_management == IndexManagement.VALIDATE_INDEX:
-            self._validate_index(index_name)
+            if self._index_management == IndexManagement.VALIDATE_INDEX:
+                self._validate_index(index_name)

         super().__init__()

@@ -473,6 +651,11 @@ def client(self) -> Any:
         """Get client."""
         return self._search_client

+    @property
+    def aclient(self) -> Any:
+        """Get async client."""
+        return self._async_search_client
+
     def _default_index_mapping(
         self, enriched_doc: Dict[str, str], metadata: Dict[str, Any]
     ) -> Dict[str, str]:
@@ -551,6 +734,74 @@ def add(
         return ids

+    async def async_add(
+        self,
+        nodes: List[BaseNode],
+        **add_kwargs: Any,
+    ) -> List[str]:
+        """
+        Add nodes to index associated with the configured search client.
+
+        Args:
+            nodes: List[BaseNode]: nodes with embeddings
+
+        """
+        from azure.search.documents import IndexDocumentsBatch
+
+        if not self._async_search_client:
+            raise ValueError("Async Search client not initialized")
+
+        if len(nodes) > 0:
+            if self._index_management == IndexManagement.CREATE_IF_NOT_EXISTS:
+                if self._index_name:
+                    await self._acreate_index_if_not_exists(self._index_name)
+
+            if self._index_management == IndexManagement.VALIDATE_INDEX:
+                await self._avalidate_index(self._index_name)
+
+        accumulator = IndexDocumentsBatch()
+        documents = []
+
+        ids = []
+        accumulated_size = 0
+        max_size = 16 * 1024 * 1024  # 16MB in bytes
+        max_docs = 1000
+
+        for node in nodes:
+            logger.debug(f"Processing embedding: {node.node_id}")
+            ids.append(node.node_id)
+
+            index_document = self._create_index_document(node)
+            document_size = len(
+                str(node.get_content(metadata_mode=MetadataMode.NONE)).encode("utf-8")
+            )
+            documents.append(index_document)
+            accumulated_size += document_size
+
+            accumulator.add_upload_actions(index_document)
+
+            if len(documents) >= max_docs or accumulated_size >= max_size:
+                logger.info(
+                    f"Uploading batch of size {len(documents)}, "
+                    f"current progress {len(ids)} of {len(nodes)}, "
+                    f"accumulated size {accumulated_size / (1024 * 1024):.2f} MB"
+                )
+                await self._async_search_client.index_documents(accumulator)
+                accumulator.dequeue_actions()
+                documents = []
+                accumulated_size = 0
+
+        # Upload remaining batch
+        if documents:
+            logger.info(
+                f"Uploading remaining batch of size {len(documents)}, "
+                f"current progress {len(ids)} of {len(nodes)}, "
+                f"accumulated size {accumulated_size / (1024 * 1024):.2f} MB"
+            )
+            await self._async_search_client.index_documents(accumulator)
+
+        return ids
+
     def _create_index_document(self, node: BaseNode) -> Dict[str, Any]:
         """Create AI Search index document from embedding result."""
         doc: Dict[str, Any] = {}
@@ -591,6 +842,30 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
             logger.debug(f"Deleting {len(docs_to_delete)} documents")
             self._search_client.delete_documents(docs_to_delete)

+    async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
+        """
+        Delete documents from the AI Search Index
+        with doc_id_field_key field equal to ref_doc_id.
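+
+        Example (hypothetical doc id):
+            await vector_store.adelete("doc-123")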
+ """ + # Locate documents to delete + filter = f'{self._field_mapping["doc_id"]} eq \'{ref_doc_id}\'' + + results = await self._async_search_client.search(search_text="*", filter=filter) + + logger.debug(f"Searching with filter {filter}") + + docs_to_delete = [] + + for result in results: + doc = {} + doc["id"] = result[self._field_mapping["id"]] + logger.debug(f"Found document to delete: {doc}") + docs_to_delete.append(doc) + + if len(docs_to_delete) > 0: + logger.debug(f"Deleting {len(docs_to_delete)} documents") + await self._search_client.delete_documents(docs_to_delete) + def _create_odata_filter(self, metadata_filters: MetadataFilters) -> str: """Generate an OData filter string using supplied metadata filters.""" odata_filter: List[str] = [] @@ -650,6 +925,31 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul ) return azure_query_result_search.search() + async def aquery( + self, query: VectorStoreQuery, **kwargs: Any + ) -> VectorStoreQueryResult: + odata_filter = None + if query.filters is not None: + odata_filter = self._create_odata_filter(query.filters) + azure_query_result_search: AzureQueryResultSearchBase = ( + AzureQueryResultSearchDefault( + query, self._field_mapping, odata_filter, self._async_search_client + ) + ) + if query.mode == VectorStoreQueryMode.SPARSE: + azure_query_result_search = AzureQueryResultSearchSparse( + query, self._field_mapping, odata_filter, self._async_search_client + ) + elif query.mode == VectorStoreQueryMode.HYBRID: + azure_query_result_search = AzureQueryResultSearchHybrid( + query, self._field_mapping, odata_filter, self._async_search_client + ) + elif query.mode == VectorStoreQueryMode.SEMANTIC_HYBRID: + azure_query_result_search = AzureQueryResultSearchSemanticHybrid( + query, self._field_mapping, odata_filter, self._async_search_client + ) + return await azure_query_result_search.asearch() + class AzureQueryResultSearchBase: def __init__( @@ -732,11 +1032,70 @@ def _create_query_result( nodes=node_result, similarities=score_result, ids=id_result ) + async def _acreate_query_result( + self, search_query: str, vectors: Optional[List[Any]] + ) -> VectorStoreQueryResult: + results = await self._search_client.search( + search_text=search_query, + vector_queries=vectors, + top=self._query.similarity_top_k, + select=self._select_fields, + filter=self._odata_filter, + ) + + id_result = [] + node_result = [] + score_result = [] + + async for result in results: + node_id = result[self._field_mapping["id"]] + metadata_str = result[self._field_mapping["metadata"]] + metadata = json.loads(metadata_str) if metadata_str else {} + score = result["@search.score"] + chunk = result[self._field_mapping["chunk"]] + + try: + node = metadata_dict_to_node(metadata) + node.set_content(chunk) + except Exception: + # NOTE: deprecated legacy logic for backward compatibility + metadata, node_info, relationships = legacy_metadata_dict_to_node( + metadata + ) + + node = TextNode( + text=chunk, + id_=node_id, + metadata=metadata, + start_char_idx=node_info.get("start", None), + end_char_idx=node_info.get("end", None), + relationships=relationships, + ) + + logger.debug(f"Retrieved node id {node_id} with node data of {node}") + + id_result.append(node_id) + node_result.append(node) + score_result.append(score) + + logger.debug( + f"Search query '{search_query}' returned {len(id_result)} results." 
     def search(self) -> VectorStoreQueryResult:
         search_query = self._create_search_query()
         vectors = self._create_query_vector()
         return self._create_query_result(search_query, vectors)

+    async def asearch(self) -> VectorStoreQueryResult:
+        search_query = self._create_search_query()
+        vectors = self._create_query_vector()
+        return await self._acreate_query_result(search_query, vectors)
+

 class AzureQueryResultSearchDefault(AzureQueryResultSearchBase):
     def _create_query_vector(self) -> Optional[List[Any]]:
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/pyproject.toml
index 3115d45a61fcb..a0eff83c2386a 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/pyproject.toml
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/pyproject.toml
@@ -28,7 +28,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-vector-stores-azureaisearch"
 readme = "README.md"
-version = "0.1.8"
+version = "0.1.9"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
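Below is a minimal end-to-end sketch of the new async path. It assumes an existing Azure AI Search service; the endpoint, API key, index name, field keys, and the toy 3-dimensional embedding are illustrative placeholders, not values taken from this change:

import asyncio

from azure.core.credentials import AzureKeyCredential
from azure.search.documents.indexes.aio import SearchIndexClient as AsyncSearchIndexClient
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQuery
from llama_index.vector_stores.azureaisearch import (
    AzureAISearchVectorStore,
    IndexManagement,
)


async def main() -> None:
    # Placeholder service endpoint and key.
    index_client = AsyncSearchIndexClient(
        endpoint="https://<service>.search.windows.net",
        credential=AzureKeyCredential("<api-key>"),
    )
    vector_store = AzureAISearchVectorStore(
        search_or_index_client=index_client,
        index_name="quickstart",
        index_management=IndexManagement.CREATE_IF_NOT_EXISTS,
        id_field_key="id",
        chunk_field_key="chunk",
        embedding_field_key="embedding",
        metadata_string_field_key="metadata",
        doc_id_field_key="doc_id",
        embedding_dimensionality=3,
    )

    # Embeddings would normally come from an embedding model.
    node = TextNode(text="hello world", embedding=[0.1, 0.2, 0.3])
    await vector_store.async_add([node])

    result = await vector_store.aquery(
        VectorStoreQuery(query_embedding=[0.1, 0.2, 0.3], similarity_top_k=1)
    )
    print(result.ids)

    await index_client.close()


asyncio.run(main())

Note that with an async client the index is not created in __init__; CREATE_IF_NOT_EXISTS is applied on the first async_add call, which is why the store can be constructed synchronously here.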