Skip to content

Commit

Permalink
Handle exceptions and preserve function type.
Browse files Browse the repository at this point in the history
We add explicit exception handling, and the decorated function will now
look like the original function to inspect (for example, an async
function will have inspect.iscoroutinefunction evaluate to True).
  • Loading branch information
aebrahim committed Oct 12, 2023
1 parent 3168628 commit 6cc9511
Show file tree
Hide file tree
Showing 3 changed files with 352 additions and 149 deletions.
179 changes: 111 additions & 68 deletions once/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,14 @@
import enum
import functools
import inspect
import time
import threading
import typing
import weakref

from . import _iterator_wrappers


def _new_lock() -> threading.Lock:
return threading.Lock()


def _is_method(func: collections.abc.Callable):
"""Determine if a function is a method on a class."""
if isinstance(func, (classmethod, staticmethod)):
Expand All @@ -42,10 +39,16 @@ def _wrapped_function_type(func: collections.abc.Callable) -> _WrappedFunctionTy
return _WrappedFunctionType.SYNC_GENERATOR
if inspect.iscoroutinefunction(func):
return _WrappedFunctionType.ASYNC_FUNCTION
# This must come last, because it would return True for all the other types
if inspect.isfunction(func):
return _WrappedFunctionType.SYNC_FUNCTION
return _WrappedFunctionType.UNSUPPORTED
# We assume it is a callable sync function if it is callable.
if not callable(func):
return _WrappedFunctionType.UNSUPPORTED
return _WrappedFunctionType.SYNC_FUNCTION


class _ExecutionState(enum.Enum):
UNCALLED = 0
WAITING = 1
COMPLETED = 2


class _OnceBase(abc.ABC):
Expand All @@ -54,15 +57,18 @@ class _OnceBase(abc.ABC):
def __init__(self, func: collections.abc.Callable) -> None:
functools.update_wrapper(self, func)
self.func = self._inspect_function(func)
self.called = False
self.call_state = _ExecutionState.UNCALLED
self.return_value: typing.Any = None
self.fn_type = _wrapped_function_type(self.func)
if self.fn_type == _WrappedFunctionType.UNSUPPORTED:
raise SyntaxError(f"Unable to wrap a {type(func)}")
if self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION:
if (
self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION
or self.fn_type == _WrappedFunctionType.ASYNC_GENERATOR
):
self.async_lock = asyncio.Lock()
else:
self.lock = _new_lock()
self.lock = threading.Lock()

@abc.abstractmethod
def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.Callable:
Expand All @@ -74,61 +80,108 @@ def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.C
It should return the function which should be executed once.
"""

def _callable(self, func: collections.abc.Callable):
    """Choose the appropriate call_once based on the function type.

    Binds ``func`` as the first argument of the ``_execute_call_once_*``
    method that matches ``self.fn_type``, then copies ``func``'s metadata
    onto the result with ``functools.update_wrapper``.

    Args:
        func: The callable that should be executed at most once.

    Returns:
        A ``functools.partial`` that forwards its arguments to the selected
        executor method.
    """
    if self.fn_type == _WrappedFunctionType.ASYNC_GENERATOR:
        wrapped = functools.partial(self._execute_call_once_async_iter, func)
    elif self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION:
        wrapped = functools.partial(self._execute_call_once_async, func)
    elif self.fn_type == _WrappedFunctionType.SYNC_FUNCTION:
        wrapped = functools.partial(self._execute_call_once_sync, func)
    else:
        # Only the sync-generator case should remain at this point.
        assert self.fn_type == _WrappedFunctionType.SYNC_GENERATOR
        wrapped = functools.partial(self._execute_call_once_sync_iter, func)
    functools.update_wrapper(wrapped, func)
    return wrapped

async def _execute_call_once_async(self, func: collections.abc.Callable, *args, **kwargs):
if self.called:
return self.return_value
async with self.async_lock:
if self.called:
return self.return_value
else:
self.return_value = await func(*args, **kwargs)
self.called = True
return self.return_value

# This cannot be an async function!
def _execute_call_once_async_iter(self, func: collections.abc.Callable, *args, **kwargs):
if self.called:
return self.return_value.yield_results()
with self.lock:
if not self.called:
self.called = True
call_state = self.call_state
while call_state != _ExecutionState.COMPLETED:
if call_state == _ExecutionState.WAITING:
# Yield control to the event loop so the in-flight call can make progress.
await asyncio.sleep(0)
async with self.async_lock:
call_state = self.call_state
if call_state == _ExecutionState.UNCALLED:
self.call_state = _ExecutionState.WAITING
# Only one thread will be allowed into this state.
if call_state == _ExecutionState.UNCALLED:
try:
return_value = await func(*args, **kwargs)
except Exception as exc:
async with self.async_lock:
self.call_state = _ExecutionState.UNCALLED
raise exc
async with self.async_lock:
self.return_value = return_value
self.call_state = _ExecutionState.COMPLETED
return self.return_value

async def _execute_call_once_async_iter(self, func: collections.abc.Callable, *args, **kwargs):
async with self.async_lock:
if self.call_state == _ExecutionState.UNCALLED:
self.return_value = _iterator_wrappers.AsyncGeneratorWrapper(func, *args, **kwargs)
return self.return_value.yield_results()

def _sync_return(self):
if self.fn_type == _WrappedFunctionType.SYNC_GENERATOR:
return self.return_value.yield_results().__iter__()
else:
return self.return_value
self.call_state = _ExecutionState.COMPLETED
next_value = None
iterator = self.return_value.yield_results()
while True:
try:
next_value = yield await iterator.asend(next_value)
except StopAsyncIteration:
return

def _execute_call_once_sync(self, func: collections.abc.Callable, *args, **kwargs):
if self.called:
return self._sync_return()
with self.lock:
if self.called:
return self._sync_return()
if self.fn_type == _WrappedFunctionType.SYNC_GENERATOR:
call_state = self.call_state
while call_state != _ExecutionState.COMPLETED:
# We only hit this state in multi-threaded code. To reduce contention, we invoke
# time.sleep so another thread can pick up the GIL.
if call_state == _ExecutionState.WAITING:
time.sleep(0)
with self.lock:
call_state = self.call_state
if call_state == _ExecutionState.UNCALLED:
self.call_state = _ExecutionState.WAITING
# Only one thread will be allowed into this state.
if call_state == _ExecutionState.UNCALLED:
try:
return_value = func(*args, **kwargs)
except Exception as exc:
with self.lock:
self.call_state = _ExecutionState.UNCALLED
raise exc
else:
with self.lock:
self.return_value = return_value
self.call_state = _ExecutionState.COMPLETED
return self.return_value

def _execute_call_once_sync_iter(self, func: collections.abc.Callable, *args, **kwargs):
with self.lock:
if self.call_state == _ExecutionState.UNCALLED:
self.return_value = _iterator_wrappers.GeneratorWrapper(func, *args, **kwargs)
else:
self.return_value = func(*args, **kwargs)
self.called = True
return self._sync_return()
self.call_state = _ExecutionState.COMPLETED
yield from self.return_value.yield_results()

def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs):
    """Choose the appropriate call_once based on the function type.

    Async generators and async functions are routed to their dedicated
    executors; every other function type goes to the sync executor.

    Args:
        func: The callable that should be executed at most once.
        *args: Positional arguments forwarded to the executor.
        **kwargs: Keyword arguments forwarded to the executor.

    Returns:
        Whatever the selected executor returns.
    """
    if self.fn_type == _WrappedFunctionType.ASYNC_GENERATOR:
        return self._execute_call_once_async_iter(func, *args, **kwargs)
    if self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION:
        return self._execute_call_once_async(func, *args, **kwargs)
    return self._execute_call_once_sync(func, *args, **kwargs)

class _OnceFn(_OnceBase):
    """``_OnceBase`` specialisation that refuses to wrap methods."""

    def _inspect_function(self, func: collections.abc.Callable):
        """Return ``func`` unchanged unless ``_is_method`` flags it.

        Args:
            func: The callable being decorated.

        Returns:
            ``func`` itself.

        Raises:
            SyntaxError: If ``_is_method(func)`` is true.
        """
        if _is_method(func):
            raise SyntaxError(
                "Attempting to use @once.once decorator on method "
                "instead of @once.once_per_class or @once.once_per_instance"
            )
        return func

class once(_OnceBase): # pylint: disable=invalid-name

def once(func: collections.abc.Callable):
"""Decorator to ensure a function is only called once.
The restriction of only one call also holds across threads. However, this
restriction does not apply to unsuccessful function calls. If the function
raises an exception, the next call will invoke a new call to the function.
raises an exception, the next call will invoke a new call to the function,
unless it is an iterator, in which case the failure will be cached.
If the function is called with multiple arguments, it will still be
called only once.
Expand All @@ -141,17 +194,8 @@ class once(_OnceBase): # pylint: disable=invalid-name
module and class level functions (i.e. non-closures), this means the return
value will never be deleted.
"""

def _inspect_function(self, func: collections.abc.Callable):
if _is_method(func):
raise SyntaxError(
"Attempting to use @once.once decorator on method "
"instead of @once.once_per_class or @once.once_per_instance"
)
return func

def __call__(self, *args, **kwargs):
return self._execute_call_once(self.func, *args, **kwargs)
once_obj = _OnceFn(func)
return once_obj._callable(func)


class once_per_class(_OnceBase): # pylint: disable=invalid-name
Expand Down Expand Up @@ -182,11 +226,10 @@ def _inspect_function(self, func: collections.abc.Callable):
# bound version of the function to the object or class.
def __get__(self, obj, cls):
if self.is_classmethod:
func = functools.partial(self.func, cls)
return functools.partial(self._execute_call_once, func)
return self._callable(functools.partial(self.func, cls))
if self.is_staticmethod:
return functools.partial(self._execute_call_once, self.func)
return functools.partial(self._execute_call_once, self.func, obj)
return self._callable(self.func)
return self._callable(functools.partial(self.func, obj))


class once_per_instance(_OnceBase): # pylint: disable=invalid-name
Expand Down Expand Up @@ -233,7 +276,7 @@ def _execute_call_once_per_instance(self, obj, *args, **kwargs):
if obj in self.inflight_lock:
inflight_lock = self.inflight_lock[obj]
else:
inflight_lock = _new_lock()
inflight_lock = threading.Lock()
self.inflight_lock[obj] = inflight_lock
# Now we have a per-object lock. This means that we will not block
# other instances. In addition to better performance, this reduces the
Expand Down
16 changes: 13 additions & 3 deletions once/_iterator_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import collections.abc
import enum
import threading
import time

# Before we begin, a note on the assert statements in this file:
# Why are we using assert in here, you might ask, instead of implementing "proper" error handling?
Expand Down Expand Up @@ -114,24 +115,26 @@ def __init__(self, func, *args, **kwargs) -> None:
self.results: list = []
self.generating = False
self.lock = threading.Lock()
self.exception: Exception | None = None

def yield_results(self) -> collections.abc.Generator:
# Fast path for subsequent repeated call:
with self.lock:
finished = self.finished
if finished:
fast_path = self.finished and self.exception is None
if fast_path:
yield from self.results
return
i = 0
yield_value = None
next_send = None
# Fast path for subsequent calls will not require a lock
while True:
action: _IteratorAction | None = None
# With a lock, we figure out which action to take, and then we take it after release.
with self.lock:
if i == len(self.results):
if self.finished:
if self.exception:
raise self.exception
return
if self.generating:
action = _IteratorAction.WAITING
Expand All @@ -142,6 +145,7 @@ def yield_results(self) -> collections.abc.Generator:
action = _IteratorAction.YIELDING
yield_value = self.results[i]
if action == _IteratorAction.WAITING:
time.sleep(0)
continue
if action == _IteratorAction.YIELDING:
next_send = yield yield_value
Expand All @@ -154,8 +158,14 @@ def yield_results(self) -> collections.abc.Generator:
except StopIteration:
with self.lock:
self.finished = True
self.generating = False
self.generator = None # Allow this to be GCed.
except Exception as e:
with self.lock:
self.finished = True
self.generating = False
self.exception = e
self.generator = None # Allow this to be GCed.
else:
with self.lock:
self.generating = False
Expand Down
Loading

0 comments on commit 6cc9511

Please sign in to comment.