From 5375a91065cc6cd4079139ed5b9fdcce6e186929 Mon Sep 17 00:00:00 2001 From: Derek Wan Date: Mon, 16 Sep 2024 22:21:51 +0900 Subject: [PATCH] hi --- src/utilities/redis.py | 52 ++++++++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/src/utilities/redis.py b/src/utilities/redis.py index 3dea727eb..c585f844e 100644 --- a/src/utilities/redis.py +++ b/src/utilities/redis.py @@ -1587,6 +1587,7 @@ def __str__(self) -> str: @contextmanager def yield_client( *, + client: redis.Redis | None = None, host: str = _HOST, port: int = _PORT, db: int = 0, @@ -1596,24 +1597,28 @@ def yield_client( **kwargs: Any, ) -> Iterator[redis.Redis]: """Yield a synchronous client.""" - client = redis.Redis( - host=host, - port=port, - db=db, - password=password, - connection_pool=connection_pool, - decode_responses=decode_responses, - **kwargs, - ) + if client is None: + client_use = redis.Redis( + host=host, + port=port, + db=db, + password=password, + connection_pool=connection_pool, + decode_responses=decode_responses, + **kwargs, + ) + else: + client_use = client try: - yield client + yield client_use finally: - client.close() + client_use.close() @asynccontextmanager async def yield_client_async( *, + client: redis.asyncio.Redis | None = None, host: str = _HOST, port: int = _PORT, db: str | int = 0, @@ -1623,19 +1628,22 @@ async def yield_client_async( **kwargs: Any, ) -> AsyncIterator[redis.asyncio.Redis]: """Yield an asynchronous client.""" - client = redis.asyncio.Redis( - host=host, - port=port, - db=db, - password=password, - connection_pool=connection_pool, - decode_responses=decode_responses, - **kwargs, - ) + if client is None: + client_use = redis.asyncio.Redis( + host=host, + port=port, + db=db, + password=password, + connection_pool=connection_pool, + decode_responses=decode_responses, + **kwargs, + ) + else: + client_use = client try: - yield client + yield client_use finally: - match client.connection_pool: + match client_use.connection_pool: case redis.ConnectionPool() as pool: pool.disconnect(inuse_connections=False) # pragma: no cover case redis.asyncio.ConnectionPool() as pool: