Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement support for async and iterators. #3

Merged
merged 2 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 56 additions & 12 deletions once/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Utility for initialization ensuring functions are called only once."""
import abc
import asyncio
import collections.abc
import functools
import inspect
import threading
import typing
import types
import weakref


Expand All @@ -24,30 +26,57 @@ class _OnceBase(abc.ABC):
"""Abstract Base Class for once function decorators."""

def __init__(self, func: collections.abc.Callable):
self._inspect_function(func)
functools.update_wrapper(self, func)
self.lock = _new_lock()
self.func = self._inspect_function(func)
self.called = False
self.return_value: typing.Any = None
self.func = func
self.is_asyncgen = inspect.isasyncgenfunction(self.func)
self.is_async = True if self.is_asyncgen else inspect.iscoroutinefunction(self.func)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a comment about why you chose iscoroutinefunction vs something like isawaitable would be good

if self.is_async:
self.async_lock = asyncio.Lock()
else:
self.lock = _new_lock()

@abc.abstractmethod
def _inspect_function(self, func: collections.abc.Callable):
def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.Callable:
"""Inspect the passed-in function to ensure it can be wrapped.

This function should raise a SyntaxError if the passed-in function is
not suitable."""
not suitable.

def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs):
It should return the function which should be executed once.
"""

async def _execute_call_once_async(self, func: collections.abc.Callable, *args, **kwargs):
if self.called:
return self.return_value
async with self.async_lock:
if self.called:
return self.return_value
if self.is_asyncgen:
self.return_value = tuple([i async for i in func(*args, **kwargs)])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a big caveat that needs detailing in the documentation, that the generator result gets converted to a tuple. I think the more correct version, that's also more work, would wrap the generator in a generator that extends the return_value list each time next is called. This is potentially surprising enough to call it a bug that warrants exclusion from this library until it's fully supported.

else:
self.return_value = await func(*args, **kwargs)
self.called = True
return self.return_value

def _execute_call_once_sync(self, func: collections.abc.Callable, *args, **kwargs):
if self.called:
return self.return_value
with self.lock:
if self.called:
return self.return_value
self.return_value = func(*args, **kwargs)
if isinstance(self.return_value, types.GeneratorType):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto, this is a deal breaker IMO

self.return_value = tuple(self.return_value)
self.called = True
return self.return_value

def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs):
if self.is_async:
return self._execute_call_once_async(func, *args, **kwargs)
return self._execute_call_once_sync(func, *args, **kwargs)


class once(_OnceBase): # pylint: disable=invalid-name
"""Decorator to ensure a function is only called once.
Expand All @@ -74,6 +103,7 @@ def _inspect_function(self, func: collections.abc.Callable):
"Attempting to use @once.once decorator on method "
"instead of @once.once_per_class or @once.once_per_instance"
)
return func

def __call__(self, *args, **kwargs):
return self._execute_call_once(self.func, *args, **kwargs)
Expand All @@ -82,22 +112,35 @@ def __call__(self, *args, **kwargs):
class once_per_class(_OnceBase): # pylint: disable=invalid-name
"""A version of once for class methods which runs once across all instances."""

def _inspect_function(self, func):
is_classmethod: bool
is_staticmethod: bool

def _inspect_function(self, func: collections.abc.Callable):
if not _is_method(func):
raise SyntaxError(
"Attempting to use @once.once_per_class method-only decorator "
"instead of @once.once"
)
if isinstance(func, classmethod):
self.is_classmethod = True
self.is_staticmethod = False
return func.__func__
if isinstance(func, staticmethod):
self.is_classmethod = False
self.is_staticmethod = True
return func.__func__
self.is_classmethod = False
self.is_staticmethod = False
return func

# This is needed for a decorator on a class method to return a
# bound version of the function to the object or class.
def __get__(self, obj, cls):
if isinstance(self.func, classmethod):
func = functools.partial(self.func.__func__, cls)
if self.is_classmethod:
func = functools.partial(self.func, cls)
return functools.partial(self._execute_call_once, func)
if isinstance(self.func, staticmethod):
# The additional __func__ is required for python <= 3.9
return functools.partial(self._execute_call_once, self.func.__func__)
if self.is_staticmethod:
return functools.partial(self._execute_call_once, self.func)
return functools.partial(self._execute_call_once, self.func, obj)


Expand All @@ -119,6 +162,7 @@ def _inspect_function(self, func: collections.abc.Callable):
"Attempting to use @once.once_per_instance method-only decorator "
"instead of @once.once"
)
return func

# This is needed for a decorator on a class method to return a
# bound version of the function to the object.
Expand Down
83 changes: 83 additions & 0 deletions once_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def test_different_args_same_result(self):
self.assertEqual(counting_fn(2), 1)
self.assertEqual(counter.value, 1)

def test_iterator(self):
@once.once
def yielding_iterator():
for i in range(3):
yield i

self.assertEqual(yielding_iterator(), (0, 1, 2))
self.assertEqual(yielding_iterator(), (0, 1, 2))

def test_threaded_single_function(self):
counting_fn, counter = generate_once_counter_fn()
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
Expand Down Expand Up @@ -338,5 +347,79 @@ def value():
self.assertEqual(_CallOnceClass.value(), 1)


class TestOnceAsync(unittest.IsolatedAsyncioTestCase):
async def test_fn_called_once(self):
counter1 = Counter()

@once.once
async def counting_fn1():
return counter1.get_incremented()

counter2 = Counter()
# We should get a different value than the previous function
counter2.get_incremented()

@once.once
async def counting_fn2():
return counter2.get_incremented()

self.assertEqual(await counting_fn1(), 1)
self.assertEqual(await counting_fn1(), 1)
self.assertEqual(await counting_fn2(), 2)
self.assertEqual(await counting_fn2(), 2)

async def test_iterator(self):
counter = Counter()

@once.once
async def async_yielding_iterator():
yield counter.get_incremented()
for i in range(3):
yield i

self.assertEqual(await async_yielding_iterator(), (1, 0, 1, 2))
self.assertEqual(await async_yielding_iterator(), (1, 0, 1, 2))

async def test_once_per_class(self):
class _CallOnceClass(Counter):
@once.once_per_class
async def once_fn(self):
return self.get_incremented()

a = _CallOnceClass() # pylint: disable=invalid-name
b = _CallOnceClass() # pylint: disable=invalid-name

self.assertEqual(await a.once_fn(), 1)
self.assertEqual(await a.once_fn(), 1)
self.assertEqual(await b.once_fn(), 1)
self.assertEqual(await b.once_fn(), 1)

async def test_once_per_class_classmethod(self):
counter = Counter()

class _CallOnceClass:
@once.once_per_class
@classmethod
async def value(cls):
nonlocal counter
return counter.get_incremented()

self.assertEqual(await _CallOnceClass.value(), 1)
self.assertEqual(await _CallOnceClass.value(), 1)

async def test_once_per_class_staticmethod(self):
counter = Counter()

class _CallOnceClass:
@once.once_per_class
@staticmethod
async def value():
nonlocal counter
return counter.get_incremented()

self.assertEqual(await _CallOnceClass.value(), 1)
self.assertEqual(await _CallOnceClass.value(), 1)


if __name__ == "__main__":
unittest.main()