From 4a10b6c4f3664987ef4f40ef335616e3977fbbd2 Mon Sep 17 00:00:00 2001 From: Ali Ebrahim Date: Mon, 25 Sep 2023 10:53:01 -0700 Subject: [PATCH] Implement support for async and iterators. This change makes the decorator work correctly for async functions, instead of always returning the same coroutine, which can only be awaited once and calls subsequent calls to fail. This also detects returned iterators, evaluates them to completion, and returns the result as a tuple. Prior to this change, an exhausted iterator would be returned. --- once/__init__.py | 68 ++++++++++++++++++++++++++++++++------- once_test.py | 83 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 12 deletions(-) diff --git a/once/__init__.py b/once/__init__.py index 8811048..8ca4e6b 100644 --- a/once/__init__.py +++ b/once/__init__.py @@ -1,10 +1,12 @@ """Utility for initialization ensuring functions are called only once.""" import abc +import asyncio import collections.abc import functools import inspect import threading import typing +import types import weakref @@ -24,30 +26,57 @@ class _OnceBase(abc.ABC): """Abstract Base Class for once function decorators.""" def __init__(self, func: collections.abc.Callable): - self._inspect_function(func) functools.update_wrapper(self, func) - self.lock = _new_lock() + self.func = self._inspect_function(func) self.called = False self.return_value: typing.Any = None - self.func = func + self.is_asyncgen = inspect.isasyncgenfunction(self.func) + self.is_async = True if self.is_asyncgen else inspect.iscoroutinefunction(self.func) + if self.is_async: + self.async_lock = asyncio.Lock() + else: + self.lock = _new_lock() @abc.abstractmethod - def _inspect_function(self, func: collections.abc.Callable): + def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.Callable: """Inspect the passed-in function to ensure it can be wrapped. This function should raise a SyntaxError if the passed-in function is - not suitable.""" + not suitable. - def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs): + It should return the function which should be executed once. + """ + + async def _execute_call_once_async(self, func: collections.abc.Callable, *args, **kwargs): + if self.called: + return self.return_value + async with self.async_lock: + if self.called: + return self.return_value + if self.is_asyncgen: + self.return_value = tuple([i async for i in func(*args, **kwargs)]) + else: + self.return_value = await func(*args, **kwargs) + self.called = True + return self.return_value + + def _execute_call_once_sync(self, func: collections.abc.Callable, *args, **kwargs): if self.called: return self.return_value with self.lock: if self.called: return self.return_value self.return_value = func(*args, **kwargs) + if isinstance(self.return_value, types.GeneratorType): + self.return_value = tuple(self.return_value) self.called = True return self.return_value + def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs): + if self.is_async: + return self._execute_call_once_async(func, *args, **kwargs) + return self._execute_call_once_sync(func, *args, **kwargs) + class once(_OnceBase): # pylint: disable=invalid-name """Decorator to ensure a function is only called once. @@ -74,6 +103,7 @@ def _inspect_function(self, func: collections.abc.Callable): "Attempting to use @once.once decorator on method " "instead of @once.once_per_class or @once.once_per_instance" ) + return func def __call__(self, *args, **kwargs): return self._execute_call_once(self.func, *args, **kwargs) @@ -82,22 +112,35 @@ def __call__(self, *args, **kwargs): class once_per_class(_OnceBase): # pylint: disable=invalid-name """A version of once for class methods which runs once across all instances.""" - def _inspect_function(self, func): + is_classmethod: bool + is_staticmethod: bool + + def _inspect_function(self, func: collections.abc.Callable): if not _is_method(func): raise SyntaxError( "Attempting to use @once.once_per_class method-only decorator " "instead of @once.once" ) + if isinstance(func, classmethod): + self.is_classmethod = True + self.is_staticmethod = False + return func.__func__ + if isinstance(func, staticmethod): + self.is_classmethod = False + self.is_staticmethod = True + return func.__func__ + self.is_classmethod = False + self.is_staticmethod = False + return func # This is needed for a decorator on a class method to return a # bound version of the function to the object or class. def __get__(self, obj, cls): - if isinstance(self.func, classmethod): - func = functools.partial(self.func.__func__, cls) + if self.is_classmethod: + func = functools.partial(self.func, cls) return functools.partial(self._execute_call_once, func) - if isinstance(self.func, staticmethod): - # The additional __func__ is required for python <= 3.9 - return functools.partial(self._execute_call_once, self.func.__func__) + if self.is_staticmethod: + return functools.partial(self._execute_call_once, self.func) return functools.partial(self._execute_call_once, self.func, obj) @@ -119,6 +162,7 @@ def _inspect_function(self, func: collections.abc.Callable): "Attempting to use @once.once_per_instance method-only decorator " "instead of @once.once" ) + return func # This is needed for a decorator on a class method to return a # bound version of the function to the object. diff --git a/once_test.py b/once_test.py index 55783a3..f872222 100644 --- a/once_test.py +++ b/once_test.py @@ -62,6 +62,15 @@ def test_different_args_same_result(self): self.assertEqual(counting_fn(2), 1) self.assertEqual(counter.value, 1) + def test_iterator(self): + @once.once + def yielding_iterator(): + for i in range(3): + yield i + + self.assertEqual(yielding_iterator(), (0, 1, 2)) + self.assertEqual(yielding_iterator(), (0, 1, 2)) + def test_threaded_single_function(self): counting_fn, counter = generate_once_counter_fn() with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: @@ -338,5 +347,79 @@ def value(): self.assertEqual(_CallOnceClass.value(), 1) +class TestOnceAsync(unittest.IsolatedAsyncioTestCase): + async def test_fn_called_once(self): + counter1 = Counter() + + @once.once + async def counting_fn1(): + return counter1.get_incremented() + + counter2 = Counter() + # We should get a different value than the previous function + counter2.get_incremented() + + @once.once + async def counting_fn2(): + return counter2.get_incremented() + + self.assertEqual(await counting_fn1(), 1) + self.assertEqual(await counting_fn1(), 1) + self.assertEqual(await counting_fn2(), 2) + self.assertEqual(await counting_fn2(), 2) + + async def test_iterator(self): + counter = Counter() + + @once.once + async def async_yielding_iterator(): + yield counter.get_incremented() + for i in range(3): + yield i + + self.assertEqual(await async_yielding_iterator(), (1, 0, 1, 2)) + self.assertEqual(await async_yielding_iterator(), (1, 0, 1, 2)) + + async def test_once_per_class(self): + class _CallOnceClass(Counter): + @once.once_per_class + async def once_fn(self): + return self.get_incremented() + + a = _CallOnceClass() # pylint: disable=invalid-name + b = _CallOnceClass() # pylint: disable=invalid-name + + self.assertEqual(await a.once_fn(), 1) + self.assertEqual(await a.once_fn(), 1) + self.assertEqual(await b.once_fn(), 1) + self.assertEqual(await b.once_fn(), 1) + + async def test_once_per_class_classmethod(self): + counter = Counter() + + class _CallOnceClass: + @once.once_per_class + @classmethod + async def value(cls): + nonlocal counter + return counter.get_incremented() + + self.assertEqual(await _CallOnceClass.value(), 1) + self.assertEqual(await _CallOnceClass.value(), 1) + + async def test_once_per_class_staticmethod(self): + counter = Counter() + + class _CallOnceClass: + @once.once_per_class + @staticmethod + async def value(): + nonlocal counter + return counter.get_incremented() + + self.assertEqual(await _CallOnceClass.value(), 1) + self.assertEqual(await _CallOnceClass.value(), 1) + + if __name__ == "__main__": unittest.main()