Skip to content

Commit

Permalink
Allow asynchronous callbacks for async retries.
Browse files Browse the repository at this point in the history
Fixes jd#249
  • Loading branch information
mastizada committed Jun 2, 2022
1 parent da1bfc9 commit 1363916
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
features:
- Allow asynchronous callbacks for `before`, `after`, `retry_error_callback`, `wait`, and `before_sleep` parameters.
76 changes: 58 additions & 18 deletions tenacity/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,32 @@

import functools
import sys
import typing
from asyncio import sleep
from asyncio import sleep as aio_sleep
from collections.abc import Awaitable
from inspect import iscoroutinefunction
from typing import Union, Callable, Any, TypeVar

from tenacity import AttemptManager
from tenacity import BaseRetrying
from tenacity import DoAttempt
from tenacity import DoSleep
from tenacity import RetryCallState
from tenacity import AttemptManager, BaseRetrying, DoAttempt, DoSleep, RetryCallState, RetryAction, TryAgain

WrappedFn = typing.TypeVar("WrappedFn", bound=typing.Callable)
_RetValT = typing.TypeVar("_RetValT")
WrappedFn = TypeVar("WrappedFn", bound=Callable)
_RetValT = TypeVar("_RetValT")


class AsyncRetrying(BaseRetrying):
def __init__(self, sleep: typing.Callable[[float], typing.Awaitable] = sleep, **kwargs: typing.Any) -> None:
def __init__(self, sleep: Callable[[float], Awaitable] = aio_sleep, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.sleep = sleep

async def __call__( # type: ignore # Change signature from supertype
self,
fn: typing.Callable[..., typing.Awaitable[_RetValT]],
*args: typing.Any,
**kwargs: typing.Any,
fn: Callable[..., Awaitable[_RetValT]],
*args: Any,
**kwargs: Any,
) -> _RetValT:
self.begin()

retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
while True:
do = self.iter(retry_state=retry_state)
do = await self.iter(retry_state=retry_state)
if isinstance(do, DoAttempt):
try:
result = await fn(*args, **kwargs)
Expand All @@ -64,9 +61,9 @@ def __aiter__(self) -> "AsyncRetrying":
self._retry_state = RetryCallState(self, fn=None, args=(), kwargs={})
return self

async def __anext__(self) -> typing.Union[AttemptManager, typing.Any]:
async def __anext__(self) -> Union[AttemptManager, Any]:
while True:
do = self.iter(retry_state=self._retry_state)
do = await self.iter(retry_state=self._retry_state)
if do is None:
raise StopAsyncIteration
elif isinstance(do, DoAttempt):
Expand All @@ -82,11 +79,54 @@ def wraps(self, fn: WrappedFn) -> WrappedFn:
# Ensure wrapper is recognized as a coroutine function.

@functools.wraps(fn)
async def async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
async def async_wrapped(*args: Any, **kwargs: Any) -> Any:
return await fn(*args, **kwargs)

# Preserve attributes
async_wrapped.retry = fn.retry
async_wrapped.retry_with = fn.retry_with

return async_wrapped

@staticmethod
async def handle_custom_function(func: Union[Callable, Awaitable], retry_state: RetryCallState) -> Any:
if iscoroutinefunction(func):
return await func(retry_state)
return func(retry_state)

async def iter(self, retry_state: "RetryCallState") -> Union[DoAttempt, DoSleep, Any]:
fut = retry_state.outcome
if fut is None:
if self.before is not None:
await self.handle_custom_function(self.before, retry_state)
return DoAttempt()

is_explicit_retry = retry_state.outcome.failed and isinstance(retry_state.outcome.exception(), TryAgain)
if not (is_explicit_retry or self.retry(retry_state=retry_state)):
return fut.result()

if self.after is not None:
await self.handle_custom_function(self.after, retry_state)

self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start
if self.stop(retry_state=retry_state):
if self.retry_error_callback:
return await self.handle_custom_function(self.retry_error_callback, retry_state)
retry_exc = self.retry_error_cls(fut)
if self.reraise:
raise retry_exc.reraise()
raise retry_exc from fut.exception()

if self.wait:
sleep = await self.handle_custom_function(self.wait, retry_state=retry_state)
else:
sleep = 0.0
retry_state.next_action = RetryAction(sleep)
retry_state.idle_for += sleep
self.statistics["idle_for"] += sleep
self.statistics["attempt_number"] += 1

if self.before_sleep is not None:
await self.handle_custom_function(self.before_sleep, retry_state)

return DoSleep(sleep)
138 changes: 131 additions & 7 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@

import asyncio
import inspect
import logging
import unittest
from functools import wraps

from tenacity import AsyncRetrying, RetryError
import tenacity
from tenacity import _asyncio as tasyncio
from tenacity import retry, stop_after_attempt
from tenacity.wait import wait_fixed

from .test_tenacity import NoIOErrorAfterCount, current_time_ms
from .test_tenacity import CapturingHandler, NoneReturnUntilAfterCount, NoIOErrorAfterCount, current_time_ms


def asynctest(callable_):
Expand Down Expand Up @@ -67,7 +67,7 @@ async def test_iscoroutinefunction(self):
@asynctest
async def test_retry_using_async_retying(self):
thing = NoIOErrorAfterCount(5)
retrying = AsyncRetrying()
retrying = tenacity.AsyncRetrying()
await retrying(_async_function, thing)
assert thing.counter == thing.count

Expand All @@ -76,7 +76,7 @@ async def test_stop_after_attempt(self):
thing = NoIOErrorAfterCount(2)
try:
await _retryable_coroutine_with_2_attempts(thing)
except RetryError:
except tenacity.RetryError:
assert thing.counter == 2

def test_repr(self):
Expand All @@ -86,6 +86,31 @@ def test_retry_attributes(self):
assert hasattr(_retryable_coroutine, "retry")
assert hasattr(_retryable_coroutine, "retry_with")

@asynctest
async def test_async_retry_error_callback_handler(self):
num_attempts = 3
self.attempt_counter = 0

async def _retry_error_callback_handler(retry_state: tenacity.RetryCallState):
_retry_error_callback_handler.called_times += 1
return retry_state.outcome

_retry_error_callback_handler.called_times = 0

@retry(
stop=stop_after_attempt(num_attempts),
retry_error_callback=_retry_error_callback_handler,
)
async def _foobar():
self.attempt_counter += 1
raise Exception("This exception should not be raised")

result = await _foobar()

self.assertEqual(_retry_error_callback_handler.called_times, 1)
self.assertEqual(num_attempts, self.attempt_counter)
self.assertIsInstance(result, tenacity.Future)

@asynctest
async def test_attempt_number_is_correct_for_interleaved_coroutines(self):

Expand Down Expand Up @@ -125,7 +150,7 @@ async def test_do_max_attempts(self):
with attempt:
attempts += 1
raise Exception
except RetryError:
except tenacity.RetryError:
pass

assert attempts == 3
Expand All @@ -151,11 +176,110 @@ async def test_sleeps(self):
async for attempt in tasyncio.AsyncRetrying(stop=stop_after_attempt(1), wait=wait_fixed(1)):
with attempt:
raise Exception()
except RetryError:
except tenacity.RetryError:
pass
t = current_time_ms() - start
self.assertLess(t, 1.1)


class TestAsyncBeforeAfterAttempts(unittest.TestCase):
_attempt_number = 0

@asynctest
async def test_before_attempts(self):
TestAsyncBeforeAfterAttempts._attempt_number = 0

async def _before(retry_state):
TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number

@retry(
wait=tenacity.wait_fixed(1),
stop=tenacity.stop_after_attempt(1),
before=_before,
)
async def _test_before():
pass

await _test_before()

self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 1)

@asynctest
async def test_after_attempts(self):
TestAsyncBeforeAfterAttempts._attempt_number = 0

async def _after(retry_state):
TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number

@retry(
wait=tenacity.wait_fixed(0.1),
stop=tenacity.stop_after_attempt(3),
after=_after,
)
async def _test_after():
if TestAsyncBeforeAfterAttempts._attempt_number < 2:
raise Exception("testing after_attempts handler")
else:
pass

await _test_after()

self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 2)

@asynctest
async def test_before_sleep(self):
async def _before_sleep(retry_state):
self.assertGreater(retry_state.next_action.sleep, 0)
_before_sleep.attempt_number = retry_state.attempt_number

_before_sleep.attempt_number = 0

@retry(
wait=tenacity.wait_fixed(0.01),
stop=tenacity.stop_after_attempt(3),
before_sleep=_before_sleep,
)
async def _test_before_sleep():
if _before_sleep.attempt_number < 2:
raise Exception("testing before_sleep_attempts handler")

await _test_before_sleep()
self.assertEqual(_before_sleep.attempt_number, 2)

async def _test_before_sleep_log_returns(self, exc_info):
thing = NoneReturnUntilAfterCount(2)
logger = logging.getLogger(self.id())
logger.propagate = False
logger.setLevel(logging.INFO)
handler = CapturingHandler()
logger.addHandler(handler)
try:
_before_sleep = tenacity.before_sleep_log(logger, logging.INFO, exc_info=exc_info)
_retry = tenacity.retry_if_result(lambda result: result is None)
retrying = tenacity.AsyncRetrying(
wait=tenacity.wait_fixed(0.01),
stop=tenacity.stop_after_attempt(3),
retry=_retry,
before_sleep=_before_sleep,
)
await retrying(_async_function, thing)
finally:
logger.removeHandler(handler)

etalon_re = r"^Retrying .* in 0\.01 seconds as it returned None\.$"
self.assertEqual(len(handler.records), 2)
fmt = logging.Formatter().format
self.assertRegex(fmt(handler.records[0]), etalon_re)
self.assertRegex(fmt(handler.records[1]), etalon_re)

@asynctest
async def test_before_sleep_log_returns_without_exc_info(self):
await self._test_before_sleep_log_returns(exc_info=False)

@asynctest
async def test_before_sleep_log_returns_with_exc_info(self):
await self._test_before_sleep_log_returns(exc_info=True)


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions tests/test_tenacity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,8 @@ def _before_sleep(retry_state):
self.assertGreater(retry_state.next_action.sleep, 0)
_before_sleep.attempt_number = retry_state.attempt_number

_before_sleep.attempt_number = 0

@retry(
wait=tenacity.wait_fixed(0.01),
stop=tenacity.stop_after_attempt(3),
Expand Down

0 comments on commit 1363916

Please sign in to comment.