Skip to content

Commit

Permalink
Implement support for async and iterators. (#3)
Browse files Browse the repository at this point in the history
* Implement support for async and iterators.

This change makes the decorator work correctly for async functions,
instead of always returning the same coroutine, which can only be
awaited once and calls subsequent calls to fail.

This also detects returned iterators, evaluates them to completion, and
returns the result as a tuple. Prior to this change, an exhausted
iterator would be returned.

* Explicitly disable async iterators.

Also, better handle sync iterators to be API-compatible with the
unwrapped function call by returning an iterator object instead of a
tuple object.
  • Loading branch information
aebrahim authored Sep 26, 2023
1 parent 9a0759e commit 63496a0
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 14 deletions.
87 changes: 73 additions & 14 deletions once/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utility for initialization ensuring functions are called only once."""
import abc
import asyncio
import collections.abc
import functools
import inspect
Expand All @@ -24,30 +25,71 @@ class _OnceBase(abc.ABC):
"""Abstract Base Class for once function decorators."""

def __init__(self, func: collections.abc.Callable):
self._inspect_function(func)
functools.update_wrapper(self, func)
self.lock = _new_lock()
self.func = self._inspect_function(func)
self.called = False
self.return_value: typing.Any = None
self.func = func
self.is_asyncgen = inspect.isasyncgenfunction(self.func)
if self.is_asyncgen:
raise SyntaxError("async generators are not (yet) supported")
self.is_syncgen = inspect.isgeneratorfunction(self.func)
# The function inspect.isawaitable is a bit of a misnomer - it refers
# to the awaitable result of an async function, not the async function
# itself.
self.is_async = True if self.is_asyncgen else inspect.iscoroutinefunction(self.func)
if self.is_async:
self.async_lock = asyncio.Lock()
else:
self.lock = _new_lock()

@abc.abstractmethod
def _inspect_function(self, func: collections.abc.Callable):
def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.Callable:
"""Inspect the passed-in function to ensure it can be wrapped.
This function should raise a SyntaxError if the passed-in function is
not suitable."""
not suitable.
def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs):
It should return the function which should be executed once.
"""

def _sync_return(self):
if self.is_syncgen:
self.return_value.__iter__()
return self.return_value

async def _execute_call_once_async(self, func: collections.abc.Callable, *args, **kwargs):
if self.called:
return self.return_value
with self.lock:
async with self.async_lock:
if self.called:
return self.return_value
self.return_value = func(*args, **kwargs)
# Currently unreachable code - Async iterators are disabled for now.
if self.is_asyncgen:
self.return_value = [i async for i in func(*args, **kwargs)]
else:
self.return_value = await func(*args, **kwargs)
self.called = True
return self.return_value

def _execute_call_once_sync(self, func: collections.abc.Callable, *args, **kwargs):
if self.called:
return self._sync_return()
with self.lock:
if self.called:
return self._sync_return()
self.return_value = func(*args, **kwargs)
if self.is_syncgen:
# A potential optimization is to evaluate the iterator lazily,
# as opposed to eagerly like we do here.
self.return_value = tuple(self.return_value)
self.called = True
return self._sync_return()

def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs):
if self.is_async:
return self._execute_call_once_async(func, *args, **kwargs)
return self._execute_call_once_sync(func, *args, **kwargs)


class once(_OnceBase): # pylint: disable=invalid-name
"""Decorator to ensure a function is only called once.
Expand All @@ -74,6 +116,7 @@ def _inspect_function(self, func: collections.abc.Callable):
"Attempting to use @once.once decorator on method "
"instead of @once.once_per_class or @once.once_per_instance"
)
return func

def __call__(self, *args, **kwargs):
return self._execute_call_once(self.func, *args, **kwargs)
Expand All @@ -82,22 +125,35 @@ def __call__(self, *args, **kwargs):
class once_per_class(_OnceBase): # pylint: disable=invalid-name
"""A version of once for class methods which runs once across all instances."""

def _inspect_function(self, func):
is_classmethod: bool
is_staticmethod: bool

def _inspect_function(self, func: collections.abc.Callable):
if not _is_method(func):
raise SyntaxError(
"Attempting to use @once.once_per_class method-only decorator "
"instead of @once.once"
)
if isinstance(func, classmethod):
self.is_classmethod = True
self.is_staticmethod = False
return func.__func__
if isinstance(func, staticmethod):
self.is_classmethod = False
self.is_staticmethod = True
return func.__func__
self.is_classmethod = False
self.is_staticmethod = False
return func

# This is needed for a decorator on a class method to return a
# bound version of the function to the object or class.
def __get__(self, obj, cls):
if isinstance(self.func, classmethod):
func = functools.partial(self.func.__func__, cls)
if self.is_classmethod:
func = functools.partial(self.func, cls)
return functools.partial(self._execute_call_once, func)
if isinstance(self.func, staticmethod):
# The additional __func__ is required for python <= 3.9
return functools.partial(self._execute_call_once, self.func.__func__)
if self.is_staticmethod:
return functools.partial(self._execute_call_once, self.func)
return functools.partial(self._execute_call_once, self.func, obj)


Expand All @@ -112,13 +168,16 @@ def __init__(self, func: collections.abc.Callable):
self.inflight_lock: typing.Dict[typing.Any, threading.Lock] = {}

def _inspect_function(self, func: collections.abc.Callable):
if inspect.isasyncgenfunction(func) or inspect.iscoroutinefunction(func):
raise SyntaxError("once_per_instance not (yet) supported for async")
if isinstance(func, (classmethod, staticmethod)):
raise SyntaxError("Must use @once.once_per_class on classmethod and staticmethod")
if not _is_method(func):
raise SyntaxError(
"Attempting to use @once.once_per_instance method-only decorator "
"instead of @once.once"
)
return func

# This is needed for a decorator on a class method to return a
# bound version of the function to the object.
Expand Down
85 changes: 85 additions & 0 deletions once_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def test_different_args_same_result(self):
self.assertEqual(counting_fn(2), 1)
self.assertEqual(counter.value, 1)

def test_iterator(self):
@once.once
def yielding_iterator():
for i in range(3):
yield i

self.assertEqual(list(yielding_iterator()), [0, 1, 2])
self.assertEqual(list(yielding_iterator()), [0, 1, 2])

def test_threaded_single_function(self):
counting_fn, counter = generate_once_counter_fn()
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
Expand Down Expand Up @@ -338,5 +347,81 @@ def value():
self.assertEqual(_CallOnceClass.value(), 1)


class TestOnceAsync(unittest.IsolatedAsyncioTestCase):
async def test_fn_called_once(self):
counter1 = Counter()

@once.once
async def counting_fn1():
return counter1.get_incremented()

counter2 = Counter()
# We should get a different value than the previous function
counter2.get_incremented()

@once.once
async def counting_fn2():
return counter2.get_incremented()

self.assertEqual(await counting_fn1(), 1)
self.assertEqual(await counting_fn1(), 1)
self.assertEqual(await counting_fn2(), 2)
self.assertEqual(await counting_fn2(), 2)

async def test_iterator(self):
counter = Counter()

with self.assertRaises(SyntaxError):

@once.once
async def async_yielding_iterator():
yield counter.get_incremented()
for i in range(3):
yield i

# self.assertEqual([i async for i in async_yielding_iterator()], [1, 0, 1, 2])
# self.assertEqual([i async for i in async_yielding_iterator()], [1, 0, 1, 2])

async def test_once_per_class(self):
class _CallOnceClass(Counter):
@once.once_per_class
async def once_fn(self):
return self.get_incremented()

a = _CallOnceClass() # pylint: disable=invalid-name
b = _CallOnceClass() # pylint: disable=invalid-name

self.assertEqual(await a.once_fn(), 1)
self.assertEqual(await a.once_fn(), 1)
self.assertEqual(await b.once_fn(), 1)
self.assertEqual(await b.once_fn(), 1)

async def test_once_per_class_classmethod(self):
counter = Counter()

class _CallOnceClass:
@once.once_per_class
@classmethod
async def value(cls):
nonlocal counter
return counter.get_incremented()

self.assertEqual(await _CallOnceClass.value(), 1)
self.assertEqual(await _CallOnceClass.value(), 1)

async def test_once_per_class_staticmethod(self):
counter = Counter()

class _CallOnceClass:
@once.once_per_class
@staticmethod
async def value():
nonlocal counter
return counter.get_incremented()

self.assertEqual(await _CallOnceClass.value(), 1)
self.assertEqual(await _CallOnceClass.value(), 1)


if __name__ == "__main__":
unittest.main()

0 comments on commit 63496a0

Please sign in to comment.