From ab7bd20e15dc5d86d068c8d01bbf6a9f3234744a Mon Sep 17 00:00:00 2001 From: Ali Ebrahim Date: Sat, 7 Oct 2023 22:44:06 -0700 Subject: [PATCH 1/7] Major refactor to allow lazy sync iterators. Fixes #6. Because the code was getting too unweildy and hard to follow, I refactored the iterator proxies to their own classes in a separate file. --- once/__init__.py | 151 ++++++++++++++----------------------- once/_iterator_wrappers.py | 142 ++++++++++++++++++++++++++++++++++ once_test.py | 79 +++++++++++++++++++ 3 files changed, 278 insertions(+), 94 deletions(-) create mode 100644 once/_iterator_wrappers.py diff --git a/once/__init__.py b/once/__init__.py index f671970..d6298b6 100644 --- a/once/__init__.py +++ b/once/__init__.py @@ -2,12 +2,15 @@ import abc import asyncio import collections.abc +import enum import functools import inspect import threading import typing import weakref +from . import _iterator_wrappers + def _new_lock() -> threading.Lock: return threading.Lock() @@ -21,30 +24,42 @@ def _is_method(func: collections.abc.Callable): return "self" in sig.parameters +class _WrappedFunctionType(enum.Enum): + UNSUPPORTED = 0 + SYNC_FUNCTION = 1 + ASYNC_FUNCTION = 2 + SYNC_GENERATOR = 3 + ASYNC_GENERATOR = 4 + + +def _wrapped_function_type(func: collections.abc.Callable) -> _WrappedFunctionType: + # The function inspect.isawaitable is a bit of a misnomer - it refers + # to the awaitable result of an async function, not the async function + # itself. 
+ if inspect.isasyncgenfunction(func): + return _WrappedFunctionType.ASYNC_GENERATOR + if inspect.isgeneratorfunction(func): + return _WrappedFunctionType.SYNC_GENERATOR + if inspect.iscoroutinefunction(func): + return _WrappedFunctionType.ASYNC_FUNCTION + # This must come last, because it would return True for all the other types + if inspect.isfunction(func): + return _WrappedFunctionType.SYNC_FUNCTION + return _WrappedFunctionType.UNSUPPORTED + + class _OnceBase(abc.ABC): """Abstract Base Class for once function decorators.""" - def __init__(self, func: collections.abc.Callable): + def __init__(self, func: collections.abc.Callable) -> None: functools.update_wrapper(self, func) self.func = self._inspect_function(func) self.called = False self.return_value: typing.Any = None - - self.is_asyncgen = inspect.isasyncgenfunction(self.func) - if self.is_asyncgen: - self.asyncgen_finished = False - self.asyncgen_generator = None - self.asyncgen_results: list = [] - self.async_generating = False - - # Only works for one way generators, not anything that requires send for now. - # Async generators do support send. - self.is_syncgen = inspect.isgeneratorfunction(self.func) - # The function inspect.isawaitable is a bit of a misnomer - it refers - # to the awaitable result of an async function, not the async function - # itself. - self.is_async = True if self.is_asyncgen else inspect.iscoroutinefunction(self.func) - if self.is_async: + self.fn_type = _wrapped_function_type(self.func) + if self.fn_type == _WrappedFunctionType.UNSUPPORTED: + raise SyntaxError(f"Unable to wrap a {type(func)}") + if self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION: self.async_lock = asyncio.Lock() else: self.lock = _new_lock() @@ -59,74 +74,6 @@ def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.C It should return the function which should be executed once. 
""" - def _sync_return(self): - if self.is_syncgen: - self.return_value.__iter__() - return self.return_value - - async def _async_gen_proxy(self, func, *args, **kwargs): - i = 0 - send = None - next_val = None - - # A copy of self.async_generating that we can access outside of the lock. - async_generating = None - - # Indicates that we're tied for the head generator, but someone started generating the next - # result first, so we should just poll until the result is available. - waiting_for_async_generating = False - - while True: - if waiting_for_async_generating: - # This is a load bearing sleep. We're waiting for the leader to generate the result, but - # we have control of the lock, so the async with will never yield execution to the event loop, - # so we would loop forever. By awaiting sleep(0), we yield execution which will allow us to - # poll for self.async_generating readiness. - await asyncio.sleep(0) - waiting_for_async_generating = False - async with self.async_lock: - if self.asyncgen_generator is None and not self.asyncgen_finished: - # We're the first! Do some setup. - self.asyncgen_generator = func(*args, **kwargs) - - if i == len(self.asyncgen_results) and not self.asyncgen_finished: - if self.async_generating: - # We're at the lead, but someone else is generating the next value - # so we just hop back onto the next iteration of the loop - # until it's ready. - waiting_for_async_generating = True - continue - # We're at the lead and no one else is generating, so we need to increment - # the iterator. We just store the value in self.asyncgen_results so that - # we can later yield it outside of the lock. - self.async_generating = self.asyncgen_generator.asend(send) - async_generating = self.async_generating - elif i == len(self.asyncgen_results) and self.asyncgen_finished: - # All done. - return - else: - # We already have the correct result, so we grab it here to - # yield it outside the lock. 
- next_val = self.asyncgen_results[i] - - if async_generating: - try: - next_val = await async_generating - except StopAsyncIteration: - async with self.async_lock: - self.asyncgen_generator = None # Allow this to be GCed. - self.asyncgen_finished = True - self.async_generating = None - async_generating = None - return - async with self.async_lock: - self.asyncgen_results.append(next_val) - async_generating = None - self.async_generating = None - - send = yield next_val - i += 1 - async def _execute_call_once_async(self, func: collections.abc.Callable, *args, **kwargs): if self.called: return self.return_value @@ -138,24 +85,40 @@ async def _execute_call_once_async(self, func: collections.abc.Callable, *args, self.called = True return self.return_value + # This cannot be an async function! + def _execute_call_once_async_iter(self, func: collections.abc.Callable, *args, **kwargs): + if self.called: + return self.return_value._yield_results() + with self.lock: + if not self.called: + self.called = True + self.return_value = _iterator_wrappers.AsyncGeneratorWrapper(func, *args, **kwargs) + return self.return_value._yield_results() + + def _sync_return(self): + if self.fn_type == _WrappedFunctionType.SYNC_GENERATOR: + return self.return_value._yield_results().__iter__() + else: + return self.return_value + def _execute_call_once_sync(self, func: collections.abc.Callable, *args, **kwargs): if self.called: return self._sync_return() with self.lock: if self.called: return self._sync_return() - self.return_value = func(*args, **kwargs) - if self.is_syncgen: - # A potential optimization is to evaluate the iterator lazily, - # as opposed to eagerly like we do here. 
- self.return_value = tuple(self.return_value) + if self.fn_type == _WrappedFunctionType.SYNC_GENERATOR: + self.return_value = _iterator_wrappers.GeneratorWrapper(func, *args, **kwargs) + else: + self.return_value = func(*args, **kwargs) self.called = True return self._sync_return() def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs): - if self.is_asyncgen: - return self._async_gen_proxy(func, *args, **kwargs) - if self.is_async: + """Choose the appropriate call_once based on the function type.""" + if self.fn_type == _WrappedFunctionType.ASYNC_GENERATOR: + return self._execute_call_once_async_iter(func, *args, **kwargs) + if self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION: return self._execute_call_once_async(func, *args, **kwargs) return self._execute_call_once_sync(func, *args, **kwargs) @@ -229,7 +192,7 @@ def __get__(self, obj, cls): class once_per_instance(_OnceBase): # pylint: disable=invalid-name """A version of once for class methods which runs once per instance.""" - def __init__(self, func: collections.abc.Callable): + def __init__(self, func: collections.abc.Callable) -> None: super().__init__(func) self.return_value: weakref.WeakKeyDictionary[ typing.Any, typing.Any diff --git a/once/_iterator_wrappers.py b/once/_iterator_wrappers.py new file mode 100644 index 0000000..5245c13 --- /dev/null +++ b/once/_iterator_wrappers.py @@ -0,0 +1,142 @@ +import asyncio +import collections.abc +import threading + +# Before we begin, a note on the assert statements in this file: +# Why are we using assert in here, you might ask, instead of implementing "proper" error handling? +# In this case, it is actually not being done out of laziness! The assert statements here +# represent our assumptions about the state at that point in time, and are always called with locks +# held, so they **REALLY** should always hold. 
If the assumption behind one of these asserts fails,
+# the subsequent calls are going to fail anyways, so it's not like they are making the code
+# artificially brittle. However, they do make testing easier, because we can directly test our
+# assumption instead of having hard-to-trace errors, and also serve as very convenient
+# documentation of the assumptions.
+# We are always open to suggestions if there are other ways to achieve the same functionality in
+# python!
+
+
+class AsyncGeneratorWrapper:
+    """Wrapper around an async generator which only runs once.
+
+    Subsequent calls will return results from the first call, which is
+    evaluated lazily.
+    """
+
+    def __init__(self, func, *args, **kwargs) -> None:
+        self.generator: collections.abc.AsyncGenerator | None = func(*args, **kwargs)
+        self.finished = False
+        self.results: list = []
+        self.generating = False
+        self.lock = asyncio.Lock()
+
+    async def _yield_results(self) -> collections.abc.AsyncGenerator:
+        i = 0
+        send = None
+        next_val = None
+
+        # A copy of self.generating that we can access outside of the lock.
+        generating = None
+
+        # Indicates that we're tied for the head generator, but someone started generating the next
+        # result first, so we should just poll until the result is available.
+        waiting_for_generating = False
+
+        while True:
+            if waiting_for_generating:
+                # This is a load bearing sleep. We're waiting for the leader to generate the result, but
+                # we have control of the lock, so the async with will never yield execution to the event loop,
+                # so we would loop forever. By awaiting sleep(0), we yield execution which will allow us to
+                # poll for self.generating readiness.
+                await asyncio.sleep(0)
+                waiting_for_generating = False
+            async with self.lock:
+                if i == len(self.results) and not self.finished:
+                    if self.generating:
+                        # We're at the lead, but someone else is generating the next value
+                        # so we just hop back onto the next iteration of the loop
+                        # until it's ready. 
+ waiting_for_generating = True + continue + # We're at the lead and no one else is generating, so we need to increment + # the iterator. We just store the value in self.results so that + # we can later yield it outside of the lock. + assert self.generator is not None + # TODO(matt): Is the fact that we have to suppress typing here a bug? + self.generating = self.generator.asend(send) # type: ignore + generating = self.generating + elif i == len(self.results) and self.finished: + # All done. + return + else: + # We already have the correct result, so we grab it here to + # yield it outside the lock. + next_val = self.results[i] + + if generating: + try: + next_val = await generating + except StopAsyncIteration: + async with self.lock: + self.generator = None # Allow this to be GCed. + self.finished = True + self.generating = None + generating = None + return + async with self.lock: + self.results.append(next_val) + generating = None + self.generating = None + + send = yield next_val + i += 1 + + +class GeneratorWrapper: + """Wrapper around an sync generator which only runs once. + + Subsequent calls will return results from the first call, which is + evaluated lazily. + """ + + def __init__(self, func, *args, **kwargs) -> None: + self.generator: collections.abc.Generator | None = func(*args, **kwargs) + self.finished = False + self.results: list = [] + self.generating = False + self.lock = threading.Lock() + self.next_send = None + + def _yield_results(self) -> collections.abc.Generator: + i = 0 + # Fast path for subsequent calls will not require a lock + while True: + if i < len(self.results): + yield self.results[i] + i += 1 + continue + if self.finished: + return + + # Initial calls, and concurrent calls before completion will require the lock. + with self.lock: + if i < len(self.results): + yield self.results[i] + i += 1 + continue + # Because we hold a lock, this should never be violated. + # If it does, something has gone seriously wrong! 
+ assert i == len(self.results) + if self.finished: + return + # The generator should never be garbage collected while self.finished is False + # and the lock is held. + assert self.generator is not None + try: + self.results.append(self.generator.send(self.next_send)) + except StopIteration: + self.finished = True + self.generator = None # Allow this to be GCed. + return + else: + i += 1 + self.next_send = yield self.results[-1] diff --git a/once_test.py b/once_test.py index 0609b17..b30fa6f 100644 --- a/once_test.py +++ b/once_test.py @@ -56,6 +56,85 @@ def counting_fn(*args) -> int: return counting_fn, counter +class TestFunctionInspection(unittest.TestCase): + """Unit tests for function inspection""" + + def sample_sync_method(self): + return 1 + + def test_sync_function(self): + def sample_sync_fn(): + return 1 + + self.assertEqual( + once._wrapped_function_type(sample_sync_fn), once._WrappedFunctionType.SYNC_FUNCTION + ) + self.assertEqual( + once._wrapped_function_type(TestFunctionInspection.sample_sync_method), + once._WrappedFunctionType.SYNC_FUNCTION, + ) + self.assertEqual( + once._wrapped_function_type(lambda x: x + 1), once._WrappedFunctionType.SYNC_FUNCTION + ) + + async def sample_async_method(self): + return 1 + + def test_async_function(self): + async def sample_async_fn(): + return 1 + + self.assertEqual( + once._wrapped_function_type(sample_async_fn), once._WrappedFunctionType.ASYNC_FUNCTION + ) + self.assertEqual( + once._wrapped_function_type(TestFunctionInspection.sample_async_method), + once._WrappedFunctionType.ASYNC_FUNCTION, + ) + + def sample_sync_generator_method(self): + yield 1 + + def test_sync_generator_function(self): + def sample_sync_generator_fn(): + yield 1 + + self.assertEqual( + once._wrapped_function_type(sample_sync_generator_fn), + once._WrappedFunctionType.SYNC_GENERATOR, + ) + self.assertEqual( + once._wrapped_function_type(TestFunctionInspection.sample_sync_generator_method), + once._WrappedFunctionType.SYNC_GENERATOR, + 
) + # The output of a sync generator is not a wrappable. + self.assertEqual( + once._wrapped_function_type(sample_sync_generator_fn()), + once._WrappedFunctionType.UNSUPPORTED, + ) + + async def sample_async_generator_method(self): + yield 1 + + def test_sync_agenerator_function(self): + async def sample_async_generator_fn(): + yield 1 + + self.assertEqual( + once._wrapped_function_type(sample_async_generator_fn), + once._WrappedFunctionType.ASYNC_GENERATOR, + ) + self.assertEqual( + once._wrapped_function_type(TestFunctionInspection.sample_async_generator_method), + once._WrappedFunctionType.ASYNC_GENERATOR, + ) + # The output of an async generator is not a wrappable. + self.assertEqual( + once._wrapped_function_type(sample_async_generator_fn()), + once._WrappedFunctionType.UNSUPPORTED, + ) + + class TestOnce(unittest.TestCase): """Unit tests for once decorators.""" From 4da3c7729e6780aac2a651479a86f54f0a15f18e Mon Sep 17 00:00:00 2001 From: Ali Ebrahim Date: Tue, 10 Oct 2023 11:06:33 -0700 Subject: [PATCH 2/7] Sync iterator listens outside the lock. We yield only outside the lock to avoid potential deadlocks. --- once/_iterator_wrappers.py | 82 +++++++++++++++++++++++-------- once_test.py | 98 ++++++++++++++++++++++++++++++++++---- 2 files changed, 152 insertions(+), 28 deletions(-) diff --git a/once/_iterator_wrappers.py b/once/_iterator_wrappers.py index 5245c13..175535b 100644 --- a/once/_iterator_wrappers.py +++ b/once/_iterator_wrappers.py @@ -110,33 +110,77 @@ def _yield_results(self) -> collections.abc.Generator: i = 0 # Fast path for subsequent calls will not require a lock while True: - if i < len(self.results): + # If we on before the penultimate entry, we can return now. When yielding the last + # element of results, we need to be recording next_send, so that needs the lock. 
+ if i < len(self.results) - 1: yield self.results[i] i += 1 continue + # Because we don't hold a lock here, we can't make this assumption + # i == len(self.results) - 1 or i == len(self.results) + # because the iterator could have moved in the interim. However, it will no longer + # move once self.finished. if self.finished: - return - - # Initial calls, and concurrent calls before completion will require the lock. - with self.lock: if i < len(self.results): yield self.results[i] i += 1 continue - # Because we hold a lock, this should never be violated. - # If it does, something has gone seriously wrong! - assert i == len(self.results) - if self.finished: + if i == len(self.results): return - # The generator should never be garbage collected while self.finished is False - # and the lock is held. - assert self.generator is not None - try: - self.results.append(self.generator.send(self.next_send)) - except StopIteration: + + # Initial calls, and concurrent calls before completion will require the lock. + with self.lock: + # Just in case a race condition prevented us from hitting these conditions before, + # check them again, so they can be handled by the code before the lock. + if i < len(self.results) - 1: + continue + if self.finished: + if i < len(self.results): + continue + if i == len(self.results): + return + assert i == len(self.results) - 1 or i == len(self.results) + # If we are at the end and waiting for the generator to complete, there is nothing + # to do! + if self.generating and i == len(self.results): + continue + + # At this point, there are 2 states to handle, which we will want to do outside the + # lock to avoid deadlocks. + # State #1: We are about to yield back the last entry in self.results and potentially + # log next send. We can allow multiple calls to enter this state, as long + # as we re-grab the lock before modifying self.next_send + # State #2: We are at the end of self.results, and need to call our underlying + # iterator. 
Only one call may enter this state due to our check of + # self.generating above. + if i == len(self.results) and not self.generating: + self.generating = True + next_send = self.next_send + listening = False + else: + assert i == len(self.results) - 1 or self.generating + listening = True + # We break outside the lock to either listen or kick off a new generation. + if listening: + next_send = yield self.results[i] + i += 1 + with self.lock: + if not self.finished and i == len(self.results): + self.next_send = next_send + continue + # We must be in generating state + assert self.generator is not None + try: + result = self.generator.send(next_send) + except StopIteration: + # This lock should be unnecessary, which by definition means there should be no + # contention on it, so we use it to preserve our assumptions about variables which + # are modified under lock. + with self.lock: self.finished = True self.generator = None # Allow this to be GCed. - return - else: - i += 1 - self.next_send = yield self.results[-1] + self.generating = False + return + with self.lock: + self.results.append(result) + self.generating = False diff --git a/once_test.py b/once_test.py index b30fa6f..323e264 100644 --- a/once_test.py +++ b/once_test.py @@ -156,13 +156,30 @@ def test_different_args_same_result(self): self.assertEqual(counter.value, 1) def test_iterator(self): + counter = Counter() + @once.once def yielding_iterator(): - for i in range(3): - yield i + nonlocal counter + for _ in range(3): + yield counter.get_incremented() + + self.assertEqual(list(yielding_iterator()), [1, 2, 3]) + self.assertEqual(list(yielding_iterator()), [1, 2, 3]) + + def test_iterator_parallel_execution(self): + counter = Counter() - self.assertEqual(list(yielding_iterator()), [0, 1, 2]) - self.assertEqual(list(yielding_iterator()), [0, 1, 2]) + @once.once + def yielding_iterator(): + nonlocal counter + for _ in range(3): + yield counter.get_incremented() + + with 
concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: + results = list(executor.map(lambda _: list(yielding_iterator()), range(32))) + for result in results: + self.assertEqual(result, [1, 2, 3]) def test_threaded_single_function(self): counting_fn, counter = generate_once_counter_fn() @@ -439,6 +456,68 @@ def value(): self.assertEqual(_CallOnceClass.value(), 1) self.assertEqual(_CallOnceClass.value(), 1) + def test_receiving_iterator(self): + @once.once + def receiving_iterator(): + next = yield 0 + while next is not None: + next = yield next + + gen_1 = receiving_iterator() + gen_2 = receiving_iterator() + self.assertEqual(gen_1.send(None), 0) + self.assertEqual(gen_1.send(1), 1) + self.assertEqual(gen_1.send(2), 2) + self.assertEqual(gen_2.send(None), 0) + self.assertEqual(gen_2.send(-1), 1) + self.assertEqual(gen_2.send(-1), 2) + self.assertEqual(gen_2.send(5), 5) + self.assertEqual(next(gen_2, None), None) + self.assertEqual(gen_1.send(None), 5) + self.assertEqual(next(gen_1, None), None) + self.assertEqual(list(receiving_iterator()), [0, 1, 2, 5]) + + def test_receiving_iterator_parallel_execution(self): + @once.once + def receiving_iterator(): + next = yield 0 + while next is not None: + next = yield next + + def call_iterator(_): + gen = receiving_iterator() + result = [] + result.append(gen.send(None)) + for i in range(1, 32): + result.append(gen.send(i)) + return result + + with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: + results = list(executor.map(call_iterator, range(32))) + for result in results: + self.assertEqual(result, list(range(32))) + + def test_receiving_iterator_parallel_execution_halting(self): + @once.once + def receiving_iterator(): + next = yield 0 + while next is not None: + next = yield next + + def call_iterator(n): + """Call the iterator but end early""" + gen = receiving_iterator() + result = [] + result.append(gen.send(None)) + for i in range(1, n): + result.append(gen.send(i)) + return result + 
+ with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: + results = list(executor.map(call_iterator, range(1, 32))) + for i, result in enumerate(results): + self.assertEqual(result, list(range(i + 1))) + class TestOnceAsync(unittest.IsolatedAsyncioTestCase): async def test_fn_called_once(self): @@ -537,22 +616,23 @@ async def async_yielding_iterator(): async def test_receiving_iterator(self): @once.once async def async_receiving_iterator(): - next = yield 1 + next = yield 0 while next is not None: next = yield next gen_1 = async_receiving_iterator() gen_2 = async_receiving_iterator() - self.assertEqual(await gen_1.asend(None), 1) + self.assertEqual(await gen_1.asend(None), 0) self.assertEqual(await gen_1.asend(1), 1) - self.assertEqual(await gen_1.asend(3), 3) - self.assertEqual(await gen_2.asend(None), 1) + self.assertEqual(await gen_1.asend(2), 2) + self.assertEqual(await gen_2.asend(None), 0) self.assertEqual(await gen_2.asend(None), 1) - self.assertEqual(await gen_2.asend(None), 3) + self.assertEqual(await gen_2.asend(None), 2) self.assertEqual(await gen_2.asend(5), 5) self.assertEqual(await anext(gen_2, None), None) self.assertEqual(await gen_1.asend(None), 5) self.assertEqual(await anext(gen_1, None), None) + self.assertEqual([i async for i in async_receiving_iterator()], [0, 1, 2, 5]) @unittest.skipIf(not hasattr(asyncio, "Barrier"), "Requires Barrier to evaluate") async def test_iterator_lock_not_held_during_evaluation(self): From 9df9ecd4ac2123bdcb9c63189210520058e6e85e Mon Sep 17 00:00:00 2001 From: Ali Ebrahim Date: Tue, 10 Oct 2023 14:25:38 -0700 Subject: [PATCH 3/7] Make yield_results a public method. 
Response to comment from #10 --- once/__init__.py | 6 +++--- once/_iterator_wrappers.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/once/__init__.py b/once/__init__.py index d6298b6..6838ba5 100644 --- a/once/__init__.py +++ b/once/__init__.py @@ -88,16 +88,16 @@ async def _execute_call_once_async(self, func: collections.abc.Callable, *args, # This cannot be an async function! def _execute_call_once_async_iter(self, func: collections.abc.Callable, *args, **kwargs): if self.called: - return self.return_value._yield_results() + return self.return_value.yield_results() with self.lock: if not self.called: self.called = True self.return_value = _iterator_wrappers.AsyncGeneratorWrapper(func, *args, **kwargs) - return self.return_value._yield_results() + return self.return_value.yield_results() def _sync_return(self): if self.fn_type == _WrappedFunctionType.SYNC_GENERATOR: - return self.return_value._yield_results().__iter__() + return self.return_value.yield_results().__iter__() else: return self.return_value diff --git a/once/_iterator_wrappers.py b/once/_iterator_wrappers.py index 175535b..0ac77d2 100644 --- a/once/_iterator_wrappers.py +++ b/once/_iterator_wrappers.py @@ -29,7 +29,7 @@ def __init__(self, func, *args, **kwargs) -> None: self.generating = False self.lock = asyncio.Lock() - async def _yield_results(self) -> collections.abc.AsyncGenerator: + async def yield_results(self) -> collections.abc.AsyncGenerator: i = 0 send = None next_val = None @@ -106,7 +106,7 @@ def __init__(self, func, *args, **kwargs) -> None: self.lock = threading.Lock() self.next_send = None - def _yield_results(self) -> collections.abc.Generator: + def yield_results(self) -> collections.abc.Generator: i = 0 # Fast path for subsequent calls will not require a lock while True: From cf63c616038e3c982cf4c58bafc56fcef9483d3d Mon Sep 17 00:00:00 2001 From: Ali Ebrahim Date: Tue, 10 Oct 2023 17:47:16 -0700 Subject: [PATCH 4/7] Sync iterator wrapper has fewer possible 
states. We use an enum to thoroughly document the possible states the yielding function can be in, making the code a lot more readable IMHO! --- once/_iterator_wrappers.py | 120 ++++++++++++++++--------------------- 1 file changed, 53 insertions(+), 67 deletions(-) diff --git a/once/_iterator_wrappers.py b/once/_iterator_wrappers.py index 0ac77d2..8138ee7 100644 --- a/once/_iterator_wrappers.py +++ b/once/_iterator_wrappers.py @@ -1,5 +1,6 @@ import asyncio import collections.abc +import enum import threading # Before we begin, a note on the assert statements in this file: @@ -91,6 +92,15 @@ async def yield_results(self) -> collections.abc.AsyncGenerator: i += 1 +class _IteratorAction(enum.Enum): + # Generating the next value from the underlying iterator + GENERATING = 1 + # Yield an already computed value + YIELDING = 2 + # Waiting for the underlying iterator, already triggered from another call. + WAITING = 3 + + class GeneratorWrapper: """Wrapper around an sync generator which only runs once. @@ -107,80 +117,56 @@ def __init__(self, func, *args, **kwargs) -> None: self.next_send = None def yield_results(self) -> collections.abc.Generator: + # Fast path for subsequent repeated call: + with self.lock: + finished = self.finished + if finished: + yield from self.results + return i = 0 + yield_value = None + next_send = None # Fast path for subsequent calls will not require a lock while True: - # If we on before the penultimate entry, we can return now. When yielding the last - # element of results, we need to be recording next_send, so that needs the lock. - if i < len(self.results) - 1: - yield self.results[i] - i += 1 - continue - # Because we don't hold a lock here, we can't make this assumption - # i == len(self.results) - 1 or i == len(self.results) - # because the iterator could have moved in the interim. However, it will no longer - # move once self.finished. 
- if self.finished: - if i < len(self.results): - yield self.results[i] - i += 1 - continue - if i == len(self.results): - return - - # Initial calls, and concurrent calls before completion will require the lock. + action: _IteratorAction | None = None + # With a lock, we figure out which action to take, and then we take it after release. with self.lock: - # Just in case a race condition prevented us from hitting these conditions before, - # check them again, so they can be handled by the code before the lock. - if i < len(self.results) - 1: - continue - if self.finished: - if i < len(self.results): - continue - if i == len(self.results): + if i == len(self.results): + if self.finished: return - assert i == len(self.results) - 1 or i == len(self.results) - # If we are at the end and waiting for the generator to complete, there is nothing - # to do! - if self.generating and i == len(self.results): - continue - - # At this point, there are 2 states to handle, which we will want to do outside the - # lock to avoid deadlocks. - # State #1: We are about to yield back the last entry in self.results and potentially - # log next send. We can allow multiple calls to enter this state, as long - # as we re-grab the lock before modifying self.next_send - # State #2: We are at the end of self.results, and need to call our underlying - # iterator. Only one call may enter this state due to our check of - # self.generating above. - if i == len(self.results) and not self.generating: - self.generating = True - next_send = self.next_send - listening = False + if self.generating: + action = _IteratorAction.WAITING + else: + action = _IteratorAction.GENERATING + next_send = self.next_send + self.generating = True else: - assert i == len(self.results) - 1 or self.generating - listening = True - # We break outside the lock to either listen or kick off a new generation. 
- if listening: - next_send = yield self.results[i] + action = _IteratorAction.YIELDING + yield_value = self.results[i] + if action == _IteratorAction.WAITING: + continue + if action == _IteratorAction.YIELDING: + next_send = yield yield_value i += 1 + # If we have just sent the last element and we have not yet kicked off the next + # iteration, we need to record the next send value. with self.lock: - if not self.finished and i == len(self.results): + if i == len(self.results) and not self.generating: self.next_send = next_send continue - # We must be in generating state - assert self.generator is not None - try: - result = self.generator.send(next_send) - except StopIteration: - # This lock should be unnecessary, which by definition means there should be no - # contention on it, so we use it to preserve our assumptions about variables which - # are modified under lock. - with self.lock: - self.finished = True - self.generator = None # Allow this to be GCed. - self.generating = False - return - with self.lock: - self.results.append(result) - self.generating = False + if action == _IteratorAction.GENERATING: + assert self.generator is not None + try: + result = self.generator.send(next_send) + except StopIteration: + # This lock should be unnecessary, which by definition means there should be no + # contention on it, so we use it to preserve our assumptions about variables which + # are modified under lock. + with self.lock: + self.finished = True + self.generator = None # Allow this to be GCed. + self.generating = False + else: + with self.lock: + self.generating = False + self.results.append(result) From eb1aaf23c257f7f8f233913deffc8ad4f29eaa5c Mon Sep 17 00:00:00 2001 From: Ali Ebrahim Date: Wed, 11 Oct 2023 15:22:45 -0700 Subject: [PATCH 5/7] Add readability improvements per review. 
--- once/_iterator_wrappers.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/once/_iterator_wrappers.py b/once/_iterator_wrappers.py index 8138ee7..f297200 100644 --- a/once/_iterator_wrappers.py +++ b/once/_iterator_wrappers.py @@ -154,19 +154,17 @@ def yield_results(self) -> collections.abc.Generator: if i == len(self.results) and not self.generating: self.next_send = next_send continue - if action == _IteratorAction.GENERATING: - assert self.generator is not None - try: - result = self.generator.send(next_send) - except StopIteration: - # This lock should be unnecessary, which by definition means there should be no - # contention on it, so we use it to preserve our assumptions about variables which - # are modified under lock. - with self.lock: - self.finished = True - self.generator = None # Allow this to be GCed. - self.generating = False - else: - with self.lock: - self.generating = False - self.results.append(result) + if action != _IteratorAction.GENERATING: + continue + assert self.generator is not None + try: + result = self.generator.send(next_send) + except StopIteration: + with self.lock: + self.finished = True + self.generator = None # Allow this to be GCed. + self.generating = False + else: + with self.lock: + self.generating = False + self.results.append(result) From 1acccbc27b9927eddcf79248879dd4edbea9a641 Mon Sep 17 00:00:00 2001 From: Ali Ebrahim Date: Wed, 11 Oct 2023 15:23:49 -0700 Subject: [PATCH 6/7] Make next_send execution local. It is no longer set on the object, per discussion. 
--- once/_iterator_wrappers.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/once/_iterator_wrappers.py b/once/_iterator_wrappers.py index f297200..9106d1f 100644 --- a/once/_iterator_wrappers.py +++ b/once/_iterator_wrappers.py @@ -114,7 +114,6 @@ def __init__(self, func, *args, **kwargs) -> None: self.results: list = [] self.generating = False self.lock = threading.Lock() - self.next_send = None def yield_results(self) -> collections.abc.Generator: # Fast path for subsequent repeated call: @@ -138,7 +137,6 @@ def yield_results(self) -> collections.abc.Generator: action = _IteratorAction.WAITING else: action = _IteratorAction.GENERATING - next_send = self.next_send self.generating = True else: action = _IteratorAction.YIELDING @@ -148,11 +146,6 @@ def yield_results(self) -> collections.abc.Generator: if action == _IteratorAction.YIELDING: next_send = yield yield_value i += 1 - # If we have just sent the last element and we have not yet kicked off the next - # iteration, we need to record the next send value. - with self.lock: - if i == len(self.results) and not self.generating: - self.next_send = next_send continue if action != _IteratorAction.GENERATING: continue From f13164dffaf16be7b40a52005fda36aaf0fe2c22 Mon Sep 17 00:00:00 2001 From: Ali Ebrahim Date: Wed, 11 Oct 2023 16:35:46 -0700 Subject: [PATCH 7/7] Address PR comments. 
--- once/_iterator_wrappers.py | 3 +-- once_test.py | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/once/_iterator_wrappers.py b/once/_iterator_wrappers.py index 9106d1f..4ae533f 100644 --- a/once/_iterator_wrappers.py +++ b/once/_iterator_wrappers.py @@ -147,8 +147,7 @@ def yield_results(self) -> collections.abc.Generator: next_send = yield yield_value i += 1 continue - if action != _IteratorAction.GENERATING: - continue + assert action == _IteratorAction.GENERATING assert self.generator is not None try: result = self.generator.send(next_send) diff --git a/once_test.py b/once_test.py index 323e264..10fe922 100644 --- a/once_test.py +++ b/once_test.py @@ -513,6 +513,8 @@ def call_iterator(n): result.append(gen.send(i)) return result + # Unlike the previous test, each execution should yield lists of different lengths. + # This ensures that the iterator does not hang, even if not exhausted with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: results = list(executor.map(call_iterator, range(1, 32))) for i, result in enumerate(results):