Skip to content

Commit

Permalink
Add client info reporting (#155)
Browse files Browse the repository at this point in the history
We need to be able to track the usage of clients for measuring adoption
and impact. The `CLIENT SETINFO` command can be used for this purpose.
The required string format is documented
[here](https://redis.io/docs/latest/commands/client-setinfo/#:~:text=The%20CLIENT%20SETINFO%20command%20assigns,CLIENT%20LIST%20and%20CLIENT%20INFO%20).

### Variants

- Standalone RedisVL:
```
CLIENT SETINFO LIB-NAME redis-py(redisvl_v0.2.0)
CLIENT SETINFO LIB-VER 5.0.4
```

- Abstraction layers (like LlamaIndex or LangChain):
```
CLIENT SETINFO LIB-NAME redis-py(redisvl_v0.2.0;llama-index-vector-stores-redis_v0.1.0)
CLIENT SETINFO LIB-VER 5.0.4
```

### Constraints
- RedisVL uses both Async connection and standard connection instances
from Redis Py
	- So the technique to run this command needs to properly handle this...
- Other wrappers around RedisVL like LangChain and LlamaIndex will need
to pass through their lib_name
- These libraries generally support the notion of providing your own
client instance OR providing the connection string and performing the
connection on your behalf -- which also adds some difficulty.


### Planned Route

In order to build the proper name string, all clients that wrap redis-py
will need to use the following format:
```
{package-name}_v{version}
```

RedisVL will implement the default which is `redisvl_v0.x.x`, but outer
wrappers of RedisVL can implement their own by using a lib_name `kwarg`
on the index class like

```
SearchIndex(lib_name="langchain_v0.2.1")
```
  • Loading branch information
tylerhutcherson authored Jun 7, 2024
1 parent c4e3017 commit 934d269
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 65 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ scratch
wiki_schema.yaml
docs/_build/
.venv
.env
coverage.xml
dist/
6 changes: 0 additions & 6 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,6 @@
from testcontainers.compose import DockerCompose


# @pytest.fixture(scope="session")
# def event_loop():
# loop = asyncio.get_event_loop_policy().new_event_loop()
# yield loop
# loop.close()


@pytest.fixture(scope="session", autouse=True)
def redis_container():
Expand Down
34 changes: 14 additions & 20 deletions redisvl/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,25 +84,12 @@ def _process(doc: "Document") -> Dict[str, Any]:
return [_process(doc) for doc in results.docs]


def check_modules_present():
def setup_redis():
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
result = func(self, *args, **kwargs)
RedisConnectionFactory.validate_redis_modules(self._redis_client)
return result

return wrapper

return decorator


def check_async_modules_present():
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
result = func(self, *args, **kwargs)
RedisConnectionFactory.validate_async_redis_modules(self._redis_client)
RedisConnectionFactory.validate_redis(self._redis_client, self._lib_name)
return result

return wrapper
Expand Down Expand Up @@ -175,6 +162,9 @@ def __init__(

self.schema = schema

# set custom lib name
self._lib_name: Optional[str] = kwargs.pop("lib_name", None)

# set up redis connection
self._redis_client: Optional[Union[redis.Redis, aredis.Redis]] = None
if redis_client is not None:
Expand Down Expand Up @@ -350,11 +340,13 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
index.connect(redis_url="redis://localhost:6379")
"""
client = RedisConnectionFactory.connect(redis_url, use_async=False, **kwargs)
client = RedisConnectionFactory.connect(
redis_url=redis_url, use_async=False, **kwargs
)
return self.set_client(client)

@check_modules_present()
def set_client(self, client: redis.Redis):
@setup_redis()
def set_client(self, client: redis.Redis, **kwargs):
"""Manually set the Redis client to use with the search index.
This method configures the search index to use a specific Redis or
Expand Down Expand Up @@ -729,10 +721,12 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
index.connect(redis_url="redis://localhost:6379")
"""
client = RedisConnectionFactory.connect(redis_url, use_async=True, **kwargs)
client = RedisConnectionFactory.connect(
redis_url=redis_url, use_async=True, **kwargs
)
return self.set_client(client)

@check_async_modules_present()
@setup_redis()
def set_client(self, client: aredis.Redis):
"""Manually set the Redis client to use with the search index.
Expand Down
141 changes: 110 additions & 31 deletions redisvl/redis/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import os
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type, Union

from redis import Redis
from redis.asyncio import Redis as AsyncRedis
Expand All @@ -10,9 +11,11 @@
ConnectionPool,
SSLConnection,
)
from redis.exceptions import ResponseError

from redisvl.redis.constants import REDIS_REQUIRED_MODULES
from redisvl.redis.utils import convert_bytes
from redisvl.version import __version__


def get_address_from_env() -> str:
Expand All @@ -26,6 +29,20 @@ def get_address_from_env() -> str:
return os.environ["REDIS_URL"]


def make_lib_name(*args) -> str:
"""Build the lib name to be reported through the Redis client setinfo
command.
Returns:
str: Redis client library name
"""
custom_libs = f"redisvl_v{__version__}"
for arg in args:
if arg:
custom_libs += f";{arg}"
return f"redis-py({custom_libs})"


class RedisConnectionFactory:
"""Builds connections to a Redis database, supporting both synchronous and
asynchronous clients.
Expand Down Expand Up @@ -108,54 +125,116 @@ def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedi
return AsyncRedis.from_url(get_address_from_env(), **kwargs)

@staticmethod
def validate_redis_modules(
client: Redis, redis_required_modules: Optional[List[Dict[str, Any]]] = None
def validate_redis(
client: Union[Redis, AsyncRedis],
lib_name: Optional[str] = None,
redis_required_modules: Optional[List[Dict[str, Any]]] = None,
) -> None:
"""Validates if the required Redis modules are installed.
"""Validates the Redis connection.
Args:
client (Redis): Synchronous Redis client.
client (Redis or AsyncRedis): Redis client.
lib_name (str): Library name to set on the Redis client.
redis_required_modules (List[Dict[str, Any]]): List of required modules and their versions.
Raises:
ValueError: If required Redis modules are not installed.
"""
RedisConnectionFactory._validate_redis_modules(
convert_bytes(client.module_list()), redis_required_modules
)
if isinstance(client, AsyncRedis):
RedisConnectionFactory._run_async(
RedisConnectionFactory._validate_async_redis,
client,
lib_name,
redis_required_modules,
)
else:
RedisConnectionFactory._validate_sync_redis(
client, lib_name, redis_required_modules
)

@staticmethod
def _validate_sync_redis(
client: Redis,
lib_name: Optional[str],
redis_required_modules: Optional[List[Dict[str, Any]]],
) -> None:
"""Validates the sync client."""
# Set client library name
_lib_name = make_lib_name(lib_name)
try:
client.client_setinfo("LIB-NAME", _lib_name) # type: ignore
except ResponseError:
# Fall back to a simple log echo
client.echo(_lib_name)

# Get list of modules
modules_list = convert_bytes(client.module_list())

# Validate available modules
RedisConnectionFactory._validate_modules(modules_list, redis_required_modules)

@staticmethod
def validate_async_redis_modules(
async def _validate_async_redis(
client: AsyncRedis,
redis_required_modules: Optional[List[Dict[str, Any]]] = None,
lib_name: Optional[str],
redis_required_modules: Optional[List[Dict[str, Any]]],
) -> None:
"""Validates the async client."""
# Set client library name
_lib_name = make_lib_name(lib_name)
try:
await client.client_setinfo("LIB-NAME", _lib_name) # type: ignore
except ResponseError:
# Fall back to a simple log echo
await client.echo(_lib_name)

# Get list of modules
modules_list = convert_bytes(await client.module_list())

# Validate available modules
RedisConnectionFactory._validate_modules(modules_list, redis_required_modules)

@staticmethod
def _run_async(coro, *args, **kwargs):
"""
Validates if the required Redis modules are installed.
Runs an asynchronous function in the appropriate event loop context.
This method checks if there is an existing event loop running. If there is,
it schedules the coroutine to be run within the current loop using `asyncio.ensure_future`.
If no event loop is running, it creates a new event loop, runs the coroutine,
and then closes the loop to avoid resource leaks.
Args:
client (AsyncRedis): Asynchronous Redis client.
coro (coroutine): The coroutine function to be run.
*args: Positional arguments to pass to the coroutine function.
**kwargs: Keyword arguments to pass to the coroutine function.
Raises:
ValueError: If required Redis modules are not installed.
Returns:
The result of the coroutine if a new event loop is created,
otherwise a task object representing the coroutine execution.
"""
# pick the right connection class
connection_class: Type[AbstractConnection] = (
SSLConnection
if client.connection_pool.connection_class == ASSLConnection
else Connection
)
# set up a temp sync client
temp_client = Redis(
connection_pool=ConnectionPool(
connection_class=connection_class,
**client.connection_pool.connection_kwargs,
)
)
RedisConnectionFactory.validate_redis_modules(
temp_client, redis_required_modules
)
try:
# Try to get the current running event loop
loop = asyncio.get_running_loop()
except RuntimeError: # No running event loop
loop = None

if loop and loop.is_running():
# If an event loop is running, schedule the coroutine to run in the existing loop
return asyncio.ensure_future(coro(*args, **kwargs))
else:
# No event loop is running, create a new event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Run the coroutine in the new event loop and wait for it to complete
return loop.run_until_complete(coro(*args, **kwargs))
finally:
# Close the event loop to release resources
loop.close()

@staticmethod
def _validate_redis_modules(
def _validate_modules(
installed_modules, redis_required_modules: Optional[List[Dict[str, Any]]] = None
) -> None:
"""
Expand Down
56 changes: 48 additions & 8 deletions tests/integration/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,45 @@
from redis.exceptions import ConnectionError

from redisvl.redis.connection import RedisConnectionFactory, get_address_from_env
from redisvl.version import __version__

EXPECTED_LIB_NAME = f"redis-py(redisvl_v{__version__})"


def compare_versions(version1, version2):
"""
Compare two Redis version strings numerically.
Parameters:
version1 (str): The first version string (e.g., "7.2.4").
version2 (str): The second version string (e.g., "6.2.1").
Returns:
int: -1 if version1 < version2, 0 if version1 == version2, 1 if version1 > version2.
"""
v1_parts = list(map(int, version1.split(".")))
v2_parts = list(map(int, version2.split(".")))

for v1, v2 in zip(v1_parts, v2_parts):
if v1 < v2:
return False
elif v1 > v2:
return True

# If the versions are equal so far, compare the lengths of the version parts
if len(v1_parts) < len(v2_parts):
return False
elif len(v1_parts) > len(v2_parts):
return True

return True


def test_get_address_from_env(redis_url):
assert get_address_from_env() == redis_url


def test_sync_redis_connection(redis_url):
def test_sync_redis_connect(redis_url):
client = RedisConnectionFactory.connect(redis_url)
assert client is not None
assert isinstance(client, Redis)
Expand All @@ -21,7 +53,7 @@ def test_sync_redis_connection(redis_url):


@pytest.mark.asyncio
async def test_async_redis_connection(redis_url):
async def test_async_redis_connect(redis_url):
client = RedisConnectionFactory.connect(redis_url, use_async=True)
assert client is not None
assert isinstance(client, AsyncRedis)
Expand Down Expand Up @@ -49,11 +81,19 @@ def test_unknown_redis():
bad_client.ping()


def test_required_modules(client):
RedisConnectionFactory.validate_redis_modules(client)
def test_validate_redis(client):
redis_version = client.info()["redis_version"]
if not compare_versions(redis_version, "7.2.0"):
pytest.skip("Not using a late enough version of Redis")
RedisConnectionFactory.validate_redis(client)
lib_name = client.client_info()
assert lib_name["lib-name"] == EXPECTED_LIB_NAME


@pytest.mark.asyncio
async def test_async_required_modules(async_client):
client = await async_client
RedisConnectionFactory.validate_async_redis_modules(client)
def test_validate_redis_custom_lib_name(client):
redis_version = client.info()["redis_version"]
if not compare_versions(redis_version, "7.2.0"):
pytest.skip("Not using a late enough version of Redis")
RedisConnectionFactory.validate_redis(client, "langchain_v0.1.0")
lib_name = client.client_info()
assert lib_name["lib-name"] == f"redis-py(redisvl_v{__version__};langchain_v0.1.0)"

0 comments on commit 934d269

Please sign in to comment.