Skip to content

Commit

Permalink
Implement support for async and iterators.
Browse files Browse the repository at this point in the history
This change makes the decorator work correctly for async functions,
instead of always returning the same coroutine, which can only be
awaited once and calls subsequent calls to fail.

This also detects returned iterators, evaluates them to completion, and
returns the result as a tuple. Prior to this change, an exhausted
iterator would be returned.
  • Loading branch information
aebrahim committed Sep 25, 2023
1 parent 9a0759e commit 4a10b6c
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 12 deletions.
68 changes: 56 additions & 12 deletions once/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Utility for initialization ensuring functions are called only once."""
import abc
import asyncio
import collections.abc
import functools
import inspect
import threading
import typing
import types
import weakref


Expand All @@ -24,30 +26,57 @@ class _OnceBase(abc.ABC):
"""Abstract Base Class for once function decorators."""

def __init__(self, func: collections.abc.Callable):
self._inspect_function(func)
functools.update_wrapper(self, func)
self.lock = _new_lock()
self.func = self._inspect_function(func)
self.called = False
self.return_value: typing.Any = None
self.func = func
self.is_asyncgen = inspect.isasyncgenfunction(self.func)
self.is_async = True if self.is_asyncgen else inspect.iscoroutinefunction(self.func)
if self.is_async:
self.async_lock = asyncio.Lock()
else:
self.lock = _new_lock()

@abc.abstractmethod
def _inspect_function(self, func: collections.abc.Callable):
def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.Callable:
"""Inspect the passed-in function to ensure it can be wrapped.
This function should raise a SyntaxError if the passed-in function is
not suitable."""
not suitable.
def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs):
It should return the function which should be executed once.
"""

async def _execute_call_once_async(self, func: collections.abc.Callable, *args, **kwargs):
if self.called:
return self.return_value
async with self.async_lock:
if self.called:
return self.return_value
if self.is_asyncgen:
self.return_value = tuple([i async for i in func(*args, **kwargs)])
else:
self.return_value = await func(*args, **kwargs)
self.called = True
return self.return_value

def _execute_call_once_sync(self, func: collections.abc.Callable, *args, **kwargs):
if self.called:
return self.return_value
with self.lock:
if self.called:
return self.return_value
self.return_value = func(*args, **kwargs)
if isinstance(self.return_value, types.GeneratorType):
self.return_value = tuple(self.return_value)
self.called = True
return self.return_value

def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs):
if self.is_async:
return self._execute_call_once_async(func, *args, **kwargs)
return self._execute_call_once_sync(func, *args, **kwargs)


class once(_OnceBase): # pylint: disable=invalid-name
"""Decorator to ensure a function is only called once.
Expand All @@ -74,6 +103,7 @@ def _inspect_function(self, func: collections.abc.Callable):
"Attempting to use @once.once decorator on method "
"instead of @once.once_per_class or @once.once_per_instance"
)
return func

def __call__(self, *args, **kwargs):
return self._execute_call_once(self.func, *args, **kwargs)
Expand All @@ -82,22 +112,35 @@ def __call__(self, *args, **kwargs):
class once_per_class(_OnceBase): # pylint: disable=invalid-name
"""A version of once for class methods which runs once across all instances."""

def _inspect_function(self, func):
is_classmethod: bool
is_staticmethod: bool

def _inspect_function(self, func: collections.abc.Callable):
if not _is_method(func):
raise SyntaxError(
"Attempting to use @once.once_per_class method-only decorator "
"instead of @once.once"
)
if isinstance(func, classmethod):
self.is_classmethod = True
self.is_staticmethod = False
return func.__func__
if isinstance(func, staticmethod):
self.is_classmethod = False
self.is_staticmethod = True
return func.__func__
self.is_classmethod = False
self.is_staticmethod = False
return func

# This is needed for a decorator on a class method to return a
# bound version of the function to the object or class.
def __get__(self, obj, cls):
if isinstance(self.func, classmethod):
func = functools.partial(self.func.__func__, cls)
if self.is_classmethod:
func = functools.partial(self.func, cls)
return functools.partial(self._execute_call_once, func)
if isinstance(self.func, staticmethod):
# The additional __func__ is required for python <= 3.9
return functools.partial(self._execute_call_once, self.func.__func__)
if self.is_staticmethod:
return functools.partial(self._execute_call_once, self.func)
return functools.partial(self._execute_call_once, self.func, obj)


Expand All @@ -119,6 +162,7 @@ def _inspect_function(self, func: collections.abc.Callable):
"Attempting to use @once.once_per_instance method-only decorator "
"instead of @once.once"
)
return func

# This is needed for a decorator on a class method to return a
# bound version of the function to the object.
Expand Down
83 changes: 83 additions & 0 deletions once_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def test_different_args_same_result(self):
self.assertEqual(counting_fn(2), 1)
self.assertEqual(counter.value, 1)

def test_iterator(self):
@once.once
def yielding_iterator():
for i in range(3):
yield i

self.assertEqual(yielding_iterator(), (0, 1, 2))
self.assertEqual(yielding_iterator(), (0, 1, 2))

def test_threaded_single_function(self):
counting_fn, counter = generate_once_counter_fn()
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
Expand Down Expand Up @@ -338,5 +347,79 @@ def value():
self.assertEqual(_CallOnceClass.value(), 1)


class TestOnceAsync(unittest.IsolatedAsyncioTestCase):
async def test_fn_called_once(self):
counter1 = Counter()

@once.once
async def counting_fn1():
return counter1.get_incremented()

counter2 = Counter()
# We should get a different value than the previous function
counter2.get_incremented()

@once.once
async def counting_fn2():
return counter2.get_incremented()

self.assertEqual(await counting_fn1(), 1)
self.assertEqual(await counting_fn1(), 1)
self.assertEqual(await counting_fn2(), 2)
self.assertEqual(await counting_fn2(), 2)

async def test_iterator(self):
counter = Counter()

@once.once
async def async_yielding_iterator():
yield counter.get_incremented()
for i in range(3):
yield i

self.assertEqual(await async_yielding_iterator(), (1, 0, 1, 2))
self.assertEqual(await async_yielding_iterator(), (1, 0, 1, 2))

async def test_once_per_class(self):
class _CallOnceClass(Counter):
@once.once_per_class
async def once_fn(self):
return self.get_incremented()

a = _CallOnceClass() # pylint: disable=invalid-name
b = _CallOnceClass() # pylint: disable=invalid-name

self.assertEqual(await a.once_fn(), 1)
self.assertEqual(await a.once_fn(), 1)
self.assertEqual(await b.once_fn(), 1)
self.assertEqual(await b.once_fn(), 1)

async def test_once_per_class_classmethod(self):
counter = Counter()

class _CallOnceClass:
@once.once_per_class
@classmethod
async def value(cls):
nonlocal counter
return counter.get_incremented()

self.assertEqual(await _CallOnceClass.value(), 1)
self.assertEqual(await _CallOnceClass.value(), 1)

async def test_once_per_class_staticmethod(self):
counter = Counter()

class _CallOnceClass:
@once.once_per_class
@staticmethod
async def value():
nonlocal counter
return counter.get_incremented()

self.assertEqual(await _CallOnceClass.value(), 1)
self.assertEqual(await _CallOnceClass.value(), 1)


if __name__ == "__main__":
unittest.main()

0 comments on commit 4a10b6c

Please sign in to comment.