Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Major refactor to allow lazy sync iterators. #10

Merged
merged 7 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 57 additions & 94 deletions once/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
import abc
import asyncio
import collections.abc
import enum
import functools
import inspect
import threading
import typing
import weakref

from . import _iterator_wrappers


def _new_lock() -> threading.Lock:
return threading.Lock()
Expand All @@ -21,30 +24,42 @@ def _is_method(func: collections.abc.Callable):
return "self" in sig.parameters


class _WrappedFunctionType(enum.Enum):
UNSUPPORTED = 0
SYNC_FUNCTION = 1
ASYNC_FUNCTION = 2
SYNC_GENERATOR = 3
ASYNC_GENERATOR = 4


def _wrapped_function_type(func: collections.abc.Callable) -> _WrappedFunctionType:
# The function inspect.isawaitable is a bit of a misnomer - it refers
# to the awaitable result of an async function, not the async function
# itself.
if inspect.isasyncgenfunction(func):
return _WrappedFunctionType.ASYNC_GENERATOR
if inspect.isgeneratorfunction(func):
return _WrappedFunctionType.SYNC_GENERATOR
if inspect.iscoroutinefunction(func):
return _WrappedFunctionType.ASYNC_FUNCTION
# This must come last, because it would return True for all the other types
if inspect.isfunction(func):
return _WrappedFunctionType.SYNC_FUNCTION
return _WrappedFunctionType.UNSUPPORTED


class _OnceBase(abc.ABC):
"""Abstract Base Class for once function decorators."""

def __init__(self, func: collections.abc.Callable) -> None:
    """Wrap *func*, classify it, and create the matching lock.

    Args:
        func: The callable to be executed at most once.

    Raises:
        SyntaxError: if *func* is not a supported callable type.
    """
    functools.update_wrapper(self, func)
    self.func = self._inspect_function(func)
    self.called = False  # Becomes True once the first call completes.
    self.return_value: typing.Any = None  # Cached result of the first call.

    self.fn_type = _wrapped_function_type(self.func)
    if self.fn_type == _WrappedFunctionType.UNSUPPORTED:
        raise SyntaxError(f"Unable to wrap a {type(func)}")
    # Plain async functions synchronize on an asyncio lock. Everything else
    # (sync functions/generators, and async generators whose wrapper carries
    # its own asyncio lock) synchronizes on a threading lock.
    if self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION:
        self.async_lock = asyncio.Lock()
    else:
        self.lock = _new_lock()
Expand All @@ -59,74 +74,6 @@ def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.C
It should return the function which should be executed once.
"""

def _sync_return(self):
if self.is_syncgen:
self.return_value.__iter__()
return self.return_value

async def _async_gen_proxy(self, func, *args, **kwargs):
i = 0
send = None
next_val = None

# A copy of self.async_generating that we can access outside of the lock.
async_generating = None

# Indicates that we're tied for the head generator, but someone started generating the next
# result first, so we should just poll until the result is available.
waiting_for_async_generating = False

while True:
if waiting_for_async_generating:
# This is a load bearing sleep. We're waiting for the leader to generate the result, but
# we have control of the lock, so the async with will never yield execution to the event loop,
# so we would loop forever. By awaiting sleep(0), we yield execution which will allow us to
# poll for self.async_generating readiness.
await asyncio.sleep(0)
waiting_for_async_generating = False
async with self.async_lock:
if self.asyncgen_generator is None and not self.asyncgen_finished:
# We're the first! Do some setup.
self.asyncgen_generator = func(*args, **kwargs)

if i == len(self.asyncgen_results) and not self.asyncgen_finished:
if self.async_generating:
# We're at the lead, but someone else is generating the next value
# so we just hop back onto the next iteration of the loop
# until it's ready.
waiting_for_async_generating = True
continue
# We're at the lead and no one else is generating, so we need to increment
# the iterator. We just store the value in self.asyncgen_results so that
# we can later yield it outside of the lock.
self.async_generating = self.asyncgen_generator.asend(send)
async_generating = self.async_generating
elif i == len(self.asyncgen_results) and self.asyncgen_finished:
# All done.
return
else:
# We already have the correct result, so we grab it here to
# yield it outside the lock.
next_val = self.asyncgen_results[i]

if async_generating:
try:
next_val = await async_generating
except StopAsyncIteration:
async with self.async_lock:
self.asyncgen_generator = None # Allow this to be GCed.
self.asyncgen_finished = True
self.async_generating = None
async_generating = None
return
async with self.async_lock:
self.asyncgen_results.append(next_val)
async_generating = None
self.async_generating = None

send = yield next_val
i += 1

async def _execute_call_once_async(self, func: collections.abc.Callable, *args, **kwargs):
if self.called:
return self.return_value
Expand All @@ -138,24 +85,40 @@ async def _execute_call_once_async(self, func: collections.abc.Callable, *args,
self.called = True
return self.return_value

# This cannot be an async function! It must synchronously return the async
# iterator object so that every caller shares the same lazily-evaluated wrapper.
def _execute_call_once_async_iter(self, func: collections.abc.Callable, *args, **kwargs):
    """Return a shared, once-evaluated async iterator over func's results."""
    if self.called:
        return self.return_value._yield_results()
    with self.lock:
        # Double-checked locking: re-test under the lock in case another
        # thread created the wrapper while we were waiting.
        if not self.called:
            self.called = True
            self.return_value = _iterator_wrappers.AsyncGeneratorWrapper(func, *args, **kwargs)
    return self.return_value._yield_results()

def _sync_return(self):
    """Return the cached result in the shape the wrapped callable produced.

    Sync generators hand each caller a fresh iterator over the shared,
    lazily-populated results; everything else returns the cached value.
    """
    if self.fn_type == _WrappedFunctionType.SYNC_GENERATOR:
        return self.return_value._yield_results().__iter__()
    else:
        return self.return_value

def _execute_call_once_sync(self, func: collections.abc.Callable, *args, **kwargs):
    """Execute *func* at most once and return (a view of) its result.

    Thread-safe via double-checked locking on self.lock.
    """
    if self.called:
        return self._sync_return()
    with self.lock:
        if self.called:
            return self._sync_return()
        if self.fn_type == _WrappedFunctionType.SYNC_GENERATOR:
            # Generators are wrapped so their output is evaluated lazily and
            # shared among all callers.
            self.return_value = _iterator_wrappers.GeneratorWrapper(func, *args, **kwargs)
        else:
            self.return_value = func(*args, **kwargs)
        self.called = True
        return self._sync_return()

def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs):
    """Choose the appropriate call_once based on the function type."""
    if self.fn_type == _WrappedFunctionType.ASYNC_GENERATOR:
        return self._execute_call_once_async_iter(func, *args, **kwargs)
    if self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION:
        return self._execute_call_once_async(func, *args, **kwargs)
    # SYNC_FUNCTION and SYNC_GENERATOR both go through the sync path.
    return self._execute_call_once_sync(func, *args, **kwargs)

Expand Down Expand Up @@ -229,7 +192,7 @@ def __get__(self, obj, cls):
class once_per_instance(_OnceBase): # pylint: disable=invalid-name
"""A version of once for class methods which runs once per instance."""

def __init__(self, func: collections.abc.Callable):
def __init__(self, func: collections.abc.Callable) -> None:
super().__init__(func)
self.return_value: weakref.WeakKeyDictionary[
typing.Any, typing.Any
Expand Down
186 changes: 186 additions & 0 deletions once/_iterator_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import asyncio
import collections.abc
import threading

# Before we begin, a note on the assert statements in this file:
# Why are we using assert in here, you might ask, instead of implementing "proper" error handling?
# In this case, it is actually not being done out of laziness! The assert statements here
# represent our assumptions about the state at that point in time, and are always called with locks
# held, so they **REALLY** should always hold. If the assumption behind one of these asserts fails,
# the subsequent calls are going to fail anyways, so it's not like they are making the code
artificially brittle. However, they do make testing easier, because we can directly test our
# assumption instead of having hard-to-trace errors, and also serve as very convenient
# documentation of the assumptions.
# We are always open to suggestions if there are other ways to achieve the same functionality in
# python!


class AsyncGeneratorWrapper:
    """Wrapper around an async generator which only runs once.

    Subsequent calls will return results from the first call, which is
    evaluated lazily.
    """

    def __init__(self, func, *args, **kwargs) -> None:
        # Underlying generator; set to None once exhausted so it can be GCed.
        self.generator: collections.abc.AsyncGenerator | None = func(*args, **kwargs)
        # True once the generator has raised StopAsyncIteration.
        self.finished = False
        # Results produced so far, shared by every iterator.
        self.results: list = []
        # Holds the in-flight asend() awaitable while one caller is advancing
        # the generator; falsy when no advance is in progress.
        self.generating = False
        self.lock = asyncio.Lock()

    async def _yield_results(self) -> collections.abc.AsyncGenerator:
        """Yield cached results, lazily advancing the shared generator.

        Concurrent iterators cooperate: whichever reaches the head of the
        cache first advances the generator; others poll until the next value
        lands in self.results.
        """
        i = 0
        send = None
        next_val = None

        # A copy of self.generating that we can access outside of the lock.
        generating = None

        # Indicates that we're tied for the head generator, but someone started generating the next
        # result first, so we should just poll until the result is available.
        waiting_for_generating = False

        while True:
            if waiting_for_generating:
                # This is a load bearing sleep. We're waiting for the leader to generate the result,
                # but we have control of the lock, so the async with will never yield execution to
                # the event loop, so we would loop forever. By awaiting sleep(0), we yield execution
                # which will allow us to poll for self.generating readiness.
                await asyncio.sleep(0)
                waiting_for_generating = False
            async with self.lock:
                if i == len(self.results) and not self.finished:
                    if self.generating:
                        # We're at the lead, but someone else is generating the next value
                        # so we just hop back onto the next iteration of the loop
                        # until it's ready.
                        waiting_for_generating = True
                        continue
                    # We're at the lead and no one else is generating, so we need to increment
                    # the iterator. We just store the value in self.results so that
                    # we can later yield it outside of the lock.
                    assert self.generator is not None
                    # TODO(matt): Is the fact that we have to suppress typing here a bug?
                    self.generating = self.generator.asend(send)  # type: ignore
                    generating = self.generating
                elif i == len(self.results) and self.finished:
                    # All done.
                    return
                else:
                    # We already have the correct result, so we grab it here to
                    # yield it outside the lock.
                    next_val = self.results[i]

            if generating:
                try:
                    next_val = await generating
                except StopAsyncIteration:
                    async with self.lock:
                        self.generator = None  # Allow this to be GCed.
                        self.finished = True
                        self.generating = None
                    generating = None
                    return
                async with self.lock:
                    self.results.append(next_val)
                    generating = None
                    self.generating = None

            send = yield next_val
            i += 1


class GeneratorWrapper:
    """Wrapper around a sync generator which only runs once.

    Subsequent calls will return results from the first call, which is
    evaluated lazily.
    """

    def __init__(self, func, *args, **kwargs) -> None:
        # Underlying generator; set to None once exhausted so it can be GCed.
        self.generator: collections.abc.Generator | None = func(*args, **kwargs)
        # True once the generator has raised StopIteration.
        self.finished = False
        # Results produced so far, shared by every iterator.
        self.results: list = []
        # True while one caller is advancing the generator.
        self.generating = False
        self.lock = threading.Lock()
        # Most recent value a consumer sent while yielding the head result.
        self.next_send = None

    def _yield_results(self) -> collections.abc.Generator:
        """Yield cached results, lazily advancing the shared generator."""
        i = 0
        # Fast path for subsequent calls will not require a lock
        while True:
            # If we are before the penultimate entry, we can yield now. When yielding the last
            # element of results, we need to be recording next_send, so that needs the lock.
            if i < len(self.results) - 1:
                yield self.results[i]
                i += 1
                continue
            # Because we don't hold a lock here, we can't make this assumption
            # i == len(self.results) - 1 or i == len(self.results)
            # because the iterator could have moved in the interim. However, it will no longer
            # move once self.finished.
            if self.finished:
                if i < len(self.results):
                    yield self.results[i]
                    i += 1
                    continue
                if i == len(self.results):
                    return

            # Initial calls, and concurrent calls before completion, require the lock.
            with self.lock:
                # Just in case a race condition prevented us from hitting these conditions before,
                # check them again, so they can be handled by the code before the lock.
                if i < len(self.results) - 1:
                    continue
                if self.finished:
                    if i < len(self.results):
                        continue
                    if i == len(self.results):
                        return
                assert i == len(self.results) - 1 or i == len(self.results)
                # If we are at the end and waiting for the generator to complete, there is nothing
                # to do but loop again (the leader will eventually publish the result).
                if self.generating and i == len(self.results):
                    continue

                # At this point, there are 2 states to handle, which we will want to do outside the
                # lock to avoid deadlocks.
                # State #1: We are about to yield back the last entry in self.results and potentially
                #           log next send. We can allow multiple calls to enter this state, as long
                #           as we re-grab the lock before modifying self.next_send
                # State #2: We are at the end of self.results, and need to call our underlying
                #           iterator. Only one call may enter this state due to our check of
                #           self.generating above.
                if i == len(self.results) and not self.generating:
                    self.generating = True
                    next_send = self.next_send
                    listening = False
                else:
                    assert i == len(self.results) - 1 or self.generating
                    listening = True
            # We break outside the lock to either listen or kick off a new generation.
            if listening:
                next_send = yield self.results[i]
                i += 1
                with self.lock:
                    if not self.finished and i == len(self.results):
                        self.next_send = next_send
                continue
            # We must be in generating state
            assert self.generator is not None
            try:
                result = self.generator.send(next_send)
            except StopIteration:
                # This lock should be unnecessary, which by definition means there should be no
                # contention on it, so we use it to preserve our assumptions about variables which
                # are modified under lock.
                with self.lock:
                    self.finished = True
                    self.generator = None  # Allow this to be GCed.
                    self.generating = False
                return
            with self.lock:
                self.results.append(result)
                self.generating = False
Loading