diff --git a/once/__init__.py b/once/__init__.py index 8811048..b7f1560 100644 --- a/once/__init__.py +++ b/once/__init__.py @@ -1,5 +1,6 @@ """Utility for initialization ensuring functions are called only once.""" import abc +import asyncio import collections.abc import functools import inspect @@ -24,30 +25,71 @@ class _OnceBase(abc.ABC): """Abstract Base Class for once function decorators.""" def __init__(self, func: collections.abc.Callable): - self._inspect_function(func) functools.update_wrapper(self, func) - self.lock = _new_lock() + self.func = self._inspect_function(func) self.called = False self.return_value: typing.Any = None - self.func = func + self.is_asyncgen = inspect.isasyncgenfunction(self.func) + if self.is_asyncgen: + raise SyntaxError("async generators are not (yet) supported") + self.is_syncgen = inspect.isgeneratorfunction(self.func) + # The function inspect.isawaitable is a bit of a misnomer - it refers + # to the awaitable result of an async function, not the async function + # itself. + self.is_async = True if self.is_asyncgen else inspect.iscoroutinefunction(self.func) + if self.is_async: + self.async_lock = asyncio.Lock() + else: + self.lock = _new_lock() @abc.abstractmethod - def _inspect_function(self, func: collections.abc.Callable): + def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.Callable: """Inspect the passed-in function to ensure it can be wrapped. This function should raise a SyntaxError if the passed-in function is - not suitable.""" + not suitable. - def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs): + It should return the function which should be executed once. + """ + + def _sync_return(self): + if self.is_syncgen: + self.return_value.__iter__() + return self.return_value + + async def _execute_call_once_async(self, func: collections.abc.Callable, *args, **kwargs): if self.called: return self.return_value - with self.lock: + async with self.async_lock: if self.called: return self.return_value - self.return_value = func(*args, **kwargs) + # Currently unreachable code - Async iterators are disabled for now. + if self.is_asyncgen: + self.return_value = [i async for i in func(*args, **kwargs)] + else: + self.return_value = await func(*args, **kwargs) self.called = True return self.return_value + def _execute_call_once_sync(self, func: collections.abc.Callable, *args, **kwargs): + if self.called: + return self._sync_return() + with self.lock: + if self.called: + return self._sync_return() + self.return_value = func(*args, **kwargs) + if self.is_syncgen: + # A potential optimization is to evaluate the iterator lazily, + # as opposed to eagerly like we do here. + self.return_value = tuple(self.return_value) + self.called = True + return self._sync_return() + + def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs): + if self.is_async: + return self._execute_call_once_async(func, *args, **kwargs) + return self._execute_call_once_sync(func, *args, **kwargs) + class once(_OnceBase): # pylint: disable=invalid-name """Decorator to ensure a function is only called once. @@ -74,6 +116,7 @@ def _inspect_function(self, func: collections.abc.Callable): "Attempting to use @once.once decorator on method " "instead of @once.once_per_class or @once.once_per_instance" ) + return func def __call__(self, *args, **kwargs): return self._execute_call_once(self.func, *args, **kwargs) @@ -82,22 +125,35 @@ def __call__(self, *args, **kwargs): class once_per_class(_OnceBase): # pylint: disable=invalid-name """A version of once for class methods which runs once across all instances.""" - def _inspect_function(self, func): + is_classmethod: bool + is_staticmethod: bool + + def _inspect_function(self, func: collections.abc.Callable): if not _is_method(func): raise SyntaxError( "Attempting to use @once.once_per_class method-only decorator " "instead of @once.once" ) + if isinstance(func, classmethod): + self.is_classmethod = True + self.is_staticmethod = False + return func.__func__ + if isinstance(func, staticmethod): + self.is_classmethod = False + self.is_staticmethod = True + return func.__func__ + self.is_classmethod = False + self.is_staticmethod = False + return func # This is needed for a decorator on a class method to return a # bound version of the function to the object or class. def __get__(self, obj, cls): - if isinstance(self.func, classmethod): - func = functools.partial(self.func.__func__, cls) + if self.is_classmethod: + func = functools.partial(self.func, cls) return functools.partial(self._execute_call_once, func) - if isinstance(self.func, staticmethod): - # The additional __func__ is required for python <= 3.9 - return functools.partial(self._execute_call_once, self.func.__func__) + if self.is_staticmethod: + return functools.partial(self._execute_call_once, self.func) return functools.partial(self._execute_call_once, self.func, obj) @@ -112,6 +168,8 @@ def __init__(self, func: collections.abc.Callable): self.inflight_lock: typing.Dict[typing.Any, threading.Lock] = {} def _inspect_function(self, func: collections.abc.Callable): + if inspect.isasyncgenfunction(func) or inspect.iscoroutinefunction(func): + raise SyntaxError("once_per_instance not (yet) supported for async") if isinstance(func, (classmethod, staticmethod)): raise SyntaxError("Must use @once.once_per_class on classmethod and staticmethod") if not _is_method(func): @@ -119,6 +177,7 @@ def _inspect_function(self, func: collections.abc.Callable): "Attempting to use @once.once_per_instance method-only decorator " "instead of @once.once" ) + return func # This is needed for a decorator on a class method to return a # bound version of the function to the object. diff --git a/once_test.py b/once_test.py index 55783a3..34def17 100644 --- a/once_test.py +++ b/once_test.py @@ -62,6 +62,15 @@ def test_different_args_same_result(self): self.assertEqual(counting_fn(2), 1) self.assertEqual(counter.value, 1) + def test_iterator(self): + @once.once + def yielding_iterator(): + for i in range(3): + yield i + + self.assertEqual(list(yielding_iterator()), [0, 1, 2]) + self.assertEqual(list(yielding_iterator()), [0, 1, 2]) + def test_threaded_single_function(self): counting_fn, counter = generate_once_counter_fn() with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: @@ -338,5 +347,81 @@ def value(): self.assertEqual(_CallOnceClass.value(), 1) +class TestOnceAsync(unittest.IsolatedAsyncioTestCase): + async def test_fn_called_once(self): + counter1 = Counter() + + @once.once + async def counting_fn1(): + return counter1.get_incremented() + + counter2 = Counter() + # We should get a different value than the previous function + counter2.get_incremented() + + @once.once + async def counting_fn2(): + return counter2.get_incremented() + + self.assertEqual(await counting_fn1(), 1) + self.assertEqual(await counting_fn1(), 1) + self.assertEqual(await counting_fn2(), 2) + self.assertEqual(await counting_fn2(), 2) + + async def test_iterator(self): + counter = Counter() + + with self.assertRaises(SyntaxError): + + @once.once + async def async_yielding_iterator(): + yield counter.get_incremented() + for i in range(3): + yield i + + # self.assertEqual([i async for i in async_yielding_iterator()], [1, 0, 1, 2]) + # self.assertEqual([i async for i in async_yielding_iterator()], [1, 0, 1, 2]) + + async def test_once_per_class(self): + class _CallOnceClass(Counter): + @once.once_per_class + async def once_fn(self): + return self.get_incremented() + + a = _CallOnceClass() # pylint: disable=invalid-name + b = _CallOnceClass() # pylint: disable=invalid-name + + self.assertEqual(await a.once_fn(), 1) + self.assertEqual(await a.once_fn(), 1) + self.assertEqual(await b.once_fn(), 1) + self.assertEqual(await b.once_fn(), 1) + + async def test_once_per_class_classmethod(self): + counter = Counter() + + class _CallOnceClass: + @once.once_per_class + @classmethod + async def value(cls): + nonlocal counter + return counter.get_incremented() + + self.assertEqual(await _CallOnceClass.value(), 1) + self.assertEqual(await _CallOnceClass.value(), 1) + + async def test_once_per_class_staticmethod(self): + counter = Counter() + + class _CallOnceClass: + @once.once_per_class + @staticmethod + async def value(): + nonlocal counter + return counter.get_incremented() + + self.assertEqual(await _CallOnceClass.value(), 1) + self.assertEqual(await _CallOnceClass.value(), 1) + + if __name__ == "__main__": unittest.main()