From 73a176657648066d922dd8b3da7e3ed22f93fa2f Mon Sep 17 00:00:00 2001 From: Adam Schiller Date: Thu, 29 Feb 2024 10:30:33 -0800 Subject: [PATCH] add async capability to opensearch with tests and associated docker-compose.yml and pyproject.toml updates --- .../vector_stores/opensearch/base.py | 132 ++++++++++----- .../pyproject.toml | 5 +- .../tests/docker-compose.yml | 11 ++ .../tests/test_opensearch_client.py | 157 ++++++++++++++++++ 4 files changed, 267 insertions(+), 38 deletions(-) create mode 100644 llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/tests/docker-compose.yml create mode 100644 llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/tests/test_opensearch_client.py diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/llama_index/vector_stores/opensearch/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/llama_index/vector_stores/opensearch/base.py index 508ef45752dea..367520abf23c7 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/llama_index/vector_stores/opensearch/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/llama_index/vector_stores/opensearch/base.py @@ -1,7 +1,12 @@ """Elasticsearch/Opensearch vector store.""" + +import asyncio import json import uuid from typing import Any, Dict, Iterable, List, Optional, Union, cast + +import nest_asyncio + from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.schema import BaseNode, MetadataMode, TextNode @@ -16,9 +21,9 @@ metadata_dict_to_node, node_to_metadata_dict, ) -from opensearchpy import OpenSearch +from opensearchpy import AsyncOpenSearch from opensearchpy.exceptions import NotFoundError -from opensearchpy.helpers import bulk +from opensearchpy.helpers import async_bulk IMPORT_OPENSEARCH_PY_ERROR = ( "Could not import OpenSearch. Please install it with `pip install opensearch-py`." @@ -29,14 +34,14 @@ MATCH_ALL_QUERY = {"match_all": {}} # type: Dict -def _import_opensearch() -> Any: +def _import_async_opensearch() -> Any: """Import OpenSearch if available, otherwise raise error.""" - return OpenSearch + return AsyncOpenSearch -def _import_bulk() -> Any: +def _import_async_bulk() -> Any: """Import bulk if available, otherwise raise error.""" - return bulk + return async_bulk def _import_not_found_error() -> Any: @@ -44,21 +49,21 @@ def _import_not_found_error() -> Any: return NotFoundError -def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any: - """Get OpenSearch client from the opensearch_url, otherwise raise error.""" +def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any: + """Get AsyncOpenSearch client from the opensearch_url, otherwise raise error.""" try: - opensearch = _import_opensearch() + opensearch = _import_async_opensearch() client = opensearch(opensearch_url, **kwargs) except ValueError as e: raise ValueError( - f"OpenSearch client string provided is not in proper format. " + f"AsyncOpenSearch client string provided is not in proper format. " f"Got error: {e} " ) return client -def _bulk_ingest_embeddings( +async def _bulk_ingest_embeddings( client: Any, index_name: str, embeddings: List[List[float]], @@ -71,20 +76,20 @@ def _bulk_ingest_embeddings( max_chunk_bytes: Optional[int] = 1 * 1024 * 1024, is_aoss: bool = False, ) -> List[str]: - """Bulk Ingest Embeddings into given index.""" + """Async Bulk Ingest Embeddings into given index.""" if not mapping: mapping = {} - bulk = _import_bulk() + async_bulk = _import_async_bulk() not_found_error = _import_not_found_error() requests = [] return_ids = [] mapping = mapping try: - client.indices.get(index=index_name) + await client.indices.get(index=index_name) except not_found_error: - client.indices.create(index=index_name, body=mapping) + await client.indices.create(index=index_name, body=mapping) for i, text in enumerate(texts): metadata = metadatas[i] if metadatas else {} @@ -102,9 +107,9 @@ def _bulk_ingest_embeddings( request["_id"] = _id requests.append(request) return_ids.append(_id) - bulk(client, requests, max_chunk_bytes=max_chunk_bytes) + await async_bulk(client, requests, max_chunk_bytes=max_chunk_bytes) if not is_aoss: - client.indices.refresh(index=index_name) + await client.indices.refresh(index=index_name) return return_ids @@ -135,7 +140,8 @@ def _knn_search_query( k: int, filters: Optional[MetadataFilters] = None, ) -> Dict: - """Do knn search. + """ + Do knn search. If there are no filters do approx-knn search. If there are (pre)-filters, do an exhaustive exact knn search using 'painless @@ -243,7 +249,8 @@ def _is_aoss_enabled(http_auth: Any) -> bool: class OpensearchVectorClient: - """Object encapsulating an Opensearch index that has vector search enabled. + """ + Object encapsulating an Opensearch index that has vector search enabled. If the index does not yet exist, it is created during init. Therefore, the underlying index is assumed to either: @@ -311,15 +318,22 @@ def __init__( } }, } - self._os_client = _get_opensearch_client(self._endpoint, **kwargs) + self._os_client = _get_async_opensearch_client(self._endpoint, **kwargs) not_found_error = _import_not_found_error() + event_loop = asyncio.get_event_loop() try: - self._os_client.indices.get(index=self._index) + event_loop.run_until_complete( + self._os_client.indices.get(index=self._index) + ) except not_found_error: - self._os_client.indices.create(index=self._index, body=idx_conf) - self._os_client.indices.refresh(index=self._index) + event_loop.run_until_complete( + self._os_client.indices.create(index=self._index, body=idx_conf) + ) + event_loop.run_until_complete( + self._os_client.indices.refresh(index=self._index) + ) - def index_results(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]: + async def index_results(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]: """Store results in the index.""" embeddings: List[List[float]] = [] texts: List[str] = [] @@ -331,7 +345,7 @@ def index_results(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]: texts.append(node.get_content(metadata_mode=MetadataMode.NONE)) metadatas.append(node_to_metadata_dict(node, remove_text=True)) - return _bulk_ingest_embeddings( + return await _bulk_ingest_embeddings( self._os_client, self._index, embeddings, @@ -345,16 +359,16 @@ def index_results(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]: is_aoss=self.is_aoss, ) - def delete_doc_id(self, doc_id: str) -> None: - """Delete a document. + async def delete_doc_id(self, doc_id: str) -> None: + """ + Delete a document. Args: doc_id (str): document id """ - body = {"query": {"match": {"metadata.ref_doc_id": doc_id}}} - self._os_client.delete_by_query(index=self._index, body=body) + await self._os_client.delete(index=self._index, id=doc_id) - def query( + async def aquery( self, query_mode: VectorStoreQueryMode, query_str: Optional[str], @@ -380,7 +394,7 @@ def query( ) params = None - res = self._os_client.search( + res = await self._os_client.search( index=self._index, body=search_query, params=params ) nodes = [] @@ -421,7 +435,8 @@ def query( class OpensearchVectorStore(BasePydanticVectorStore): - """Elasticsearch/Opensearch vector store. + """ + Elasticsearch/Opensearch vector store. Args: client (OpensearchVectorClient): Vector index client to use @@ -437,6 +452,7 @@ def __init__( ) -> None: """Initialize params.""" super().__init__() + nest_asyncio.apply() self._client = client @property @@ -449,13 +465,30 @@ def add( nodes: List[BaseNode], **add_kwargs: Any, ) -> List[str]: - """Add nodes to index. + """ + Add nodes to index. + + Args: + nodes: List[BaseNode]: list of nodes with embeddings. + + """ + return asyncio.get_event_loop().run_until_complete( + self.async_add(nodes, **add_kwargs) + ) + + async def async_add( + self, + nodes: List[BaseNode], + **add_kwargs: Any, + ) -> List[str]: + """ + Async add nodes to index. Args: nodes: List[BaseNode]: list of nodes with embeddings. """ - self._client.index_results(nodes) + await self._client.index_results(nodes) return [result.node_id for result in nodes] def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: @@ -466,10 +499,35 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: ref_doc_id (str): The doc_id of the document to delete. """ - self._client.delete_doc_id(ref_doc_id) + asyncio.get_event_loop().run_until_complete( + self.adelete(ref_doc_id, **delete_kwargs) + ) + + async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + """ + Async delete nodes using with ref_doc_id. + + Args: + ref_doc_id (str): The doc_id of the document to delete. + + """ + await self._client.delete_doc_id(ref_doc_id) def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: - """Query index for top k most similar nodes. + """ + Query index for top k most similar nodes. + + Args: + query (VectorStoreQuery): Store query object. + + """ + return asyncio.get_event_loop().run_until_complete(self.aquery(query, **kwargs)) + + async def aquery( + self, query: VectorStoreQuery, **kwargs: Any + ) -> VectorStoreQueryResult: + """ + Async query index for top k most similar nodes. Args: query (VectorStoreQuery): Store query object. @@ -477,7 +535,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul """ query_embedding = cast(List[float], query.query_embedding) - return self._client.query( + return await self._client.aquery( query.mode, query.query_str, query_embedding, diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/pyproject.toml index 88d546fb78ebe..a443465d31270 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/pyproject.toml +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/pyproject.toml @@ -32,7 +32,10 @@ version = "0.1.4" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" llama-index-core = "^0.10.1" -opensearch-py = "^2.4.2" + +[tool.poetry.dependencies.opensearch-py] +extras = ["async"] +version = "^2.4.2" [tool.poetry.group.dev.dependencies] ipython = "8.10.0" diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/tests/docker-compose.yml b/llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/tests/docker-compose.yml new file mode 100644 index 0000000000000..753e6b3621132 --- /dev/null +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/tests/docker-compose.yml @@ -0,0 +1,11 @@ +version: "3" + +services: + opensearch: + image: opensearchproject/opensearch:latest + environment: + - discovery.type=single-node + - plugins.security.disabled=true + - OPENSEARCH_INITIAL_ADMIN_PASSWORD=Asd234%@#% + ports: + - "9200:9200" diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/tests/test_opensearch_client.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/tests/test_opensearch_client.py new file mode 100644 index 0000000000000..bcb0e1d9cb720 --- /dev/null +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-opensearch/tests/test_opensearch_client.py @@ -0,0 +1,157 @@ +import asyncio +import logging +import pytest +import uuid +from typing import List, Generator + +from llama_index.core.schema import TextNode +from llama_index.vector_stores.opensearch import ( + OpensearchVectorClient, + OpensearchVectorStore, +) +from llama_index.core.vector_stores.types import VectorStoreQuery + +## +# Start Opensearch locally +# cd tests +# docker-compose up +# +# Run tests +# pytest test_opensearch_client.py + +logging.basicConfig(level=logging.DEBUG) +evt_loop = asyncio.get_event_loop() + +try: + from opensearchpy import AsyncOpenSearch + + os_client = AsyncOpenSearch("localhost:9200") + evt_loop.run_until_complete(os_client.info()) + opensearch_not_available = False +except (ImportError, Exception): + opensearch_not_available = True +finally: + evt_loop.run_until_complete(os_client.close()) + + +def test_connection() -> None: + assert not opensearch_not_available + + +@pytest.fixture() +def index_name() -> str: + """Return the index name.""" + return f"test_{uuid.uuid4().hex}" + + +@pytest.fixture() +def os_store(index_name: str) -> Generator[OpensearchVectorStore, None, None]: + client = OpensearchVectorClient( + endpoint="localhost:9200", + index=index_name, + dim=3, + ) + + yield OpensearchVectorStore(client) + + # teardown step + # delete index + evt_loop.run_until_complete(client._os_client.indices.delete(index=index_name)) + # close client aiohttp session + evt_loop.run_until_complete(client._os_client.close()) + + +@pytest.fixture(scope="session") +def node_embeddings() -> List[TextNode]: + return [ + TextNode( + text="lorem ipsum", + id_="c330d77f-90bd-4c51-9ed2-57d8d693b3b0", + # relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-0")}, + metadata={ + "author": "Stephen King", + "theme": "Friendship", + }, + embedding=[1.0, 0.0, 0.0], + ), + TextNode( + text="lorem ipsum", + id_="c3d1e1dd-8fb4-4b8f-b7ea-7fa96038d39d", + # relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-1")}, + metadata={ + "director": "Francis Ford Coppola", + "theme": "Mafia", + }, + embedding=[0.0, 1.0, 0.0], + ), + TextNode( + text="lorem ipsum", + id_="c3ew11cd-8fb4-4b8f-b7ea-7fa96038d39d", + # relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-2")}, + metadata={ + "director": "Christopher Nolan", + }, + embedding=[0.0, 0.0, 1.0], + ), + TextNode( + text="I was taught that the way of progress was neither swift nor easy.", + id_="0b31ae71-b797-4e88-8495-031371a7752e", + # relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-3")}, + metadate={ + "author": "Marie Curie", + }, + embedding=[0.0, 0.0, 0.9], + ), + TextNode( + text=( + "The important thing is not to stop questioning." + + " Curiosity has its own reason for existing." + ), + id_="bd2e080b-159a-4030-acc3-d98afd2ba49b", + # relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-4")}, + metadate={ + "author": "Albert Einstein", + }, + embedding=[0.0, 0.0, 0.5], + ), + TextNode( + text=( + "I am no bird; and no net ensnares me;" + + " I am a free human being with an independent will." + ), + id_="f658de3b-8cef-4d1c-8bed-9a263c907251", + # relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-5")}, + metadate={ + "author": "Charlotte Bronte", + }, + embedding=[0.0, 0.0, 0.3], + ), + ] + + +def count_docs_in_index(os_store: OpensearchVectorStore) -> int: + """Refresh indices and return the count of documents in the index.""" + evt_loop.run_until_complete( + os_store.client._os_client.indices.refresh(index=os_store.client._index) + ) + count = evt_loop.run_until_complete( + os_store.client._os_client.count(index=os_store.client._index) + ) + return count["count"] + + +@pytest.mark.skipif(opensearch_not_available, reason="opensearch is not available") +def test_functionality( + os_store: OpensearchVectorStore, node_embeddings: List[TextNode] +) -> None: + # add + assert len(os_store.add(node_embeddings)) == len(node_embeddings) + # query + exp_node = node_embeddings[3] + query = VectorStoreQuery(query_embedding=exp_node.embedding, similarity_top_k=1) + query_result = os_store.query(query) + assert query_result.nodes + assert query_result.nodes[0].get_content() == exp_node.text + # delete + os_store.delete(exp_node.id_) + assert count_docs_in_index(os_store) == len(node_embeddings) - 1