From 63496a022269f9da707c27a551f0fa5a97cbc396 Mon Sep 17 00:00:00 2001 From: Ali Ebrahim Date: Mon, 25 Sep 2023 18:34:45 -0700 Subject: [PATCH] Implement support for async and iterators. (#3) * Implement support for async and iterators. This change makes the decorator work correctly for async functions, instead of always returning the same coroutine, which can only be awaited once and calls subsequent calls to fail. This also detects returned iterators, evaluates them to completion, and returns the result as a tuple. Prior to this change, an exhausted iterator would be returned. * Explicitly disable async iterators. Also, better handle sync iterators to be API-compatible with the unwrapped function call by returning an iterator object instead of a tuple object. --- once/__init__.py | 87 ++++++++++++++++++++++++++++++++++++++++-------- once_test.py | 85 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+), 14 deletions(-) diff --git a/once/__init__.py b/once/__init__.py index 8811048..b7f1560 100644 --- a/once/__init__.py +++ b/once/__init__.py @@ -1,5 +1,6 @@ """Utility for initialization ensuring functions are called only once.""" import abc +import asyncio import collections.abc import functools import inspect @@ -24,30 +25,71 @@ class _OnceBase(abc.ABC): """Abstract Base Class for once function decorators.""" def __init__(self, func: collections.abc.Callable): - self._inspect_function(func) functools.update_wrapper(self, func) - self.lock = _new_lock() + self.func = self._inspect_function(func) self.called = False self.return_value: typing.Any = None - self.func = func + self.is_asyncgen = inspect.isasyncgenfunction(self.func) + if self.is_asyncgen: + raise SyntaxError("async generators are not (yet) supported") + self.is_syncgen = inspect.isgeneratorfunction(self.func) + # The function inspect.isawaitable is a bit of a misnomer - it refers + # to the awaitable result of an async function, not the async function + # itself. + self.is_async = True if self.is_asyncgen else inspect.iscoroutinefunction(self.func) + if self.is_async: + self.async_lock = asyncio.Lock() + else: + self.lock = _new_lock() @abc.abstractmethod - def _inspect_function(self, func: collections.abc.Callable): + def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.Callable: """Inspect the passed-in function to ensure it can be wrapped. This function should raise a SyntaxError if the passed-in function is - not suitable.""" + not suitable. - def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs): + It should return the function which should be executed once. + """ + + def _sync_return(self): + if self.is_syncgen: + self.return_value.__iter__() + return self.return_value + + async def _execute_call_once_async(self, func: collections.abc.Callable, *args, **kwargs): if self.called: return self.return_value - with self.lock: + async with self.async_lock: if self.called: return self.return_value - self.return_value = func(*args, **kwargs) + # Currently unreachable code - Async iterators are disabled for now. + if self.is_asyncgen: + self.return_value = [i async for i in func(*args, **kwargs)] + else: + self.return_value = await func(*args, **kwargs) self.called = True return self.return_value + def _execute_call_once_sync(self, func: collections.abc.Callable, *args, **kwargs): + if self.called: + return self._sync_return() + with self.lock: + if self.called: + return self._sync_return() + self.return_value = func(*args, **kwargs) + if self.is_syncgen: + # A potential optimization is to evaluate the iterator lazily, + # as opposed to eagerly like we do here. + self.return_value = tuple(self.return_value) + self.called = True + return self._sync_return() + + def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs): + if self.is_async: + return self._execute_call_once_async(func, *args, **kwargs) + return self._execute_call_once_sync(func, *args, **kwargs) + class once(_OnceBase): # pylint: disable=invalid-name """Decorator to ensure a function is only called once. @@ -74,6 +116,7 @@ def _inspect_function(self, func: collections.abc.Callable): "Attempting to use @once.once decorator on method " "instead of @once.once_per_class or @once.once_per_instance" ) + return func def __call__(self, *args, **kwargs): return self._execute_call_once(self.func, *args, **kwargs) @@ -82,22 +125,35 @@ def __call__(self, *args, **kwargs): class once_per_class(_OnceBase): # pylint: disable=invalid-name """A version of once for class methods which runs once across all instances.""" - def _inspect_function(self, func): + is_classmethod: bool + is_staticmethod: bool + + def _inspect_function(self, func: collections.abc.Callable): if not _is_method(func): raise SyntaxError( "Attempting to use @once.once_per_class method-only decorator " "instead of @once.once" ) + if isinstance(func, classmethod): + self.is_classmethod = True + self.is_staticmethod = False + return func.__func__ + if isinstance(func, staticmethod): + self.is_classmethod = False + self.is_staticmethod = True + return func.__func__ + self.is_classmethod = False + self.is_staticmethod = False + return func # This is needed for a decorator on a class method to return a # bound version of the function to the object or class. def __get__(self, obj, cls): - if isinstance(self.func, classmethod): - func = functools.partial(self.func.__func__, cls) + if self.is_classmethod: + func = functools.partial(self.func, cls) return functools.partial(self._execute_call_once, func) - if isinstance(self.func, staticmethod): - # The additional __func__ is required for python <= 3.9 - return functools.partial(self._execute_call_once, self.func.__func__) + if self.is_staticmethod: + return functools.partial(self._execute_call_once, self.func) return functools.partial(self._execute_call_once, self.func, obj) @@ -112,6 +168,8 @@ def __init__(self, func: collections.abc.Callable): self.inflight_lock: typing.Dict[typing.Any, threading.Lock] = {} def _inspect_function(self, func: collections.abc.Callable): + if inspect.isasyncgenfunction(func) or inspect.iscoroutinefunction(func): + raise SyntaxError("once_per_instance not (yet) supported for async") if isinstance(func, (classmethod, staticmethod)): raise SyntaxError("Must use @once.once_per_class on classmethod and staticmethod") if not _is_method(func): @@ -119,6 +177,7 @@ def _inspect_function(self, func: collections.abc.Callable): "Attempting to use @once.once_per_instance method-only decorator " "instead of @once.once" ) + return func # This is needed for a decorator on a class method to return a # bound version of the function to the object. diff --git a/once_test.py b/once_test.py index 55783a3..34def17 100644 --- a/once_test.py +++ b/once_test.py @@ -62,6 +62,15 @@ def test_different_args_same_result(self): self.assertEqual(counting_fn(2), 1) self.assertEqual(counter.value, 1) + def test_iterator(self): + @once.once + def yielding_iterator(): + for i in range(3): + yield i + + self.assertEqual(list(yielding_iterator()), [0, 1, 2]) + self.assertEqual(list(yielding_iterator()), [0, 1, 2]) + def test_threaded_single_function(self): counting_fn, counter = generate_once_counter_fn() with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: @@ -338,5 +347,81 @@ def value(): self.assertEqual(_CallOnceClass.value(), 1) +class TestOnceAsync(unittest.IsolatedAsyncioTestCase): + async def test_fn_called_once(self): + counter1 = Counter() + + @once.once + async def counting_fn1(): + return counter1.get_incremented() + + counter2 = Counter() + # We should get a different value than the previous function + counter2.get_incremented() + + @once.once + async def counting_fn2(): + return counter2.get_incremented() + + self.assertEqual(await counting_fn1(), 1) + self.assertEqual(await counting_fn1(), 1) + self.assertEqual(await counting_fn2(), 2) + self.assertEqual(await counting_fn2(), 2) + + async def test_iterator(self): + counter = Counter() + + with self.assertRaises(SyntaxError): + + @once.once + async def async_yielding_iterator(): + yield counter.get_incremented() + for i in range(3): + yield i + + # self.assertEqual([i async for i in async_yielding_iterator()], [1, 0, 1, 2]) + # self.assertEqual([i async for i in async_yielding_iterator()], [1, 0, 1, 2]) + + async def test_once_per_class(self): + class _CallOnceClass(Counter): + @once.once_per_class + async def once_fn(self): + return self.get_incremented() + + a = _CallOnceClass() # pylint: disable=invalid-name + b = _CallOnceClass() # pylint: disable=invalid-name + + self.assertEqual(await a.once_fn(), 1) + self.assertEqual(await a.once_fn(), 1) + self.assertEqual(await b.once_fn(), 1) + self.assertEqual(await b.once_fn(), 1) + + async def test_once_per_class_classmethod(self): + counter = Counter() + + class _CallOnceClass: + @once.once_per_class + @classmethod + async def value(cls): + nonlocal counter + return counter.get_incremented() + + self.assertEqual(await _CallOnceClass.value(), 1) + self.assertEqual(await _CallOnceClass.value(), 1) + + async def test_once_per_class_staticmethod(self): + counter = Counter() + + class _CallOnceClass: + @once.once_per_class + @staticmethod + async def value(): + nonlocal counter + return counter.get_incremented() + + self.assertEqual(await _CallOnceClass.value(), 1) + self.assertEqual(await _CallOnceClass.value(), 1) + + if __name__ == "__main__": unittest.main()