From 42d4e841b7108f8b92849dc512b33a64757ac163 Mon Sep 17 00:00:00 2001
From: Ali Ebrahim
Date: Thu, 19 Oct 2023 17:19:40 +0000
Subject: [PATCH] Implement exception handling for async generators.

Fixes #12.

The default behavior is to cache exceptions. However, there is an
option to retry exceptions, which will also respect the concurrency
guarantees from once.
---
 once/__init__.py           | 112 ++++++++++++-----
 once/_iterator_wrappers.py | 240 ++++++++++++++++++++-----------------
 once_test.py               | 152 +++++++++++++++++++++--
 3 files changed, 355 insertions(+), 149 deletions(-)

diff --git a/once/__init__.py b/once/__init__.py
index 0e95f82..6ad8ce6 100644
--- a/once/__init__.py
+++ b/once/__init__.py
@@ -95,10 +95,16 @@ def return_value(self, value: typing.Any) -> None:
 _ONCE_FACTORY_TYPE = collections.abc.Callable  # type: ignore
 
 
+class _CachedException:
+    def __init__(self, exception: Exception):
+        self.exception = exception
+
+
 def _wrap(
     func: collections.abc.Callable,
     once_factory: _ONCE_FACTORY_TYPE,
     fn_type: _WrappedFunctionType,
+    retry_exceptions: bool,
 ) -> collections.abc.Callable:
     """Generate a wrapped function appropriate to the function type.
 
@@ -119,7 +125,7 @@ async def wrapped(*args, **kwargs) -> typing.Any:
             async with once_base.async_lock:
                 if not once_base.called:
                     once_base.return_value = _iterator_wrappers.AsyncGeneratorWrapper(
-                        func, *args, **kwargs
+                        retry_exceptions, func, *args, **kwargs
                     )
                     once_base.called = True
                 return_value = once_base.return_value
@@ -132,24 +138,58 @@ async def wrapped(*args, **kwargs) -> typing.Any:
             return
 
     elif fn_type == _WrappedFunctionType.ASYNC_FUNCTION:
+        if retry_exceptions:
 
-        async def wrapped(*args, **kwargs) -> typing.Any:
-            once_base: _OnceBase = once_factory()
-            async with once_base.async_lock:
-                if not once_base.called:
-                    once_base.return_value = await func(*args, **kwargs)
-                    once_base.called = True
-            return once_base.return_value
+            async def wrapped(*args, **kwargs) -> typing.Any:
+                once_base: _OnceBase = once_factory()
+                async with once_base.async_lock:
+                    if not once_base.called:
+                        once_base.return_value = await func(*args, **kwargs)
+                        once_base.called = True
+                return once_base.return_value
+
+        else:
+
+            async def wrapped(*args, **kwargs) -> typing.Any:
+                once_base: _OnceBase = once_factory()
+                async with once_base.async_lock:
+                    if not once_base.called:
+                        try:
+                            once_base.return_value = await func(*args, **kwargs)
+                        except Exception as exception:
+                            once_base.return_value = _CachedException(exception)
+                        once_base.called = True
+                    return_value = once_base.return_value
+                if isinstance(return_value, _CachedException):
+                    raise return_value.exception
+                return return_value
 
     elif fn_type == _WrappedFunctionType.SYNC_FUNCTION:
+        if retry_exceptions:
 
-        def wrapped(*args, **kwargs) -> typing.Any:
-            once_base: _OnceBase = once_factory()
-            with once_base.lock:
-                if not once_base.called:
-                    once_base.return_value = func(*args, **kwargs)
-                    once_base.called = True
-            return once_base.return_value
+            def wrapped(*args, **kwargs) -> typing.Any:
+                once_base: _OnceBase = once_factory()
+                with once_base.lock:
+                    if not once_base.called:
+                        once_base.return_value = func(*args, **kwargs)
+                        once_base.called = True
+                return once_base.return_value
+
+        else:
+
+            def wrapped(*args, **kwargs) -> typing.Any:
+                once_base: _OnceBase = once_factory()
+                with once_base.lock:
+                    if not once_base.called:
+                        try:
+                            once_base.return_value = func(*args, **kwargs)
+                        except Exception as exception:
+                            once_base.return_value = _CachedException(exception)
+                        once_base.called = True
+                    return_value = once_base.return_value
+                if isinstance(return_value, _CachedException):
+                    raise return_value.exception
+                return return_value
 
     elif fn_type == _WrappedFunctionType.SYNC_GENERATOR:
 
@@ -158,7 +198,7 @@ def wrapped(*args, **kwargs) -> typing.Any:
             with once_base.lock:
                 if not once_base.called:
                     once_base.return_value = _iterator_wrappers.GeneratorWrapper(
-                        func, *args, **kwargs
+                        retry_exceptions, func, *args, **kwargs
                     )
                     once_base.called = True
                 iterator = once_base.return_value
@@ -195,7 +235,7 @@ def _once_factory(is_async: bool, per_thread: bool) -> _ONCE_FACTORY_TYPE:
     return lambda: singleton_once
 
 
-def once(*args, per_thread=False) -> collections.abc.Callable:
+def once(*args, per_thread=False, retry_exceptions=False) -> collections.abc.Callable:
     """Decorator to ensure a function is only called once.
 
     The restriction of only one call also holds across threads. However, this
@@ -225,7 +265,7 @@ def once(*args, per_thread=False) -> collections.abc.Callable:
         # This trick lets this function be a decorator directly, or be called
         # to create a decorator.
         # Both @once and @once() will function correctly.
-        return functools.partial(once, per_thread=per_thread)
+        return functools.partial(once, per_thread=per_thread, retry_exceptions=retry_exceptions)
     if _is_method(func):
         raise SyntaxError(
             "Attempting to use @once.once decorator on method "
@@ -233,7 +273,7 @@ def once(*args, per_thread=False) -> collections.abc.Callable:
         )
     fn_type = _wrapped_function_type(func)
     once_factory = _once_factory(is_async=fn_type in _ASYNC_FN_TYPES, per_thread=per_thread)
-    return _wrap(func, once_factory, fn_type)
+    return _wrap(func, once_factory, fn_type, retry_exceptions)
 
 
 class once_per_class:  # pylint: disable=invalid-name
@@ -243,15 +283,21 @@ class once_per_class:  # pylint: disable=invalid-name
     is_staticmethod: bool
 
     @classmethod
-    def with_options(cls, per_thread: bool = False):
-        return lambda func: cls(func, per_thread=per_thread)
-
-    def __init__(self, func: collections.abc.Callable, per_thread: bool = False) -> None:
+    def with_options(cls, per_thread: bool = False, retry_exceptions=False):
+        return lambda func: cls(func, per_thread=per_thread, retry_exceptions=retry_exceptions)
+
+    def __init__(
+        self,
+        func: collections.abc.Callable,
+        per_thread: bool = False,
+        retry_exceptions: bool = False,
+    ) -> None:
         self.func = self._inspect_function(func)
         self.fn_type = _wrapped_function_type(self.func)
         self.once_factory = _once_factory(
             is_async=self.fn_type in _ASYNC_FN_TYPES, per_thread=per_thread
         )
+        self.retry_exceptions = retry_exceptions
 
     def _inspect_function(self, func: collections.abc.Callable):
         if not _is_method(func):
@@ -280,17 +326,22 @@ def __get__(self, obj, cls) -> collections.abc.Callable:
             func = self.func
         else:
             func = functools.partial(self.func, obj)
-        return _wrap(func, self.once_factory, self.fn_type)
+        return _wrap(func, self.once_factory, self.fn_type, self.retry_exceptions)
 
 
 class once_per_instance:  # pylint: disable=invalid-name
     """A version of once for class methods which runs once per instance."""
 
     @classmethod
-    def with_options(cls, per_thread: bool = False):
-        return lambda func: cls(func, per_thread=per_thread)
-
-    def __init__(self, func: collections.abc.Callable, per_thread: bool = False) -> None:
+    def with_options(cls, per_thread: bool = False, retry_exceptions=False):
+        return lambda func: cls(func, per_thread=per_thread, retry_exceptions=retry_exceptions)
+
+    def __init__(
+        self,
+        func: collections.abc.Callable,
+        per_thread: bool = False,
+        retry_exceptions: bool = False,
+    ) -> None:
         self.func = self._inspect_function(func)
         self.fn_type = _wrapped_function_type(self.func)
         self.is_async_fn = self.fn_type in _ASYNC_FN_TYPES
@@ -299,6 +350,7 @@ def __init__(self, func: collections.abc.Callable, per_thread: bool = False) ->
             typing.Any, collections.abc.Callable
         ] = weakref.WeakKeyDictionary()
         self.per_thread = per_thread
+        self.retry_exceptions = retry_exceptions
 
     def once_factory(self) -> _ONCE_FACTORY_TYPE:
         """Generate a new once factory.
@@ -324,6 +376,8 @@ def __get__(self, obj, cls) -> collections.abc.Callable:
         with self.callables_lock:
             if (callable := self.callables.get(obj)) is None:
                 bound_func = functools.partial(self.func, obj)
-                callable = _wrap(bound_func, self.once_factory(), self.fn_type)
+                callable = _wrap(
+                    bound_func, self.once_factory(), self.fn_type, self.retry_exceptions
+                )
                 self.callables[obj] = callable
         return callable
diff --git a/once/_iterator_wrappers.py b/once/_iterator_wrappers.py
index 6185e22..b6c0959 100644
--- a/once/_iterator_wrappers.py
+++ b/once/_iterator_wrappers.py
@@ -4,6 +4,7 @@
 import functools
 import threading
 import time
+import typing
 
 # Before we begin, a note on the assert statements in this file:
 # Why are we using assert in here, you might ask, instead of implementing "proper" error handling?
@@ -24,116 +25,160 @@ def __init__(self) -> None:
         self.exception: Exception | None = None
         self.finished = False
 
+    def fast_path(self):
+        """Should only be computed with a lock."""
+        return self.finished and self.exception is None
 
-# TODO(matt): Refactor AsyncGeneratorWrapper to use state enums, and add error-handling.
-class AsyncGeneratorWrapper:
+
+class _IteratorAction(enum.Enum):
+    # Generating the next value from the underlying iterator
+    GENERATING = 1
+    # Yield an already computed value
+    YIELDING = 2
+    # Waiting for the underlying iterator, already triggered from another call.
+    WAITING = 3
+    # We can return, we are done!
+    RETURNING = 4
+
+
+class _GeneratorWrapperBase:
+    """Base class for generator wrapper.
+
+    Even though the class stores a result, all of the methods separately take a result input.
+    Why is that? Great question.
+    During a call to yield_results, we should grab a reference to the existing self.result, and
+    only operate on that by passing it into here. In the event of an Exception, if
+    reset_on_exception is set, a new execution is kicked off to retry, but existing iterators from
+    yield_results will continue as if that never happened, and still raise an Exception. This will
+    avoid mixing results from different iterators.
+    """
+
+    def __init__(
+        self, reset_on_exception: bool, func: collections.abc.Callable, *args, **kwargs
+    ) -> None:
+        self.callable: collections.abc.Callable | None = functools.partial(func, *args, **kwargs)
+        self.generator = self.callable()
+        self.result = IteratorResults()
+        self.generating = False
+        self.reset_on_exception = reset_on_exception
+
+    def compute_next_action(
+        self, result: IteratorResults, i: int
+    ) -> typing.Tuple[_IteratorAction, typing.Any]:
+        """Must be called with lock."""
+        if i == len(result.items):
+            if result.finished:
+                if result.exception:
+                    raise result.exception
+                return _IteratorAction.RETURNING, None
+            if self.generating:
+                return _IteratorAction.WAITING, None
+            else:
+                # If all of these functions are called with locks, we will never have more than one
+                # caller have GENERATING at any time.
+                self.generating = True
+                return _IteratorAction.GENERATING, None
+        else:
+            return _IteratorAction.YIELDING, result.items[i]
+
+    def successful_completion(self, result: IteratorResults):
+        """Must be called with lock."""
+        result.finished = True
+        self.generating = False
+        self.generator = None  # Allow this to be GCed.
+        self.callable = None  # Allow this to be GCed.
+
+    def record_item(self, result: IteratorResults, item: typing.Any):
+        self.generating = False
+        result.items.append(item)
+
+    def record_exception(self, result: IteratorResults, exception: Exception):
+        """Must be called with lock."""
+        result.finished = True
+        # We need to keep track of the exception so that we can raise it in the same
+        # position every time the iterator is called.
+        result.exception = exception
+        self.generating = False
+        assert self.callable is not None
+        self.generator = self.callable()  # Reset the iterator for the next call.
+        if self.reset_on_exception:
+            self.result = IteratorResults()
+
+
+class AsyncGeneratorWrapper(_GeneratorWrapperBase):
     """Wrapper around an async generator which only runs once.
 
     Subsequent calls will return results from the first call, which is
     evaluated lazily.
     """
 
-    def __init__(self, func, *args, **kwargs) -> None:
-        self.generator: collections.abc.AsyncGenerator | None = func(*args, **kwargs)
-        self.result = IteratorResults()
-        self.generating = False
+    generator: collections.abc.AsyncGenerator
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
         self.lock = asyncio.Lock()
 
     async def yield_results(self) -> collections.abc.AsyncGenerator:
-        i = 0
-        send = None
-        next_val = None
-
-        # A copy of self.generating that we can access outside of the lock.
-        generating = None
+        # Fast path for subsequent repeated call:
+        async with self.lock:
+            result = self.result
+            fast_path = result.fast_path()
+        if fast_path:
+            for item in result.items:
+                yield item
+            return
 
-        # Indicates that we're tied for the head generator, but someone started generating the next
-        # result first, so we should just poll until the result is available.
-        waiting_for_generating = False
+        i = 0
+        yield_value = None
+        next_send = None
 
         while True:
-            if waiting_for_generating:
-                # This is a load bearing sleep. We're waiting for the leader to generate the result, but
-                # we have control of the lock, so the async with will never yield execution to the event loop,
-                # so we would loop forever. By awaiting sleep(0), we yield execution which will allow us to
-                # poll for self.generating readiness.
-                await asyncio.sleep(0)
-                waiting_for_generating = False
+            # With a lock, we figure out which action to take, and then we take it after release.
             async with self.lock:
-                if i == len(self.result.items) and not self.result.finished:
-                    if self.generating:
-                        # We're at the lead, but someone else is generating the next value
-                        # so we just hop back onto the next iteration of the loop
-                        # until it's ready.
-                        waiting_for_generating = True
-                        continue
-                    # We're at the lead and no one else is generating, so we need to increment
-                    # the iterator. We just store the value in self.result.items so that
-                    # we can later yield it outside of the lock.
-                    assert self.generator is not None
-                    # TODO(matt): Is the fact that we have to suppress typing here a bug?
-                    self.generating = self.generator.asend(send)  # type: ignore
-                    generating = self.generating
-                elif i == len(self.result.items) and self.result.finished:
-                    # All done.
-                    return
-                else:
-                    # We already have the correct result, so we grab it here to
-                    # yield it outside the lock.
-                    next_val = self.result.items[i]
-
-            if generating:
-                try:
-                    next_val = await generating
-                except StopAsyncIteration:
-                    async with self.lock:
-                        self.generator = None  # Allow this to be GCed.
-                        self.result.finished = True
-                        self.generating = None
-                        generating = None
-                        return
+                action, yield_value = self.compute_next_action(result, i)
+            if action == _IteratorAction.RETURNING:
+                return
+            if action == _IteratorAction.WAITING:
+                # Yield control to the event loop so the call that is generating the next
+                # value can make progress; otherwise this loop would spin forever.
+                await asyncio.sleep(0)
+                continue
+            if action == _IteratorAction.YIELDING:
+                next_send = yield yield_value
+                i += 1
+                continue
+            assert action == _IteratorAction.GENERATING
+            assert self.generator is not None
+            try:
+                item = await self.generator.asend(next_send)
+            except StopAsyncIteration:
                 async with self.lock:
-                    self.result.items.append(next_val)
-                    generating = None
-                    self.generating = None
-
-            send = yield next_val
-            i += 1
-
-
-class _IteratorAction(enum.Enum):
-    # Generating the next value from the underlying iterator
-    GENERATING = 1
-    # Yield an already computed value
-    YIELDING = 2
-    # Waiting for the underlying iterator, already triggered from another call.
-    WAITING = 3
+                    self.successful_completion(result)
+            except Exception as e:
+                async with self.lock:
+                    self.record_exception(result, e)
+            else:
+                async with self.lock:
+                    self.record_item(result, item)
 
 
-class GeneratorWrapper:
+class GeneratorWrapper(_GeneratorWrapperBase):
     """Wrapper around an sync generator which only runs once.
 
     Subsequent calls will return results from the first call, which is
     evaluated lazily.
     """
 
-    def __init__(self, func: collections.abc.Callable, *args, **kwargs) -> None:
-        self.callable: collections.abc.Callable | None = functools.partial(func, *args, **kwargs)
-        self.generator: collections.abc.Generator | None = self.callable()
-        self.result = IteratorResults()
-        self.generating = False
+    generator: collections.abc.Generator
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
         self.lock = threading.Lock()
 
     def yield_results(self) -> collections.abc.Generator:
-        # We will grab a reference to the existing result. In the event of an Exception, a new
-        # execution can be kicked off to retry, but the existing call will therefore continue as
-        # if that never happened, and still raise an Exception. This will avoid mixing results from
-        # different iterators.
-
         # Fast path for subsequent repeated call:
         with self.lock:
             result = self.result
-            fast_path = result.finished and result.exception is None
+            fast_path = result.fast_path()
         if fast_path:
             yield from self.result.items
             return
@@ -144,19 +189,9 @@ def yield_results(self) -> collections.abc.Generator:
             action: _IteratorAction | None = None
             # With a lock, we figure out which action to take, and then we take it after release.
             with self.lock:
-                if i == len(result.items):
-                    if result.finished:
-                        if result.exception:
-                            raise result.exception
-                        return
-                    if self.generating:
-                        action = _IteratorAction.WAITING
-                    else:
-                        action = _IteratorAction.GENERATING
-                        self.generating = True
-                else:
-                    action = _IteratorAction.YIELDING
-                    yield_value = self.result.items[i]
+                action, yield_value = self.compute_next_action(result, i)
+            if action == _IteratorAction.RETURNING:
+                return
             if action == _IteratorAction.WAITING:
                 # Indicate to python that it should switch to another thread, so we do not hog the GIL.
                 time.sleep(0)
@@ -171,21 +206,10 @@ def yield_results(self) -> collections.abc.Generator:
                 item = self.generator.send(next_send)
             except StopIteration:
                 with self.lock:
-                    result.finished = True
-                    self.generating = False
-                    self.generator = None  # Allow this to be GCed.
-                    self.callable = None  # Allow this to be GCed.
+                    self.successful_completion(result)
             except Exception as e:
                 with self.lock:
-                    result.finished = True
-                    # We need to keep track of the exception so that we can raise it in the same
-                    # position every time the iterator is called.
-                    result.exception = e
-                    self.generating = False
-                    assert self.callable is not None
-                    self.generator = self.callable()  # Reset the iterator for the next call.
-                    self.result = IteratorResults()
+                    self.record_exception(result, e)
             else:
                 with self.lock:
-                    self.generating = False
-                    result.items.append(item)
+                    self.record_item(result, item)
diff --git a/once_test.py b/once_test.py
index be21001..b67418e 100644
--- a/once_test.py
+++ b/once_test.py
@@ -8,7 +8,6 @@
 import math
 import sys
 import threading
-import time
 import unittest
 import weakref
 
@@ -28,7 +27,7 @@ async def anext(iter, default=StopAsyncIteration):
 
 
 # This is a "large" number of workers to schedule function calls in parallel.
-_N_WORKERS = 16
+_N_WORKERS = 32
 
 
 class Counter:
@@ -87,9 +86,7 @@ def run():
         event.set()
         return func(*args, **kwargs)
 
-    future = executor.submit(run)
-    event.wait()
-    return future
+    return executor.submit(run)
 
 
 def generate_once_counter_fn():
@@ -311,6 +308,22 @@ def sample_failing_fn():
                 raise ValueError("expected failure")
             return 1
 
+        with self.assertRaises(ValueError):
+            sample_failing_fn()
+        self.assertEqual(counter.get_incremented(), 2)
+        with self.assertRaises(ValueError):
+            sample_failing_fn()
+        self.assertEqual(counter.get_incremented(), 3, "Function call incremented the counter")
+
+    def test_failing_function_retry_exceptions(self):
+        counter = Counter()
+
+        @once.once(retry_exceptions=True)
+        def sample_failing_fn():
+            if counter.get_incremented() < 4:
+                raise ValueError("expected failure")
+            return 1
+
         with self.assertRaises(ValueError):
             sample_failing_fn()
         self.assertEqual(counter.get_incremented(), 2)
@@ -343,6 +356,40 @@ def sample_failing_fn():
             if result == 2:
                 raise ValueError("expected failure after 2.")
 
+        # Both of these calls should return the same results.
+        call1 = sample_failing_fn()
+        call2 = sample_failing_fn()
+        self.assertEqual(next(call1), 1)
+        self.assertEqual(next(call2), 1)
+        self.assertEqual(next(call1), 2)
+        self.assertEqual(next(call2), 2)
+        with self.assertRaises(ValueError):
+            next(call1)
+        with self.assertRaises(ValueError):
+            next(call2)
+        # These next 2 calls should also fail.
+        call3 = sample_failing_fn()
+        call4 = sample_failing_fn()
+        self.assertEqual(next(call3), 1)
+        self.assertEqual(next(call4), 1)
+        self.assertEqual(next(call3), 2)
+        self.assertEqual(next(call4), 2)
+        with self.assertRaises(ValueError):
+            next(call3)
+        with self.assertRaises(ValueError):
+            next(call4)
+
+    def test_failing_generator_retry_exceptions(self):
+        counter = Counter()
+
+        @once.once(retry_exceptions=True)
+        def sample_failing_fn():
+            yield counter.get_incremented()
+            result = counter.get_incremented()
+            yield result
+            if result == 2:
+                raise ValueError("expected failure after 2.")
+
+        # Both of these calls should return the same results.
         call1 = sample_failing_fn()
         call2 = sample_failing_fn()
@@ -846,6 +893,38 @@ async def sample_failing_fn():
                 raise ValueError("expected failure")
             return 1
 
+        with self.assertRaises(ValueError):
+            await sample_failing_fn()
+        self.assertEqual(counter.get_incremented(), 2)
+        with self.assertRaises(ValueError):
+            await sample_failing_fn()
+        self.assertEqual(counter.get_incremented(), 3, "Function call incremented the counter")
+
+    async def test_inspect_func(self):
+        @once.once
+        async def async_func():
+            return True
+
+        self.assertFalse(inspect.isasyncgenfunction(async_func))
+        self.assertTrue(inspect.iscoroutinefunction(async_func))
+
+        coroutine = async_func()
+        self.assertTrue(inspect.iscoroutine(coroutine))
+        self.assertTrue(inspect.isawaitable(coroutine))
+        self.assertFalse(inspect.isasyncgen(coroutine))
+
+        # Just for cleanup.
+        await coroutine
+
+    async def test_failing_function_retry_exceptions(self):
+        counter = Counter()
+
+        @once.once(retry_exceptions=True)
+        async def sample_failing_fn():
+            if counter.get_incremented() < 4:
+                raise ValueError("expected failure")
+            return 1
+
         with self.assertRaises(ValueError):
             await sample_failing_fn()
         self.assertEqual(counter.get_incremented(), 2)
@@ -900,21 +979,70 @@ async def async_yielding_iterator():
         self.assertEqual([i async for i in async_yielding_iterator()], [1, 2, 3])
         self.assertEqual([i async for i in async_yielding_iterator()], [1, 2, 3])
 
-    @unittest.skip("This currently hangs and needs to be fixed, GitHub Issue #12")
     async def test_failing_generator(self):
         counter = Counter()
 
         @once.once
         async def sample_failing_fn():
             yield counter.get_incremented()
-            raise ValueError("expected failure")
+            result = counter.get_incremented()
+            yield result
+            if result == 2:
+                raise ValueError("we raise an error when result is exactly 2")
 
+        # Both of these calls should return the same results.
+        call1 = sample_failing_fn()
+        call2 = sample_failing_fn()
+        self.assertEqual(await anext(call1), 1)
+        self.assertEqual(await anext(call2), 1)
+        self.assertEqual(await anext(call1), 2)
+        self.assertEqual(await anext(call2), 2)
+        with self.assertRaises(ValueError):
+            await anext(call1)
+        with self.assertRaises(ValueError):
+            await anext(call2)
+        # These next 2 calls should also fail.
+        call3 = sample_failing_fn()
+        call4 = sample_failing_fn()
+        self.assertEqual(await anext(call3), 1)
+        self.assertEqual(await anext(call4), 1)
+        self.assertEqual(await anext(call3), 2)
+        self.assertEqual(await anext(call4), 2)
         with self.assertRaises(ValueError):
-            [i async for i in sample_failing_fn()]
+            await anext(call3)
         with self.assertRaises(ValueError):
-            [i async for i in sample_failing_fn()]
-        self.assertEqual(await anext(sample_failing_fn()), 1)
-        self.assertEqual(await anext(sample_failing_fn()), 1)
+            await anext(call4)
+
+    async def test_failing_generator_retry_exceptions(self):
+        counter = Counter()
+
+        @once.once(retry_exceptions=True)
+        async def sample_failing_fn():
+            yield counter.get_incremented()
+            result = counter.get_incremented()
+            yield result
+            if result == 2:
+                raise ValueError("we raise an error when result is exactly 2")
+
+        # Both of these calls should return the same results.
+        call1 = sample_failing_fn()
+        call2 = sample_failing_fn()
+        self.assertEqual(await anext(call1), 1)
+        self.assertEqual(await anext(call2), 1)
+        self.assertEqual(await anext(call1), 2)
+        self.assertEqual(await anext(call2), 2)
+        with self.assertRaises(ValueError):
+            await anext(call1)
+        with self.assertRaises(ValueError):
+            await anext(call2)
+        # These next 2 calls should succeed.
+        call3 = sample_failing_fn()
+        call4 = sample_failing_fn()
+        self.assertEqual([i async for i in call3], [3, 4])
+        self.assertEqual([i async for i in call4], [3, 4])
+        # Subsequent calls should return the original value.
+        self.assertEqual([i async for i in sample_failing_fn()], [3, 4])
+        self.assertEqual([i async for i in sample_failing_fn()], [3, 4])
 
     async def test_iterator_is_lazily_evaluted(self):
         counter = Counter()
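
Usage sketch of the retry_exceptions semantics this patch adds. The code below is not
part of the patch; the names (attempts, flaky, calls, always_fails) are invented for
illustration, and it assumes the package is importable as once with this patch applied.

    # Illustrative example, not from the patch itself.
    import once

    attempts = []

    @once.once(retry_exceptions=True)
    def flaky():
        # Invented helper: fails on the first call, succeeds afterwards.
        attempts.append(1)
        if len(attempts) < 2:
            raise ValueError("transient failure")
        return "ok"

    # With retry_exceptions=True the failure is not cached: the first call raises,
    # the second call re-runs the function, and its successful result is then cached.
    try:
        flaky()
    except ValueError:
        pass
    assert flaky() == "ok"
    assert flaky() == "ok"
    assert len(attempts) == 2  # The body ran exactly twice.

    calls = []

    @once.once
    def always_fails():
        # Invented helper demonstrating the default behavior.
        calls.append(1)
        raise ValueError("cached failure")

    # The default (retry_exceptions=False) caches the exception instead: every call
    # re-raises the same exception, and the body runs only once.
    for _ in range(2):
        try:
            always_fails()
        except ValueError:
            pass
    assert len(calls) == 1

The same flag applies to generators (sync and async), where a retry resets the cached
items, while iterators already in flight continue to see the failed run's results, as
exercised by the tests above.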