From cbd9ccf02745f75a7ff643a0272d058265407b15 Mon Sep 17 00:00:00 2001 From: Dominic Pelini <111786059+DomPeliniAerospike@users.noreply.github.com> Date: Fri, 26 Apr 2024 17:44:55 -0600 Subject: [PATCH 1/4] Added a sync client and reconfigured the async client into the aio module Changed file structure for better organization Added helper functions to share code between base and aio modules Modified channel_provider to block on channel.close() Added sync and aio test modules Renamed VectorDbClient to Client --- proto/codegen.sh | 6 +- src/aerospike_vector/__init__.py | 4 +- src/aerospike_vector/admin.py | 277 +++++++++++++ src/aerospike_vector/aio/__init__.py | 3 + .../{vectordb_admin.py => aio/admin.py} | 165 ++------ .../{vectordb_client.py => aio/client.py} | 200 ++-------- .../aio/internal/channel_provider.py | 98 +++++ src/aerospike_vector/client.py | 283 +++++++++++++ src/aerospike_vector/internal/__init__.py | 0 .../internal/channel_provider.py | 98 +++++ src/aerospike_vector/shared/__init__.py | 0 src/aerospike_vector/shared/admin_helpers.py | 158 ++++++++ .../shared/base_channel_provider.py | 79 ++++ .../shared/channel_provider_helpers.py | 50 +++ src/aerospike_vector/shared/client_helpers.py | 182 +++++++++ .../{ => shared}/conversions.py | 4 +- src/aerospike_vector/shared/helpers.py | 28 ++ .../{ => shared/proto_generated}/auth_pb2.py | 0 .../proto_generated}/auth_pb2_grpc.py | 0 .../{ => shared/proto_generated}/index_pb2.py | 0 .../proto_generated}/index_pb2_grpc.py | 0 .../proto_generated}/transact_pb2.py | 0 .../proto_generated}/transact_pb2_grpc.py | 0 .../{ => shared/proto_generated}/types_pb2.py | 0 .../proto_generated}/types_pb2_grpc.py | 0 .../proto_generated}/vector_db_pb2.py | 0 .../proto_generated}/vector_db_pb2_grpc.py | 0 src/aerospike_vector/types.py | 2 +- .../vectordb_channel_provider.py | 165 -------- tests/aio/__init__.py | 0 tests/{ => aio}/conftest.py | 18 +- tests/aio/requirements.txt | 4 + .../test_admin_client_index_create.py | 0 tests/{ => aio}/test_admin_client_index_drop.py | 0 tests/{ => aio}/test_admin_client_index_get.py | 0 .../test_admin_client_index_get_status.py | 0 tests/{ => aio}/test_admin_client_index_list.py | 0 tests/{ => aio}/test_vector_client_exists.py | 0 tests/{ => aio}/test_vector_client_get.py | 0 tests/{ => aio}/test_vector_client_put.py | 0 tests/{ => aio}/test_vector_search.py | 1 + tests/pytest.ini | 1 - tests/sync/__init__.py | 0 tests/sync/conftest.py | 42 ++ tests/sync/requirements.txt | 3 + tests/sync/test_admin_client_index_create.py | 377 ++++++++++++++++++ tests/sync/test_admin_client_index_drop.py | 18 + tests/sync/test_admin_client_index_get.py | 29 ++ .../test_admin_client_index_get_status.py | 17 + tests/sync/test_admin_client_index_list.py | 17 + tests/sync/test_vector_client_exists.py | 47 +++ tests/sync/test_vector_client_get.py | 59 +++ tests/sync/test_vector_client_put.py | 41 ++ tests/sync/test_vector_search.py | 171 ++++++++ 54 files changed, 2157 insertions(+), 490 deletions(-) create mode 100644 src/aerospike_vector/admin.py create mode 100644 src/aerospike_vector/aio/__init__.py rename src/aerospike_vector/{vectordb_admin.py => aio/admin.py} (64%) rename src/aerospike_vector/{vectordb_client.py => aio/client.py} (60%) create mode 100644 src/aerospike_vector/aio/internal/channel_provider.py create mode 100644 src/aerospike_vector/client.py create mode 100644 src/aerospike_vector/internal/__init__.py create mode 100644 src/aerospike_vector/internal/channel_provider.py create mode 100644 src/aerospike_vector/shared/__init__.py
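Reviewer note: a minimal sketch of how the reorganized modules are expected to be imported and used after this change. The import paths follow the file moves listed above and the method signatures match the new client.py/admin.py/aio/client.py below; the HostPort constructor field names are an assumption, not verified against types.py.

import asyncio
from aerospike_vector import Client, types                 # sync vector client (new client.py)
from aerospike_vector.admin import Client as AdminClient   # sync admin client (new admin.py)
from aerospike_vector.aio import Client as AioClient       # async vector client (new aio/client.py)

seed = types.HostPort(host="localhost", port=5000)  # field names assumed

# Sync admin API: __enter__/__exit__ added in this patch.
with AdminClient(seeds=seed) as admin:
    admin.index_create(namespace="test", name="demo_index",
                       vector_field="vector", dimensions=2)

# Sync data API.
with Client(seeds=seed) as client:
    client.put(namespace="test", key="k1", record_data={"vector": [1.0, 2.0]})
    neighbors = client.vector_search(namespace="test", index_name="demo_index",
                                     query=[1.0, 2.0], limit=3)

# Async data API: __aenter__/__aexit__ carried over from the old VectorDbClient.
async def aio_demo():
    async with AioClient(seeds=seed) as client:
        await client.put(namespace="test", key="k2", record_data={"vector": [0.0, 1.0]})

asyncio.run(aio_demo())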
create mode 100644 src/aerospike_vector/shared/admin_helpers.py create mode 100644 src/aerospike_vector/shared/base_channel_provider.py create mode 100644 src/aerospike_vector/shared/channel_provider_helpers.py create mode 100644 src/aerospike_vector/shared/client_helpers.py rename src/aerospike_vector/{ => shared}/conversions.py (98%) create mode 100644 src/aerospike_vector/shared/helpers.py rename src/aerospike_vector/{ => shared/proto_generated}/auth_pb2.py (100%) rename src/aerospike_vector/{ => shared/proto_generated}/auth_pb2_grpc.py (100%) rename src/aerospike_vector/{ => shared/proto_generated}/index_pb2.py (100%) rename src/aerospike_vector/{ => shared/proto_generated}/index_pb2_grpc.py (100%) rename src/aerospike_vector/{ => shared/proto_generated}/transact_pb2.py (100%) rename src/aerospike_vector/{ => shared/proto_generated}/transact_pb2_grpc.py (100%) rename src/aerospike_vector/{ => shared/proto_generated}/types_pb2.py (100%) rename src/aerospike_vector/{ => shared/proto_generated}/types_pb2_grpc.py (100%) rename src/aerospike_vector/{ => shared/proto_generated}/vector_db_pb2.py (100%) rename src/aerospike_vector/{ => shared/proto_generated}/vector_db_pb2_grpc.py (100%) delete mode 100644 src/aerospike_vector/vectordb_channel_provider.py create mode 100644 tests/aio/__init__.py rename tests/{ => aio}/conftest.py (72%) create mode 100644 tests/aio/requirements.txt rename tests/{ => aio}/test_admin_client_index_create.py (100%) rename tests/{ => aio}/test_admin_client_index_drop.py (100%) rename tests/{ => aio}/test_admin_client_index_get.py (100%) rename tests/{ => aio}/test_admin_client_index_get_status.py (100%) rename tests/{ => aio}/test_admin_client_index_list.py (100%) rename tests/{ => aio}/test_vector_client_exists.py (100%) rename tests/{ => aio}/test_vector_client_get.py (100%) rename tests/{ => aio}/test_vector_client_put.py (100%) rename tests/{ => aio}/test_vector_search.py (99%) delete mode 100644 tests/pytest.ini create mode 100644 tests/sync/__init__.py create mode 100644 tests/sync/conftest.py create mode 100644 tests/sync/requirements.txt create mode 100644 tests/sync/test_admin_client_index_create.py create mode 100644 tests/sync/test_admin_client_index_drop.py create mode 100644 tests/sync/test_admin_client_index_get.py create mode 100644 tests/sync/test_admin_client_index_get_status.py create mode 100644 tests/sync/test_admin_client_index_list.py create mode 100644 tests/sync/test_vector_client_exists.py create mode 100644 tests/sync/test_vector_client_get.py create mode 100644 tests/sync/test_vector_client_put.py create mode 100644 tests/sync/test_vector_search.py diff --git a/proto/codegen.sh b/proto/codegen.sh index ebf6a483..069515ac 100755 --- a/proto/codegen.sh +++ b/proto/codegen.sh @@ -6,10 +6,10 @@ cd "$(dirname "$0")" python3 -m pip install grpcio-tools python3 -m grpc_tools.protoc \ --proto_path=. \ - --python_out=../src/aerospike_vector/ \ - --grpc_python_out=../src/aerospike_vector/ \ + --python_out=../src/aerospike_vector/shared/proto_generated/ \ + --grpc_python_out=../src/aerospike_vector/shared/proto_generated/ \ *.proto # The generated imports are not relative and fail in generated packages. # Fix with relative imports. -find ../src/aerospike_vector/ -name "*.py" -exec sed -i -e 's/^import \(.*\)_pb2 /from . import \1_pb2 /g' {} \; +find ../src/aerospike_vector/shared/proto_generated/ -name "*.py" -exec sed -i -e 's/^import \(.*\)_pb2 /from . 
import \1_pb2 /g' {} \; diff --git a/src/aerospike_vector/__init__.py b/src/aerospike_vector/__init__.py index e0ed1a45..4b7456f4 100644 --- a/src/aerospike_vector/__init__.py +++ b/src/aerospike_vector/__init__.py @@ -1,5 +1,3 @@ -import logging +from .client import Client -name = "aerospike_vector" -logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/src/aerospike_vector/admin.py b/src/aerospike_vector/admin.py new file mode 100644 index 00000000..573cecaa --- /dev/null +++ b/src/aerospike_vector/admin.py @@ -0,0 +1,277 @@ +import logging +import sys +import time +from typing import Any, Optional, Union +import grpc + +from . import types +from .internal import channel_provider +from .shared import admin_helpers + +logger = logging.getLogger(__name__) + +class Client(object): + """ + Aerospike Vector Admin Client + + This client is designed to conduct Aerospike Vector administrative operation such as creating indexes, querying index information, and dropping indexes. + """ + + def __init__( + self, + *, + seeds: Union[types.HostPort, tuple[types.HostPort, ...]], + listener_name: Optional[str] = None, + is_loadbalancer: Optional[bool] = False, + ) -> None: + + seeds = admin_helpers.prepare_seeds(seeds) + + self._channelProvider = channel_provider.ChannelProvider( + seeds, listener_name, is_loadbalancer + ) + """ + Initialize the Aerospike Vector Admin Client. + + Args: + seeds (Union[types.HostPort, tuple[types.HostPort, ...]]): Used to create appropriate gRPC channels for interacting with Aerospike Vector. + listener_name (Optional[str], optional): Advertised listener for the client. Defaults to None. + is_loadbalancer (bool, optional): If true, the first seed address will be treated as a load balancer node. + + Raises: + Exception: Raised when no seed host is provided. + + """ + + def index_create( + self, + *, + namespace: str, + name: str, + vector_field: str, + dimensions: int, + vector_distance_metric: Optional[types.VectorDistanceMetric] = ( + types.VectorDistanceMetric.SQUARED_EUCLIDEAN + ), + sets: Optional[str] = None, + index_params: Optional[types.HnswParams] = None, + index_meta_data: Optional[dict[str, str]] = None, + ) -> None: + """ + Create an index. + + Args: + namespace (str): The namespace for the index. + name (str): The name of the index. + vector_field (str): The name of the field containing vector data. + dimensions (int): The number of dimensions in the vector data. + vector_distance_metric (Optional[types.VectorDistanceMetric], optional): + The distance metric used to compare when performing a vector search. + Defaults to VectorDistanceMetric.SQUARED_EUCLIDEAN. + sets (Optional[str], optional): The set used for the index. Defaults to None. + index_params (Optional[types.HnswParams], optional): + Parameters used for tuning vector search. Defaults to None. If index_params is None, then the default values specified for + types.HnswParams will be used. + index_meta_data (Optional[dict[str, str]], optional): Meta data associated with the index. Defaults to None. + + Raises: + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + Note: + This method creates an index with the specified parameters and waits for the index creation to complete. + It waits for up to 100,000 seconds for the index creation to complete. 
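Reviewer note: index_params is typed as types.HnswParams, and the PEP-8 renaming in shared/admin_helpers.py later in this patch (efConstruction, batchingParams.maxRecords) suggests tunables such as ef_construction and batching max_records. A hedged sketch, with keyword names assumed rather than verified:

# Hypothetical sketch; HnswParams keyword names are assumed from the dict keys
# that admin_helpers renames, not confirmed against types.py.
params = types.HnswParams(m=16, ef_construction=200)
admin.index_create(
    namespace="test",
    name="demo_index",
    vector_field="vector",
    dimensions=2,
    index_params=params,  # None falls back to the HnswParams defaults
)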
+ """ + (index_stub, index_create_request) = admin_helpers.prepare_index_create(self, namespace, name, vector_field, dimensions, vector_distance_metric, sets, index_params, index_meta_data, logger) + try: + index_stub.Create(index_create_request) + except grpc.RpcError as e: + logger.error("Failed with error: %s", e) + raise e + try: + self._wait_for_index_creation( + namespace=namespace, name=name, timeout=100_000 + ) + except grpc.RpcError as e: + logger.error("Failed waiting for creation with error: %s", e) + raise e + + def index_drop(self, *, namespace: str, name: str) -> None: + """ + Drop an index. + + Args: + namespace (str): The namespace of the index. + name (str): The name of the index. + + Raises: + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + Note: + This method drops an index with the specified parameters and waits for the index deletion to complete. + It waits for up to 100,000 seconds for the index deletion to complete. + """ + (index_stub, index_drop_request) = admin_helpers.prepare_index_drop(self, namespace, name, logger) + try: + index_stub.Drop(index_drop_request) + except grpc.RpcError as e: + logger.error("Failed with error: %s", e) + raise e + try: + self._wait_for_index_deletion( + namespace=namespace, name=name, timeout=100_000 + ) + except grpc.RpcError as e: + logger.error("Failed waiting for deletion with error: %s", e) + raise e + + def index_list(self) -> list[dict]: + """ + List all indices. + + Returns: + list[dict]: A list of indices. + + Raises: + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + """ + (index_stub, index_list_request) = admin_helpers.prepare_index_list(self, logger) + try: + response = index_stub.List(index_list_request) + except grpc.RpcError as e: + logger.error("Failed with error: %s", e) + raise e + return admin_helpers.respond_index_list(response) + + def index_get( + self, *, namespace: str, name: str + ) -> dict[str, Union[int, str]]: + """ + Retrieve the information related with an index. + + Args: + namespace (str): The namespace of the index. + name (str): The name of the index. + + Returns: + dict[str, Union[int, str]: Information about an index. + + Raises: + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (index_stub, index_get_request) = admin_helpers.prepare_index_get(self, namespace, name, logger) + try: + response = index_stub.Get(index_get_request) + except grpc.RpcError as e: + logger.error("Failed with error: %s", e) + raise e + return admin_helpers.respond_index_get(response) + + + def index_get_status(self, *, namespace: str, name: str) -> int: + """ + Retrieve the number of records queued to be merged into an index. + + Args: + namespace (str): The namespace of the index. + name (str): The name of the index. + + Returns: + int: Records queued to be merged into an index. + + Raises: + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. 
+ This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + Note: + This method retrieves the status of the specified index. If index_get_status is called immediately after the vector client puts records into Aerospike Vector, the records may not immediately begin to merge into the index. To wait for all records to be merged into an index, use vector_client.wait_for_index_completion. + + Warning: This API is subject to change. + """ + (index_stub, index_get_status_request) = admin_helpers.prepare_index_get_status(self, namespace, name, logger) + try: + response = index_stub.GetStatus(index_get_status_request) + except grpc.RpcError as e: + logger.error("Failed with error: %s", e) + raise e + + return admin_helpers.respond_index_get_status(response) + + def _wait_for_index_creation( + self, *, namespace: str, name: str, timeout: int = sys.maxsize + ) -> None: + """ + Wait for the index to be created. + """ + (index_stub, wait_interval, start_time, _, _, index_creation_request) = admin_helpers.prepare_wait_for_index_waiting(self, namespace, name) + while True: + admin_helpers.check_timeout(start_time, timeout) + try: + index_stub.GetStatus(index_creation_request) + logger.debug("Index created successfully") + # Index has been created + return + except grpc.RpcError as e: + if e.code() in (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.NOT_FOUND): + + # Wait for some more time. + time.sleep(wait_interval) + else: + logger.error("Failed with error: %s", e) + raise e + + def _wait_for_index_deletion( + self, *, namespace: str, name: str, timeout: int = sys.maxsize + ) -> None: + """ + Wait for the index to be deleted. + """ + + # Wait interval between polling + (index_stub, wait_interval, start_time, _, _, index_deletion_request) = admin_helpers.prepare_wait_for_index_waiting(self, namespace, name) + + while True: + admin_helpers.check_timeout(start_time, timeout) + + try: + index_stub.GetStatus(index_deletion_request) + # Wait for some more time. + time.sleep(wait_interval) + except grpc.RpcError as e: + if e.code() in (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.NOT_FOUND): + logger.debug("Index deleted successfully") + # Index has been deleted + return + else: + raise e + + def close(self): + """ + Close the Aerospike Vector Admin Client. + + This method closes gRPC channels connected to Aerospike Vector. + + Note: + This method should be called when the admin client is no longer needed to release resources. + """ + self._channelProvider.close() + + def __enter__(self): + """ + Enter a context manager for the admin client. + + Returns: + Client: Aerospike Vector Admin Client instance. + """ + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Exit a context manager for the admin client.
+ """ + self.close() diff --git a/src/aerospike_vector/aio/__init__.py b/src/aerospike_vector/aio/__init__.py new file mode 100644 index 00000000..4b7456f4 --- /dev/null +++ b/src/aerospike_vector/aio/__init__.py @@ -0,0 +1,3 @@ +from .client import Client + + diff --git a/src/aerospike_vector/vectordb_admin.py b/src/aerospike_vector/aio/admin.py similarity index 64% rename from src/aerospike_vector/vectordb_admin.py rename to src/aerospike_vector/aio/admin.py index a24e0ec3..857a94aa 100644 --- a/src/aerospike_vector/vectordb_admin.py +++ b/src/aerospike_vector/aio/admin.py @@ -3,22 +3,15 @@ import sys import time from typing import Any, Optional, Union - -import google.protobuf.empty_pb2 -from google.protobuf.json_format import MessageToDict import grpc -from . import index_pb2 -from . import index_pb2_grpc -from . import types -from . import types_pb2 -from . import vectordb_channel_provider +from .. import types +from .internal import channel_provider +from ..shared import admin_helpers -empty = google.protobuf.empty_pb2.Empty() logger = logging.getLogger(__name__) - -class VectorDbAdminClient(object): +class Client(object): """ Aerospike Vector Admin Client @@ -33,13 +26,9 @@ def __init__( is_loadbalancer: Optional[bool] = False, ) -> None: - if not seeds: - raise Exception("at least one seed host needed") - - if isinstance(seeds, types.HostPort): - seeds = (seeds,) + seeds = admin_helpers.prepare_seeds(seeds) - self._channelProvider = vectordb_channel_provider.VectorDbChannelProvider( + self._channelProvider = channel_provider.ChannelProvider( seeds, listener_name, is_loadbalancer ) """ @@ -94,38 +83,9 @@ async def index_create( This method creates an index with the specified parameters and waits for the index creation to complete. It waits for up to 100,000 seconds for the index creation to complete. """ - index_stub = index_pb2_grpc.IndexServiceStub( - self._channelProvider.get_channel() - ) - if sets and not sets.strip(): - sets = None - - logger.debug( - "Creating index: namespace=%s, name=%s, vector_field=%s, dimensions=%d, vector_distance_metric=%s, " - "sets=%s, index_params=%s, index_meta_data=%s", - namespace, - name, - vector_field, - dimensions, - vector_distance_metric, - sets, - index_params, - index_meta_data, - ) - if index_params != None: - index_params = index_params._to_pb2() + (index_stub, index_create_request) = admin_helpers.prepare_index_create(self, namespace, name, vector_field, dimensions, vector_distance_metric, sets, index_params, index_meta_data, logger) try: - await index_stub.Create( - types_pb2.IndexDefinition( - id=types_pb2.IndexId(namespace=namespace, name=name), - vectorDistanceMetric=vector_distance_metric.value, - setFilter=sets, - hnswParams=index_params, - bin=vector_field, - dimensions=dimensions, - labels=index_meta_data, - ) - ) + await index_stub.Create(index_create_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e @@ -138,9 +98,6 @@ async def index_create( raise e async def index_drop(self, *, namespace: str, name: str) -> None: - index_stub = index_pb2_grpc.IndexServiceStub( - self._channelProvider.get_channel() - ) """ Drop an index. @@ -156,9 +113,9 @@ async def index_drop(self, *, namespace: str, name: str) -> None: This method drops an index with the specified parameters and waits for the index deletion to complete. It waits for up to 100,000 seconds for the index deletion to complete. 
""" - logger.debug("Dropping index: namespace=%s, name=%s", namespace, name) + (index_stub, index_drop_request) = admin_helpers.prepare_index_drop(self, namespace, name, logger) try: - await index_stub.Drop(types_pb2.IndexId(namespace=namespace, name=name)) + await index_stub.Drop(index_drop_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e @@ -181,37 +138,13 @@ async def index_list(self) -> list[dict]: grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ - - index_stub = index_pb2_grpc.IndexServiceStub( - self._channelProvider.get_channel() - ) - - logger.debug("Getting index list") + (index_stub, index_list_request) = admin_helpers.prepare_index_list(self, logger) try: - response = await index_stub.List(empty) + response = await index_stub.List(index_list_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e - response_list = [] - for index in response.indices: - response_dict = MessageToDict(index) - - # Modifying dict to adhere to PEP-8 naming - hnsw_params_dict = response_dict.pop("hnswParams", None) - - hnsw_params_dict["ef_construction"] = hnsw_params_dict.pop( - "efConstruction", None - ) - - batching_params_dict = hnsw_params_dict.pop("batchingParams", None) - batching_params_dict["max_records"] = batching_params_dict.pop( - "maxRecords", None - ) - hnsw_params_dict["batching_params"] = batching_params_dict - - response_dict["hnsw_params"] = hnsw_params_dict - response_list.append(response_dict) - return response_list + return admin_helpers.respond_index_list(response) async def index_get( self, *, namespace: str, name: str @@ -231,39 +164,14 @@ async def index_get( This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ - - index_stub = index_pb2_grpc.IndexServiceStub( - self._channelProvider.get_channel() - ) - - logger.debug( - "Getting index information: namespace=%s, name=%s", namespace, name - ) + (index_stub, index_get_request) = admin_helpers.prepare_index_get(self, namespace, name, logger) try: - response = await index_stub.Get( - types_pb2.IndexId(namespace=namespace, name=name) - ) + response = await index_stub.Get(index_get_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e + return admin_helpers.respond_index_get(response) - response_dict = MessageToDict(response) - - # Modifying dict to adhere to PEP-8 naming - hnsw_params_dict = response_dict.pop("hnswParams", None) - - hnsw_params_dict["ef_construction"] = hnsw_params_dict.pop( - "efConstruction", None - ) - - batching_params_dict = hnsw_params_dict.pop("batchingParams", None) - batching_params_dict["max_records"] = batching_params_dict.pop( - "maxRecords", None - ) - hnsw_params_dict["batching_params"] = batching_params_dict - - response_dict["hnsw_params"] = hnsw_params_dict - return response_dict async def index_get_status(self, *, namespace: str, name: str) -> int: """ @@ -286,19 +194,14 @@ async def index_get_status(self, *, namespace: str, name: str) -> int: Warning: This API is subject to change. 
""" - index_stub = index_pb2_grpc.IndexServiceStub( - self._channelProvider.get_channel() - ) - logger.debug("Getting index status: namespace=%s, name=%s", namespace, name) + (index_stub, index_get_status_request) = admin_helpers.prepare_index_get_status(self, namespace, name, logger) try: - response = await index_stub.GetStatus( - types_pb2.IndexId(namespace=namespace, name=name) - ) + response = await index_stub.GetStatus(index_get_status_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e - return response.unmergedRecordCount + return admin_helpers.respond_index_get_status(response) async def _wait_for_index_creation( self, *, namespace: str, name: str, timeout: int = sys.maxsize @@ -306,21 +209,11 @@ async def _wait_for_index_creation( """ Wait for the index to be created. """ - - # Wait interval between polling - index_stub = index_pb2_grpc.IndexServiceStub( - self._channelProvider.get_channel() - ) - wait_interval = 0.100 - - start_time = time.monotonic() + (index_stub, wait_interval, start_time, _, _, index_creation_request) = admin_helpers.prepare_wait_for_index_waiting(self, namespace, name) while True: - if start_time + timeout < time.monotonic(): - raise "timed-out waiting for index creation" + admin_helpers.check_timeout(start_time, timeout) try: - await index_stub.GetStatus( - types_pb2.IndexId(namespace=namespace, name=name) - ) + await index_stub.GetStatus(index_creation_request) logger.debug("Index created succesfully") # Index has been created return @@ -341,19 +234,13 @@ async def _wait_for_index_deletion( """ # Wait interval between polling - index_stub = index_pb2_grpc.IndexServiceStub( - self._channelProvider.get_channel() - ) - wait_interval = 0.100 + (index_stub, wait_interval, start_time, _, _, index_deletion_request) = admin_helpers.prepare_wait_for_index_waiting(self, namespace, name) - start_time = time.monotonic() while True: - if start_time + timeout < time.monotonic(): - raise "timed-out waiting for index creation" + admin_helpers.check_timeout(start_time, timeout) + try: - await index_stub.GetStatus( - types_pb2.IndexId(namespace=namespace, name=name) - ) + await index_stub.GetStatus(index_deletion_request) # Wait for some more time. await asyncio.sleep(wait_interval) except grpc.RpcError as e: diff --git a/src/aerospike_vector/vectordb_client.py b/src/aerospike_vector/aio/client.py similarity index 60% rename from src/aerospike_vector/vectordb_client.py rename to src/aerospike_vector/aio/client.py index 5e4424d0..012f74b8 100644 --- a/src/aerospike_vector/vectordb_client.py +++ b/src/aerospike_vector/aio/client.py @@ -1,28 +1,19 @@ import asyncio import logging import sys -import time from typing import Any, Optional, Union -import google.protobuf.empty_pb2 import grpc -from . import conversions -from . import index_pb2 -from . import index_pb2_grpc -from . import transact_pb2 -from . import transact_pb2_grpc -from . import types -from . import types_pb2 -from . import vectordb_channel_provider +from .. import types +from .internal import channel_provider +from ..shared import client_helpers -empty = google.protobuf.empty_pb2.Empty() logger = logging.getLogger(__name__) - -class VectorDbClient(object): +class Client(object): """ - Aerospike Vector Vector Client + Aerospike Vector Admin Client This client specializes in performing database operations with vector data. 
Moreover, the client supports Hierarchical Navigable Small World (HNSW) vector searches, @@ -50,13 +41,8 @@ def __init__( Raises: Exception: Raised when no seed host is provided. """ - if not seeds: - raise Exception("at least one seed host needed") - - if isinstance(seeds, types.HostPort): - seeds = (seeds,) - - self._channelProvider = vectordb_channel_provider.VectorDbChannelProvider( + seeds = client_helpers.prepare_seeds(seeds) + self._channelProvider = channel_provider.ChannelProvider( seeds, listener_name, is_loadbalancer ) @@ -82,23 +68,10 @@ async def put( This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ - transact_stub = transact_pb2_grpc.TransactStub( - self._channelProvider.get_channel() - ) - key = _get_key(namespace, set_name, key) - bin_list = [ - types_pb2.Bin(name=k, value=conversions.toVectorDbValue(v)) - for (k, v) in record_data.items() - ] - logger.debug( - "Putting record: namespace=%s, key=%s, record_data:%s, set_name:%s", - namespace, - key, - record_data, - set_name, - ) + (transact_stub, put_request) = client_helpers.prepare_put(self, namespace, key, record_data, set_name, logger) + try: - await transact_stub.Put(transact_pb2.PutRequest(key=key, bins=bin_list)) + await transact_stub.Put(put_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e @@ -128,30 +101,14 @@ async def get( grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ - transact_stub = transact_pb2_grpc.TransactStub( - self._channelProvider.get_channel() - ) - key = _get_key(namespace, set_name, key) - bin_selector = _get_bin_selector(bin_names=bin_names) - logger.debug( - "Getting record: namespace=%s, key=%s, bin_names:%s, set_name:%s", - namespace, - key, - bin_names, - set_name, - ) + (transact_stub, key, get_request) = client_helpers.prepare_get(self, namespace, key, bin_names, set_name, logger) try: - result = await transact_stub.Get( - transact_pb2.GetRequest(key=key, binSelector=bin_selector) - ) + response = await transact_stub.Get(get_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e - return types.RecordWithKey( - key=conversions.fromVectorDbKey(key), - bins=conversions.fromVectorDbRecord(result), - ) + return client_helpers.respond_get(response, key) async def exists( self, *, namespace: str, key: Any, set_name: Optional[str] = None @@ -171,23 +128,14 @@ async def exists( grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
""" - transact_stub = transact_pb2_grpc.TransactStub( - self._channelProvider.get_channel() - ) - key = _get_key(namespace, set_name, key) - logger.debug( - "Getting record existence: namespace=%s, key=%s, set_name:%s", - namespace, - key, - set_name, - ) + (transact_stub, key) = client_helpers.prepare_exists(self, namespace, key, set_name, logger) try: - result = await transact_stub.Exists(key) + response = await transact_stub.Exists(key) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e - return result.value + return client_helpers.respond_exists(response) async def is_indexed( self, @@ -216,29 +164,13 @@ async def is_indexed( grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ - if not index_namespace: - index_namespace = namespace - - index_id = types_pb2.IndexId(namespace=index_namespace, name=index_name) - key = _get_key(namespace, set_name, key) - request = transact_pb2.IsIndexedRequest(key=key, indexId=index_id) - transact_stub = transact_pb2_grpc.TransactStub( - self._channelProvider.get_channel() - ) - logger.debug( - "Checking if index exists: namespace=%s, key=%s, index_name=%s, index_namespace=%s, set_name=%s", - namespace, - key, - index_name, - index_namespace, - set_name, - ) + (transact_stub, is_indexed_request) = client_helpers.prepare_is_indexed(self, namespace, key, index_name, index_namespace, set_name, logger) try: - result = await transact_stub.IsIndexed(request) + response = await transact_stub.IsIndexed(is_indexed_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e - return result.value + return client_helpers.respond_is_indexed(response) async def vector_search( self, @@ -270,36 +202,16 @@ async def vector_search( grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ - transact_stub = transact_pb2_grpc.TransactStub( - self._channelProvider.get_channel() - ) - logger.debug( - "Performing vector search: namespace=%s, index_name=%s, query=%s, limit=%s, search_params=%s, bin_names=%s", - namespace, - index_name, - query, - limit, - search_params, - bin_names, - ) - if search_params != None: - search_params = search_params._to_pb2() + (transact_stub, vector_search_request) = client_helpers.prepare_vector_search(self, namespace, index_name, query, limit, search_params, bin_names, logger) + try: - results_stream = transact_stub.VectorSearch( - transact_pb2.VectorSearchRequest( - index=types_pb2.IndexId(namespace=namespace, name=index_name), - queryVector=(conversions.toVectorDbValue(query).vectorValue), - limit=limit, - hnswSearchParams=search_params, - binSelector=_get_bin_selector(bin_names=bin_names), - ) - ) + results_stream = transact_stub.VectorSearch(vector_search_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e async_results = [] async for result in results_stream: - async_results.append(conversions.fromVectorDbNeighbor(result)) + async_results.append(client_helpers.respond_neighbor(result)) return async_results @@ -325,25 +237,10 @@ async def wait_for_index_completion( the timeout is reached or the index has no pending index update operations. 
""" # Wait interval between polling - index_stub = index_pb2_grpc.IndexServiceStub( - self._channelProvider.get_channel() - ) - - wait_interval = 1 - - unmerged_record_count = sys.maxsize - start_time = time.monotonic() - unmerged_record_initialized = False + (index_stub, wait_interval, start_time, unmerged_record_initialized, double_check, index_completion_request) = client_helpers.prepare_wait_for_index_waiting(self, namespace, name) while True: - if start_time + timeout < time.monotonic(): - raise "timed-out waiting for index completion" - # Wait for in-memory batches to be flushed to storage. - if start_time + 10 < time.monotonic(): - unmerged_record_initialized = True try: - index_status = await index_stub.GetStatus( - types_pb2.IndexId(namespace=namespace, name=name) - ) + index_status = await index_stub.GetStatus(index_completion_request) except grpc.RpcError as e: if e.code() == grpc.StatusCode.UNAVAILABLE: @@ -351,19 +248,13 @@ async def wait_for_index_completion( else: logger.error("Failed with error: %s", e) raise e - - if index_status.unmergedRecordCount > 0: - unmerged_record_initialized = True - - if ( - unmerged_record_count == 0 - and index_status.unmergedRecordCount == 0 - and unmerged_record_initialized == True - ): - return - # Update. - unmerged_record_count = index_status.unmergedRecordCount - await asyncio.sleep(wait_interval) + if client_helpers.check_completion_condition(start_time, timeout, index_status, unmerged_record_initialized): + if double_check: + return + else: + double_check = True + else: + await asyncio.sleep(wait_interval) async def close(self): """ @@ -389,29 +280,4 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): """ Exit an asynchronous context manager for the vector client. """ - await self.close() - - -def _get_bin_selector(*, bin_names: Optional[list] = None): - - if not bin_names: - bin_selector = transact_pb2.BinSelector( - type=transact_pb2.BinSelectorType.ALL, binNames=bin_names - ) - else: - bin_selector = transact_pb2.BinSelector( - type=transact_pb2.BinSelectorType.SPECIFIED, binNames=bin_names - ) - return bin_selector - - -def _get_key(namespace: str, set: str, key: Union[int, str, bytes, bytearray]): - if isinstance(key, str): - key = types_pb2.Key(namespace=namespace, set=set, stringValue=key) - elif isinstance(key, int): - key = types_pb2.Key(namespace=namespace, set=set, longValue=key) - elif isinstance(key, (bytes, bytearray)): - key = types_pb2.Key(namespace=namespace, set=set, bytesValue=key) - else: - raise Exception("Invalid key type" + type(key)) - return key + await self.close() \ No newline at end of file diff --git a/src/aerospike_vector/aio/internal/channel_provider.py b/src/aerospike_vector/aio/internal/channel_provider.py new file mode 100644 index 00000000..c98fd48b --- /dev/null +++ b/src/aerospike_vector/aio/internal/channel_provider.py @@ -0,0 +1,98 @@ +import re +import asyncio +import logging +from typing import Optional, Union + +import google.protobuf.empty_pb2 +import grpc +import random + +from ... 
import types +from ...shared.proto_generated import vector_db_pb2 +from ...shared.proto_generated import vector_db_pb2_grpc +from ...shared import base_channel_provider +from ...shared import channel_provider_helpers + +empty = google.protobuf.empty_pb2.Empty() + +logger = logging.getLogger(__name__) + + +class ChannelProvider(base_channel_provider.BaseChannelProvider): + """Proximus Channel Provider""" + def __init__( + self, seeds: tuple[types.HostPort, ...], listener_name: Optional[str] = None, is_loadbalancer: Optional[bool] = False + ) -> None: + super().__init__(seeds, listener_name, is_loadbalancer) + asyncio.create_task(self._tend()) + + async def close(self): + self._closed = True + for channel in self._seedChannels: + await channel.close() + + for k, channelEndpoints in self._nodeChannels.items(): + if channelEndpoints.channel: + await channelEndpoints.channel.close() + + async def _tend(self): + (temp_endpoints, update_endpoints, channels, end_tend) = channel_provider_helpers.init_tend(self) + if end_tend: + return + + for seedChannel in channels: + + stub = vector_db_pb2_grpc.ClusterInfoStub(seedChannel) + + try: + new_cluster_id = (await stub.GetClusterId(empty)).id + if channel_provider_helpers.check_cluster_id(self, new_cluster_id): + update_endpoints = True + else: + continue + + except Exception as e: + logger.debug("While tending, failed to get cluster id with error:" + str(e)) + + + try: + response = await stub.GetClusterEndpoints( + vector_db_pb2.ClusterNodeEndpointsRequest( + listenerName=self.listener_name + ) + ) + except Exception as e: + logger.debug("While tending, failed to get cluster endpoints with error:" + str(e)) + continue + + temp_endpoints = channel_provider_helpers.update_temp_endpoints(response, temp_endpoints) + + if update_endpoints: + for node, newEndpoints in temp_endpoints.items(): + (channel_endpoints, add_new_channel) = channel_provider_helpers.check_for_new_endpoints(self, node, newEndpoints) + if add_new_channel: + try: + # TODO: Wait for all calls to drain + await channel_endpoints.channel.close() + except Exception as e: + logger.debug("While tending, failed to close GRPC channel:" + str(e)) + + self.add_new_channel_to_node_channels(node, newEndpoints) + + for node, channel_endpoints in list(self._nodeChannels.items()): + if not temp_endpoints.get(node): + # TODO: Wait for all calls to drain + try: + await channel_endpoints.channel.close() + del self._nodeChannels[node] + + except Exception as e: + logger.debug("While tending, failed to close GRPC channel:" + str(e)) + + # TODO: check tend interval. + await asyncio.sleep(1) + asyncio.create_task(self._tend()) + + def _create_channel(self, host: str, port: int, isTls: bool) -> grpc.aio.Channel: + # TODO: Take care of TLS + host = re.sub(r"%.*", "", host) + return grpc.aio.insecure_channel(f"{host}:{port}") \ No newline at end of file diff --git a/src/aerospike_vector/client.py b/src/aerospike_vector/client.py new file mode 100644 index 00000000..9ddc07c9 --- /dev/null +++ b/src/aerospike_vector/client.py @@ -0,0 +1,283 @@ +import logging +import sys +import time +from typing import Any, Optional, Union + +import grpc + +from . import types +from .internal import channel_provider +from .shared import client_helpers + +logger = logging.getLogger(__name__) + +class Client(object): + """ + Aerospike Vector Client + + This client specializes in performing database operations with vector data.
+ Moreover, the client supports Hierarchical Navigable Small World (HNSW) vector searches, + allowing users to find vectors similar to a given query vector within an index. + """ + + def __init__( + self, + *, + seeds: Union[types.HostPort, tuple[types.HostPort, ...]], + listener_name: str = None, + is_loadbalancer: Optional[bool] = False, + ) -> None: + """ + Initialize the Aerospike Vector Vector Client. + + Args: + seeds (Union[types.HostPort, tuple[types.HostPort, ...]]): + Used to create appropriate gRPC channels for interacting with Aerospike Vector. + listener_name (Optional[str], optional): + Advertised listener for the client. Defaults to None. + is_loadbalancer (bool, optional): + If true, the first seed address will be treated as a load balancer node. + + Raises: + Exception: Raised when no seed host is provided. + """ + seeds = client_helpers.prepare_seeds(seeds) + self._channelProvider = channel_provider.ChannelProvider( + seeds, listener_name, is_loadbalancer + ) + + def put( + self, + *, + namespace: str, + key: Union[int, str, bytes, bytearray], + record_data: dict[str, Any], + set_name: Optional[str] = None, + ) -> None: + """ + Write a record to Aerospike Vector. + + Args: + namespace (str): The namespace for the record. + key (Union[int, str, bytes, bytearray]): The key for the record. + record_data (dict[str, Any]): The data to be stored in the record. + set_name (Optional[str], optional): The name of the set to which the record belongs. Defaults to None. + + Raises: + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (transact_stub, put_request) = client_helpers.prepare_put(self, namespace, key, record_data, set_name, logger) + + try: + transact_stub.Put(put_request) + except grpc.RpcError as e: + logger.error("Failed with error: %s", e) + raise e + + def get( + self, + *, + namespace: str, + key: Union[int, str, bytes, bytearray], + bin_names: Optional[list[str]] = None, + set_name: Optional[str] = None, + ) -> types.RecordWithKey: + """ + Read a record from Aerospike Vector. + + Args: + namespace (str): The namespace for the record. + key (Union[int, str, bytes, bytearray]): The key for the record. + bin_names (Optional[list[str]], optional): A list of bin names to retrieve from the record. + If None, all bins are retrieved. Defaults to None. + set_name (Optional[str], optional): The name of the set from which to read the record. Defaults to None. + + Returns: + types.RecordWithKey: A record with its associated key. + + Raises: + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + """ + (transact_stub, key, get_request) = client_helpers.prepare_get(self, namespace, key, bin_names, set_name, logger) + try: + response = transact_stub.Get(get_request) + except grpc.RpcError as e: + logger.error("Failed with error: %s", e) + raise e + + return client_helpers.respond_get(response, key) + + def exists( + self, *, namespace: str, key: Any, set_name: Optional[str] = None + ) -> bool: + """ + Check if a record exists in Aerospike Vector. + + Args: + namespace (str): The namespace for the record. + key (Any): The key for the record. 
+ set_name (Optional[str], optional): The name of the set to which the record belongs. Defaults to None. + + Returns: + bool: True if the record exists, False otherwise. + + Raises: + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + """ + (transact_stub, key) = client_helpers.prepare_exists(self, namespace, key, set_name, logger) + try: + response = transact_stub.Exists(key) + except grpc.RpcError as e: + logger.error("Failed with error: %s", e) + raise e + + return client_helpers.respond_exists(response) + + def is_indexed( + self, + *, + namespace: str, + key: Union[int, str, bytes, bytearray], + index_name: str, + index_namespace: Optional[str] = None, + set_name: Optional[str] = None, + ) -> bool: + """ + Check if a record is indexed in the Vector DB. + + Args: + namespace (str): The namespace for the record. + key (Union[int, str, bytes, bytearray]): The key for the record. + index_name (str): The name of the index. + index_namespace (Optional[str], optional): The namespace of the index. + If None, defaults to the namespace of the record. Defaults to None. + set_name (Optional[str], optional): The name of the set to which the record belongs. Defaults to None. + + Returns: + bool: True if the record is indexed, False otherwise. + + Raises: + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + """ + (transact_stub, is_indexed_request) = client_helpers.prepare_is_indexed(self, namespace, key, index_name, index_namespace, set_name, logger) + try: + response = transact_stub.IsIndexed(is_indexed_request) + except grpc.RpcError as e: + logger.error("Failed with error: %s", e) + raise e + return client_helpers.respond_is_indexed(response) + + def vector_search( + self, + *, + namespace: str, + index_name: str, + query: list[Union[bool, float]], + limit: int, + search_params: Optional[types.HnswSearchParams] = None, + bin_names: Optional[list[str]] = None, + ) -> list[types.Neighbor]: + """ + Perform a Hierarchical Navigable Small World (HNSW) vector search in Aerospike Vector. + + Args: + namespace (str): The namespace for the records. + index_name (str): The name of the index. + query (list[Union[bool, float]]): The query vector for the search. + limit (int): The maximum number of neighbors to return. K value. + search_params (Optional[types_pb2.HnswSearchParams], optional): Parameters for the HNSW algorithm. + If None, the default parameters for the index are used. Defaults to None. + bin_names (Optional[list[str]], optional): A list of bin names to retrieve from the results. + If None, all bins are retrieved. Defaults to None. + + Returns: + list[types.Neighbor]: A list of neighbors records found by the search. + + Raises: + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
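Reviewer note: a hedged usage sketch for the sync search path defined above. The method arguments match this patch; the HnswSearchParams field name is an assumption, since only its _to_pb2 conversion is visible in this diff.

# 'ef' is an assumed HnswSearchParams field name; everything else mirrors this patch.
params = types.HnswSearchParams(ef=100)
neighbors = client.vector_search(
    namespace="test",
    index_name="demo_index",
    query=[0.5, 0.25],
    limit=5,
    search_params=params,   # None uses the index's default search parameters
    bin_names=["vector"],   # None retrieves all bins
)
for neighbor in neighbors:  # list[types.Neighbor]
    print(neighbor)         # Neighbor fields are defined in types.py, not shown in this diff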
+ """ + (transact_stub, vector_search_request) = client_helpers.prepare_vector_search(self, namespace, index_name, query, limit, search_params, bin_names, logger) + + try: + results_stream = transact_stub.VectorSearch(vector_search_request) + except grpc.RpcError as e: + logger.error("Failed with error: %s", e) + raise e + results = [] + for result in results_stream: + results.append(client_helpers.respond_neighbor(result)) + + return results + + def wait_for_index_completion( + self, *, namespace: str, name: str, timeout: Optional[int] = sys.maxsize + ) -> None: + """ + Wait for the index to have no pending index update operations. + + Args: + namespace (str): The namespace of the index. + name (str): The name of the index. + timeout (int, optional): The maximum time (in seconds) to wait for the index to complete. + Defaults to sys.maxsize. + + Raises: + Exception: Raised when the timeout occurs while waiting for index completion. + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + Note: + The function polls the index status with a wait interval of 10 seconds until either + the timeout is reached or the index has no pending index update operations. + """ + # Wait interval between polling + (index_stub, wait_interval, start_time, unmerged_record_initialized, double_check, index_completion_request) = client_helpers.prepare_wait_for_index_waiting(self, namespace, name) + while True: + try: + index_status = index_stub.GetStatus(index_completion_request) + + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNAVAILABLE: + continue + else: + logger.error("Failed with error: %s", e) + raise e + if client_helpers.check_completion_condition(start_time, timeout, index_status, unmerged_record_initialized): + if double_check: + return + else: + double_check = True + else: + time.sleep(wait_interval) + + def close(self): + """ + Close the Aerospike Vector Vector Client. + + This method closes gRPC channels connected to Aerospike Vector. + + Note: + This method should be called when the VectorDbAdminClient is no longer needed to release resources. + """ + self._channelProvider.close() + + def __enter__(self): + """ + Enter an asynchronous context manager for the vector client. + + Returns: + VectorDbClient: Aerospike Vector Vector Client instance. + """ + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Exit an asynchronous context manager for the vector client. + """ + self.close() \ No newline at end of file diff --git a/src/aerospike_vector/internal/__init__.py b/src/aerospike_vector/internal/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aerospike_vector/internal/channel_provider.py b/src/aerospike_vector/internal/channel_provider.py new file mode 100644 index 00000000..b70192b8 --- /dev/null +++ b/src/aerospike_vector/internal/channel_provider.py @@ -0,0 +1,98 @@ +import re +import logging +import threading +from typing import Optional, Union + +import google.protobuf.empty_pb2 +import grpc + +from .. 
import types +from ..shared.proto_generated import vector_db_pb2 +from ..shared.proto_generated import vector_db_pb2_grpc +from ..shared import base_channel_provider +from ..shared import channel_provider_helpers + +empty = google.protobuf.empty_pb2.Empty() + +logger = logging.getLogger(__name__) + + +class ChannelProvider(base_channel_provider.BaseChannelProvider): + """Proximus Channel Provider""" + def __init__( + self, seeds: tuple[types.HostPort, ...], listener_name: Optional[str] = None, is_loadbalancer: Optional[bool] = False + ) -> None: + super().__init__(seeds, listener_name, is_loadbalancer) + self._tend() + + def close(self): + self._closed = True + for channel in self._seedChannels: + channel.close() + + for k, channelEndpoints in self._nodeChannels.items(): + if channelEndpoints.channel: + channelEndpoints.channel.close() + + def _tend(self): + (temp_endpoints, update_endpoints, channels, end_tend) = channel_provider_helpers.init_tend(self) + + if end_tend: + return + + for seedChannel in channels: + + stub = vector_db_pb2_grpc.ClusterInfoStub(seedChannel) + + try: + new_cluster_id = stub.GetClusterId(empty).id + if channel_provider_helpers.check_cluster_id(self, new_cluster_id): + update_endpoints = True + else: + continue + + except Exception as e: + logger.debug("While tending, failed to get cluster id with error:" + str(e)) + + + try: + response = stub.GetClusterEndpoints( + vector_db_pb2.ClusterNodeEndpointsRequest( + listenerName=self.listener_name + ) + ) + except Exception as e: + logger.debug("While tending, failed to get cluster endpoints with error:" + str(e)) + continue + + temp_endpoints = channel_provider_helpers.update_temp_endpoints(response, temp_endpoints) + + if update_endpoints: + for node, newEndpoints in temp_endpoints.items(): + (channel_endpoints, add_new_channel) = channel_provider_helpers.check_for_new_endpoints(self, node, newEndpoints) + + if add_new_channel: + try: + # TODO: Wait for all calls to drain + channel_endpoints.channel.close() + except Exception as e: + logger.debug("While tending, failed to close GRPC channel:" + str(e)) + + self.add_new_channel_to_node_channels(node, newEndpoints) + + for node, channel_endpoints in list(self._nodeChannels.items()): + if not temp_endpoints.get(node): + # TODO: Wait for all calls to drain + try: + channel_endpoints.channel.close() + del self._nodeChannels[node] + + except Exception as e: + logger.debug("While tending, failed to close GRPC channel:" + str(e)) + + # TODO: check tend interval. + threading.Timer(1, self._tend).start() + + def _create_channel(self, host: str, port: int, isTls: bool) -> grpc.Channel: + # TODO: Take care of TLS + host = re.sub(r"%.*", "", host) + return grpc.insecure_channel(f"{host}:{port}") \ No newline at end of file diff --git a/src/aerospike_vector/shared/__init__.py b/src/aerospike_vector/shared/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aerospike_vector/shared/admin_helpers.py b/src/aerospike_vector/shared/admin_helpers.py new file mode 100644 index 00000000..017a1c37 --- /dev/null +++ b/src/aerospike_vector/shared/admin_helpers.py @@ -0,0 +1,158 @@ +import asyncio +import logging +from typing import Any, Optional, Union +import time + +import google.protobuf.empty_pb2 +from google.protobuf.json_format import MessageToDict +import grpc + +from . import helpers +from .proto_generated import index_pb2_grpc +from .proto_generated import types_pb2 +from ..
import types + +logger = logging.getLogger(__name__) + +empty = google.protobuf.empty_pb2.Empty() + +def prepare_seeds(seeds) -> None: + return helpers.prepare_seeds(seeds) + +def prepare_index_create(self, namespace, name, vector_field, dimensions, vector_distance_metric, sets, index_params, index_meta_data, logger) -> None: + + logger.debug( + "Creating index: namespace=%s, name=%s, vector_field=%s, dimensions=%d, vector_distance_metric=%s, " + "sets=%s, index_params=%s, index_meta_data=%s", + namespace, + name, + vector_field, + dimensions, + vector_distance_metric, + sets, + index_params, + index_meta_data, + ) + + if sets and not sets.strip(): + sets = None + if index_params != None: + index_params = index_params._to_pb2() + id = get_index_id(namespace, name) + vector_distance_metric = vector_distance_metric.value + + index_stub = get_index_stub(self) + + index_create_request = types_pb2.IndexDefinition( + id=id, + vectorDistanceMetric=vector_distance_metric, + setFilter=sets, + hnswParams=index_params, + bin=vector_field, + dimensions=dimensions, + labels=index_meta_data, + ) + return (index_stub, index_create_request) + +def prepare_index_drop(self, namespace, name, logger) -> None: + + logger.debug("Dropping index: namespace=%s, name=%s", namespace, name) + + index_stub = get_index_stub(self) + index_drop_request = get_index_id(namespace, name) + + return (index_stub, index_drop_request) + +def prepare_index_list(self, logger) -> None: + + logger.debug("Getting index list") + + index_stub = get_index_stub(self) + index_list_request = empty + + return (index_stub, index_list_request) + +def prepare_index_get(self, namespace, name, logger) -> None: + + logger.debug( + "Getting index information: namespace=%s, name=%s", namespace, name + ) + + index_stub = get_index_stub(self) + index_get_request = get_index_id(namespace, name) + + return (index_stub, index_get_request) + +def prepare_index_get_status(self, namespace, name, logger) -> None: + + logger.debug("Getting index status: namespace=%s, name=%s", namespace, name) + + index_stub = get_index_stub(self) + index_get_status_request = get_index_id(namespace, name) + + + return (index_stub, index_get_status_request) + +def respond_index_list(response) -> None: + response_list = [] + for index in response.indices: + response_dict = MessageToDict(index) + + # Modifying dict to adhere to PEP-8 naming + hnsw_params_dict = response_dict.pop("hnswParams", None) + + hnsw_params_dict["ef_construction"] = hnsw_params_dict.pop( + "efConstruction", None + ) + + batching_params_dict = hnsw_params_dict.pop("batchingParams", None) + batching_params_dict["max_records"] = batching_params_dict.pop( + "maxRecords", None + ) + hnsw_params_dict["batching_params"] = batching_params_dict + + response_dict["hnsw_params"] = hnsw_params_dict + response_list.append(response_dict) + return response_list + +def respond_index_get(response) -> None: + response_dict = MessageToDict(response) + + # Modifying dict to adhere to PEP-8 naming + hnsw_params_dict = response_dict.pop("hnswParams", None) + + hnsw_params_dict["ef_construction"] = hnsw_params_dict.pop( + "efConstruction", None + ) + + batching_params_dict = hnsw_params_dict.pop("batchingParams", None) + batching_params_dict["max_records"] = batching_params_dict.pop( + "maxRecords", None + ) + hnsw_params_dict["batching_params"] = batching_params_dict + + response_dict["hnsw_params"] = hnsw_params_dict + return response_dict + +def respond_index_get_status(response) -> None: + return 
response.unmergedRecordCount
+
+def get_index_stub(self):
+    return index_pb2_grpc.IndexServiceStub(
+        self._channelProvider.get_channel()
+    )
+
+def get_index_id(namespace, name):
+    return types_pb2.IndexId(namespace=namespace, name=name)
+
+def prepare_wait_for_index_waiting(self, namespace, name):
+    return helpers.prepare_wait_for_index_waiting(self, namespace, name)
+
+
+def check_timeout(start_time, timeout):
+    if start_time + timeout < time.monotonic():
+        raise Exception("timed-out waiting for index creation")
+
+
+
+
diff --git a/src/aerospike_vector/shared/base_channel_provider.py b/src/aerospike_vector/shared/base_channel_provider.py
new file mode 100644
index 00000000..ad69750d
--- /dev/null
+++ b/src/aerospike_vector/shared/base_channel_provider.py
@@ -0,0 +1,79 @@
+import logging
+import random
+
+from typing import Optional, Union
+
+import grpc
+
+from .. import types
+from .proto_generated import vector_db_pb2
+
+logger = logging.getLogger(__name__)
+
+class ChannelAndEndpoints(object):
+    def __init__(
+        self, channel: Union[grpc.Channel, grpc.aio.Channel], endpoints: vector_db_pb2.ServerEndpointList
+    ) -> None:
+        self.channel = channel
+        self.endpoints = endpoints
+
+
+class BaseChannelProvider(object):
+    """Proximus Channel Provider"""
+    def __init__(
+        self, seeds: tuple[types.HostPort, ...], listener_name: Optional[str] = None, is_loadbalancer: Optional[bool] = False
+    ) -> None:
+        if not seeds:
+            raise Exception("at least one seed host needed")
+        self._nodeChannels: dict[int, ChannelAndEndpoints] = {}
+        self._seedChannels: list[Union[grpc.Channel, grpc.aio.Channel]] = []
+        self._closed = False
+        self._clusterId = 0
+        self.seeds = seeds
+        self.listener_name = listener_name
+        self._is_loadbalancer = is_loadbalancer
+        self._seedChannels = [
+            self._create_channel_from_host_port(seed) for seed in self.seeds
+        ]
+
+    def get_channel(self) -> Union[grpc.aio.Channel, grpc.Channel]:
+        if not self._is_loadbalancer:
+            discovered_channels: list[ChannelAndEndpoints] = list(
+                self._nodeChannels.values())
+            if len(discovered_channels) <= 0:
+                return self._seedChannels[0]
+
+            # Return a random channel.
+            channel = random.choice(discovered_channels).channel
+            if channel:
+                return channel
+
+        return self._seedChannels[0]
+
+    def _create_channel_from_host_port(self, host: types.HostPort) -> Union[grpc.aio.Channel, grpc.Channel]:
+        return self._create_channel(host.host, host.port, host.isTls)
+
+    def _create_channel_from_server_endpoint_list(
+        self, endpoints: vector_db_pb2.ServerEndpointList
+    ) -> Union[grpc.aio.Channel, grpc.Channel]:
+        # TODO: Create channel with all endpoints
+        for endpoint in endpoints.endpoints:
+            if ":" in endpoint.address:
+                # TODO: Ignoring IPv6 for now. 
Needs fix + continue + try: + return self._create_channel( + endpoint.address, endpoint.port, endpoint.isTls + ) + except Exception as e: + logger.debug("failure creating channel: " + str(e)) + + def add_new_channel_to_node_channels(self, node, newEndpoints): + + # We have discovered a new node + new_channel = self._create_channel_from_server_endpoint_list( + newEndpoints + ) + self._nodeChannels[node] = ChannelAndEndpoints( + new_channel, newEndpoints + ) diff --git a/src/aerospike_vector/shared/channel_provider_helpers.py b/src/aerospike_vector/shared/channel_provider_helpers.py new file mode 100644 index 00000000..d832ad65 --- /dev/null +++ b/src/aerospike_vector/shared/channel_provider_helpers.py @@ -0,0 +1,50 @@ +def init_tend(self) -> None: + end_tend = False + if self._is_loadbalancer: + # Skip tend if we are behind a load-balancer + end_tend = True + + if self._closed: + end_tend = True + + # TODO: Worry about thread safety + temp_endpoints: dict[int, vector_db_pb2.ServerEndpointList] = {} + + update_endpoints = False + channels = self._seedChannels + [ + x.channel for x in self._nodeChannels.values() + ] + return (temp_endpoints, update_endpoints, channels, end_tend) + + +def check_cluster_id(self, new_cluster_id) -> None: + if new_cluster_id == self._clusterId: + return False + + self._clusterId = new_cluster_id + + return True + +def update_temp_endpoints(response, temp_endpoints): + endpoints = response.endpoints + if len(endpoints) > len(temp_endpoints): + return endpoints + else: + return temp_endpoints + + +def check_for_new_endpoints(self, node, newEndpoints): + + channel_endpoints = self._nodeChannels.get(node) + add_new_channel = True + + if channel_endpoints: + # We have this node. Check if the endpoints changed. + if channel_endpoints.endpoints == newEndpoints: + # Nothing to be done for this node + add_new_channel = False + else: + add_new_channel = True + + return (channel_endpoints, add_new_channel) + diff --git a/src/aerospike_vector/shared/client_helpers.py b/src/aerospike_vector/shared/client_helpers.py new file mode 100644 index 00000000..e6386455 --- /dev/null +++ b/src/aerospike_vector/shared/client_helpers.py @@ -0,0 +1,182 @@ +from typing import Any, Optional, Union +import time +from . import conversions + +from .proto_generated import transact_pb2 +from .proto_generated import transact_pb2_grpc +from .. import types +from .proto_generated import types_pb2 +from . 
import helpers + + +def prepare_seeds(seeds) -> None: + return helpers.prepare_seeds(seeds) + +def prepare_put(self, namespace, key, record_data, set_name, logger) -> None: + + logger.debug( + "Putting record: namespace=%s, key=%s, record_data:%s, set_name:%s", + namespace, + key, + record_data, + set_name, + ) + + key = _get_key(namespace, set_name, key) + bin_list = [ + types_pb2.Bin(name=k, value=conversions.toVectorDbValue(v)) + for (k, v) in record_data.items() + ] + + transact_stub = get_transact_stub(self) + put_request = transact_pb2.PutRequest(key=key, bins=bin_list) + + return (transact_stub, put_request) + +def prepare_get(self, namespace, key, bin_names, set_name, logger) -> None: + + logger.debug( + "Getting record: namespace=%s, key=%s, bin_names:%s, set_name:%s", + namespace, + key, + bin_names, + set_name, + ) + + + key = _get_key(namespace, set_name, key) + bin_selector = _get_bin_selector(bin_names=bin_names) + + transact_stub = get_transact_stub(self) + get_request = transact_pb2.GetRequest(key=key, binSelector=bin_selector) + + return (transact_stub, key, get_request) + +def prepare_exists(self, namespace, key, set_name, logger) -> None: + + logger.debug( + "Getting record existence: namespace=%s, key=%s, set_name:%s", + namespace, + key, + set_name, + ) + + key = _get_key(namespace, set_name, key) + + transact_stub = get_transact_stub(self) + + return (transact_stub, key) + +def prepare_is_indexed(self, namespace, key, index_name, index_namespace, set_name, logger) -> None: + + logger.debug( + "Checking if index exists: namespace=%s, key=%s, index_name=%s, index_namespace=%s, set_name=%s", + namespace, + key, + index_name, + index_namespace, + set_name, + ) + + if not index_namespace: + index_namespace = namespace + index_id = types_pb2.IndexId(namespace=index_namespace, name=index_name) + key = _get_key(namespace, set_name, key) + + transact_stub = get_transact_stub(self) + is_indexed_request = transact_pb2.IsIndexedRequest(key=key, indexId=index_id) + + return (transact_stub, is_indexed_request) + +def prepare_vector_search(self, namespace, index_name, query, limit, search_params, bin_names, logger) -> None: + + logger.debug( + "Performing vector search: namespace=%s, index_name=%s, query=%s, limit=%s, search_params=%s, bin_names=%s", + namespace, + index_name, + query, + limit, + search_params, + bin_names, + ) + + if search_params != None: + search_params = search_params._to_pb2() + bin_selector = _get_bin_selector(bin_names=bin_names) + index = types_pb2.IndexId(namespace=namespace, name=index_name) + query_vector = conversions.toVectorDbValue(query).vectorValue + + + transact_stub = get_transact_stub(self) + + vector_search_request = transact_pb2.VectorSearchRequest( + index=index, + queryVector=query_vector, + limit=limit, + hnswSearchParams=search_params, + binSelector=bin_selector, + ) + + return (transact_stub, vector_search_request) + +def get_transact_stub(self): + return transact_pb2_grpc.TransactStub( + self._channelProvider.get_channel() + ) + +def respond_get(response, key) -> None: + return types.RecordWithKey( + key=conversions.fromVectorDbKey(key), + bins=conversions.fromVectorDbRecord(response), + ) + +def respond_exists(response) -> None: + return response.value + +def respond_is_indexed(response) -> None: + return response.value + +def respond_neighbor(response) -> None: + return conversions.fromVectorDbNeighbor(response) + +def _get_bin_selector(*, bin_names: Optional[list] = None): + + if not bin_names: + bin_selector = transact_pb2.BinSelector( 
+            type=transact_pb2.BinSelectorType.ALL, binNames=bin_names
+        )
+    else:
+        bin_selector = transact_pb2.BinSelector(
+            type=transact_pb2.BinSelectorType.SPECIFIED, binNames=bin_names
+        )
+    return bin_selector
+
+def _get_key(namespace: str, set: str, key: Union[int, str, bytes, bytearray]):
+    if isinstance(key, str):
+        key = types_pb2.Key(namespace=namespace, set=set, stringValue=key)
+    elif isinstance(key, int):
+        key = types_pb2.Key(namespace=namespace, set=set, longValue=key)
+    elif isinstance(key, (bytes, bytearray)):
+        key = types_pb2.Key(namespace=namespace, set=set, bytesValue=key)
+    else:
+        raise Exception("Invalid key type: " + str(type(key)))
+    return key
+
+def prepare_wait_for_index_waiting(self, namespace, name):
+    return helpers.prepare_wait_for_index_waiting(self, namespace, name)
+
+def check_completion_condition(start_time, timeout, index_status, unmerged_record_initialized):
+
+    if start_time + 10 < time.monotonic():
+        unmerged_record_initialized = True
+
+    if index_status.unmergedRecordCount > 0:
+        unmerged_record_initialized = True
+
+    if (
+        index_status.unmergedRecordCount == 0
+        and unmerged_record_initialized
+    ):
+        return True
+    else:
+        return False
diff --git a/src/aerospike_vector/conversions.py b/src/aerospike_vector/shared/conversions.py
similarity index 98%
rename from src/aerospike_vector/conversions.py
rename to src/aerospike_vector/shared/conversions.py
index 6a689416..b3614924 100644
--- a/src/aerospike_vector/conversions.py
+++ b/src/aerospike_vector/shared/conversions.py
@@ -1,7 +1,7 @@
 from typing import Any
 
-from . import types
-from . import types_pb2
+from .. import types
+from .proto_generated import types_pb2
 
 
 def toVectorDbValue(value: Any) -> types_pb2.Value:
diff --git a/src/aerospike_vector/shared/helpers.py b/src/aerospike_vector/shared/helpers.py
new file mode 100644
index 00000000..e18c9ddb
--- /dev/null
+++ b/src/aerospike_vector/shared/helpers.py
@@ -0,0 +1,28 @@
+import time
+from .. import types
+from .proto_generated import types_pb2
+from .proto_generated import index_pb2_grpc
+
+def prepare_seeds(seeds) -> tuple:
+
+    if not seeds:
+        raise Exception("at least one seed host needed")
+
+    if isinstance(seeds, types.HostPort):
+        seeds = (seeds,)
+
+    return seeds
+
+
+def prepare_wait_for_index_waiting(self, namespace, name):
+
+    wait_interval = 0.100
+    unmerged_record_initialized = False
+    start_time = time.monotonic()
+    double_check = False
+
+    index_stub = index_pb2_grpc.IndexServiceStub(
+        self._channelProvider.get_channel()
+    )
+    index_wait_request = types_pb2.IndexId(namespace=namespace, name=name)
+    return (index_stub, wait_interval, start_time, unmerged_record_initialized, double_check, index_wait_request)
diff --git a/src/aerospike_vector/auth_pb2.py b/src/aerospike_vector/shared/proto_generated/auth_pb2.py
similarity index 100%
rename from src/aerospike_vector/auth_pb2.py
rename to src/aerospike_vector/shared/proto_generated/auth_pb2.py
diff --git a/src/aerospike_vector/auth_pb2_grpc.py b/src/aerospike_vector/shared/proto_generated/auth_pb2_grpc.py
similarity index 100%
rename from src/aerospike_vector/auth_pb2_grpc.py
rename to src/aerospike_vector/shared/proto_generated/auth_pb2_grpc.py
diff --git a/src/aerospike_vector/index_pb2.py b/src/aerospike_vector/shared/proto_generated/index_pb2.py
similarity index 100%
rename from src/aerospike_vector/index_pb2.py
rename to src/aerospike_vector/shared/proto_generated/index_pb2.py
diff --git a/src/aerospike_vector/index_pb2_grpc.py b/src/aerospike_vector/shared/proto_generated/index_pb2_grpc.py
similarity index 100%
rename from src/aerospike_vector/index_pb2_grpc.py
rename to src/aerospike_vector/shared/proto_generated/index_pb2_grpc.py
diff --git a/src/aerospike_vector/transact_pb2.py b/src/aerospike_vector/shared/proto_generated/transact_pb2.py
similarity index 100%
rename from src/aerospike_vector/transact_pb2.py
rename to src/aerospike_vector/shared/proto_generated/transact_pb2.py
diff --git a/src/aerospike_vector/transact_pb2_grpc.py b/src/aerospike_vector/shared/proto_generated/transact_pb2_grpc.py
similarity index 100%
rename from src/aerospike_vector/transact_pb2_grpc.py
rename to src/aerospike_vector/shared/proto_generated/transact_pb2_grpc.py
diff --git a/src/aerospike_vector/types_pb2.py b/src/aerospike_vector/shared/proto_generated/types_pb2.py
similarity index 100%
rename from src/aerospike_vector/types_pb2.py
rename to src/aerospike_vector/shared/proto_generated/types_pb2.py
diff --git a/src/aerospike_vector/types_pb2_grpc.py b/src/aerospike_vector/shared/proto_generated/types_pb2_grpc.py
similarity index 100%
rename from src/aerospike_vector/types_pb2_grpc.py
rename to src/aerospike_vector/shared/proto_generated/types_pb2_grpc.py
diff --git a/src/aerospike_vector/vector_db_pb2.py b/src/aerospike_vector/shared/proto_generated/vector_db_pb2.py
similarity index 100%
rename from src/aerospike_vector/vector_db_pb2.py
rename to src/aerospike_vector/shared/proto_generated/vector_db_pb2.py
diff --git a/src/aerospike_vector/vector_db_pb2_grpc.py b/src/aerospike_vector/shared/proto_generated/vector_db_pb2_grpc.py
similarity index 100%
rename from src/aerospike_vector/vector_db_pb2_grpc.py
rename to src/aerospike_vector/shared/proto_generated/vector_db_pb2_grpc.py
diff --git a/src/aerospike_vector/types.py b/src/aerospike_vector/types.py
index 6a2a5eea..63d724ea 100644
--- a/src/aerospike_vector/types.py
+++ b/src/aerospike_vector/types.py
@@ -1,7 +1,7 @@
 import enum
 from typing import Any, Optional
 
-from . 
import types_pb2 +from .shared.proto_generated import types_pb2 class HostPort(object): diff --git a/src/aerospike_vector/vectordb_channel_provider.py b/src/aerospike_vector/vectordb_channel_provider.py deleted file mode 100644 index 09694218..00000000 --- a/src/aerospike_vector/vectordb_channel_provider.py +++ /dev/null @@ -1,165 +0,0 @@ -import re -import threading -import warnings -import logging -from typing import Optional - -import google.protobuf.empty_pb2 -import grpc -import random - -from . import types -from . import vector_db_pb2 -from . import vector_db_pb2_grpc - -empty = google.protobuf.empty_pb2.Empty() - -logger = logging.getLogger(__name__) - -class ChannelAndEndpoints(object): - def __init__( - self, channel: grpc.aio.Channel, endpoints: vector_db_pb2.ServerEndpointList - ) -> None: - self.channel = channel - self.endpoints = endpoints - - -class VectorDbChannelProvider(object): - """Proximus Channel Provider""" - def __init__( - self, seeds: tuple[types.HostPort, ...], listener_name: Optional[str] = None, is_loadbalancer: Optional[bool] = False - ) -> None: - if not seeds: - raise Exception("at least one seed host needed") - self._nodeChannels: dict[int, ChannelAndEndpoints] = {} - self._seedChannels: dict[grpc.aio.Channel] = {} - self._closed = False - self._clusterId = 0 - self.seeds = seeds - self.listener_name = listener_name - self._is_loadbalancer = is_loadbalancer - self._seedChannels = [ - self._create_channel_from_host_port(seed) for seed in self.seeds - ] - self._tend() - - async def close(self): - self._closed = True - for channel in self._seedChannels: - await channel.close() - - for k, channelEndpoints in self._nodeChannels.items(): - if channelEndpoints.channel: - await channelEndpoints.channel.close() - - def get_channel(self) -> grpc.Channel: - if not self._is_loadbalancer: - discovered_channels: list[ChannelAndEndpoints] = list( - self._nodeChannels.values()) - if len(discovered_channels) <= 0: - return self._seedChannels[0] - - - # Return a random channel. - channel = random.choice(discovered_channels).channel - if channel: - return channel - - return self._seedChannels[0] - - def _tend(self): - if self._is_loadbalancer: - # Skip tend if we are behind a load-balancer - return - - # TODO: Worry about thread safety - temp_endpoints: dict[int, vector_db_pb2.ServerEndpointList] = {} - - if self._closed: - return - - try: - update_endpoints = False - channels = self._seedChannels + [ - x.channel for x in self._nodeChannels.values() - ] - for seedChannel in channels: - try: - stub = vector_db_pb2_grpc.ClusterInfoStub(seedChannel) - newClusterId = stub.GetClusterId(empty).id - - if newClusterId == self._clusterId: - continue - - update_endpoints = True - self._clusterId = newClusterId - endpoints = stub.GetClusterEndpoints( - vector_db_pb2.ClusterNodeEndpointsRequest( - listenerName=self.listener_name - ) - ).endpoints - - if len(endpoints) > len(temp_endpoints): - temp_endpoints = endpoints - - except Exception as e: - logger.debug("failure tending cluster endpoints: " + str(e)) - - if update_endpoints: - for node, newEndpoints in temp_endpoints.items(): - channel_endpoints = self._nodeChannels.get(node) - add_new_channel = True - if channel_endpoints: - # We have this node. Check if the endpoints changed. 
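The provider deleted here forced every caller through `grpc.aio` channels, even in synchronous code. This patch splits it into `internal/channel_provider.py` (sync, where `close()` blocks on `channel.close()`) and `aio/internal/channel_provider.py` (async, where `close()` is awaited). A minimal sketch of that split, assuming only stock `grpcio` channel APIs; the class names and `address` parameter are illustrative, not from the patch:

```python
import grpc


class SyncProviderSketch:
    def __init__(self, address: str) -> None:
        # Sync variant: a plain blocking channel.
        self._channel = grpc.insecure_channel(address)

    def close(self) -> None:
        # Synchronous close; releases the channel's resources before returning.
        self._channel.close()


class AioProviderSketch:
    def __init__(self, address: str) -> None:
        # Async variant: an asyncio-based channel.
        self._channel = grpc.aio.insecure_channel(address)

    async def close(self) -> None:
        # grpc.aio channels expose close() as a coroutine, so it must be awaited.
        await self._channel.close()
```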
- if channel_endpoints.endpoints == newEndpoints: - # Nothing to be done for this node - add_new_channel = False - else: - # TODO: Wait for all calls to drain - channel_endpoints.channel.close() - add_new_channel = True - - if add_new_channel: - # We have discovered a new node - new_channel = self._create_channel_from_server_endpoint_list( - newEndpoints - ) - self._nodeChannels[node] = ChannelAndEndpoints( - new_channel, newEndpoints - ) - - for node, channel_endpoints in self._nodeChannels.items(): - if not temp_endpoints.get(node): - # TODO: Wait for all calls to drain - channel_endpoints.channel.close() - del self._nodeChannels[node] - - except Exception as e: - logger.debug("failure tending: " + str(e)) - - if not self._closed: - # TODO: check tend interval. - threading.Timer(1, self._tend).start() - - def _create_channel_from_host_port(self, host: types.HostPort) -> grpc.aio.Channel: - return self._create_channel(host.host, host.port, host.isTls) - - def _create_channel_from_server_endpoint_list( - self, endpoints: vector_db_pb2.ServerEndpointList - ) -> grpc.aio.Channel: - # TODO: Create channel with all endpoints - for endpoint in endpoints.endpoints: - if ":" in endpoint.address: - # TODO: Ignoring IPv6 for now. Needs fix - continue - try: - return self._create_channel( - endpoint.address, endpoint.port, endpoint.isTls - ) - except Exception as e: - logger.debug("failure creating channel: " + str(e)) - - def _create_channel(self, host: str, port: int, isTls: bool) -> grpc.aio.Channel: - # TODO: Take care of TLS - host = re.sub(r"%.*", "", host) - return grpc.aio.insecure_channel(f"{host}:{port}") \ No newline at end of file diff --git a/tests/aio/__init__.py b/tests/aio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/aio/conftest.py similarity index 72% rename from tests/conftest.py rename to tests/aio/conftest.py index d15f017b..6f046eb6 100644 --- a/tests/conftest.py +++ b/tests/aio/conftest.py @@ -1,13 +1,14 @@ import pytest import asyncio -from aerospike_vector import vectordb_client, vectordb_admin, types - +from aerospike_vector.aio import Client +from aerospike_vector.aio.admin import Client as AdminClient +from aerospike_vector import types host = 'localhost' port = 5000 @pytest.fixture(scope="session", autouse=True) async def drop_all_indexes(): - async with vectordb_admin.VectorDbAdminClient( + async with AdminClient( seeds=types.HostPort(host=host, port=port) ) as client: index_list = await client.index_list() @@ -16,19 +17,18 @@ async def drop_all_indexes(): tasks.append(client.index_drop(namespace="test", name=item['id']['name'])) await asyncio.gather(*tasks) - -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") async def session_admin_client(): - client = vectordb_admin.VectorDbAdminClient( + client = AdminClient( seeds=types.HostPort(host=host, port=port) ) yield client await client.close() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") async def session_vector_client(): - client = vectordb_client.VectorDbClient( + client = Client( seeds=types.HostPort(host=host, port=port) ) yield client @@ -36,7 +36,7 @@ async def session_vector_client(): @pytest.fixture async def function_admin_client(): - client = vectordb_admin.VectorDbAdminClient( + client = AdminClient( seeds=types.HostPort(host=host, port=port) ) yield client diff --git a/tests/aio/requirements.txt b/tests/aio/requirements.txt new file mode 100644 index 00000000..483b5812 --- /dev/null +++ b/tests/aio/requirements.txt @@ -0,0 
+1,4 @@ +numpy==1.26.4 +pytest==7.4.0 +pytest-aio==1.5.0 +.. \ No newline at end of file diff --git a/tests/test_admin_client_index_create.py b/tests/aio/test_admin_client_index_create.py similarity index 100% rename from tests/test_admin_client_index_create.py rename to tests/aio/test_admin_client_index_create.py diff --git a/tests/test_admin_client_index_drop.py b/tests/aio/test_admin_client_index_drop.py similarity index 100% rename from tests/test_admin_client_index_drop.py rename to tests/aio/test_admin_client_index_drop.py diff --git a/tests/test_admin_client_index_get.py b/tests/aio/test_admin_client_index_get.py similarity index 100% rename from tests/test_admin_client_index_get.py rename to tests/aio/test_admin_client_index_get.py diff --git a/tests/test_admin_client_index_get_status.py b/tests/aio/test_admin_client_index_get_status.py similarity index 100% rename from tests/test_admin_client_index_get_status.py rename to tests/aio/test_admin_client_index_get_status.py diff --git a/tests/test_admin_client_index_list.py b/tests/aio/test_admin_client_index_list.py similarity index 100% rename from tests/test_admin_client_index_list.py rename to tests/aio/test_admin_client_index_list.py diff --git a/tests/test_vector_client_exists.py b/tests/aio/test_vector_client_exists.py similarity index 100% rename from tests/test_vector_client_exists.py rename to tests/aio/test_vector_client_exists.py diff --git a/tests/test_vector_client_get.py b/tests/aio/test_vector_client_get.py similarity index 100% rename from tests/test_vector_client_get.py rename to tests/aio/test_vector_client_get.py diff --git a/tests/test_vector_client_put.py b/tests/aio/test_vector_client_put.py similarity index 100% rename from tests/test_vector_client_put.py rename to tests/aio/test_vector_client_put.py diff --git a/tests/test_vector_search.py b/tests/aio/test_vector_search.py similarity index 99% rename from tests/test_vector_search.py rename to tests/aio/test_vector_search.py index f3ff8c04..681dacee 100644 --- a/tests/test_vector_search.py +++ b/tests/aio/test_vector_search.py @@ -10,6 +10,7 @@ query_vector_number = 100 +# Print the current working directory def parse_sift_to_numpy_array(length, dim, byte_buffer, dtype): numpy = np.empty((length,), dtype=object) diff --git a/tests/pytest.ini b/tests/pytest.ini deleted file mode 100644 index eea2c180..00000000 --- a/tests/pytest.ini +++ /dev/null @@ -1 +0,0 @@ -[pytest] diff --git a/tests/sync/__init__.py b/tests/sync/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/sync/conftest.py b/tests/sync/conftest.py new file mode 100644 index 00000000..ce92a375 --- /dev/null +++ b/tests/sync/conftest.py @@ -0,0 +1,42 @@ +import pytest +import asyncio +from aerospike_vector import Client +from aerospike_vector.admin import Client as AdminClient +from aerospike_vector import types + +host = 'localhost' +port = 5000 +@pytest.fixture(scope="session", autouse=True) +def drop_all_indexes(): + with AdminClient( + seeds=types.HostPort(host=host, port=port) + ) as client: + index_list = client.index_list() + tasks = [] + for item in index_list: + client.index_drop(namespace="test", name=item['id']['name']) + +@pytest.fixture(scope="module") +def session_admin_client(): + client = AdminClient( + seeds=types.HostPort(host=host, port=port) + ) + yield client + client.close() + + +@pytest.fixture(scope="module") +def session_vector_client(): + client = Client( + seeds=types.HostPort(host=host, port=port) + ) + yield client + client.close() + 
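The synchronous fixtures in this conftest mirror the aio ones with the awaits removed. A hypothetical test consuming them might look like the following sketch; the key and bin name are invented for illustration and are not part of this suite:

```python
def test_round_trip_sketch(session_vector_client):
    # Write a record, then confirm it exists, via the module-scoped sync client.
    session_vector_client.put(
        namespace="test",
        key="sketch/1",
        record_data={"example": [float(i) for i in range(1024)]},
    )
    assert session_vector_client.exists(namespace="test", key="sketch/1") is True
```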
+@pytest.fixture
+def function_admin_client():
+    client = AdminClient(
+        seeds=types.HostPort(host=host, port=port)
+    )
+    yield client
+    client.close()
\ No newline at end of file
diff --git a/tests/sync/requirements.txt b/tests/sync/requirements.txt
new file mode 100644
index 00000000..3d934824
--- /dev/null
+++ b/tests/sync/requirements.txt
@@ -0,0 +1,3 @@
+numpy==1.26.4
+pytest==7.4.0
+..
\ No newline at end of file
diff --git a/tests/sync/test_admin_client_index_create.py b/tests/sync/test_admin_client_index_create.py
new file mode 100644
index 00000000..98b34b7a
--- /dev/null
+++ b/tests/sync/test_admin_client_index_create.py
@@ -0,0 +1,377 @@
+import pytest
+from aerospike_vector import types
+
+
+class index_create_test_case:
+    def __init__(
+        self,
+        *,
+        namespace,
+        name,
+        vector_field,
+        dimensions,
+        vector_distance_metric,
+        sets,
+        index_params,
+        index_meta_data,
+    ):
+        self.namespace = namespace
+        self.name = name
+        self.vector_field = vector_field
+        self.dimensions = dimensions
+        if vector_distance_metric is None:
+            self.vector_distance_metric = types.VectorDistanceMetric.SQUARED_EUCLIDEAN
+        else:
+            self.vector_distance_metric = vector_distance_metric
+        self.sets = sets
+        self.index_params = index_params
+        self.index_meta_data = index_meta_data
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        index_create_test_case(
+            namespace="test",
+            name="index_1",
+            vector_field="example_1",
+            dimensions=1024,
+            vector_distance_metric=None,
+            sets=None,
+            index_params=None,
+            index_meta_data=None,
+        )
+    ],
+)
+def test_index_create(session_admin_client, test_case):
+    session_admin_client.index_create(
+        namespace=test_case.namespace,
+        name=test_case.name,
+        vector_field=test_case.vector_field,
+        dimensions=test_case.dimensions,
+        vector_distance_metric=test_case.vector_distance_metric,
+        sets=test_case.sets,
+        index_params=test_case.index_params,
+        index_meta_data=test_case.index_meta_data,
+    )
+    results = session_admin_client.index_list()
+    found = False
+    for result in results:
+        if result['id']['name'] == test_case.name:
+            found = True
+            assert result['id']['namespace'] == test_case.namespace
+            assert result['dimensions'] == test_case.dimensions
+            assert result['bin'] == test_case.vector_field
+            assert result['hnsw_params']['m'] == 16
+            assert result['hnsw_params']['ef_construction'] == 100
+            assert result['hnsw_params']['ef'] == 100
+            assert result['hnsw_params']['batching_params']['max_records'] == 100000
+            assert result['hnsw_params']['batching_params']['interval'] == 30000
+            assert result['hnsw_params']['batching_params']['disabled'] == False
+            assert result['aerospikeStorage']['namespace'] == test_case.namespace
+            assert result['aerospikeStorage']['set'] == test_case.name
+    assert found == True
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        index_create_test_case(
+            namespace="test",
+            name="index_2",
+            vector_field="example_2",
+            dimensions=495,
+            vector_distance_metric=None,
+            sets=None,
+            index_params=None,
+            index_meta_data=None,
+        ),
+        index_create_test_case(
+            namespace="test",
+            name="index_3",
+            vector_field="example_3",
+            dimensions=2048,
+            vector_distance_metric=None,
+            sets=None,
+            index_params=None,
+            index_meta_data=None,
+        ),
+    ],
+)
+def test_index_create_with_dimensions(session_admin_client, test_case):
+    session_admin_client.index_create(
+        namespace=test_case.namespace,
+        name=test_case.name,
+        vector_field=test_case.vector_field,
+        dimensions=test_case.dimensions,
+        vector_distance_metric=test_case.vector_distance_metric,
+        sets=test_case.sets,
+        
index_params=test_case.index_params, + index_meta_data=test_case.index_meta_data, + ) + results = session_admin_client.index_list() + found = False + for result in results: + if result['id']['name'] == test_case.name: + found = True + assert result['id']['namespace'] == test_case.namespace + assert result['dimensions'] == test_case.dimensions + assert result['bin'] == test_case.vector_field + assert result['hnsw_params']['m'] == 16 + assert result['hnsw_params']['ef_construction'] == 100 + assert result['hnsw_params']['ef'] == 100 + assert result['hnsw_params']['batching_params']['max_records'] == 100000 + assert result['hnsw_params']['batching_params']['interval'] == 30000 + assert result['hnsw_params']['batching_params']['disabled'] == False + assert result['aerospikeStorage']['namespace'] == test_case.namespace + assert result['aerospikeStorage']['set'] == test_case.name + assert found == True + +@pytest.mark.parametrize( + "test_case", + [ + index_create_test_case( + namespace="test", + name="index_4", + vector_field="example_4", + dimensions=1024, + vector_distance_metric=types.VectorDistanceMetric.COSINE, + sets=None, + index_params=None, + index_meta_data=None, + ), + index_create_test_case( + namespace="test", + name="index_5", + vector_field="example_5", + dimensions=1024, + vector_distance_metric=types.VectorDistanceMetric.DOT_PRODUCT, + sets=None, + index_params=None, + index_meta_data=None, + ), + index_create_test_case( + namespace="test", + name="index_6", + vector_field="example_6", + dimensions=1024, + vector_distance_metric=types.VectorDistanceMetric.MANHATTAN, + sets=None, + index_params=None, + index_meta_data=None, + ), + index_create_test_case( + namespace="test", + name="index_7", + vector_field="example_7", + dimensions=1024, + vector_distance_metric=types.VectorDistanceMetric.HAMMING, + sets=None, + index_params=None, + index_meta_data=None, + ), + ], +) +def test_index_create_with_vector_distance_metric( + session_admin_client, test_case +): + session_admin_client.index_create( + namespace=test_case.namespace, + name=test_case.name, + vector_field=test_case.vector_field, + dimensions=test_case.dimensions, + vector_distance_metric=test_case.vector_distance_metric, + sets=test_case.sets, + index_params=test_case.index_params, + index_meta_data=test_case.index_meta_data, + ) + results = session_admin_client.index_list() + found = False + for result in results: + if result['id']['name'] == test_case.name: + found = True + assert result['id']['namespace'] == test_case.namespace + assert result['dimensions'] == test_case.dimensions + assert result['bin'] == test_case.vector_field + assert result['hnsw_params']['m'] == 16 + assert result['hnsw_params']['ef_construction'] == 100 + assert result['hnsw_params']['ef'] == 100 + assert result['hnsw_params']['batching_params']['max_records'] == 100000 + assert result['hnsw_params']['batching_params']['interval'] == 30000 + assert result['hnsw_params']['batching_params']['disabled'] == False + assert result['aerospikeStorage']['namespace'] == test_case.namespace + assert result['aerospikeStorage']['set'] == test_case.name + assert found == True + +@pytest.mark.parametrize( + "test_case", + [ + index_create_test_case( + namespace="test", + name="index_8", + vector_field="example_8", + dimensions=1024, + vector_distance_metric=None, + sets="Demo", + index_params=None, + index_meta_data=None, + ), + index_create_test_case( + namespace="test", + name="index_9", + vector_field="example_9", + dimensions=1024, + 
vector_distance_metric=None, + sets="Cheese", + index_params=None, + index_meta_data=None, + ), + ], +) +def test_index_create_with_sets(session_admin_client, test_case): + session_admin_client.index_create( + namespace=test_case.namespace, + name=test_case.name, + vector_field=test_case.vector_field, + dimensions=test_case.dimensions, + vector_distance_metric=test_case.vector_distance_metric, + sets=test_case.sets, + index_params=test_case.index_params, + index_meta_data=test_case.index_meta_data, + ) + results = session_admin_client.index_list() + found = False + for result in results: + if result['id']['name'] == test_case.name: + found = True + assert result['id']['namespace'] == test_case.namespace + assert result['dimensions'] == test_case.dimensions + assert result['bin'] == test_case.vector_field + assert result['hnsw_params']['m'] == 16 + assert result['hnsw_params']['ef_construction'] == 100 + assert result['hnsw_params']['ef'] == 100 + assert result['hnsw_params']['batching_params']['max_records'] == 100000 + assert result['hnsw_params']['batching_params']['interval'] == 30000 + assert result['hnsw_params']['batching_params']['disabled'] == False + assert result['aerospikeStorage']['namespace'] == test_case.namespace + assert result['aerospikeStorage']['set'] == test_case.name + assert found == True + +@pytest.mark.parametrize( + "test_case", + [ + index_create_test_case( + namespace="test", + name="index_10", + vector_field="example_10", + dimensions=1024, + vector_distance_metric=None, + sets=None, + index_params=types.HnswParams( + m=32, + ef_construction=200, + ef=400, + ), + index_meta_data=None, + ), + index_create_test_case( + namespace="test", + name="index_11", + vector_field="example_11", + dimensions=1024, + vector_distance_metric=None, + sets=None, + index_params=types.HnswParams( + m=8, + ef_construction=50, + ef=25, + ), + index_meta_data=None, + ), + index_create_test_case( + namespace="test", + name="index_12", + vector_field="example_12", + dimensions=1024, + vector_distance_metric=None, + sets=None, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + max_records=500, interval=500, disabled=True + ) + ), + index_meta_data=None, + ), + ], +) +def test_index_create_with_index_params(session_admin_client, test_case): + session_admin_client.index_create( + namespace=test_case.namespace, + name=test_case.name, + vector_field=test_case.vector_field, + dimensions=test_case.dimensions, + vector_distance_metric=test_case.vector_distance_metric, + sets=test_case.sets, + index_params=test_case.index_params, + index_meta_data=test_case.index_meta_data, + ) + results = session_admin_client.index_list() + found = False + for result in results: + if result['id']['name'] == test_case.name: + found = True + assert result['id']['namespace'] == test_case.namespace + assert result['dimensions'] == test_case.dimensions + assert result['bin'] == test_case.vector_field + assert result['hnsw_params']['m'] == test_case.index_params.m + assert result['hnsw_params']['ef_construction'] == test_case.index_params.ef_construction + assert result['hnsw_params']['ef'] == test_case.index_params.ef + assert result['hnsw_params']['batching_params']['max_records'] == test_case.index_params.batching_params.max_records + assert result['hnsw_params']['batching_params']['interval'] == test_case.index_params.batching_params.interval + assert result['hnsw_params']['batching_params']['disabled'] == test_case.index_params.batching_params.disabled + assert 
result['aerospikeStorage']['namespace'] == test_case.namespace + assert result['aerospikeStorage']['set'] == test_case.name + assert found == True + +@pytest.mark.parametrize( + "test_case", + [ + index_create_test_case( + namespace="test", + name="index_13", + vector_field="example_13", + dimensions=1024, + vector_distance_metric=None, + sets=None, + index_params=None, + index_meta_data={"size": "large", "price": "$4.99", "currencyType": "CAN"}, + ), + ], +) +def test_index_create_index_meta_data(session_admin_client, test_case): + session_admin_client.index_create( + namespace=test_case.namespace, + name=test_case.name, + vector_field=test_case.vector_field, + dimensions=test_case.dimensions, + vector_distance_metric=test_case.vector_distance_metric, + sets=test_case.sets, + index_params=test_case.index_params, + index_meta_data=test_case.index_meta_data, + ) + results = session_admin_client.index_list() + found = False + for result in results: + if result['id']['name'] == test_case.name: + found = True + assert result['id']['namespace'] == test_case.namespace + assert result['dimensions'] == test_case.dimensions + assert result['bin'] == test_case.vector_field + assert result['hnsw_params']['m'] == 16 + assert result['hnsw_params']['ef_construction'] == 100 + assert result['hnsw_params']['ef'] == 100 + assert result['hnsw_params']['batching_params']['max_records'] == 100000 + assert result['hnsw_params']['batching_params']['interval'] == 30000 + assert result['hnsw_params']['batching_params']['disabled'] == False + assert result['aerospikeStorage']['namespace'] == test_case.namespace + assert result['aerospikeStorage']['set'] == test_case.name + assert found == True \ No newline at end of file diff --git a/tests/sync/test_admin_client_index_drop.py b/tests/sync/test_admin_client_index_drop.py new file mode 100644 index 00000000..92777662 --- /dev/null +++ b/tests/sync/test_admin_client_index_drop.py @@ -0,0 +1,18 @@ +import pytest + +@pytest.fixture +def add_index(function_admin_client): + function_admin_client.index_create( + namespace="test", + name="index_drop_1", + vector_field="art", + dimensions=1024, + ) + +def test_index_drop(add_index, session_admin_client): + session_admin_client.index_drop(namespace="test", name="index_drop_1") + + result = session_admin_client.index_list() + result = result + for index in result: + assert index["id"]["name"] != "index_drop_1" diff --git a/tests/sync/test_admin_client_index_get.py b/tests/sync/test_admin_client_index_get.py new file mode 100644 index 00000000..2dc9a267 --- /dev/null +++ b/tests/sync/test_admin_client_index_get.py @@ -0,0 +1,29 @@ +import pytest + +@pytest.fixture +def add_index(function_admin_client): + function_admin_client.index_create( + namespace="test", + name="index_get_1", + vector_field="science", + dimensions=1024, + ) + + +def test_index_get(add_index, session_admin_client): + result = session_admin_client.index_get( + namespace="test", name="index_get_1" + ) + + assert result["id"]["name"] == "index_get_1" + assert result["id"]["namespace"] == "test" + assert result["dimensions"] == 1024 + assert result["bin"] == "science" + assert result["hnsw_params"]["m"] == 16 + assert result["hnsw_params"]["ef_construction"] == 100 + assert result["hnsw_params"]["ef"] == 100 + assert result["hnsw_params"]["batching_params"]["max_records"] == 100000 + assert result["hnsw_params"]["batching_params"]["interval"] == 30000 + assert not result["hnsw_params"]["batching_params"]["disabled"] + assert 
result["aerospikeStorage"]["namespace"] == "test" + assert result["aerospikeStorage"]["set"] == "index_get_1" \ No newline at end of file diff --git a/tests/sync/test_admin_client_index_get_status.py b/tests/sync/test_admin_client_index_get_status.py new file mode 100644 index 00000000..fee6bbe6 --- /dev/null +++ b/tests/sync/test_admin_client_index_get_status.py @@ -0,0 +1,17 @@ +import pytest + +@pytest.fixture +def add_index(function_admin_client): + function_admin_client.index_create( + namespace="test", + name="index_get_status_1", + vector_field="science", + dimensions=1024, + ) + + +def test_index_get_status(add_index, session_admin_client): + result = session_admin_client.index_get_status( + namespace="test", name="index_get_status_1" + ) + assert result == 0 diff --git a/tests/sync/test_admin_client_index_list.py b/tests/sync/test_admin_client_index_list.py new file mode 100644 index 00000000..7d4c0b92 --- /dev/null +++ b/tests/sync/test_admin_client_index_list.py @@ -0,0 +1,17 @@ +def test_index_list(session_admin_client): + result = session_admin_client.index_list() + assert len(result) > 0 + for index in result: + assert isinstance(index['id']['name'], str) + assert isinstance(index['id']['namespace'], str) + assert isinstance(index['dimensions'], int) + assert isinstance(index['bin'], str) + assert isinstance(index['hnsw_params']['m'], int) + assert isinstance(index['hnsw_params']['ef_construction'], int) + assert isinstance(index['hnsw_params']['ef'], int) + assert isinstance(index['hnsw_params']['batching_params']['max_records'], int) + assert isinstance(index['hnsw_params']['batching_params']['interval'], int) + assert isinstance(index['hnsw_params']['batching_params']['disabled'], bool) + assert isinstance(index['aerospikeStorage']['namespace'], str) + assert isinstance(index['aerospikeStorage']['set'], str) + diff --git a/tests/sync/test_vector_client_exists.py b/tests/sync/test_vector_client_exists.py new file mode 100644 index 00000000..db7c917c --- /dev/null +++ b/tests/sync/test_vector_client_exists.py @@ -0,0 +1,47 @@ +import pytest + +class exists_test_case: + def __init__( + self, + *, + namespace, + key, + record_data, + set_name, + + ): + self.namespace = namespace + self.key = key + self.set_name = set_name + self.record_data = record_data + +@pytest.mark.parametrize( + "test_case", + [ + exists_test_case( + namespace="test", + key="get/1", + set_name=None, + record_data={"skills": [i for i in range(1024)]}, + ), + exists_test_case( + namespace="test", + key="get/1", + set_name=None, + record_data={"english": [float(i) for i in range(1024)]}, + ) + ], +) +def test_vector_exists(session_vector_client, test_case): + session_vector_client.put( + namespace=test_case.namespace, + key=test_case.key, + record_data=test_case.record_data, + set_name=test_case.set_name + + ) + result = session_vector_client.exists( + namespace=test_case.namespace, + key=test_case.key, + ) + assert result is True diff --git a/tests/sync/test_vector_client_get.py b/tests/sync/test_vector_client_get.py new file mode 100644 index 00000000..a7fca2ca --- /dev/null +++ b/tests/sync/test_vector_client_get.py @@ -0,0 +1,59 @@ +import pytest + +class get_test_case: + def __init__( + self, + *, + namespace, + key, + bin_names, + set_name, + record_data, + expected_bins + ): + self.namespace = namespace + self.key = key + self.bin_names = bin_names + self.set_name = set_name + self.record_data = record_data + self.expected_bins = expected_bins + +@pytest.mark.parametrize( + "test_case", + [ + 
get_test_case( + namespace="test", + key="get/1", + bin_names=['skills'], + set_name=None, + record_data={"skills": [i for i in range(1024)]}, + expected_bins={"skills": [i for i in range(1024)]} + ), + get_test_case( + namespace="test", + key="get/1", + bin_names=['english'], + set_name=None, + record_data={"english": [float(i) for i in range(1024)]}, + expected_bins={"english": [float(i) for i in range(1024)]} + ) + ], +) +def test_vector_get(session_vector_client, test_case): + session_vector_client.put( + namespace=test_case.namespace, + key=test_case.key, + record_data=test_case.record_data, + set_name=test_case.set_name + + ) + result = session_vector_client.get( + namespace=test_case.namespace, key=test_case.key, bin_names=test_case.bin_names + ) + assert result.key.namespace == test_case.namespace + if(test_case.set_name == None): + test_case.set_name = "" + assert result.key.set == test_case.set_name + assert result.key.key == test_case.key + assert isinstance(result.key.digest, bytes) + assert result.bins == test_case.expected_bins diff --git a/tests/sync/test_vector_client_put.py b/tests/sync/test_vector_client_put.py new file mode 100644 index 00000000..104ac9f6 --- /dev/null +++ b/tests/sync/test_vector_client_put.py @@ -0,0 +1,41 @@ +import pytest +class put_test_case: + def __init__( + self, + *, + namespace, + key, + record_data, + set_name + ): + self.namespace = namespace + self.key = key + self.record_data = record_data + self.set_name = set_name + + +@pytest.mark.parametrize( + "test_case", + [ + put_test_case( + namespace="test", + key="put/1", + record_data={"math": [i for i in range(1024)]}, + set_name=None + ), + put_test_case( + namespace="test", + key="put/2", + record_data={"english": [float(i) for i in range(1024)]}, + set_name=None + ) + ], +) +def test_vector_put(session_vector_client, test_case): + session_vector_client.put( + namespace=test_case.namespace, + key=test_case.key, + record_data=test_case.record_data, + set_name=test_case.set_name + + ) \ No newline at end of file diff --git a/tests/sync/test_vector_search.py b/tests/sync/test_vector_search.py new file mode 100644 index 00000000..57811b89 --- /dev/null +++ b/tests/sync/test_vector_search.py @@ -0,0 +1,171 @@ +import numpy as np +import pytest +import random +from aerospike_vector import types + +dimensions = 128 +truth_vector_dimensions = 100 +base_vector_number = 10_000 +query_vector_number = 100 + + +# Print the current working directory +def parse_sift_to_numpy_array(length, dim, byte_buffer, dtype): + numpy = np.empty((length,), dtype=object) + + record_length = (dim * 4) + 4 + + for i in range(length): + current_offset = i * record_length + begin = current_offset + vector_begin = current_offset + 4 + end = current_offset + record_length + if np.frombuffer(byte_buffer[begin:vector_begin], dtype=np.int32)[0] != dim: + raise Exception("Failed to parse byte buffer correctly") + numpy[i] = np.frombuffer(byte_buffer[vector_begin:end], dtype=dtype) + return numpy + + +@pytest.fixture +def base_numpy(): + base_filename = "siftsmall/siftsmall_base.fvecs" + with open(base_filename, "rb") as file: + base_bytes = bytearray(file.read()) + + base_numpy = parse_sift_to_numpy_array( + base_vector_number, dimensions, base_bytes, np.float32 + ) + + return base_numpy + + +@pytest.fixture +def truth_numpy(): + truth_filename = "siftsmall/siftsmall_groundtruth.ivecs" + with open(truth_filename, "rb") as file: + truth_bytes = bytearray(file.read()) + + truth_numpy = parse_sift_to_numpy_array( + 
query_vector_number, truth_vector_dimensions, truth_bytes, np.int32 + ) + + return truth_numpy + + +@pytest.fixture +def query_numpy(): + query_filename = "siftsmall/siftsmall_query.fvecs" + with open(query_filename, "rb") as file: + query_bytes = bytearray(file.read()) + + truth_numpy = parse_sift_to_numpy_array( + query_vector_number, dimensions, query_bytes, np.float32 + ) + + return truth_numpy + + +def put_vector(client, vector, j): + client.put( + namespace="test", key=str(j), record_data={"unit_test": vector}, set_name="demo" + ) + + +def get_vector(client, j): + result = client.get(namespace="test", key=str(j), set_name="demo") + + +def vector_search(client, vector): + result = client.vector_search( + namespace="test", + index_name="demo", + query=vector, + limit=100, + bin_names=["unit_test"], + ) + return result + + +def vector_search_ef_80(client, vector): + result = client.vector_search( + namespace="test", + index_name="demo", + query=vector, + limit=100, + bin_names=["unit_test"], + search_params=types.HnswSearchParams(ef=80) + ) + return result + +def test_vector_search( + base_numpy, + truth_numpy, + query_numpy, + session_vector_client, + session_admin_client, +): + + session_admin_client.index_create( + namespace="test", + name="demo", + vector_field="unit_test", + dimensions=128, + sets="demo", + ) + + # Put base vectors for search + for j, vector in enumerate(base_numpy): + put_vector(session_vector_client, vector.tolist(), j) + + session_vector_client.wait_for_index_completion(namespace='test', name='demo') + + + # Vector search all query vectors + results = [] + count = 0 + for i in query_numpy: + if count % 2: + results.append(vector_search(session_vector_client, i.tolist())) + else: + results.append(vector_search_ef_80(session_vector_client, i.tolist())) + count += 1 + + # Get recall numbers for each query + recall_for_each_query = [] + for i, outside in enumerate(truth_numpy): + true_positive = 0 + false_negative = 0 + # Parse all bins for each neighbor into an array + binList = [] + + for j, result in enumerate(results[i]): + binList.append(result.bins["unit_test"]) + for j, index in enumerate(outside): + vector = base_numpy[index].tolist() + if vector in binList: + true_positive = true_positive + 1 + else: + false_negative = false_negative + 1 + + recall = true_positive / (true_positive + false_negative) + recall_for_each_query.append(recall) + + # Calculate the sum of all values + recall_sum = sum(recall_for_each_query) + + # Calculate the average + average = recall_sum / len(recall_for_each_query) + + assert average > 0.95 + for recall in recall_for_each_query: + assert recall > 0.9 + + +def test_vector_is_indexed(session_vector_client, session_admin_client): + result = session_vector_client.is_indexed( + namespace="test", + key=str(random.randrange(10_000)), + set_name="demo", + index_name="demo", + ) + assert result is True From bd0074ac9f9d3f2c36003643bfd87f1df123a2ed Mon Sep 17 00:00:00 2001 From: Dominic Pelini <111786059+DomPeliniAerospike@users.noreply.github.com> Date: Mon, 29 Apr 2024 10:17:23 -0600 Subject: [PATCH 2/4] Update src/aerospike_vector/client.py Co-authored-by: Jesse S --- src/aerospike_vector/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aerospike_vector/client.py b/src/aerospike_vector/client.py index 9ddc07c9..06a8a1af 100644 --- a/src/aerospike_vector/client.py +++ b/src/aerospike_vector/client.py @@ -24,7 +24,7 @@ def __init__( self, *, seeds: Union[types.HostPort, tuple[types.HostPort, ...]], - 
listener_name: str = None, + listener_name: Optional[str] = None, is_loadbalancer: Optional[bool] = False, ) -> None: """ From 4210fb26b257fe7923c092449d1c8355d49a1be2 Mon Sep 17 00:00:00 2001 From: Dominic Pelini <111786059+DomPeliniAerospike@users.noreply.github.com> Date: Tue, 30 Apr 2024 10:52:51 -0600 Subject: [PATCH 3/4] Renamed package to aerospike vector search Addressed Jesse's review comments. Fixed index_waiting logic to check index status twice before returning. Changed index wait to 5 seconds for consistency --- README.md | 12 +- docs/conf.py | 4 +- docs/index.rst | 4 +- docs/requirements.txt | 2 +- docs/types.rst | 2 +- docs/vectordb_admin.rst | 2 +- docs/vectordb_client.rst | 2 +- proto/codegen.sh | 6 +- pyproject.toml | 6 +- src/aerospike_vector/shared/admin_helpers.py | 158 ---------- .../shared/channel_provider_helpers.py | 50 --- src/aerospike_vector/shared/client_helpers.py | 182 ----------- .../__init__.py | 0 .../admin.py | 46 +-- .../aio/__init__.py | 0 .../aio/admin.py | 46 +-- .../aio/client.py | 53 ++-- .../aio/internal/channel_provider.py | 17 +- src/aerospike_vector_search/client.py | 284 ++++++++++++++++++ .../internal/__init__.py | 0 .../internal/channel_provider.py | 17 +- .../shared/__init__.py | 0 .../shared/admin_helpers.py | 161 ++++++++++ .../shared/base_channel_provider.py | 61 +++- .../shared/client_helpers.py | 183 +++++++++++ .../shared/conversions.py | 0 .../shared/helpers.py | 2 +- .../shared/proto_generated/auth_pb2.py | 0 .../shared/proto_generated/auth_pb2_grpc.py | 0 .../shared/proto_generated/index_pb2.py | 0 .../shared/proto_generated/index_pb2_grpc.py | 0 .../shared/proto_generated/transact_pb2.py | 0 .../proto_generated/transact_pb2_grpc.py | 0 .../shared/proto_generated/types_pb2.py | 0 .../shared/proto_generated/types_pb2_grpc.py | 0 .../shared/proto_generated/vector_db_pb2.py | 0 .../proto_generated/vector_db_pb2_grpc.py | 0 .../types.py | 6 +- tests/aio/conftest.py | 7 +- tests/aio/test_admin_client_index_create.py | 2 +- tests/aio/test_vector_search.py | 5 +- tests/sync/conftest.py | 6 +- tests/sync/test_admin_client_index_create.py | 2 +- tests/sync/test_vector_search.py | 3 +- 44 files changed, 810 insertions(+), 521 deletions(-) delete mode 100644 src/aerospike_vector/shared/admin_helpers.py delete mode 100644 src/aerospike_vector/shared/channel_provider_helpers.py delete mode 100644 src/aerospike_vector/shared/client_helpers.py rename src/{aerospike_vector => aerospike_vector_search}/__init__.py (100%) rename src/{aerospike_vector => aerospike_vector_search}/admin.py (85%) rename src/{aerospike_vector => aerospike_vector_search}/aio/__init__.py (100%) rename src/{aerospike_vector => aerospike_vector_search}/aio/admin.py (85%) rename src/{aerospike_vector => aerospike_vector_search}/aio/client.py (84%) rename src/{aerospike_vector => aerospike_vector_search}/aio/internal/channel_provider.py (81%) create mode 100644 src/aerospike_vector_search/client.py rename src/{aerospike_vector => aerospike_vector_search}/internal/__init__.py (100%) rename src/{aerospike_vector => aerospike_vector_search}/internal/channel_provider.py (81%) rename src/{aerospike_vector => aerospike_vector_search}/shared/__init__.py (100%) create mode 100644 src/aerospike_vector_search/shared/admin_helpers.py rename src/{aerospike_vector => aerospike_vector_search}/shared/base_channel_provider.py (60%) create mode 100644 src/aerospike_vector_search/shared/client_helpers.py rename src/{aerospike_vector => aerospike_vector_search}/shared/conversions.py (100%) rename 
src/{aerospike_vector => aerospike_vector_search}/shared/helpers.py (96%) rename src/{aerospike_vector => aerospike_vector_search}/shared/proto_generated/auth_pb2.py (100%) rename src/{aerospike_vector => aerospike_vector_search}/shared/proto_generated/auth_pb2_grpc.py (100%) rename src/{aerospike_vector => aerospike_vector_search}/shared/proto_generated/index_pb2.py (100%) rename src/{aerospike_vector => aerospike_vector_search}/shared/proto_generated/index_pb2_grpc.py (100%) rename src/{aerospike_vector => aerospike_vector_search}/shared/proto_generated/transact_pb2.py (100%) rename src/{aerospike_vector => aerospike_vector_search}/shared/proto_generated/transact_pb2_grpc.py (100%) rename src/{aerospike_vector => aerospike_vector_search}/shared/proto_generated/types_pb2.py (100%) rename src/{aerospike_vector => aerospike_vector_search}/shared/proto_generated/types_pb2_grpc.py (100%) rename src/{aerospike_vector => aerospike_vector_search}/shared/proto_generated/vector_db_pb2.py (100%) rename src/{aerospike_vector => aerospike_vector_search}/shared/proto_generated/vector_db_pb2_grpc.py (100%) rename src/{aerospike_vector => aerospike_vector_search}/types.py (97%) diff --git a/README.md b/README.md index 5a6f5f0f..ddaefeb5 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ -# Aerospike Vector Client Python -Python client for Aerospike VectorDB +# Aerospike Vector Search Client Python +Python client for Aerospike Vector Search Database ## Prerequisites - Python 3.8 or higher - pip version 9.0.1 or higher - - Aerospike VectorDB and Aerospike clusters running. + - Aerospike Vector Search DB and Aerospike clusters running. ## Using the client from your application using pip @@ -15,13 +15,13 @@ To resolve the client packages using pip, add the following to $HOME/.pip/pip.co extra-index-url=https://:@aerospike.jfrog.io/artifactory/api/pypi/ecosystem-python-dev-local/simple ``` -### Install the aerospike_vector using pip +### Install the aerospike_vector_search using pip ```shell -python3 -m pip install aerospike-vector +python3 -m pip install aerospike-vector-search ``` Or -You can add the package name `aerospike-vector` to your application's `requirements.txt` and install all dependencies using +You can add the package name `aerospike-vector-search` to your application's `requirements.txt` and install all dependencies using ```shell python3 -m pip install -r requirements.txt ``` diff --git a/docs/conf.py b/docs/conf.py index e9b95de5..ffe0cc29 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -6,10 +6,10 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = 'aerospike-vector' +project = 'aerospike-vector-search' copyright = '2024, Dominic Pelini' author = 'Dominic Pelini' -release = '0.4.0' +release = '0.5.0' # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/index.rst b/docs/index.rst index 4b611144..cfad5f25 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,9 +1,9 @@ -.. aerospike-vector documentation master file, created by +.. aerospike-vector-search documentation master file, created by sphinx-quickstart on Thu Apr 11 07:35:51 2024. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Welcome to Aerospike Vector Client for Python. 
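Since this commit renames the package from `aerospike-vector` to `aerospike-vector-search`, every import path shifts with it. A quick orientation sketch of the new layout; the module paths come from the diffs in this patch set, while the aliases are illustrative:

```python
# Sync clients live at the package root; async clients live under .aio.
from aerospike_vector_search import Client, types
from aerospike_vector_search.admin import Client as AdminClient
from aerospike_vector_search.aio import Client as AioClient
from aerospike_vector_search.aio.admin import Client as AioAdminClient
```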
+Welcome to Aerospike Vector Search Client for Python. This package splits the client functionality into two separate clients. diff --git a/docs/requirements.txt b/docs/requirements.txt index e5ffa603..499120b7 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,2 @@ sphinx_rtd_theme -aerospike-vector +aerospike-vector-search diff --git a/docs/types.rst b/docs/types.rst index 01d3ab4e..60d9d6dd 100644 --- a/docs/types.rst +++ b/docs/types.rst @@ -3,7 +3,7 @@ types This module provides administrative functions for the VectorDB system. -.. automodule:: aerospike_vector.types +.. automodule:: aerospike_vector_search.types :members: :undoc-members: :show-inheritance: diff --git a/docs/vectordb_admin.rst b/docs/vectordb_admin.rst index 51c15fad..5dc2beef 100644 --- a/docs/vectordb_admin.rst +++ b/docs/vectordb_admin.rst @@ -4,7 +4,7 @@ vectordb_admin Module This module contains the admin client, which is designed to conduct Proximus administrative operation such as creating indexes, querying index information, and dropping indexes. -.. automodule:: aerospike_vector.vectordb_admin +.. automodule:: aerospike_vector_search.vectordb_admin :members: :undoc-members: :show-inheritance: diff --git a/docs/vectordb_client.rst b/docs/vectordb_client.rst index 51bdbd9d..d662146b 100644 --- a/docs/vectordb_client.rst +++ b/docs/vectordb_client.rst @@ -4,7 +4,7 @@ vectordb_client Module This module contains the vector client (VectorDbClient), which specializes in performing database operations with vector data. -.. automodule:: aerospike_vector.vectordb_client +.. automodule:: aerospike_vector_search.vectordb_client :members: :undoc-members: :show-inheritance: diff --git a/proto/codegen.sh b/proto/codegen.sh index 069515ac..643db4b5 100755 --- a/proto/codegen.sh +++ b/proto/codegen.sh @@ -6,10 +6,10 @@ cd "$(dirname "$0")" python3 -m pip install grpcio-tools python3 -m grpc_tools.protoc \ --proto_path=. \ - --python_out=../src/aerospike_vector/shared/proto_generated/ \ - --grpc_python_out=../src/aerospike_vector/shared/proto_generated/ \ + --python_out=../src/aerospike_vector_search/shared/proto_generated/ \ + --grpc_python_out=../src/aerospike_vector_search/shared/proto_generated/ \ *.proto # The generated imports are not relative and fail in generated packages. # Fix with relative imports. -find ../src/aerospike_vector/shared/proto_generated/ -name "*.py" -exec sed -i -e 's/^import \(.*\)_pb2 /from . import \1_pb2 /g' {} \; +find ../src/aerospike_vector_search/shared/proto_generated/ -name "*.py" -exec sed -i -e 's/^import \(.*\)_pb2 /from . 
import \1_pb2 /g' {} \; diff --git a/pyproject.toml b/pyproject.toml index c6eb8446..38c8217f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools","wheel"] build-backend = "setuptools.build_meta" [project] -name = "aerospike-vector" +name = "aerospike-vector-search" description = "Aerospike Proximus Client Library for Python" authors = [ { name = "Aerospike, Inc.", email = "info@aerospike.com" } @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Topic :: Database" ] -version = "0.4.0" +version = "0.5.0dev1" requires-python = ">3.8" dependencies = [ "grpcio", @@ -35,5 +35,5 @@ dependencies = [ [tool.setuptools] zip-safe = false include-package-data = true -packages = ["aerospike_vector"] +packages = ["aerospike_vector_search", "aerospike_vector_search.aio", "aerospike_vector_search.shared", "aerospike_vector_search.shared.proto_generated", "aerospike_vector_search.internal", "aerospike_vector_search.aio.internal"] package-dir={"" = "src"} diff --git a/src/aerospike_vector/shared/admin_helpers.py b/src/aerospike_vector/shared/admin_helpers.py deleted file mode 100644 index 017a1c37..00000000 --- a/src/aerospike_vector/shared/admin_helpers.py +++ /dev/null @@ -1,158 +0,0 @@ -import asyncio -import logging -from typing import Any, Optional, Union -import time - -import google.protobuf.empty_pb2 -from google.protobuf.json_format import MessageToDict -import grpc - -from . import helpers -from .proto_generated import index_pb2_grpc -from .proto_generated import types_pb2 -from .. import types - -logger = logging.getLogger(__name__) - -empty = google.protobuf.empty_pb2.Empty() - -def prepare_seeds(seeds) -> None: - return helpers.prepare_seeds(seeds) - -def prepare_index_create(self, namespace, name, vector_field, dimensions, vector_distance_metric, sets, index_params, index_meta_data, logger) -> None: - - logger.debug( - "Creating index: namespace=%s, name=%s, vector_field=%s, dimensions=%d, vector_distance_metric=%s, " - "sets=%s, index_params=%s, index_meta_data=%s", - namespace, - name, - vector_field, - dimensions, - vector_distance_metric, - sets, - index_params, - index_meta_data, - ) - - if sets and not sets.strip(): - sets = None - if index_params != None: - index_params = index_params._to_pb2() - id = get_index_id(namespace, name) - vector_distance_metric = vector_distance_metric.value - - index_stub = get_index_stub(self) - - index_create_request = types_pb2.IndexDefinition( - id=id, - vectorDistanceMetric=vector_distance_metric, - setFilter=sets, - hnswParams=index_params, - bin=vector_field, - dimensions=dimensions, - labels=index_meta_data, - ) - return (index_stub, index_create_request) - -def prepare_index_drop(self, namespace, name, logger) -> None: - - logger.debug("Dropping index: namespace=%s, name=%s", namespace, name) - - index_stub = get_index_stub(self) - index_drop_request = get_index_id(namespace, name) - - return (index_stub, index_drop_request) - -def prepare_index_list(self, logger) -> None: - - logger.debug("Getting index list") - - index_stub = get_index_stub(self) - index_list_request = empty - - return (index_stub, index_list_request) - -def prepare_index_get(self, namespace, name, logger) -> None: - - logger.debug( - "Getting index information: namespace=%s, name=%s", namespace, name - ) - - index_stub = get_index_stub(self) - index_get_request = get_index_id(namespace, name) - - return (index_stub, index_get_request) - -def prepare_index_get_status(self, namespace, name, logger) 
-> None: - - logger.debug("Getting index status: namespace=%s, name=%s", namespace, name) - - index_stub = get_index_stub(self) - index_get_status_request = get_index_id(namespace, name) - - - return (index_stub, index_get_status_request) - -def respond_index_list(response) -> None: - response_list = [] - for index in response.indices: - response_dict = MessageToDict(index) - - # Modifying dict to adhere to PEP-8 naming - hnsw_params_dict = response_dict.pop("hnswParams", None) - - hnsw_params_dict["ef_construction"] = hnsw_params_dict.pop( - "efConstruction", None - ) - - batching_params_dict = hnsw_params_dict.pop("batchingParams", None) - batching_params_dict["max_records"] = batching_params_dict.pop( - "maxRecords", None - ) - hnsw_params_dict["batching_params"] = batching_params_dict - - response_dict["hnsw_params"] = hnsw_params_dict - response_list.append(response_dict) - return response_list - -def respond_index_get(response) -> None: - response_dict = MessageToDict(response) - - # Modifying dict to adhere to PEP-8 naming - hnsw_params_dict = response_dict.pop("hnswParams", None) - - hnsw_params_dict["ef_construction"] = hnsw_params_dict.pop( - "efConstruction", None - ) - - batching_params_dict = hnsw_params_dict.pop("batchingParams", None) - batching_params_dict["max_records"] = batching_params_dict.pop( - "maxRecords", None - ) - hnsw_params_dict["batching_params"] = batching_params_dict - - response_dict["hnsw_params"] = hnsw_params_dict - return response_dict - -def respond_index_get_status(response) -> None: - return response.unmergedRecordCount - -def get_index_stub(self): - return index_pb2_grpc.IndexServiceStub( - self._channelProvider.get_channel() - ) - -def get_index_id(namespace, name): - return types_pb2.IndexId(namespace=namespace, name=name) - -def prepare_wait_for_index_waiting(self, namespace, name): - return helpers.prepare_wait_for_index_waiting(self, namespace, name) - - -def check_timeout(start_time, timeout): - if start_time + timeout < time.monotonic(): - raise "timed-out waiting for index creation" - - - - diff --git a/src/aerospike_vector/shared/channel_provider_helpers.py b/src/aerospike_vector/shared/channel_provider_helpers.py deleted file mode 100644 index d832ad65..00000000 --- a/src/aerospike_vector/shared/channel_provider_helpers.py +++ /dev/null @@ -1,50 +0,0 @@ -def init_tend(self) -> None: - end_tend = False - if self._is_loadbalancer: - # Skip tend if we are behind a load-balancer - end_tend = True - - if self._closed: - end_tend = True - - # TODO: Worry about thread safety - temp_endpoints: dict[int, vector_db_pb2.ServerEndpointList] = {} - - update_endpoints = False - channels = self._seedChannels + [ - x.channel for x in self._nodeChannels.values() - ] - return (temp_endpoints, update_endpoints, channels, end_tend) - - -def check_cluster_id(self, new_cluster_id) -> None: - if new_cluster_id == self._clusterId: - return False - - self._clusterId = new_cluster_id - - return True - -def update_temp_endpoints(response, temp_endpoints): - endpoints = response.endpoints - if len(endpoints) > len(temp_endpoints): - return endpoints - else: - return temp_endpoints - - -def check_for_new_endpoints(self, node, newEndpoints): - - channel_endpoints = self._nodeChannels.get(node) - add_new_channel = True - - if channel_endpoints: - # We have this node. Check if the endpoints changed. 
- if channel_endpoints.endpoints == newEndpoints: - # Nothing to be done for this node - add_new_channel = False - else: - add_new_channel = True - - return (channel_endpoints, add_new_channel) - diff --git a/src/aerospike_vector/shared/client_helpers.py b/src/aerospike_vector/shared/client_helpers.py deleted file mode 100644 index e6386455..00000000 --- a/src/aerospike_vector/shared/client_helpers.py +++ /dev/null @@ -1,182 +0,0 @@ -from typing import Any, Optional, Union -import time -from . import conversions - -from .proto_generated import transact_pb2 -from .proto_generated import transact_pb2_grpc -from .. import types -from .proto_generated import types_pb2 -from . import helpers - - -def prepare_seeds(seeds) -> None: - return helpers.prepare_seeds(seeds) - -def prepare_put(self, namespace, key, record_data, set_name, logger) -> None: - - logger.debug( - "Putting record: namespace=%s, key=%s, record_data:%s, set_name:%s", - namespace, - key, - record_data, - set_name, - ) - - key = _get_key(namespace, set_name, key) - bin_list = [ - types_pb2.Bin(name=k, value=conversions.toVectorDbValue(v)) - for (k, v) in record_data.items() - ] - - transact_stub = get_transact_stub(self) - put_request = transact_pb2.PutRequest(key=key, bins=bin_list) - - return (transact_stub, put_request) - -def prepare_get(self, namespace, key, bin_names, set_name, logger) -> None: - - logger.debug( - "Getting record: namespace=%s, key=%s, bin_names:%s, set_name:%s", - namespace, - key, - bin_names, - set_name, - ) - - - key = _get_key(namespace, set_name, key) - bin_selector = _get_bin_selector(bin_names=bin_names) - - transact_stub = get_transact_stub(self) - get_request = transact_pb2.GetRequest(key=key, binSelector=bin_selector) - - return (transact_stub, key, get_request) - -def prepare_exists(self, namespace, key, set_name, logger) -> None: - - logger.debug( - "Getting record existence: namespace=%s, key=%s, set_name:%s", - namespace, - key, - set_name, - ) - - key = _get_key(namespace, set_name, key) - - transact_stub = get_transact_stub(self) - - return (transact_stub, key) - -def prepare_is_indexed(self, namespace, key, index_name, index_namespace, set_name, logger) -> None: - - logger.debug( - "Checking if index exists: namespace=%s, key=%s, index_name=%s, index_namespace=%s, set_name=%s", - namespace, - key, - index_name, - index_namespace, - set_name, - ) - - if not index_namespace: - index_namespace = namespace - index_id = types_pb2.IndexId(namespace=index_namespace, name=index_name) - key = _get_key(namespace, set_name, key) - - transact_stub = get_transact_stub(self) - is_indexed_request = transact_pb2.IsIndexedRequest(key=key, indexId=index_id) - - return (transact_stub, is_indexed_request) - -def prepare_vector_search(self, namespace, index_name, query, limit, search_params, bin_names, logger) -> None: - - logger.debug( - "Performing vector search: namespace=%s, index_name=%s, query=%s, limit=%s, search_params=%s, bin_names=%s", - namespace, - index_name, - query, - limit, - search_params, - bin_names, - ) - - if search_params != None: - search_params = search_params._to_pb2() - bin_selector = _get_bin_selector(bin_names=bin_names) - index = types_pb2.IndexId(namespace=namespace, name=index_name) - query_vector = conversions.toVectorDbValue(query).vectorValue - - - transact_stub = get_transact_stub(self) - - vector_search_request = transact_pb2.VectorSearchRequest( - index=index, - queryVector=query_vector, - limit=limit, - hnswSearchParams=search_params, - binSelector=bin_selector, - ) - - 
return (transact_stub, vector_search_request) - -def get_transact_stub(self): - return transact_pb2_grpc.TransactStub( - self._channelProvider.get_channel() - ) - -def respond_get(response, key) -> None: - return types.RecordWithKey( - key=conversions.fromVectorDbKey(key), - bins=conversions.fromVectorDbRecord(response), - ) - -def respond_exists(response) -> None: - return response.value - -def respond_is_indexed(response) -> None: - return response.value - -def respond_neighbor(response) -> None: - return conversions.fromVectorDbNeighbor(response) - -def _get_bin_selector(*, bin_names: Optional[list] = None): - - if not bin_names: - bin_selector = transact_pb2.BinSelector( - type=transact_pb2.BinSelectorType.ALL, binNames=bin_names - ) - else: - bin_selector = transact_pb2.BinSelector( - type=transact_pb2.BinSelectorType.SPECIFIED, binNames=bin_names - ) - return bin_selector - -def _get_key(namespace: str, set: str, key: Union[int, str, bytes, bytearray]): - if isinstance(key, str): - key = types_pb2.Key(namespace=namespace, set=set, stringValue=key) - elif isinstance(key, int): - key = types_pb2.Key(namespace=namespace, set=set, longValue=key) - elif isinstance(key, (bytes, bytearray)): - key = types_pb2.Key(namespace=namespace, set=set, bytesValue=key) - else: - raise Exception("Invalid key type" + type(key)) - return key - -def prepare_wait_for_index_waiting(self, namespace, name): - return helpers.prepare_wait_for_index_waiting(self, namespace, name) - -def check_completion_condition(start_time, timeout, index_status, unmerged_record_initialized): - - if start_time + 10 < time.monotonic(): - unmerged_record_initialized = True - - if index_status.unmergedRecordCount > 0: - unmerged_record_initialized = True - - if ( - index_status.unmergedRecordCount == 0 - and unmerged_record_initialized == True - ): - return True - else: - return False diff --git a/src/aerospike_vector/__init__.py b/src/aerospike_vector_search/__init__.py similarity index 100% rename from src/aerospike_vector/__init__.py rename to src/aerospike_vector_search/__init__.py diff --git a/src/aerospike_vector/admin.py b/src/aerospike_vector_search/admin.py similarity index 85% rename from src/aerospike_vector/admin.py rename to src/aerospike_vector_search/admin.py index 573cecaa..deb10499 100644 --- a/src/aerospike_vector/admin.py +++ b/src/aerospike_vector_search/admin.py @@ -6,15 +6,15 @@ from . import types from .internal import channel_provider -from .shared import admin_helpers +from .shared.admin_helpers import BaseClient logger = logging.getLogger(__name__) -class Client(object): +class Client(BaseClient): """ - Aerospike Vector Admin Client + Aerospike Vector Search Admin Client - This client is designed to conduct Aerospike Vector administrative operation such as creating indexes, querying index information, and dropping indexes. + This client is designed to conduct Aerospike Vector Search administrative operation such as creating indexes, querying index information, and dropping indexes. """ def __init__( @@ -25,16 +25,16 @@ def __init__( is_loadbalancer: Optional[bool] = False, ) -> None: - seeds = admin_helpers.prepare_seeds(seeds) + seeds = self.prepare_seeds(seeds) self._channelProvider = channel_provider.ChannelProvider( seeds, listener_name, is_loadbalancer ) """ - Initialize the Aerospike Vector Admin Client. + Initialize the Aerospike Vector Search Admin Client. 
Args: - seeds (Union[types.HostPort, tuple[types.HostPort, ...]]): Used to create appropriate gRPC channels for interacting with Aerospike Vector. + seeds (Union[types.HostPort, tuple[types.HostPort, ...]]): Used to create appropriate gRPC channels for interacting with Aerospike Vector Search. listener_name (Optional[str], optional): Advertised listener for the client. Defaults to None. is_loadbalancer (bool, optional): If true, the first seed address will be treated as a load balancer node. @@ -82,7 +82,7 @@ def index_create( This method creates an index with the specified parameters and waits for the index creation to complete. It waits for up to 100,000 seconds for the index creation to complete. """ - (index_stub, index_create_request) = admin_helpers.prepare_index_create(self, namespace, name, vector_field, dimensions, vector_distance_metric, sets, index_params, index_meta_data, logger) + (index_stub, index_create_request) = self.prepare_index_create(namespace, name, vector_field, dimensions, vector_distance_metric, sets, index_params, index_meta_data, logger) try: index_stub.Create(index_create_request) except grpc.RpcError as e: @@ -112,7 +112,7 @@ def index_drop(self, *, namespace: str, name: str) -> None: This method drops an index with the specified parameters and waits for the index deletion to complete. It waits for up to 100,000 seconds for the index deletion to complete. """ - (index_stub, index_drop_request) = admin_helpers.prepare_index_drop(self, namespace, name, logger) + (index_stub, index_drop_request) = self.prepare_index_drop(namespace, name, logger) try: index_stub.Drop(index_drop_request) except grpc.RpcError as e: @@ -137,13 +137,13 @@ def index_list(self) -> list[dict]: grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ - (index_stub, index_list_request) = admin_helpers.prepare_index_list(self, logger) + (index_stub, index_list_request) = self.prepare_index_list(logger) try: response = index_stub.List(index_list_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e - return admin_helpers.respond_index_list(response) + return self.respond_index_list(response) def index_get( self, *, namespace: str, name: str @@ -163,13 +163,13 @@ def index_get( This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ - (index_stub, index_get_request) = admin_helpers.prepare_index_get(self, namespace, name, logger) + (index_stub, index_get_request) = self.prepare_index_get(namespace, name, logger) try: response = index_stub.Get(index_get_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e - return admin_helpers.respond_index_get(response) + return self.respond_index_get(response) def index_get_status(self, *, namespace: str, name: str) -> int: @@ -188,19 +188,19 @@ def index_get_status(self, *, namespace: str, name: str) -> int: This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. Note: - This method retrieves the status of the specified index. If index_get_status is called the vector client puts some records into Aerospike Vector, + This method retrieves the status of the specified index. 
If index_get_status is called the vector client puts some records into Aerospike Vector Search, the records may not immediately begin to merge into the index. To wait for all records to be merged into an index, use vector_client.wait_for_index_completion. Warning: This API is subject to change. """ - (index_stub, index_get_status_request) = admin_helpers.prepare_index_get_status(self, namespace, name, logger) + (index_stub, index_get_status_request) = self.prepare_index_get_status(namespace, name, logger) try: response = index_stub.GetStatus(index_get_status_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e - return admin_helpers.respond_index_get_status(response) + return self.respond_index_get_status(response) def _wait_for_index_creation( self, *, namespace: str, name: str, timeout: int = sys.maxsize @@ -208,9 +208,9 @@ def _wait_for_index_creation( """ Wait for the index to be created. """ - (index_stub, wait_interval, start_time, _, _, index_creation_request) = admin_helpers.prepare_wait_for_index_waiting(self, namespace, name) + (index_stub, wait_interval, start_time, _, _, index_creation_request) = self.prepare_wait_for_index_waiting(namespace, name) while True: - admin_helpers.check_timeout(start_time, timeout) + self.check_timeout(start_time, timeout) try: index_stub.GetStatus(index_creation_request) logger.debug("Index created succesfully") @@ -233,10 +233,10 @@ def _wait_for_index_deletion( """ # Wait interval between polling - (index_stub, wait_interval, start_time, _, _, index_deletion_request) = admin_helpers.prepare_wait_for_index_waiting(self, namespace, name) + (index_stub, wait_interval, start_time, _, _, index_deletion_request) = self.prepare_wait_for_index_waiting(namespace, name) while True: - admin_helpers.check_timeout(start_time, timeout) + self.check_timeout(start_time, timeout) try: index_stub.GetStatus(index_deletion_request) @@ -252,9 +252,9 @@ def close(self): """ - Close the Aerospike Vector Admin Client. + Close the Aerospike Vector Search Admin Client. - This method closes gRPC channels connected to Aerospike Vector. + This method closes gRPC channels connected to Aerospike Vector Search. Note: This method should be called when the VectorDbAdminClient is no longer needed to release resources. @@ -266,7 +266,7 @@ def __enter__(self): """ Enter an asynchronous context manager for the admin client. Returns: - VectorDbAdminlient: Aerospike Vector Admin Client instance. + VectorDbAdminClient: Aerospike Vector Search Admin Client instance. """ return self diff --git a/src/aerospike_vector/aio/__init__.py b/src/aerospike_vector_search/aio/__init__.py similarity index 100% rename from src/aerospike_vector/aio/__init__.py rename to src/aerospike_vector_search/aio/__init__.py diff --git a/src/aerospike_vector/aio/admin.py b/src/aerospike_vector_search/aio/admin.py similarity index 85% rename from src/aerospike_vector/aio/admin.py rename to src/aerospike_vector_search/aio/admin.py index 857a94aa..0d99fc6d 100644 --- a/src/aerospike_vector/aio/admin.py +++ b/src/aerospike_vector_search/aio/admin.py @@ -7,15 +7,15 @@ from ..
import types from .internal import channel_provider -from ..shared import admin_helpers +from ..shared.admin_helpers import BaseClient logger = logging.getLogger(__name__) -class Client(object): +class Client(BaseClient): """ - Aerospike Vector Admin Client + Aerospike Vector Search Asyncio Admin Client - This client is designed to conduct Aerospike Vector administrative operation such as creating indexes, querying index information, and dropping indexes. + This client is designed to conduct Aerospike Vector Search administrative operation such as creating indexes, querying index information, and dropping indexes. """ def __init__( @@ -26,16 +26,16 @@ def __init__( is_loadbalancer: Optional[bool] = False, ) -> None: - seeds = admin_helpers.prepare_seeds(seeds) + seeds = self.prepare_seeds(seeds) self._channelProvider = channel_provider.ChannelProvider( seeds, listener_name, is_loadbalancer ) """ - Initialize the Aerospike Vector Admin Client. + Initialize the Aerospike Vector Search Admin Client. Args: - seeds (Union[types.HostPort, tuple[types.HostPort, ...]]): Used to create appropriate gRPC channels for interacting with Aerospike Vector. + seeds (Union[types.HostPort, tuple[types.HostPort, ...]]): Used to create appropriate gRPC channels for interacting with Aerospike Vector Search. listener_name (Optional[str], optional): Advertised listener for the client. Defaults to None. is_loadbalancer (bool, optional): If true, the first seed address will be treated as a load balancer node. @@ -83,7 +83,7 @@ async def index_create( This method creates an index with the specified parameters and waits for the index creation to complete. It waits for up to 100,000 seconds for the index creation to complete. """ - (index_stub, index_create_request) = admin_helpers.prepare_index_create(self, namespace, name, vector_field, dimensions, vector_distance_metric, sets, index_params, index_meta_data, logger) + (index_stub, index_create_request) = self.prepare_index_create(namespace, name, vector_field, dimensions, vector_distance_metric, sets, index_params, index_meta_data, logger) try: await index_stub.Create(index_create_request) except grpc.RpcError as e: @@ -113,7 +113,7 @@ async def index_drop(self, *, namespace: str, name: str) -> None: This method drops an index with the specified parameters and waits for the index deletion to complete. It waits for up to 100,000 seconds for the index deletion to complete. """ - (index_stub, index_drop_request) = admin_helpers.prepare_index_drop(self, namespace, name, logger) + (index_stub, index_drop_request) = self.prepare_index_drop(namespace, name, logger) try: await index_stub.Drop(index_drop_request) except grpc.RpcError as e: @@ -138,13 +138,13 @@ async def index_list(self) -> list[dict]: grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
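For reference, a minimal sketch of driving the asyncio admin client shown above. The seed address, namespace, and index name are placeholders, `types.HostPort(host=..., port=...)` keyword arguments are inferred from its `host`/`port` fields, and `index_create`'s optional parameters (distance metric, sets, HNSW params, metadata) are assumed to keep their defaults:

```python
import asyncio

from aerospike_vector_search import types
from aerospike_vector_search.aio import admin


async def main():
    # Placeholder seed; point this at a reachable Aerospike Vector Search node.
    client = admin.Client(seeds=types.HostPort(host="localhost", port=5000))
    try:
        # Create a 4-dimensional HNSW index over the "vector" bin.
        await client.index_create(
            namespace="test",
            name="demo_index",
            vector_field="vector",
            dimensions=4,
        )
        print(await client.index_list())
    finally:
        await client.close()


asyncio.run(main())
```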
""" - (index_stub, index_list_request) = admin_helpers.prepare_index_list(self, logger) + (index_stub, index_list_request) = self.prepare_index_list(logger) try: response = await index_stub.List(index_list_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e - return admin_helpers.respond_index_list(response) + return self.respond_index_list(response) async def index_get( self, *, namespace: str, name: str @@ -164,13 +164,13 @@ async def index_get( This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ - (index_stub, index_get_request) = admin_helpers.prepare_index_get(self, namespace, name, logger) + (index_stub, index_get_request) = self.prepare_index_get(namespace, name, logger) try: response = await index_stub.Get(index_get_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e - return admin_helpers.respond_index_get(response) + return self.respond_index_get(response) async def index_get_status(self, *, namespace: str, name: str) -> int: @@ -189,19 +189,19 @@ async def index_get_status(self, *, namespace: str, name: str) -> int: This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. Note: - This method retrieves the status of the specified index. If index_get_status is called the vector client puts some records into Aerospike Vector, + This method retrieves the status of the specified index. If index_get_status is called the vector client puts some records into Aerospike Vector Search, the records may not immediately begin to merge into the index. To wait for all records to be merged into an index, use vector_client.wait_for_index_completion. Warning: This API is subject to change. """ - (index_stub, index_get_status_request) = admin_helpers.prepare_index_get_status(self, namespace, name, logger) + (index_stub, index_get_status_request) = self.prepare_index_get_status(namespace, name, logger) try: response = await index_stub.GetStatus(index_get_status_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e - return admin_helpers.respond_index_get_status(response) + return self.respond_index_get_status(response) async def _wait_for_index_creation( self, *, namespace: str, name: str, timeout: int = sys.maxsize @@ -209,9 +209,9 @@ async def _wait_for_index_creation( """ Wait for the index to be created. 
""" - (index_stub, wait_interval, start_time, _, _, index_creation_request) = admin_helpers.prepare_wait_for_index_waiting(self, namespace, name) + (index_stub, wait_interval, start_time, _, _, index_creation_request) = self.prepare_wait_for_index_waiting(namespace, name) while True: - admin_helpers.check_timeout(start_time, timeout) + self.check_timeout(start_time, timeout) try: await index_stub.GetStatus(index_creation_request) logger.debug("Index created succesfully") @@ -234,10 +234,10 @@ async def _wait_for_index_deletion( """ # Wait interval between polling - (index_stub, wait_interval, start_time, _, _, index_deletion_request) = admin_helpers.prepare_wait_for_index_waiting(self, namespace, name) + (index_stub, wait_interval, start_time, _, _, index_deletion_request) = self.prepare_wait_for_index_waiting(namespace, name) while True: - admin_helpers.check_timeout(start_time, timeout) + self.check_timeout(start_time, timeout) try: await index_stub.GetStatus(index_deletion_request) @@ -253,9 +253,9 @@ async def _wait_for_index_deletion( async def close(self): """ - Close the Aerospike Vector Admin Client. + Close the Aerospike Vector Search Admin Client. - This method closes gRPC channels connected to Aerospike Vector. + This method closes gRPC channels connected to Aerospike Vector Search. Note: This method should be called when the VectorDbAdminClient is no longer needed to release resources. @@ -267,7 +267,7 @@ async def __aenter__(self): Enter an asynchronous context manager for the admin client. Returns: - VectorDbAdminlient: Aerospike Vector Admin Client instance. + VectorDbAdminlient: Aerospike Vector Search Admin Client instance. """ return self diff --git a/src/aerospike_vector/aio/client.py b/src/aerospike_vector_search/aio/client.py similarity index 84% rename from src/aerospike_vector/aio/client.py rename to src/aerospike_vector_search/aio/client.py index 012f74b8..6921cfc7 100644 --- a/src/aerospike_vector/aio/client.py +++ b/src/aerospike_vector_search/aio/client.py @@ -7,13 +7,13 @@ from .. import types from .internal import channel_provider -from ..shared import client_helpers +from ..shared.client_helpers import BaseClient logger = logging.getLogger(__name__) -class Client(object): +class Client(BaseClient): """ - Aerospike Vector Admin Client + Aerospike Vector Search Asyncio Admin Client This client specializes in performing database operations with vector data. Moreover, the client supports Hierarchical Navigable Small World (HNSW) vector searches, @@ -24,15 +24,15 @@ def __init__( self, *, seeds: Union[types.HostPort, tuple[types.HostPort, ...]], - listener_name: str = None, + listener_name: Optional[str] = None, is_loadbalancer: Optional[bool] = False, ) -> None: """ - Initialize the Aerospike Vector Vector Client. + Initialize the Aerospike Vector Search Vector Client. Args: seeds (Union[types.HostPort, tuple[types.HostPort, ...]]): - Used to create appropriate gRPC channels for interacting with Aerospike Vector. + Used to create appropriate gRPC channels for interacting with Aerospike Vector Search. listener_name (Optional[str], optional): Advertised listener for the client. Defaults to None. is_loadbalancer (bool, optional): @@ -41,7 +41,7 @@ def __init__( Raises: Exception: Raised when no seed host is provided. 
""" - seeds = client_helpers.prepare_seeds(seeds) + seeds = self.prepare_seeds(seeds) self._channelProvider = channel_provider.ChannelProvider( seeds, listener_name, is_loadbalancer ) @@ -55,7 +55,7 @@ async def put( set_name: Optional[str] = None, ) -> None: """ - Write a record to Aerospike Vector. + Write a record to Aerospike Vector Search. Args: namespace (str): The namespace for the record. @@ -68,7 +68,7 @@ async def put( This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ - (transact_stub, put_request) = client_helpers.prepare_put(self, namespace, key, record_data, set_name, logger) + (transact_stub, put_request) = self.prepare_put(namespace, key, record_data, set_name, logger) try: await transact_stub.Put(put_request) @@ -85,7 +85,7 @@ async def get( set_name: Optional[str] = None, ) -> types.RecordWithKey: """ - Read a record from Aerospike Vector. + Read a record from Aerospike Vector Search. Args: namespace (str): The namespace for the record. @@ -101,20 +101,20 @@ async def get( grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ - (transact_stub, key, get_request) = client_helpers.prepare_get(self, namespace, key, bin_names, set_name, logger) + (transact_stub, key, get_request) = self.prepare_get(namespace, key, bin_names, set_name, logger) try: response = await transact_stub.Get(get_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e - return client_helpers.respond_get(response, key) + return self.respond_get(response, key) async def exists( self, *, namespace: str, key: Any, set_name: Optional[str] = None ) -> bool: """ - Check if a record exists in Aerospike Vector. + Check if a record exists in Aerospike Vector Search. Args: namespace (str): The namespace for the record. @@ -128,14 +128,14 @@ async def exists( grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ - (transact_stub, key) = client_helpers.prepare_exists(self, namespace, key, set_name, logger) + (transact_stub, key) = self.prepare_exists(namespace, key, set_name, logger) try: response = await transact_stub.Exists(key) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e - return client_helpers.respond_exists(response) + return self.respond_exists(response) async def is_indexed( self, @@ -164,13 +164,13 @@ async def is_indexed( grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
""" - (transact_stub, is_indexed_request) = client_helpers.prepare_is_indexed(self, namespace, key, index_name, index_namespace, set_name, logger) + (transact_stub, is_indexed_request) = self.prepare_is_indexed(namespace, key, index_name, index_namespace, set_name, logger) try: response = await transact_stub.IsIndexed(is_indexed_request) except grpc.RpcError as e: logger.error("Failed with error: %s", e) raise e - return client_helpers.respond_is_indexed(response) + return self.respond_is_indexed(response) async def vector_search( self, @@ -183,7 +183,7 @@ async def vector_search( bin_names: Optional[list[str]] = None, ) -> list[types.Neighbor]: """ - Perform a Hierarchical Navigable Small World (HNSW) vector search in Aerospike Vector. + Perform a Hierarchical Navigable Small World (HNSW) vector search in Aerospike Vector Search. Args: namespace (str): The namespace for the records. @@ -202,7 +202,7 @@ async def vector_search( grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ - (transact_stub, vector_search_request) = client_helpers.prepare_vector_search(self, namespace, index_name, query, limit, search_params, bin_names, logger) + (transact_stub, vector_search_request) = self.prepare_vector_search(namespace, index_name, query, limit, search_params, bin_names, logger) try: results_stream = transact_stub.VectorSearch(vector_search_request) @@ -211,7 +211,7 @@ async def vector_search( raise e async_results = [] async for result in results_stream: - async_results.append(client_helpers.respond_neighbor(result)) + async_results.append(self.respond_neighbor(result)) return async_results @@ -237,7 +237,7 @@ async def wait_for_index_completion( the timeout is reached or the index has no pending index update operations. """ # Wait interval between polling - (index_stub, wait_interval, start_time, unmerged_record_initialized, double_check, index_completion_request) = client_helpers.prepare_wait_for_index_waiting(self, namespace, name) + (index_stub, wait_interval, start_time, unmerged_record_initialized, double_check, index_completion_request) = self.prepare_wait_for_index_waiting(namespace, name) while True: try: index_status = await index_stub.GetStatus(index_completion_request) @@ -248,19 +248,20 @@ async def wait_for_index_completion( else: logger.error("Failed with error: %s", e) raise e - if client_helpers.check_completion_condition(start_time, timeout, index_status, unmerged_record_initialized): + if self.check_completion_condition(start_time, timeout, index_status, unmerged_record_initialized): if double_check: return else: double_check = True else: - await asyncio.sleep(wait_interval) + double_check = False + await asyncio.sleep(wait_interval) async def close(self): """ - Close the Aerospike Vector Vector Client. + Close the Aerospike Vector Search Vector Client. - This method closes gRPC channels connected to Aerospike Vector. + This method closes gRPC channels connected to Aerospike Vector Search. Note: This method should be called when the VectorDbAdminClient is no longer needed to release resources. @@ -272,7 +273,7 @@ async def __aenter__(self): Enter an asynchronous context manager for the vector client. Returns: - VectorDbClient: Aerospike Vector Vector Client instance. + VectorDbClient: Aerospike Vector Search Vector Client instance. 
""" return self diff --git a/src/aerospike_vector/aio/internal/channel_provider.py b/src/aerospike_vector_search/aio/internal/channel_provider.py similarity index 81% rename from src/aerospike_vector/aio/internal/channel_provider.py rename to src/aerospike_vector_search/aio/internal/channel_provider.py index c98fd48b..dd14506b 100644 --- a/src/aerospike_vector/aio/internal/channel_provider.py +++ b/src/aerospike_vector_search/aio/internal/channel_provider.py @@ -11,7 +11,6 @@ from ...shared.proto_generated import vector_db_pb2 from ...shared.proto_generated import vector_db_pb2_grpc from ...shared import base_channel_provider -from ...shared import channel_provider_helpers empty = google.protobuf.empty_pb2.Empty() @@ -31,12 +30,12 @@ async def close(self): for channel in self._seedChannels: await channel.close() - for k, channelEndpoints in self._nodeChannels.items(): + for k, channelEndpoints in self._node_channels.items(): if channelEndpoints.channel: await channelEndpoints.channel.close() async def _tend(self): - (temp_endpoints, update_endpoints, channels, end_tend) = channel_provider_helpers.init_tend(self) + (temp_endpoints, update_endpoints, channels, end_tend) = self.init_tend() if end_tend: return @@ -46,7 +45,7 @@ async def _tend(self): try: new_cluster_id = await stub.GetClusterId(empty).id - if channel_provider_helpers.check_cluster_id(self, new_cluster_id): + if self.check_cluster_id(new_cluster_id): update_endpoints = True else: continue @@ -64,11 +63,11 @@ async def _tend(self): except Exception as e: logger.debug("While tending, failed to get cluster endpoints with error:" + str(e)) - temp_endpoints = channel_provider_helpers.update_temp_endpoints(response, temp_endpoints) + temp_endpoints = self.update_temp_endpoints(response, temp_endpoints) if update_endpoints: for node, newEndpoints in temp_endpoints.items(): - (channel_endpoints, add_new_channel) = channel_provider_helpers.check_for_new_endpoints(self, node, newEndpoints) + (channel_endpoints, add_new_channel) = self.check_for_new_endpoints(node, newEndpoints) if add_new_channel: try: # TODO: Wait for all calls to drain @@ -78,12 +77,12 @@ async def _tend(self): self.add_new_channel_to_node_channels(node, newEndpoints) - for node, channel_endpoints in self._nodeChannels.items(): + for node, channel_endpoints in self._node_channels.items(): if not temp_endpoints.get(node): # TODO: Wait for all calls to drain try: await channel_endpoints.channel.close() - del self._nodeChannels[node] + del self._node_channels[node] except Exception as e: logger.debug("While tending, failed to close GRPC channel:" + str(e)) @@ -92,7 +91,7 @@ async def _tend(self): await asyncio.sleep(1) asyncio.create_task(self._tend()) - def _create_channel(self, host: str, port: int, isTls: bool) -> grpc.aio.Channel: + def _create_channel(self, host: str, port: int, is_tls: bool) -> grpc.aio.Channel: # TODO: Take care of TLS host = re.sub(r"%.*", "", host) return grpc.aio.insecure_channel(f"{host}:{port}") \ No newline at end of file diff --git a/src/aerospike_vector_search/client.py b/src/aerospike_vector_search/client.py new file mode 100644 index 00000000..36056e8f --- /dev/null +++ b/src/aerospike_vector_search/client.py @@ -0,0 +1,284 @@ +import logging +import sys +import time +from typing import Any, Optional, Union + +import grpc + +from . 
import types +from .internal import channel_provider +from .shared.client_helpers import BaseClient + +logger = logging.getLogger(__name__) + +class Client(BaseClient): + """ + Aerospike Vector Search Vector Client + + This client specializes in performing database operations with vector data. + Moreover, the client supports Hierarchical Navigable Small World (HNSW) vector searches, + allowing users to find vectors similar to a given query vector within an index. + """ + + def __init__( + self, + *, + seeds: Union[types.HostPort, tuple[types.HostPort, ...]], + listener_name: Optional[str] = None, + is_loadbalancer: Optional[bool] = False, + ) -> None: + """ + Initialize the Aerospike Vector Search Vector Client. + + Args: + seeds (Union[types.HostPort, tuple[types.HostPort, ...]]): + Used to create appropriate gRPC channels for interacting with Aerospike Vector Search. + listener_name (Optional[str], optional): + Advertised listener for the client. Defaults to None. + is_loadbalancer (bool, optional): + If true, the first seed address will be treated as a load balancer node. + + Raises: + Exception: Raised when no seed host is provided. + """ + seeds = self.prepare_seeds(seeds) + self._channelProvider = channel_provider.ChannelProvider( + seeds, listener_name, is_loadbalancer + ) + + def put( + self, + *, + namespace: str, + key: Union[int, str, bytes, bytearray], + record_data: dict[str, Any], + set_name: Optional[str] = None, + ) -> None: + """ + Write a record to Aerospike Vector Search. + + Args: + namespace (str): The namespace for the record. + key (Union[int, str, bytes, bytearray]): The key for the record. + record_data (dict[str, Any]): The data to be stored in the record. + set_name (Optional[str], optional): The name of the set to which the record belongs. Defaults to None. + + Raises: + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (transact_stub, put_request) = self.prepare_put(namespace, key, record_data, set_name, logger) + + try: + transact_stub.Put(put_request) + except grpc.RpcError as e: + logger.error("Failed with error: %s", e) + raise e +
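The synchronous client mirrors the aio API without the awaits. A put sketch under the same placeholder assumptions (local seed, "test" namespace):

```python
from aerospike_vector_search import types
from aerospike_vector_search.client import Client

c = Client(seeds=types.HostPort(host="localhost", port=5000))
try:
    c.put(
        namespace="test",
        key="user-1",
        record_data={"vector": [0.1, 0.2, 0.3, 0.4]},
    )
finally:
    c.close()
```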
+ """ + (transact_stub, key, get_request) = self.prepare_get(namespace, key, bin_names, set_name, logger) + try: + response = transact_stub.Get(get_request) + except grpc.RpcError as e: + logger.error("Failed with error: %s", e) + raise e + + return self.respond_get(response, key) + + def exists( + self, *, namespace: str, key: Any, set_name: Optional[str] = None + ) -> bool: + """ + Check if a record exists in Aerospike Vector Search. + + Args: + namespace (str): The namespace for the record. + key (Any): The key for the record. + set_name (Optional[str], optional): The name of the set to which the record belongs. Defaults to None. + + Returns: + bool: True if the record exists, False otherwise. + + Raises: + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + """ + (transact_stub, key) = self.prepare_exists(namespace, key, set_name, logger) + try: + response = transact_stub.Exists(key) + except grpc.RpcError as e: + logger.error("Failed with error: %s", e) + raise e + + return self.respond_exists(response) + + def is_indexed( + self, + *, + namespace: str, + key: Union[int, str, bytes, bytearray], + index_name: str, + index_namespace: Optional[str] = None, + set_name: Optional[str] = None, + ) -> bool: + """ + Check if a record is indexed in the Vector DB. + + Args: + namespace (str): The namespace for the record. + key (Union[int, str, bytes, bytearray]): The key for the record. + index_name (str): The name of the index. + index_namespace (Optional[str], optional): The namespace of the index. + If None, defaults to the namespace of the record. Defaults to None. + set_name (Optional[str], optional): The name of the set to which the record belongs. Defaults to None. + + Returns: + bool: True if the record is indexed, False otherwise. + + Raises: + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + """ + (transact_stub, is_indexed_request) = self.prepare_is_indexed(namespace, key, index_name, index_namespace, set_name, logger) + try: + response = transact_stub.IsIndexed(is_indexed_request) + except grpc.RpcError as e: + logger.error("Failed with error: %s", e) + raise e + return self.respond_is_indexed(response) + + def vector_search( + self, + *, + namespace: str, + index_name: str, + query: list[Union[bool, float]], + limit: int, + search_params: Optional[types.HnswSearchParams] = None, + bin_names: Optional[list[str]] = None, + ) -> list[types.Neighbor]: + """ + Perform a Hierarchical Navigable Small World (HNSW) vector search in Aerospike Vector Search. + + Args: + namespace (str): The namespace for the records. + index_name (str): The name of the index. + query (list[Union[bool, float]]): The query vector for the search. + limit (int): The maximum number of neighbors to return. K value. + search_params (Optional[types_pb2.HnswSearchParams], optional): Parameters for the HNSW algorithm. + If None, the default parameters for the index are used. Defaults to None. + bin_names (Optional[list[str]], optional): A list of bin names to retrieve from the results. + If None, all bins are retrieved. Defaults to None. + + Returns: + list[types.Neighbor]: A list of neighbors records found by the search. 
+ def wait_for_index_completion( + self, *, namespace: str, name: str, timeout: Optional[int] = sys.maxsize + ) -> None: + """ + Wait for the index to have no pending index update operations. + + Args: + namespace (str): The namespace of the index. + name (str): The name of the index. + timeout (int, optional): The maximum time (in seconds) to wait for the index to complete. + Defaults to sys.maxsize. + + Raises: + Exception: Raised when the timeout occurs while waiting for index completion. + grpc.RpcError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + Note: + The function polls the index status with a wait interval of 10 seconds until either + the timeout is reached or the index has no pending index update operations. + """ + # Wait interval between polling + (index_stub, wait_interval, start_time, unmerged_record_initialized, double_check, index_completion_request) = self.prepare_wait_for_index_waiting(namespace, name) + while True: + try: + index_status = index_stub.GetStatus(index_completion_request) + + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNAVAILABLE: + continue + else: + logger.error("Failed with error: %s", e) + raise e + if self.check_completion_condition(start_time, timeout, index_status, unmerged_record_initialized): + if double_check: + return + else: + double_check = True + else: + double_check = False + time.sleep(wait_interval) + + def close(self): + """ + Close the Aerospike Vector Search Vector Client. + + This method closes gRPC channels connected to Aerospike Vector Search. + + Note: + This method should be called when the VectorDbClient is no longer needed to release resources. + """ + self._channelProvider.close() + + def __enter__(self): + """ + Enter a context manager for the vector client. + + Returns: + VectorDbClient: Aerospike Vector Search Vector Client instance. + """ + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Exit a context manager for the vector client.
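Since `__exit__` calls `close()`, the synchronous client also works in a with-statement; a short sketch with the usual placeholder seed and key:

```python
from aerospike_vector_search import types
from aerospike_vector_search.client import Client

# __exit__ calls close(), so channels are released even if the body raises.
with Client(seeds=types.HostPort(host="localhost", port=5000)) as c:
    print(c.exists(namespace="test", key="user-1"))
```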
+ """ + self.close() \ No newline at end of file diff --git a/src/aerospike_vector/internal/__init__.py b/src/aerospike_vector_search/internal/__init__.py similarity index 100% rename from src/aerospike_vector/internal/__init__.py rename to src/aerospike_vector_search/internal/__init__.py diff --git a/src/aerospike_vector/internal/channel_provider.py b/src/aerospike_vector_search/internal/channel_provider.py similarity index 81% rename from src/aerospike_vector/internal/channel_provider.py rename to src/aerospike_vector_search/internal/channel_provider.py index b70192b8..cc6b4b02 100644 --- a/src/aerospike_vector/internal/channel_provider.py +++ b/src/aerospike_vector_search/internal/channel_provider.py @@ -10,7 +10,6 @@ from ..shared.proto_generated import vector_db_pb2 from ..shared.proto_generated import vector_db_pb2_grpc from ..shared import base_channel_provider -from ..shared import channel_provider_helpers empty = google.protobuf.empty_pb2.Empty() @@ -30,12 +29,12 @@ def close(self): for channel in self._seedChannels: channel.close() - for k, channelEndpoints in self._nodeChannels.items(): + for k, channelEndpoints in self._node_channels.items(): if channelEndpoints.channel: channelEndpoints.channel.close() def _tend(self): - (temp_endpoints, update_endpoints, channels, end_tend) = channel_provider_helpers.init_tend(self) + (temp_endpoints, update_endpoints, channels, end_tend) = self.init_tend() if end_tend: return @@ -46,7 +45,7 @@ def _tend(self): try: new_cluster_id = stub.GetClusterId(empty).id - if channel_provider_helpers.check_cluster_id(self, new_cluster_id): + if self.check_cluster_id(new_cluster_id): update_endpoints = True else: continue @@ -61,14 +60,14 @@ def _tend(self): listenerName=self.listener_name ) ) + temp_endpoints = self.update_temp_endpoints(response, temp_endpoints) except Exception as e: logger.debug("While tending, failed to get cluster endpoints with error:" + str(e)) - temp_endpoints = channel_provider_helpers.update_temp_endpoints(response, temp_endpoints) if update_endpoints: for node, newEndpoints in temp_endpoints.items(): - (channel_endpoints, add_new_channel) = channel_provider_helpers.check_for_new_endpoints(self, node, newEndpoints) + (channel_endpoints, add_new_channel) = self.check_for_new_endpoints(node, newEndpoints) if add_new_channel: try: @@ -79,12 +78,12 @@ def _tend(self): self.add_new_channel_to_node_channels(node, newEndpoints) - for node, channel_endpoints in self._nodeChannels.items(): + for node, channel_endpoints in self._node_channels.items(): if not temp_endpoints.get(node): # TODO: Wait for all calls to drain try: channel_endpoints.channel.close() - del self._nodeChannels[node] + del self._node_channels[node] except Exception as e: logger.debug("While tending, failed to close GRPC channel:" + str(e)) @@ -92,7 +91,7 @@ def _tend(self): # TODO: check tend interval. 
threading.Timer(1, self._tend).start() - def _create_channel(self, host: str, port: int, isTls: bool) -> grpc.Channel: + def _create_channel(self, host: str, port: int, is_tls: bool) -> grpc.Channel: # TODO: Take care of TLS host = re.sub(r"%.*", "", host) return grpc.insecure_channel(f"{host}:{port}") \ No newline at end of file diff --git a/src/aerospike_vector/shared/__init__.py b/src/aerospike_vector_search/shared/__init__.py similarity index 100% rename from src/aerospike_vector/shared/__init__.py rename to src/aerospike_vector_search/shared/__init__.py diff --git a/src/aerospike_vector_search/shared/admin_helpers.py b/src/aerospike_vector_search/shared/admin_helpers.py new file mode 100644 index 00000000..35213df0 --- /dev/null +++ b/src/aerospike_vector_search/shared/admin_helpers.py @@ -0,0 +1,161 @@ +import asyncio +import logging +from typing import Any, Optional, Union +import time + +import google.protobuf.empty_pb2 +from google.protobuf.json_format import MessageToDict +import grpc + +from . import helpers +from .proto_generated import index_pb2_grpc +from .proto_generated import types_pb2 +from .. import types + +logger = logging.getLogger(__name__) + +empty = google.protobuf.empty_pb2.Empty() + + +class BaseClient(object): + + def prepare_seeds(self, seeds) -> None: + return helpers.prepare_seeds(seeds) + + def prepare_index_create(self, namespace, name, vector_field, dimensions, vector_distance_metric, sets, index_params, index_meta_data, logger) -> None: + + logger.debug( + "Creating index: namespace=%s, name=%s, vector_field=%s, dimensions=%d, vector_distance_metric=%s, " + "sets=%s, index_params=%s, index_meta_data=%s", + namespace, + name, + vector_field, + dimensions, + vector_distance_metric, + sets, + index_params, + index_meta_data, + ) + + if sets and not sets.strip(): + sets = None + if index_params != None: + index_params = index_params._to_pb2() + id = self.get_index_id(namespace, name) + vector_distance_metric = vector_distance_metric.value + + index_stub = self.get_index_stub() + + index_create_request = types_pb2.IndexDefinition( + id=id, + vectorDistanceMetric=vector_distance_metric, + setFilter=sets, + hnswParams=index_params, + bin=vector_field, + dimensions=dimensions, + labels=index_meta_data, + ) + return (index_stub, index_create_request) + + def prepare_index_drop(self, namespace, name, logger) -> None: + + logger.debug("Dropping index: namespace=%s, name=%s", namespace, name) + + index_stub = self.get_index_stub() + index_drop_request = self.get_index_id(namespace, name) + + return (index_stub, index_drop_request) + + def prepare_index_list(self, logger) -> None: + + logger.debug("Getting index list") + + index_stub = self.get_index_stub() + index_list_request = empty + + return (index_stub, index_list_request) + + def prepare_index_get(self, namespace, name, logger) -> None: + + logger.debug( + "Getting index information: namespace=%s, name=%s", namespace, name + ) + + index_stub = self.get_index_stub() + index_get_request = self.get_index_id(namespace, name) + + return (index_stub, index_get_request) + + def prepare_index_get_status(self, namespace, name, logger) -> None: + + logger.debug("Getting index status: namespace=%s, name=%s", namespace, name) + + index_stub = self.get_index_stub() + index_get_status_request = self.get_index_id(namespace, name) + + + return (index_stub, index_get_status_request) + + def respond_index_list(self, response) -> None: + response_list = [] + for index in response.indices: + response_dict = MessageToDict(index) + 
+ # Modifying dict to adhere to PEP-8 naming + hnsw_params_dict = response_dict.pop("hnswParams", None) + + hnsw_params_dict["ef_construction"] = hnsw_params_dict.pop( + "efConstruction", None + ) + + batching_params_dict = hnsw_params_dict.pop("batchingParams", None) + batching_params_dict["max_records"] = batching_params_dict.pop( + "maxRecords", None + ) + hnsw_params_dict["batching_params"] = batching_params_dict + + response_dict["hnsw_params"] = hnsw_params_dict + response_list.append(response_dict) + return response_list + + def respond_index_get(self, response) -> None: + response_dict = MessageToDict(response) + + # Modifying dict to adhere to PEP-8 naming + hnsw_params_dict = response_dict.pop("hnswParams", None) + + hnsw_params_dict["ef_construction"] = hnsw_params_dict.pop( + "efConstruction", None + ) + + batching_params_dict = hnsw_params_dict.pop("batchingParams", None) + batching_params_dict["max_records"] = batching_params_dict.pop( + "maxRecords", None + ) + hnsw_params_dict["batching_params"] = batching_params_dict + + response_dict["hnsw_params"] = hnsw_params_dict + return response_dict + + def respond_index_get_status(self, response) -> None: + return response.unmergedRecordCount + + def get_index_stub(self): + return index_pb2_grpc.IndexServiceStub( + self._channelProvider.get_channel() + ) + + def get_index_id(self, namespace, name): + return types_pb2.IndexId(namespace=namespace, name=name) + + def prepare_wait_for_index_waiting(self, namespace, name): + return helpers.prepare_wait_for_index_waiting(self, namespace, name) + + + def check_timeout(self, start_time, timeout): + if start_time + timeout < time.monotonic(): + raise Exception("timed-out waiting for index creation") + + + + diff --git a/src/aerospike_vector/shared/base_channel_provider.py b/src/aerospike_vector_search/shared/base_channel_provider.py similarity index 60% rename from src/aerospike_vector/shared/base_channel_provider.py rename to src/aerospike_vector_search/shared/base_channel_provider.py index ad69750d..1d9db8dc 100644 --- a/src/aerospike_vector/shared/base_channel_provider.py +++ b/src/aerospike_vector_search/shared/base_channel_provider.py @@ -25,10 +25,10 @@ def __init__( ) -> None: if not seeds: raise Exception("at least one seed host needed") - self._nodeChannels: dict[int, ChannelAndEndpoints] = {} + self._node_channels: dict[int, ChannelAndEndpoints] = {} self._seedChannels: Union[dict[grpc.Channel], dict[grpc.Channel.aio]] = {} self._closed = False - self._clusterId = 0 + self._cluster_id = 0 self.seeds = seeds self.listener_name = listener_name self._is_loadbalancer = is_loadbalancer @@ -39,7 +39,7 @@ def get_channel(self) -> Union[grpc.aio.Channel, grpc.Channel]: if not self._is_loadbalancer: discovered_channels: list[ChannelAndEndpoints] = list( - self._nodeChannels.values()) + self._node_channels.values()) if len(discovered_channels) <= 0: return self._seedChannels[0] @@ -51,7 +51,7 @@ def get_channel(self) -> Union[grpc.aio.Channel, grpc.Channel]: return self._seedChannels[0] def _create_channel_from_host_port(self, host: types.HostPort) -> Union[grpc.aio.Channel, grpc.Channel]: - return self._create_channel(host.host, host.port, host.isTls) + return self._create_channel(host.host, host.port, host.is_tls) def _create_channel_from_server_endpoint_list( self, endpoints: vector_db_pb2.ServerEndpointList @@ -74,6 +74,57 @@ def add_new_channel_to_node_channels(self, node, newEndpoints): new_channel = self._create_channel_from_server_endpoint_list( newEndpoints ) -
diff --git a/src/aerospike_vector_search/shared/client_helpers.py b/src/aerospike_vector_search/shared/client_helpers.py
new file mode 100644
index 00000000..a8df779f
--- /dev/null
+++ b/src/aerospike_vector_search/shared/client_helpers.py
@@ -0,0 +1,183 @@
+from typing import Any, Optional, Union
+import time
+from . import conversions
+
+from .proto_generated import transact_pb2
+from .proto_generated import transact_pb2_grpc
+from .. import types
+from .proto_generated import types_pb2
+from . import helpers
+
+class BaseClient(object):
+
+    def prepare_seeds(self, seeds) -> None:
+        return helpers.prepare_seeds(seeds)
+
+    def prepare_put(self, namespace, key, record_data, set_name, logger) -> tuple:
+
+        logger.debug(
+            "Putting record: namespace=%s, key=%s, record_data=%s, set_name=%s",
+            namespace,
+            key,
+            record_data,
+            set_name,
+        )
+
+        key = self._get_key(namespace, set_name, key)
+        bin_list = [
+            types_pb2.Bin(name=k, value=conversions.toVectorDbValue(v))
+            for (k, v) in record_data.items()
+        ]
+
+        transact_stub = self.get_transact_stub()
+        put_request = transact_pb2.PutRequest(key=key, bins=bin_list)
+
+        return (transact_stub, put_request)
+
+    def prepare_get(self, namespace, key, bin_names, set_name, logger) -> tuple:
+
+        logger.debug(
+            "Getting record: namespace=%s, key=%s, bin_names=%s, set_name=%s",
+            namespace,
+            key,
+            bin_names,
+            set_name,
+        )
+
+        key = self._get_key(namespace, set_name, key)
+        bin_selector = self._get_bin_selector(bin_names=bin_names)
+
+        transact_stub = self.get_transact_stub()
+        get_request = transact_pb2.GetRequest(key=key, binSelector=bin_selector)
+
+        return (transact_stub, key, get_request)
+
+    def prepare_exists(self, namespace, key, set_name, logger) -> tuple:
+
+        logger.debug(
+            "Getting record existence: namespace=%s, key=%s, set_name=%s",
+            namespace,
+            key,
+            set_name,
+        )
+
+        key = self._get_key(namespace, set_name, key)
+
+        transact_stub = self.get_transact_stub()
+
+        return (transact_stub, key)
+
+    def prepare_is_indexed(self, namespace, key, index_name, index_namespace, set_name, logger) -> tuple:
+
+        logger.debug(
+            "Checking if index exists: namespace=%s, key=%s, index_name=%s, index_namespace=%s, set_name=%s",
+            namespace,
+            key,
+            index_name,
+            index_namespace,
+            set_name,
+        )
+
+        if not index_namespace:
+            index_namespace = namespace
+        index_id = types_pb2.IndexId(namespace=index_namespace, name=index_name)
+        key = self._get_key(namespace, set_name, key)
+
+        transact_stub = self.get_transact_stub()
+        is_indexed_request = transact_pb2.IsIndexedRequest(key=key, indexId=index_id)
+
+        return (transact_stub, is_indexed_request)
+
+    def prepare_vector_search(self, namespace, index_name, query, limit, search_params, bin_names, logger) -> tuple:
+
+        logger.debug(
+            "Performing vector search: namespace=%s, index_name=%s, query=%s, limit=%s, search_params=%s, bin_names=%s",
+            namespace,
+            index_name,
+            query,
+            limit,
+            search_params,
+            bin_names,
+        )
+
+        if search_params is not None:
+            search_params = search_params._to_pb2()
+        bin_selector = self._get_bin_selector(bin_names=bin_names)
+        index = types_pb2.IndexId(namespace=namespace, name=index_name)
+        query_vector = conversions.toVectorDbValue(query).vectorValue
+
+        transact_stub = self.get_transact_stub()
+
+        vector_search_request = transact_pb2.VectorSearchRequest(
+            index=index,
+            queryVector=query_vector,
+            limit=limit,
+            hnswSearchParams=search_params,
+            binSelector=bin_selector,
+        )
+
+        return (transact_stub, vector_search_request)
+
+    def get_transact_stub(self):
+        return transact_pb2_grpc.TransactStub(
+            self._channelProvider.get_channel()
+        )
+
+    def respond_get(self, response, key) -> types.RecordWithKey:
+        return types.RecordWithKey(
+            key=conversions.fromVectorDbKey(key),
+            bins=conversions.fromVectorDbRecord(response),
+        )
+
+    def respond_exists(self, response) -> bool:
+        return response.value
+
+    def respond_is_indexed(self, response) -> bool:
+        return response.value
+
+    def respond_neighbor(self, response):
+        return conversions.fromVectorDbNeighbor(response)
+
+    def _get_bin_selector(self, *, bin_names: Optional[list] = None):
+
+        if not bin_names:
+            bin_selector = transact_pb2.BinSelector(
+                type=transact_pb2.BinSelectorType.ALL, binNames=bin_names
+            )
+        else:
+            bin_selector = transact_pb2.BinSelector(
+                type=transact_pb2.BinSelectorType.SPECIFIED, binNames=bin_names
+            )
+        return bin_selector
+
+    def _get_key(self, namespace: str, set: str, key: Union[int, str, bytes, bytearray]):
+        if isinstance(key, str):
+            key = types_pb2.Key(namespace=namespace, set=set, stringValue=key)
+        elif isinstance(key, int):
+            key = types_pb2.Key(namespace=namespace, set=set, longValue=key)
+        elif isinstance(key, (bytes, bytearray)):
+            key = types_pb2.Key(namespace=namespace, set=set, bytesValue=key)
+        else:
+            raise Exception("Invalid key type: " + str(type(key)))
+        return key
+
+    def prepare_wait_for_index_waiting(self, namespace, name):
+        return helpers.prepare_wait_for_index_waiting(self, namespace, name)
+
+    def check_completion_condition(self, start_time, timeout, index_status, unmerged_record_initialized):
+
+        if start_time + 10 < time.monotonic():
+            unmerged_record_initialized = True
+
+        if index_status.unmergedRecordCount > 0:
+            unmerged_record_initialized = True
+
+        if (
+            index_status.unmergedRecordCount == 0
+            and unmerged_record_initialized
+        ):
+            return True
+        else:
+            return False
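
From the caller's side, `_get_key` is what maps Python key types onto the protobuf Key oneof. A quick sketch of the resulting public API, assuming the sync `Client` mirrors the aio client's context-manager support and keyword arguments (neither is shown in this diff):

    from aerospike_vector_search import Client, types

    with Client(seeds=types.HostPort(host="localhost", port=5000)) as client:
        # str -> stringValue, int -> longValue, bytes/bytearray -> bytesValue;
        # any other key type raises from _get_key.
        client.put(namespace="test", key="user-1",
                   record_data={"unit_test": [0.5] * 128})
        client.put(namespace="test", key=42,
                   record_data={"unit_test": [0.25] * 128})
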
diff --git a/src/aerospike_vector/shared/conversions.py b/src/aerospike_vector_search/shared/conversions.py
similarity index 100%
rename from src/aerospike_vector/shared/conversions.py
rename to src/aerospike_vector_search/shared/conversions.py
diff --git a/src/aerospike_vector/shared/helpers.py b/src/aerospike_vector_search/shared/helpers.py
similarity index 96%
rename from src/aerospike_vector/shared/helpers.py
rename to src/aerospike_vector_search/shared/helpers.py
index e18c9ddb..7c644117 100644
--- a/src/aerospike_vector/shared/helpers.py
+++ b/src/aerospike_vector_search/shared/helpers.py
@@ -16,7 +16,7 @@ def prepare_seeds(seeds) -> None:

 def prepare_wait_for_index_waiting(self, namespace, name):
-    wait_interval = 0.100
+    wait_interval = 5
     unmerged_record_initialized = False
     start_time = time.monotonic()
     double_check = False
diff --git a/src/aerospike_vector/shared/proto_generated/auth_pb2.py b/src/aerospike_vector_search/shared/proto_generated/auth_pb2.py
similarity index 100%
rename from src/aerospike_vector/shared/proto_generated/auth_pb2.py
rename to src/aerospike_vector_search/shared/proto_generated/auth_pb2.py
diff --git a/src/aerospike_vector/shared/proto_generated/auth_pb2_grpc.py b/src/aerospike_vector_search/shared/proto_generated/auth_pb2_grpc.py
similarity index 100%
rename from src/aerospike_vector/shared/proto_generated/auth_pb2_grpc.py
rename to src/aerospike_vector_search/shared/proto_generated/auth_pb2_grpc.py
diff --git a/src/aerospike_vector/shared/proto_generated/index_pb2.py b/src/aerospike_vector_search/shared/proto_generated/index_pb2.py
similarity index 100%
rename from src/aerospike_vector/shared/proto_generated/index_pb2.py
rename to src/aerospike_vector_search/shared/proto_generated/index_pb2.py
diff --git a/src/aerospike_vector/shared/proto_generated/index_pb2_grpc.py b/src/aerospike_vector_search/shared/proto_generated/index_pb2_grpc.py
similarity index 100%
rename from src/aerospike_vector/shared/proto_generated/index_pb2_grpc.py
rename to src/aerospike_vector_search/shared/proto_generated/index_pb2_grpc.py
diff --git a/src/aerospike_vector/shared/proto_generated/transact_pb2.py b/src/aerospike_vector_search/shared/proto_generated/transact_pb2.py
similarity index 100%
rename from src/aerospike_vector/shared/proto_generated/transact_pb2.py
rename to src/aerospike_vector_search/shared/proto_generated/transact_pb2.py
diff --git a/src/aerospike_vector/shared/proto_generated/transact_pb2_grpc.py b/src/aerospike_vector_search/shared/proto_generated/transact_pb2_grpc.py
similarity index 100%
rename from src/aerospike_vector/shared/proto_generated/transact_pb2_grpc.py
rename to src/aerospike_vector_search/shared/proto_generated/transact_pb2_grpc.py
diff --git a/src/aerospike_vector/shared/proto_generated/types_pb2.py b/src/aerospike_vector_search/shared/proto_generated/types_pb2.py
similarity index 100%
rename from src/aerospike_vector/shared/proto_generated/types_pb2.py
rename to src/aerospike_vector_search/shared/proto_generated/types_pb2.py
diff --git a/src/aerospike_vector/shared/proto_generated/types_pb2_grpc.py b/src/aerospike_vector_search/shared/proto_generated/types_pb2_grpc.py
similarity index 100%
rename from src/aerospike_vector/shared/proto_generated/types_pb2_grpc.py
rename to src/aerospike_vector_search/shared/proto_generated/types_pb2_grpc.py
diff --git a/src/aerospike_vector/shared/proto_generated/vector_db_pb2.py b/src/aerospike_vector_search/shared/proto_generated/vector_db_pb2.py
similarity index 100%
rename from src/aerospike_vector/shared/proto_generated/vector_db_pb2.py
rename to src/aerospike_vector_search/shared/proto_generated/vector_db_pb2.py
diff --git a/src/aerospike_vector/shared/proto_generated/vector_db_pb2_grpc.py b/src/aerospike_vector_search/shared/proto_generated/vector_db_pb2_grpc.py
similarity index 100%
rename from src/aerospike_vector/shared/proto_generated/vector_db_pb2_grpc.py
rename to src/aerospike_vector_search/shared/proto_generated/vector_db_pb2_grpc.py
diff --git a/src/aerospike_vector/types.py b/src/aerospike_vector_search/types.py
similarity index 97%
rename from src/aerospike_vector/types.py
rename to src/aerospike_vector_search/types.py
index 63d724ea..77258dae 100644
--- a/src/aerospike_vector/types.py
+++ b/src/aerospike_vector_search/types.py
@@ -12,13 +12,13 @@ class HostPort(object):
     Args:
         host (str): The host address.
         port (int): The port number.
-        isTls (Optional[bool], optional): Indicates if TLS is enabled. Defaults to False.
+        is_tls (Optional[bool], optional): Indicates if TLS is enabled. Defaults to False.
     """

-    def __init__(self, *, host: str, port: int, isTls: Optional[bool] = False) -> None:
+    def __init__(self, *, host: str, port: int, is_tls: Optional[bool] = False) -> None:
         self.host = host
         self.port = port
-        self.isTls = isTls
+        self.is_tls = is_tls


 class Key(object):
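
Note for callers upgrading across this rename: `HostPort` is keyword-only, so the camelCase keyword breaks at the call site. For example:

    from aerospike_vector_search import types

    # Before this patch: types.HostPort(host="localhost", port=5000, isTls=False)
    seed = types.HostPort(host="localhost", port=5000, is_tls=False)
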
""" - def __init__(self, *, host: str, port: int, isTls: Optional[bool] = False) -> None: + def __init__(self, *, host: str, port: int, is_tls: Optional[bool] = False) -> None: self.host = host self.port = port - self.isTls = isTls + self.is_tls = is_tls class Key(object): diff --git a/tests/aio/conftest.py b/tests/aio/conftest.py index 6f046eb6..54f9b182 100644 --- a/tests/aio/conftest.py +++ b/tests/aio/conftest.py @@ -1,8 +1,8 @@ import pytest import asyncio -from aerospike_vector.aio import Client -from aerospike_vector.aio.admin import Client as AdminClient -from aerospike_vector import types +from aerospike_vector_search.aio import Client +from aerospike_vector_search.aio.admin import Client as AdminClient +from aerospike_vector_search import types host = 'localhost' port = 5000 @@ -12,6 +12,7 @@ async def drop_all_indexes(): seeds=types.HostPort(host=host, port=port) ) as client: index_list = await client.index_list() + tasks = [] for item in index_list: tasks.append(client.index_drop(namespace="test", name=item['id']['name'])) diff --git a/tests/aio/test_admin_client_index_create.py b/tests/aio/test_admin_client_index_create.py index 9e92583f..ca40b93c 100644 --- a/tests/aio/test_admin_client_index_create.py +++ b/tests/aio/test_admin_client_index_create.py @@ -1,5 +1,5 @@ import pytest -from aerospike_vector import types +from aerospike_vector_search import types class index_create_test_case: diff --git a/tests/aio/test_vector_search.py b/tests/aio/test_vector_search.py index 681dacee..c120709d 100644 --- a/tests/aio/test_vector_search.py +++ b/tests/aio/test_vector_search.py @@ -2,7 +2,7 @@ import asyncio import pytest import random -from aerospike_vector import types +from aerospike_vector_search import types dimensions = 128 truth_vector_dimensions = 100 @@ -134,7 +134,6 @@ async def test_vector_search( count += 1 results = await asyncio.gather(*tasks) - # Get recall numbers for each query recall_for_each_query = [] for i, outside in enumerate(truth_numpy): @@ -145,6 +144,8 @@ async def test_vector_search( for j, result in enumerate(results[i]): binList.append(result.bins["unit_test"]) + + binList.append(result.bins["unit_test"]) for j, index in enumerate(outside): vector = base_numpy[index].tolist() if vector in binList: diff --git a/tests/sync/conftest.py b/tests/sync/conftest.py index ce92a375..6b25c78d 100644 --- a/tests/sync/conftest.py +++ b/tests/sync/conftest.py @@ -1,8 +1,8 @@ import pytest import asyncio -from aerospike_vector import Client -from aerospike_vector.admin import Client as AdminClient -from aerospike_vector import types +from aerospike_vector_search import Client +from aerospike_vector_search.admin import Client as AdminClient +from aerospike_vector_search import types host = 'localhost' port = 5000 diff --git a/tests/sync/test_admin_client_index_create.py b/tests/sync/test_admin_client_index_create.py index 98b34b7a..8600d9c7 100644 --- a/tests/sync/test_admin_client_index_create.py +++ b/tests/sync/test_admin_client_index_create.py @@ -1,5 +1,5 @@ import pytest -from aerospike_vector import types +from aerospike_vector_search import types class index_create_test_case: diff --git a/tests/sync/test_vector_search.py b/tests/sync/test_vector_search.py index 57811b89..2a2d95e0 100644 --- a/tests/sync/test_vector_search.py +++ b/tests/sync/test_vector_search.py @@ -1,7 +1,7 @@ import numpy as np import pytest import random -from aerospike_vector import types +from aerospike_vector_search import types dimensions = 128 truth_vector_dimensions = 100 @@ -129,7 
diff --git a/tests/sync/test_vector_search.py b/tests/sync/test_vector_search.py
index 57811b89..2a2d95e0 100644
--- a/tests/sync/test_vector_search.py
+++ b/tests/sync/test_vector_search.py
@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 import random
-from aerospike_vector import types
+from aerospike_vector_search import types

 dimensions = 128
 truth_vector_dimensions = 100
@@ -129,7 +129,6 @@ def test_vector_search(
         else:
             results.append(vector_search_ef_80(session_vector_client, i.tolist()))
         count += 1
-
     # Get recall numbers for each query
     recall_for_each_query = []
     for i, outside in enumerate(truth_numpy):

From 68a72ccce00200c8d7a955e2823f1adf1060d90b Mon Sep 17 00:00:00 2001
From: Dominic Pelini <111786059+DomPeliniAerospike@users.noreply.github.com>
Date: Tue, 30 Apr 2024 10:59:00 -0600
Subject: [PATCH 4/4] Update pyproject.toml

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 38c8217f..080a620f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ classifiers = [
     "Programming Language :: Python :: Implementation :: CPython",
     "Topic :: Database"
 ]
-version = "0.5.0dev1"
+version = "0.5.0"
 requires-python = ">3.8"
 dependencies = [
     "grpcio",