Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle exceptions and preserve function type. #11

Merged
merged 12 commits into from
Oct 16, 2023
276 changes: 147 additions & 129 deletions once/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
"""Utility for initialization ensuring functions are called only once."""
import abc
import asyncio
import collections.abc
import enum
import functools
import inspect
import time
import threading
import typing
import weakref

from . import _iterator_wrappers


def _new_lock() -> threading.Lock:
return threading.Lock()


def _is_method(func: collections.abc.Callable):
"""Determine if a function is a method on a class."""
if isinstance(func, (classmethod, staticmethod)):
Expand All @@ -25,110 +21,163 @@ def _is_method(func: collections.abc.Callable):


class _WrappedFunctionType(enum.Enum):
UNSUPPORTED = 0
SYNC_FUNCTION = 1
ASYNC_FUNCTION = 2
SYNC_GENERATOR = 3
ASYNC_GENERATOR = 4
SYNC_FUNCTION = 0
ASYNC_FUNCTION = 1
SYNC_GENERATOR = 2
ASYNC_GENERATOR = 3


def _wrapped_function_type(func: collections.abc.Callable) -> _WrappedFunctionType:
# The function inspect.isawaitable is a bit of a misnomer - it refers
# to the awaitable result of an async function, not the async function
# itself.
original_func = func
while isinstance(func, functools.partial):
# Work around inspect not functioning properly in python < 3.10 for partial functions.
func = func.func
if inspect.isasyncgenfunction(func):
return _WrappedFunctionType.ASYNC_GENERATOR
if inspect.isgeneratorfunction(func):
return _WrappedFunctionType.SYNC_GENERATOR
if inspect.iscoroutinefunction(func):
return _WrappedFunctionType.ASYNC_FUNCTION
# This must come last, because it would return True for all the other types
if inspect.isfunction(func):
return _WrappedFunctionType.SYNC_FUNCTION
return _WrappedFunctionType.UNSUPPORTED
raise SyntaxError(f"Unable to determine function type for {repr(original_func)}")


class _OnceBase(abc.ABC):
"""Abstract Base Class for once function decorators."""
class _ExecutionState(enum.Enum):
UNCALLED = 0
WAITING = 1
COMPLETED = 2

def __init__(self, func: collections.abc.Callable) -> None:
functools.update_wrapper(self, func)
self.func = self._inspect_function(func)
self.called = False

class _OnceBase:
def __init__(self, fn_type: _WrappedFunctionType) -> None:
self.call_state = _ExecutionState.UNCALLED
self.return_value: typing.Any = None
self.fn_type = _wrapped_function_type(self.func)
if self.fn_type == _WrappedFunctionType.UNSUPPORTED:
raise SyntaxError(f"Unable to wrap a {type(func)}")
if self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION:
self.fn_type = fn_type
if (
self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION
or self.fn_type == _WrappedFunctionType.ASYNC_GENERATOR
):
self.async_lock = asyncio.Lock()
else:
self.lock = _new_lock()
self.lock = threading.Lock()

@abc.abstractmethod
def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.Callable:
"""Inspect the passed-in function to ensure it can be wrapped.
def _callable(self, func: collections.abc.Callable):
"""Generate a wrapped function appropriate to the function type.

This function should raise a SyntaxError if the passed-in function is
not suitable.

It should return the function which should be executed once.
This wrapped function will call the correct _execute_call_once function.
"""
if self.fn_type == _WrappedFunctionType.ASYNC_GENERATOR:

async def wrapped(*args, **kwargs):
aebrahim marked this conversation as resolved.
Show resolved Hide resolved
next_value = None
iterator = self._execute_call_once_async_iter(func, *args, **kwargs)
while True:
try:
next_value = yield await iterator.asend(next_value)
except StopAsyncIteration:
return

elif self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION:

async def wrapped(*args, **kwargs):
return await self._execute_call_once_async(func, *args, **kwargs)

elif self.fn_type == _WrappedFunctionType.SYNC_FUNCTION:

def wrapped(*args, **kwargs):
return self._execute_call_once_sync(func, *args, **kwargs)

else:
assert self.fn_type == _WrappedFunctionType.SYNC_GENERATOR
aebrahim marked this conversation as resolved.
Show resolved Hide resolved

def wrapped(*args, **kwargs):
yield from self._execute_call_once_sync_iter(func, *args, **kwargs)

functools.update_wrapper(wrapped, func)
return wrapped

async def _execute_call_once_async(self, func: collections.abc.Callable, *args, **kwargs):
if self.called:
return self.return_value
async with self.async_lock:
if self.called:
return self.return_value
else:
self.return_value = await func(*args, **kwargs)
self.called = True
return self.return_value

# This cannot be an async function!
def _execute_call_once_async_iter(self, func: collections.abc.Callable, *args, **kwargs):
if self.called:
return self.return_value.yield_results()
with self.lock:
if not self.called:
self.called = True
call_state = self.call_state
aebrahim marked this conversation as resolved.
Show resolved Hide resolved
while call_state != _ExecutionState.COMPLETED:
if call_state == _ExecutionState.WAITING:
# Allow another thread to grab the GIL.
await asyncio.sleep(0)
async with self.async_lock:
call_state = self.call_state
if call_state == _ExecutionState.UNCALLED:
self.call_state = _ExecutionState.WAITING
# Only one thread will be allowed into this state.
if call_state == _ExecutionState.UNCALLED:
try:
return_value = await func(*args, **kwargs)
except Exception as exc:
async with self.async_lock:
self.call_state = _ExecutionState.UNCALLED
raise exc
async with self.async_lock:
self.return_value = return_value
self.call_state = _ExecutionState.COMPLETED
return self.return_value

async def _execute_call_once_async_iter(self, func: collections.abc.Callable, *args, **kwargs):
aebrahim marked this conversation as resolved.
Show resolved Hide resolved
aebrahim marked this conversation as resolved.
Show resolved Hide resolved
async with self.async_lock:
if self.call_state == _ExecutionState.UNCALLED:
self.return_value = _iterator_wrappers.AsyncGeneratorWrapper(func, *args, **kwargs)
return self.return_value.yield_results()

def _sync_return(self):
if self.fn_type == _WrappedFunctionType.SYNC_GENERATOR:
return self.return_value.yield_results().__iter__()
else:
return self.return_value
self.call_state = _ExecutionState.COMPLETED
next_value = None
iterator = self.return_value.yield_results()
while True:
try:
next_value = yield await iterator.asend(next_value)
except StopAsyncIteration:
return

def _execute_call_once_sync(self, func: collections.abc.Callable, *args, **kwargs):
if self.called:
return self._sync_return()
with self.lock:
if self.called:
return self._sync_return()
if self.fn_type == _WrappedFunctionType.SYNC_GENERATOR:
call_state = self.call_state
while call_state != _ExecutionState.COMPLETED:
# We only hit this state in multi-threded code. To reduce contention, we invoke
# time.sleep so another thread an pick up the GIL.
if call_state == _ExecutionState.WAITING:
time.sleep(0)
with self.lock:
call_state = self.call_state
if call_state == _ExecutionState.UNCALLED:
self.call_state = _ExecutionState.WAITING
# Only one thread will be allowed into this state.
if call_state == _ExecutionState.UNCALLED:
try:
return_value = func(*args, **kwargs)
except Exception as exc:
with self.lock:
self.call_state = _ExecutionState.UNCALLED
raise exc
else:
with self.lock:
self.return_value = return_value
self.call_state = _ExecutionState.COMPLETED
return self.return_value

def _execute_call_once_sync_iter(self, func: collections.abc.Callable, *args, **kwargs):
with self.lock:
if self.call_state == _ExecutionState.UNCALLED:
self.return_value = _iterator_wrappers.GeneratorWrapper(func, *args, **kwargs)
else:
self.return_value = func(*args, **kwargs)
self.called = True
return self._sync_return()
self.call_state = _ExecutionState.COMPLETED
yield from self.return_value.yield_results()

def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs):
"""Choose the appropriate call_once based on the function type."""
if self.fn_type == _WrappedFunctionType.ASYNC_GENERATOR:
return self._execute_call_once_async_iter(func, *args, **kwargs)
if self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION:
return self._execute_call_once_async(func, *args, **kwargs)
return self._execute_call_once_sync(func, *args, **kwargs)


class once(_OnceBase): # pylint: disable=invalid-name
def once(func: collections.abc.Callable):
"""Decorator to ensure a function is only called once.

The restriction of only one call also holds across threads. However, this
restriction does not apply to unsuccessful function calls. If the function
raises an exception, the next call will invoke a new call to the function.
raises an exception, the next call will invoke a new call to the function,
unless it is in iterator, in which case the failure will be cached.
If the function is called with multiple arguments, it will still only be
called only once.

Expand All @@ -141,17 +190,13 @@ class once(_OnceBase): # pylint: disable=invalid-name
module and class level functions (i.e. non-closures), this means the return
value will never be deleted.
"""

def _inspect_function(self, func: collections.abc.Callable):
if _is_method(func):
raise SyntaxError(
"Attempting to use @once.once decorator on method "
"instead of @once.once_per_class or @once.once_per_instance"
)
return func

def __call__(self, *args, **kwargs):
return self._execute_call_once(self.func, *args, **kwargs)
if _is_method(func):
raise SyntaxError(
"Attempting to use @once.once decorator on method "
"instead of @once.once_per_class or @once.once_per_instance"
)
once_obj = _OnceBase(_wrapped_function_type(func))
return once_obj._callable(func)


class once_per_class(_OnceBase): # pylint: disable=invalid-name
Expand All @@ -160,6 +205,10 @@ class once_per_class(_OnceBase): # pylint: disable=invalid-name
is_classmethod: bool
is_staticmethod: bool

def __init__(self, func: collections.abc.Callable) -> None:
self.func = self._inspect_function(func)
super().__init__(_wrapped_function_type(self.func))

def _inspect_function(self, func: collections.abc.Callable):
if not _is_method(func):
raise SyntaxError(
Expand All @@ -182,26 +231,24 @@ def _inspect_function(self, func: collections.abc.Callable):
# bound version of the function to the object or class.
def __get__(self, obj, cls):
if self.is_classmethod:
func = functools.partial(self.func, cls)
return functools.partial(self._execute_call_once, func)
return self._callable(functools.partial(self.func, cls))
if self.is_staticmethod:
return functools.partial(self._execute_call_once, self.func)
return functools.partial(self._execute_call_once, self.func, obj)
return self._callable(self.func)
return self._callable(functools.partial(self.func, obj))
aebrahim marked this conversation as resolved.
Show resolved Hide resolved


class once_per_instance(_OnceBase): # pylint: disable=invalid-name
"""A version of once for class methods which runs once per instance."""

def __init__(self, func: collections.abc.Callable) -> None:
super().__init__(func)
self.return_value: weakref.WeakKeyDictionary[
typing.Any, typing.Any
self.func = self._inspect_function(func)
super().__init__(_wrapped_function_type(self.func))
self.callables_lock = threading.Lock()
self.callables: weakref.WeakKeyDictionary[
typing.Any, collections.abc.Callable
aebrahim marked this conversation as resolved.
Show resolved Hide resolved
] = weakref.WeakKeyDictionary()
self.inflight_lock: typing.Dict[typing.Any, threading.Lock] = {}

def _inspect_function(self, func: collections.abc.Callable):
if inspect.isasyncgenfunction(func) or inspect.iscoroutinefunction(func):
raise SyntaxError("once_per_instance not (yet) supported for async")
if isinstance(func, (classmethod, staticmethod)):
raise SyntaxError("Must use @once.once_per_class on classmethod and staticmethod")
if not _is_method(func):
Expand All @@ -215,39 +262,10 @@ def _inspect_function(self, func: collections.abc.Callable):
# bound version of the function to the object.
def __get__(self, obj, cls):
del cls
return functools.partial(self._execute_call_once_per_instance, obj)

def _execute_call_once_per_instance(self, obj, *args, **kwargs):
# We only append to the call history, and do not overwrite or remove keys.
# Therefore, we can check the call history without a lock for an early
# exit.
# Another concern might be the weakref dictionary for return_value
# getting garbage collected without a lock. However, because
# user_function references whichever key it matches, it cannot be
# garbage collected during this call.
if obj in self.return_value:
return self.return_value[obj]
with self.lock:
if obj in self.return_value:
return self.return_value[obj]
if obj in self.inflight_lock:
inflight_lock = self.inflight_lock[obj]
else:
inflight_lock = _new_lock()
self.inflight_lock[obj] = inflight_lock
# Now we have a per-object lock. This means that we will not block
# other instances. In addition to better performance, this reduces the
# potential for deadlocks.
with inflight_lock:
if obj in self.return_value:
return self.return_value[obj]
result = self.func(obj, *args, **kwargs)
self.return_value[obj] = result
# At this point, any new call will find a cache hit before
# even grabbing a lock. It is now safe to clean up the inflight
# lock entry from the dictionary, as all subsequent will not need
# it. Any other previously called inflight requests already have
# their reference to the lock object, and do not need it present
# in this dict either.
self.inflight_lock.pop(obj)
return result
with self.callables_lock:
if callable := self.callables.get(obj):
return callable
once_obj = _OnceBase(self.fn_type)
callable = once_obj._callable(functools.partial(self.func, obj))
self.callables[obj] = callable
return callable
Loading