class _WrappedFunctionType(enum.Enum):
    """Kinds of callable that the once decorators distinguish between."""

    UNSUPPORTED = 0
    SYNC_FUNCTION = 1
    ASYNC_FUNCTION = 2
    SYNC_GENERATOR = 3
    ASYNC_GENERATOR = 4


def _wrapped_function_type(func: collections.abc.Callable) -> _WrappedFunctionType:
    """Classify ``func`` by how it has to be invoked and cached.

    Returns the matching ``_WrappedFunctionType`` member, or ``UNSUPPORTED``
    when ``func`` is not a plain Python function of any flavour (for example
    an already-created generator object).
    """
    # Note that inspect.isawaitable would be the wrong probe here: it describes
    # the awaitable *result* of calling an async function, not the function.
    #
    # The probes run from most to least specific. inspect.isfunction is also
    # true for generator, coroutine and async-generator functions, so it has
    # to be consulted last.
    ordered_probes = (
        (inspect.isasyncgenfunction, _WrappedFunctionType.ASYNC_GENERATOR),
        (inspect.isgeneratorfunction, _WrappedFunctionType.SYNC_GENERATOR),
        (inspect.iscoroutinefunction, _WrappedFunctionType.ASYNC_FUNCTION),
        (inspect.isfunction, _WrappedFunctionType.SYNC_FUNCTION),
    )
    for probe, kind in ordered_probes:
        if probe(func):
            return kind
    return _WrappedFunctionType.UNSUPPORTED
- self.is_syncgen = inspect.isgeneratorfunction(self.func) - # The function inspect.isawaitable is a bit of a misnomer - it refers - # to the awaitable result of an async function, not the async function - # itself. - self.is_async = True if self.is_asyncgen else inspect.iscoroutinefunction(self.func) - if self.is_async: + self.fn_type = _wrapped_function_type(self.func) + if self.fn_type == _WrappedFunctionType.UNSUPPORTED: + raise SyntaxError(f"Unable to wrap a {type(func)}") + if self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION: self.async_lock = asyncio.Lock() else: self.lock = _new_lock() @@ -59,74 +74,6 @@ def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.C It should return the function which should be executed once. """ - def _sync_return(self): - if self.is_syncgen: - self.return_value.__iter__() - return self.return_value - - async def _async_gen_proxy(self, func, *args, **kwargs): - i = 0 - send = None - next_val = None - - # A copy of self.async_generating that we can access outside of the lock. - async_generating = None - - # Indicates that we're tied for the head generator, but someone started generating the next - # result first, so we should just poll until the result is available. - waiting_for_async_generating = False - - while True: - if waiting_for_async_generating: - # This is a load bearing sleep. We're waiting for the leader to generate the result, but - # we have control of the lock, so the async with will never yield execution to the event loop, - # so we would loop forever. By awaiting sleep(0), we yield execution which will allow us to - # poll for self.async_generating readiness. - await asyncio.sleep(0) - waiting_for_async_generating = False - async with self.async_lock: - if self.asyncgen_generator is None and not self.asyncgen_finished: - # We're the first! Do some setup. 
- self.asyncgen_generator = func(*args, **kwargs) - - if i == len(self.asyncgen_results) and not self.asyncgen_finished: - if self.async_generating: - # We're at the lead, but someone else is generating the next value - # so we just hop back onto the next iteration of the loop - # until it's ready. - waiting_for_async_generating = True - continue - # We're at the lead and no one else is generating, so we need to increment - # the iterator. We just store the value in self.asyncgen_results so that - # we can later yield it outside of the lock. - self.async_generating = self.asyncgen_generator.asend(send) - async_generating = self.async_generating - elif i == len(self.asyncgen_results) and self.asyncgen_finished: - # All done. - return - else: - # We already have the correct result, so we grab it here to - # yield it outside the lock. - next_val = self.asyncgen_results[i] - - if async_generating: - try: - next_val = await async_generating - except StopAsyncIteration: - async with self.async_lock: - self.asyncgen_generator = None # Allow this to be GCed. - self.asyncgen_finished = True - self.async_generating = None - async_generating = None - return - async with self.async_lock: - self.asyncgen_results.append(next_val) - async_generating = None - self.async_generating = None - - send = yield next_val - i += 1 - async def _execute_call_once_async(self, func: collections.abc.Callable, *args, **kwargs): if self.called: return self.return_value @@ -138,24 +85,40 @@ async def _execute_call_once_async(self, func: collections.abc.Callable, *args, self.called = True return self.return_value + # This cannot be an async function! 
def _execute_call_once_async_iter(self, func: collections.abc.Callable, *args, **kwargs):
    """Return an async generator that replays the single shared run of ``func``.

    This is deliberately a plain (non-async) function: the caller must get the
    async generator object back directly from the call, not a coroutine that
    would first have to be awaited.

    The first caller builds one ``AsyncGeneratorWrapper`` around ``func``;
    every caller, including the first, receives a fresh ``yield_results()``
    generator over that shared wrapper.
    """
    if self.called:
        return self.return_value.yield_results()
    with self.lock:
        if not self.called:
            # Publish the wrapper *before* flipping ``called``. The fast path
            # above reads ``called`` without holding the lock, so the reverse
            # order lets another thread observe ``called`` while
            # ``return_value`` is still None. It also keeps ``called`` False
            # if constructing the wrapper raises.
            self.return_value = _iterator_wrappers.AsyncGeneratorWrapper(func, *args, **kwargs)
            self.called = True
        return self.return_value.yield_results()


def _sync_return(self):
    """Return the cached result in the form a synchronous caller expects.

    For sync generators this is a fresh iterator over the shared
    ``GeneratorWrapper``; for everything else it is the cached value itself.
    """
    if self.fn_type == _WrappedFunctionType.SYNC_GENERATOR:
        return self.return_value.yield_results().__iter__()
    return self.return_value


def _execute_call_once_sync(self, func: collections.abc.Callable, *args, **kwargs):
    """Run ``func`` at most once and return its (cached) synchronous result."""
    # Unlocked fast path once the result has been published.
    if self.called:
        return self._sync_return()
    with self.lock:
        # Re-check under the lock: another thread may have won the race.
        if self.called:
            return self._sync_return()
        if self.fn_type == _WrappedFunctionType.SYNC_GENERATOR:
            # The wrapper evaluates the generator lazily and shares its
            # results between all consumers.
            self.return_value = _iterator_wrappers.GeneratorWrapper(func, *args, **kwargs)
        else:
            self.return_value = func(*args, **kwargs)
        # Set last, so the unlocked fast path never sees a missing value.
        self.called = True
        return self._sync_return()


def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs):
    """Choose the appropriate call_once based on the function type."""
    if self.fn_type == _WrappedFunctionType.ASYNC_GENERATOR:
        return self._execute_call_once_async_iter(func, *args, **kwargs)
    if self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION:
        return self._execute_call_once_async(func, *args, **kwargs)
    # Sync functions and sync generators share the synchronous path.
    return self._execute_call_once_sync(func, *args, **kwargs)
import asyncio
import collections.abc
import enum
import threading

# Before we begin, a note on the assert statements in this file:
# Why are we using assert in here, you might ask, instead of implementing "proper" error handling?
# In this case, it is actually not being done out of laziness! The assert statements here
# represent our assumptions about the state at that point in time, and they **REALLY** should
# always hold. If the assumption behind one of these asserts fails, the subsequent calls are
# going to fail anyways, so it's not like they are making the code artificially brittle.
# However, they do make testing easier, because we can directly test our assumption instead of
# having hard-to-trace errors, and also serve as very convenient documentation of the
# assumptions.
# We are always open to suggestions if there are other ways to achieve the same functionality in
# python!


class AsyncGeneratorWrapper:
    """Wrapper around an async generator which only runs once.

    Subsequent calls will return results from the first call, which is
    evaluated lazily: a value is only pulled from the underlying generator
    when some consumer reaches a position nobody has produced yet.

    If the underlying generator raises, the stream is marked finished and the
    same exception is re-raised for every consumer that reaches that position.
    """

    def __init__(self, func, *args, **kwargs) -> None:
        self.generator: collections.abc.AsyncGenerator | None = func(*args, **kwargs)
        self.finished = False
        self.results: list = []
        # The in-flight ``asend`` awaitable, or None when nobody is generating.
        self.generating: collections.abc.Awaitable | None = None
        # Exception raised by the underlying generator, replayed to consumers.
        self.exception: Exception | None = None
        self.lock = asyncio.Lock()

    async def yield_results(self) -> collections.abc.AsyncGenerator:
        """Yield the shared results in order, generating new ones on demand.

        Values passed in with ``asend`` are forwarded to the underlying
        generator only by the consumer that produces the next value; they are
        ignored while replaying values that were already produced.
        """
        i = 0
        send = None
        next_val = None

        while True:
            # Our own handle on the in-flight awaitable, set only if this
            # consumer is the one producing the next value.
            generating = None
            # Set when another consumer is already producing the value we need.
            waiting_for_generating = False

            async with self.lock:
                if i < len(self.results):
                    # We already have the correct result, so we grab it here
                    # to yield it outside the lock.
                    next_val = self.results[i]
                elif self.finished:
                    # All done; replay the failure if the run ended with one.
                    if self.exception is not None:
                        raise self.exception
                    return
                elif self.generating is not None:
                    waiting_for_generating = True
                else:
                    # We're at the lead and no one else is generating, so we
                    # need to advance the underlying generator.
                    assert self.generator is not None
                    generating = self.generator.asend(send)  # type: ignore
                    self.generating = generating

            if waiting_for_generating:
                # This is a load bearing sleep. Taking the uncontended lock
                # never yields to the event loop, so without it we would loop
                # forever and the leader could never finish.
                await asyncio.sleep(0)
                continue

            if generating is not None:
                try:
                    next_val = await generating
                except StopAsyncIteration:
                    async with self.lock:
                        self.generator = None  # Allow this to be GCed.
                        self.finished = True
                        self.generating = None
                    return
                except BaseException as exc:
                    # Always clear ``generating``; otherwise every other
                    # consumer would poll forever for a value that never comes.
                    async with self.lock:
                        self.generator = None
                        self.finished = True
                        self.generating = None
                        # Cancellation and interpreter-exit signals concern
                        # this consumer only, so they are not replayed.
                        if isinstance(exc, Exception):
                            self.exception = exc
                    raise
                async with self.lock:
                    self.results.append(next_val)
                    self.generating = None

            send = yield next_val
            i += 1


class _IteratorAction(enum.Enum):
    # Generating the next value from the underlying iterator
    GENERATING = 1
    # Yield an already computed value
    YIELDING = 2
    # Waiting for the underlying iterator, already triggered from another call.
    WAITING = 3


class GeneratorWrapper:
    """Wrapper around a sync generator which only runs once.

    Subsequent calls will return results from the first call, which is
    evaluated lazily.

    If the underlying generator raises, the stream is marked finished and the
    same exception is re-raised for every consumer that reaches that position.
    """

    def __init__(self, func, *args, **kwargs) -> None:
        self.generator: collections.abc.Generator | None = func(*args, **kwargs)
        self.finished = False
        self.results: list = []
        self.generating = False
        # Exception raised by the underlying generator, replayed to consumers.
        self.exception: Exception | None = None
        self.lock = threading.Lock()
        # Shares ``self.lock``; signalled whenever a generation attempt ends so
        # that waiting consumers can block instead of spinning on the lock.
        self._progress = threading.Condition(self.lock)

    def yield_results(self) -> collections.abc.Generator:
        """Yield the shared results in order, generating new ones on demand.

        Values passed in with ``send`` are forwarded to the underlying
        generator only by the consumer that produces the next value; they are
        ignored while replaying values that were already produced.
        """
        # Fast path for calls made after the run completed: nothing mutates
        # any more, so we can replay without taking the lock per item.
        with self.lock:
            finished = self.finished
        if finished:
            # Deliberately not ``yield from``: that would delegate a non-None
            # send() to the list iterator, which has no send() method.
            for value in self.results:
                yield value
            if self.exception is not None:
                raise self.exception
            return

        i = 0
        yield_value = None
        next_send = None
        while True:
            # With the lock held we decide which action to take; yielding and
            # generating themselves happen after it is released.
            with self.lock:
                if i < len(self.results):
                    action = _IteratorAction.YIELDING
                    yield_value = self.results[i]
                elif self.finished:
                    if self.exception is not None:
                        raise self.exception
                    return
                elif self.generating:
                    action = _IteratorAction.WAITING
                    # Releases the lock while blocked and re-acquires it once
                    # the generating consumer signals that it is done.
                    self._progress.wait()
                else:
                    action = _IteratorAction.GENERATING
                    self.generating = True

            if action is _IteratorAction.WAITING:
                continue
            if action is _IteratorAction.YIELDING:
                next_send = yield yield_value
                i += 1
                continue

            assert action is _IteratorAction.GENERATING
            assert self.generator is not None
            try:
                result = self.generator.send(next_send)
            except StopIteration:
                with self.lock:
                    self.finished = True
                    self.generator = None  # Allow this to be GCed.
                    self.generating = False
                    self._progress.notify_all()
            except BaseException as exc:
                # Always clear ``generating``; otherwise every other consumer
                # would wait forever for a value that never comes.
                with self.lock:
                    self.finished = True
                    self.generator = None
                    self.generating = False
                    # Interpreter-exit signals concern this consumer only, so
                    # they are not replayed.
                    if isinstance(exc, Exception):
                        self.exception = exc
                    self._progress.notify_all()
                raise
            else:
                with self.lock:
                    self.generating = False
                    self.results.append(result)
                    self._progress.notify_all()
def test_sync_agenerator_function(self):
    """Async generator functions classify as ASYNC_GENERATOR; their output does not."""

    async def sample_async_generator_fn():
        yield 1

    # Both a local function and an (unbound) method must be recognised.
    wrappable = (
        sample_async_generator_fn,
        TestFunctionInspection.sample_async_generator_method,
    )
    for candidate in wrappable:
        self.assertEqual(
            once._wrapped_function_type(candidate),
            once._WrappedFunctionType.ASYNC_GENERATOR,
        )

    # The output of an async generator is not a wrappable.
    generator_object = sample_async_generator_fn()
    self.assertEqual(
        once._wrapped_function_type(generator_object),
        once._WrappedFunctionType.UNSUPPORTED,
    )
async def test_receiving_iterator(self):
    """Two consumers of a once-wrapped async generator share one sequence.

    Values sent by the consumer at the head are forwarded to the generator;
    values sent while replaying already-produced results are ignored.
    """

    @once.once
    async def async_receiving_iterator():
        received = yield 0
        while received is not None:
            received = yield received

    gen_1 = async_receiving_iterator()
    gen_2 = async_receiving_iterator()

    # (consumer, value sent, value expected back), in execution order.
    steps = (
        (gen_1, None, 0),
        (gen_1, 1, 1),
        (gen_1, 2, 2),
        (gen_2, None, 0),
        (gen_2, None, 1),
        (gen_2, None, 2),
        (gen_2, 5, 5),
    )
    for consumer, sent, expected in steps:
        self.assertEqual(await consumer.asend(sent), expected)

    # gen_2 ends the underlying generator; gen_1 still sees the last value.
    self.assertEqual(await anext(gen_2, None), None)
    self.assertEqual(await gen_1.asend(None), 5)
    self.assertEqual(await anext(gen_1, None), None)

    replayed = [i async for i in async_receiving_iterator()]
    self.assertEqual(replayed, [0, 1, 2, 5])