Skip to content

Commit

Permalink
add async capability to opensearch with tests and associated docker-compose.yml and pyproject.toml updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ahs8w committed Feb 29, 2024
1 parent 348cad7 commit 73a1766
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 38 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
"""Elasticsearch/Opensearch vector store."""

import asyncio
import json
import uuid
from typing import Any, Dict, Iterable, List, Optional, Union, cast

import nest_asyncio

from llama_index.core.bridge.pydantic import PrivateAttr

from llama_index.core.schema import BaseNode, MetadataMode, TextNode
Expand All @@ -16,9 +21,9 @@
metadata_dict_to_node,
node_to_metadata_dict,
)
from opensearchpy import OpenSearch
from opensearchpy import AsyncOpenSearch
from opensearchpy.exceptions import NotFoundError
from opensearchpy.helpers import bulk
from opensearchpy.helpers import async_bulk

IMPORT_OPENSEARCH_PY_ERROR = (
"Could not import OpenSearch. Please install it with `pip install opensearch-py`."
Expand All @@ -29,36 +34,36 @@
MATCH_ALL_QUERY = {"match_all": {}} # type: Dict


def _import_async_opensearch() -> Any:
    """Import AsyncOpenSearch if available, otherwise raise error.

    Returns:
        The ``opensearchpy.AsyncOpenSearch`` class (imported at module top).
    """
    return AsyncOpenSearch


def _import_async_bulk() -> Any:
    """Import the async bulk helper if available, otherwise raise error.

    Returns:
        The ``opensearchpy.helpers.async_bulk`` coroutine function.
    """
    return async_bulk


def _import_not_found_error() -> Any:
    """Return the opensearch-py ``NotFoundError`` exception class."""
    return NotFoundError


def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
    """Get AsyncOpenSearch client from the opensearch_url, otherwise raise error.

    Args:
        opensearch_url (str): Connection string/URL for the OpenSearch cluster.
        **kwargs: Extra keyword arguments forwarded verbatim to the
            AsyncOpenSearch constructor (e.g. ``http_auth``, ``use_ssl``).

    Raises:
        ValueError: If the connection string is not in a valid format.
    """
    try:
        opensearch = _import_async_opensearch()
        client = opensearch(opensearch_url, **kwargs)
    except ValueError as e:
        # Chain the original error so the underlying cause stays visible.
        raise ValueError(
            "AsyncOpenSearch client string provided is not in proper format. "
            f"Got error: {e} "
        ) from e
    return client


def _bulk_ingest_embeddings(
async def _bulk_ingest_embeddings(
client: Any,
index_name: str,
embeddings: List[List[float]],
Expand All @@ -71,20 +76,20 @@ def _bulk_ingest_embeddings(
max_chunk_bytes: Optional[int] = 1 * 1024 * 1024,
is_aoss: bool = False,
) -> List[str]:
"""Bulk Ingest Embeddings into given index."""
"""Async Bulk Ingest Embeddings into given index."""
if not mapping:
mapping = {}

bulk = _import_bulk()
async_bulk = _import_async_bulk()
not_found_error = _import_not_found_error()
requests = []
return_ids = []
mapping = mapping

try:
client.indices.get(index=index_name)
await client.indices.get(index=index_name)
except not_found_error:
client.indices.create(index=index_name, body=mapping)
await client.indices.create(index=index_name, body=mapping)

for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
Expand All @@ -102,9 +107,9 @@ def _bulk_ingest_embeddings(
request["_id"] = _id
requests.append(request)
return_ids.append(_id)
bulk(client, requests, max_chunk_bytes=max_chunk_bytes)
await async_bulk(client, requests, max_chunk_bytes=max_chunk_bytes)
if not is_aoss:
client.indices.refresh(index=index_name)
await client.indices.refresh(index=index_name)
return return_ids


Expand Down Expand Up @@ -135,7 +140,8 @@ def _knn_search_query(
k: int,
filters: Optional[MetadataFilters] = None,
) -> Dict:
"""Do knn search.
"""
Do knn search.
If there are no filters do approx-knn search.
If there are (pre)-filters, do an exhaustive exact knn search using 'painless
Expand Down Expand Up @@ -243,7 +249,8 @@ def _is_aoss_enabled(http_auth: Any) -> bool:


class OpensearchVectorClient:
"""Object encapsulating an Opensearch index that has vector search enabled.
"""
Object encapsulating an Opensearch index that has vector search enabled.
If the index does not yet exist, it is created during init.
Therefore, the underlying index is assumed to either:
Expand Down Expand Up @@ -311,15 +318,22 @@ def __init__(
}
},
}
self._os_client = _get_opensearch_client(self._endpoint, **kwargs)
self._os_client = _get_async_opensearch_client(self._endpoint, **kwargs)
not_found_error = _import_not_found_error()
event_loop = asyncio.get_event_loop()
try:
self._os_client.indices.get(index=self._index)
event_loop.run_until_complete(
self._os_client.indices.get(index=self._index)
)
except not_found_error:
self._os_client.indices.create(index=self._index, body=idx_conf)
self._os_client.indices.refresh(index=self._index)
event_loop.run_until_complete(
self._os_client.indices.create(index=self._index, body=idx_conf)
)
event_loop.run_until_complete(
self._os_client.indices.refresh(index=self._index)
)

def index_results(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]:
async def index_results(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]:
"""Store results in the index."""
embeddings: List[List[float]] = []
texts: List[str] = []
Expand All @@ -331,7 +345,7 @@ def index_results(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]:
texts.append(node.get_content(metadata_mode=MetadataMode.NONE))
metadatas.append(node_to_metadata_dict(node, remove_text=True))

return _bulk_ingest_embeddings(
return await _bulk_ingest_embeddings(
self._os_client,
self._index,
embeddings,
Expand All @@ -345,16 +359,16 @@ def index_results(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]:
is_aoss=self.is_aoss,
)

async def delete_doc_id(self, doc_id: str) -> None:
    """
    Delete a document.

    Args:
        doc_id (str): document id

    """
    # NOTE(review): the previous sync implementation used delete_by_query on
    # metadata.ref_doc_id (removing every node of a source document); deleting
    # by _id removes only the single document whose id equals doc_id — confirm
    # this narrowing is intended.
    await self._os_client.delete(index=self._index, id=doc_id)

def query(
async def aquery(
self,
query_mode: VectorStoreQueryMode,
query_str: Optional[str],
Expand All @@ -380,7 +394,7 @@ def query(
)
params = None

res = self._os_client.search(
res = await self._os_client.search(
index=self._index, body=search_query, params=params
)
nodes = []
Expand Down Expand Up @@ -421,7 +435,8 @@ def query(


class OpensearchVectorStore(BasePydanticVectorStore):
"""Elasticsearch/Opensearch vector store.
"""
Elasticsearch/Opensearch vector store.
Args:
client (OpensearchVectorClient): Vector index client to use
Expand All @@ -437,6 +452,7 @@ def __init__(
) -> None:
"""Initialize params."""
super().__init__()
nest_asyncio.apply()
self._client = client

@property
def add(
    self,
    nodes: List[BaseNode],
    **add_kwargs: Any,
) -> List[str]:
    """
    Add nodes to index.

    Args:
        nodes: List[BaseNode]: list of nodes with embeddings.

    """
    # Sync facade over async_add; nest_asyncio.apply() in __init__ allows
    # run_until_complete even when a loop is already running.
    return asyncio.get_event_loop().run_until_complete(
        self.async_add(nodes, **add_kwargs)
    )

async def async_add(
    self,
    nodes: List[BaseNode],
    **add_kwargs: Any,
) -> List[str]:
    """
    Async add nodes to index.

    Args:
        nodes: List[BaseNode]: list of nodes with embeddings.

    """
    # NOTE(review): add_kwargs is accepted for interface parity but is not
    # forwarded to index_results — confirm that is intended.
    await self._client.index_results(nodes)
    return [result.node_id for result in nodes]

def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
    """
    Delete nodes using with ref_doc_id.

    Args:
        ref_doc_id (str): The doc_id of the document to delete.

    """
    # Sync facade: delegate to the async implementation on the current loop.
    asyncio.get_event_loop().run_until_complete(
        self.adelete(ref_doc_id, **delete_kwargs)
    )

async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
    """
    Asynchronously delete the nodes associated with the given ref_doc_id.

    Args:
        ref_doc_id (str): The doc_id of the document to delete.

    """
    vector_client = self._client
    await vector_client.delete_doc_id(ref_doc_id)

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
    """
    Query index for top k most similar nodes.

    Args:
        query (VectorStoreQuery): Store query object.

    """
    # Sync facade: run the async query to completion on the current loop.
    return asyncio.get_event_loop().run_until_complete(self.aquery(query, **kwargs))

async def aquery(
self, query: VectorStoreQuery, **kwargs: Any
) -> VectorStoreQueryResult:
"""
Async query index for top k most similar nodes.
Args:
query (VectorStoreQuery): Store query object.
"""
query_embedding = cast(List[float], query.query_embedding)

return self._client.query(
return await self._client.aquery(
query.mode,
query.query_str,
query_embedding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ version = "0.1.4"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.1"

[tool.poetry.dependencies.opensearch-py]
extras = ["async"]
version = "^2.4.2"

[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Minimal OpenSearch instance for running the async vector-store tests locally.
version: "3"

services:
  opensearch:
    # Single-node development cluster; not suitable for production.
    image: opensearchproject/opensearch:latest
    environment:
      - discovery.type=single-node
      # Security plugin disabled so tests can connect over plain HTTP without auth.
      - plugins.security.disabled=true
      # NOTE(review): presumably required by newer images even with security
      # disabled — confirm; rotate if this compose file is ever reused beyond tests.
      - OPENSEARCH_INITIAL_ADMIN_PASSWORD=Asd234%@#%
    ports:
      - "9200:9200"  # REST API on localhost:9200
Loading

0 comments on commit 73a1766

Please sign in to comment.