diff --git a/.gitignore b/.gitignore index d4204a44..a238288e 100644 --- a/.gitignore +++ b/.gitignore @@ -9,5 +9,6 @@ scratch wiki_schema.yaml docs/_build/ .venv +.env coverage.xml dist/ \ No newline at end of file diff --git a/conftest.py b/conftest.py index e7975057..204ac177 100644 --- a/conftest.py +++ b/conftest.py @@ -6,12 +6,6 @@ from testcontainers.compose import DockerCompose -# @pytest.fixture(scope="session") -# def event_loop(): -# loop = asyncio.get_event_loop_policy().new_event_loop() -# yield loop -# loop.close() - @pytest.fixture(scope="session", autouse=True) def redis_container(): diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 2162f4de..3fe79631 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -84,25 +84,12 @@ def _process(doc: "Document") -> Dict[str, Any]: return [_process(doc) for doc in results.docs] -def check_modules_present(): +def setup_redis(): def decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): result = func(self, *args, **kwargs) - RedisConnectionFactory.validate_redis_modules(self._redis_client) - return result - - return wrapper - - return decorator - - -def check_async_modules_present(): - def decorator(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - result = func(self, *args, **kwargs) - RedisConnectionFactory.validate_async_redis_modules(self._redis_client) + RedisConnectionFactory.validate_redis(self._redis_client, self._lib_name) return result return wrapper @@ -175,6 +162,9 @@ def __init__( self.schema = schema + # set custom lib name + self._lib_name: Optional[str] = kwargs.pop("lib_name", None) + # set up redis connection self._redis_client: Optional[Union[redis.Redis, aredis.Redis]] = None if redis_client is not None: @@ -350,11 +340,13 @@ def connect(self, redis_url: Optional[str] = None, **kwargs): index.connect(redis_url="redis://localhost:6379") """ - client = RedisConnectionFactory.connect(redis_url, use_async=False, **kwargs) + client = RedisConnectionFactory.connect( + redis_url=redis_url, use_async=False, **kwargs + ) return self.set_client(client) - @check_modules_present() - def set_client(self, client: redis.Redis): + @setup_redis() + def set_client(self, client: redis.Redis, **kwargs): """Manually set the Redis client to use with the search index. This method configures the search index to use a specific Redis or @@ -729,10 +721,12 @@ def connect(self, redis_url: Optional[str] = None, **kwargs): index.connect(redis_url="redis://localhost:6379") """ - client = RedisConnectionFactory.connect(redis_url, use_async=True, **kwargs) + client = RedisConnectionFactory.connect( + redis_url=redis_url, use_async=True, **kwargs + ) return self.set_client(client) - @check_async_modules_present() + @setup_redis() def set_client(self, client: aredis.Redis): """Manually set the Redis client to use with the search index. diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 79eb3060..d747b7ca 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -1,5 +1,6 @@ +import asyncio import os -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, Union from redis import Redis from redis.asyncio import Redis as AsyncRedis @@ -10,9 +11,11 @@ ConnectionPool, SSLConnection, ) +from redis.exceptions import ResponseError from redisvl.redis.constants import REDIS_REQUIRED_MODULES from redisvl.redis.utils import convert_bytes +from redisvl.version import __version__ def get_address_from_env() -> str: @@ -26,6 +29,20 @@ def get_address_from_env() -> str: return os.environ["REDIS_URL"] +def make_lib_name(*args) -> str: + """Build the lib name to be reported through the Redis client setinfo + command. + + Returns: + str: Redis client library name + """ + custom_libs = f"redisvl_v{__version__}" + for arg in args: + if arg: + custom_libs += f";{arg}" + return f"redis-py({custom_libs})" + + class RedisConnectionFactory: """Builds connections to a Redis database, supporting both synchronous and asynchronous clients. @@ -108,54 +125,116 @@ def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedi return AsyncRedis.from_url(get_address_from_env(), **kwargs) @staticmethod - def validate_redis_modules( - client: Redis, redis_required_modules: Optional[List[Dict[str, Any]]] = None + def validate_redis( + client: Union[Redis, AsyncRedis], + lib_name: Optional[str] = None, + redis_required_modules: Optional[List[Dict[str, Any]]] = None, ) -> None: - """Validates if the required Redis modules are installed. + """Validates the Redis connection. Args: - client (Redis): Synchronous Redis client. + client (Redis or AsyncRedis): Redis client. + lib_name (str): Library name to set on the Redis client. + redis_required_modules (List[Dict[str, Any]]): List of required modules and their versions. Raises: ValueError: If required Redis modules are not installed. """ - RedisConnectionFactory._validate_redis_modules( - convert_bytes(client.module_list()), redis_required_modules - ) + if isinstance(client, AsyncRedis): + RedisConnectionFactory._run_async( + RedisConnectionFactory._validate_async_redis, + client, + lib_name, + redis_required_modules, + ) + else: + RedisConnectionFactory._validate_sync_redis( + client, lib_name, redis_required_modules + ) + + @staticmethod + def _validate_sync_redis( + client: Redis, + lib_name: Optional[str], + redis_required_modules: Optional[List[Dict[str, Any]]], + ) -> None: + """Validates the sync client.""" + # Set client library name + _lib_name = make_lib_name(lib_name) + try: + client.client_setinfo("LIB-NAME", _lib_name) # type: ignore + except ResponseError: + # Fall back to a simple log echo + client.echo(_lib_name) + + # Get list of modules + modules_list = convert_bytes(client.module_list()) + + # Validate available modules + RedisConnectionFactory._validate_modules(modules_list, redis_required_modules) @staticmethod - def validate_async_redis_modules( + async def _validate_async_redis( client: AsyncRedis, - redis_required_modules: Optional[List[Dict[str, Any]]] = None, + lib_name: Optional[str], + redis_required_modules: Optional[List[Dict[str, Any]]], ) -> None: + """Validates the async client.""" + # Set client library name + _lib_name = make_lib_name(lib_name) + try: + await client.client_setinfo("LIB-NAME", _lib_name) # type: ignore + except ResponseError: + # Fall back to a simple log echo + await client.echo(_lib_name) + + # Get list of modules + modules_list = convert_bytes(await client.module_list()) + + # Validate available modules + RedisConnectionFactory._validate_modules(modules_list, redis_required_modules) + + @staticmethod + def _run_async(coro, *args, **kwargs): """ - Validates if the required Redis modules are installed. + Runs an asynchronous function in the appropriate event loop context. + + This method checks if there is an existing event loop running. If there is, + it schedules the coroutine to be run within the current loop using `asyncio.ensure_future`. + If no event loop is running, it creates a new event loop, runs the coroutine, + and then closes the loop to avoid resource leaks. Args: - client (AsyncRedis): Asynchronous Redis client. + coro (coroutine): The coroutine function to be run. + *args: Positional arguments to pass to the coroutine function. + **kwargs: Keyword arguments to pass to the coroutine function. - Raises: - ValueError: If required Redis modules are not installed. + Returns: + The result of the coroutine if a new event loop is created, + otherwise a task object representing the coroutine execution. """ - # pick the right connection class - connection_class: Type[AbstractConnection] = ( - SSLConnection - if client.connection_pool.connection_class == ASSLConnection - else Connection - ) - # set up a temp sync client - temp_client = Redis( - connection_pool=ConnectionPool( - connection_class=connection_class, - **client.connection_pool.connection_kwargs, - ) - ) - RedisConnectionFactory.validate_redis_modules( - temp_client, redis_required_modules - ) + try: + # Try to get the current running event loop + loop = asyncio.get_running_loop() + except RuntimeError: # No running event loop + loop = None + + if loop and loop.is_running(): + # If an event loop is running, schedule the coroutine to run in the existing loop + return asyncio.ensure_future(coro(*args, **kwargs)) + else: + # No event loop is running, create a new event loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + # Run the coroutine in the new event loop and wait for it to complete + return loop.run_until_complete(coro(*args, **kwargs)) + finally: + # Close the event loop to release resources + loop.close() @staticmethod - def _validate_redis_modules( + def _validate_modules( installed_modules, redis_required_modules: Optional[List[Dict[str, Any]]] = None ) -> None: """ diff --git a/tests/integration/test_connection.py b/tests/integration/test_connection.py index 13347bfe..cfe6e45b 100644 --- a/tests/integration/test_connection.py +++ b/tests/integration/test_connection.py @@ -6,13 +6,45 @@ from redis.exceptions import ConnectionError from redisvl.redis.connection import RedisConnectionFactory, get_address_from_env +from redisvl.version import __version__ + +EXPECTED_LIB_NAME = f"redis-py(redisvl_v{__version__})" + + +def compare_versions(version1, version2): + """ + Compare two Redis version strings numerically. + + Parameters: + version1 (str): The first version string (e.g., "7.2.4"). + version2 (str): The second version string (e.g., "6.2.1"). + + Returns: + int: -1 if version1 < version2, 0 if version1 == version2, 1 if version1 > version2. + """ + v1_parts = list(map(int, version1.split("."))) + v2_parts = list(map(int, version2.split("."))) + + for v1, v2 in zip(v1_parts, v2_parts): + if v1 < v2: + return False + elif v1 > v2: + return True + + # If the versions are equal so far, compare the lengths of the version parts + if len(v1_parts) < len(v2_parts): + return False + elif len(v1_parts) > len(v2_parts): + return True + + return True def test_get_address_from_env(redis_url): assert get_address_from_env() == redis_url -def test_sync_redis_connection(redis_url): +def test_sync_redis_connect(redis_url): client = RedisConnectionFactory.connect(redis_url) assert client is not None assert isinstance(client, Redis) @@ -21,7 +53,7 @@ def test_sync_redis_connection(redis_url): @pytest.mark.asyncio -async def test_async_redis_connection(redis_url): +async def test_async_redis_connect(redis_url): client = RedisConnectionFactory.connect(redis_url, use_async=True) assert client is not None assert isinstance(client, AsyncRedis) @@ -49,11 +81,19 @@ def test_unknown_redis(): bad_client.ping() -def test_required_modules(client): - RedisConnectionFactory.validate_redis_modules(client) +def test_validate_redis(client): + redis_version = client.info()["redis_version"] + if not compare_versions(redis_version, "7.2.0"): + pytest.skip("Not using a late enough version of Redis") + RedisConnectionFactory.validate_redis(client) + lib_name = client.client_info() + assert lib_name["lib-name"] == EXPECTED_LIB_NAME -@pytest.mark.asyncio -async def test_async_required_modules(async_client): - client = await async_client - RedisConnectionFactory.validate_async_redis_modules(client) +def test_validate_redis_custom_lib_name(client): + redis_version = client.info()["redis_version"] + if not compare_versions(redis_version, "7.2.0"): + pytest.skip("Not using a late enough version of Redis") + RedisConnectionFactory.validate_redis(client, "langchain_v0.1.0") + lib_name = client.client_info() + assert lib_name["lib-name"] == f"redis-py(redisvl_v{__version__};langchain_v0.1.0)"