-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement support for async and iterators. #3
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,12 @@ | ||
"""Utility for initialization ensuring functions are called only once.""" | ||
import abc | ||
import asyncio | ||
import collections.abc | ||
import functools | ||
import inspect | ||
import threading | ||
import typing | ||
import types | ||
import weakref | ||
|
||
|
||
|
@@ -24,30 +26,57 @@ class _OnceBase(abc.ABC): | |
"""Abstract Base Class for once function decorators.""" | ||
|
||
def __init__(self, func: collections.abc.Callable): | ||
self._inspect_function(func) | ||
functools.update_wrapper(self, func) | ||
self.lock = _new_lock() | ||
self.func = self._inspect_function(func) | ||
self.called = False | ||
self.return_value: typing.Any = None | ||
self.func = func | ||
self.is_asyncgen = inspect.isasyncgenfunction(self.func) | ||
self.is_async = True if self.is_asyncgen else inspect.iscoroutinefunction(self.func) | ||
if self.is_async: | ||
self.async_lock = asyncio.Lock() | ||
else: | ||
self.lock = _new_lock() | ||
|
||
@abc.abstractmethod | ||
def _inspect_function(self, func: collections.abc.Callable): | ||
def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.Callable: | ||
"""Inspect the passed-in function to ensure it can be wrapped. | ||
|
||
This function should raise a SyntaxError if the passed-in function is | ||
not suitable.""" | ||
not suitable. | ||
|
||
def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs): | ||
It should return the function which should be executed once. | ||
""" | ||
|
||
async def _execute_call_once_async(self, func: collections.abc.Callable, *args, **kwargs): | ||
if self.called: | ||
return self.return_value | ||
async with self.async_lock: | ||
if self.called: | ||
return self.return_value | ||
if self.is_asyncgen: | ||
self.return_value = tuple([i async for i in func(*args, **kwargs)]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a big caveat that needs detailing in the documentation, that the generator result gets converted to a tuple. I think the more correct version, that's also more work, would wrap the generator in a generator that extends the return_value list each time next is called. This is potentially surprising enough to call it a bug that warrants exclusion from this library until it's fully supported. |
||
else: | ||
self.return_value = await func(*args, **kwargs) | ||
self.called = True | ||
return self.return_value | ||
|
||
def _execute_call_once_sync(self, func: collections.abc.Callable, *args, **kwargs): | ||
if self.called: | ||
return self.return_value | ||
with self.lock: | ||
if self.called: | ||
return self.return_value | ||
self.return_value = func(*args, **kwargs) | ||
if isinstance(self.return_value, types.GeneratorType): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto, this is a deal breaker IMO |
||
self.return_value = tuple(self.return_value) | ||
self.called = True | ||
return self.return_value | ||
|
||
def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs): | ||
if self.is_async: | ||
return self._execute_call_once_async(func, *args, **kwargs) | ||
return self._execute_call_once_sync(func, *args, **kwargs) | ||
|
||
|
||
class once(_OnceBase): # pylint: disable=invalid-name | ||
"""Decorator to ensure a function is only called once. | ||
|
@@ -74,6 +103,7 @@ def _inspect_function(self, func: collections.abc.Callable): | |
"Attempting to use @once.once decorator on method " | ||
"instead of @once.once_per_class or @once.once_per_instance" | ||
) | ||
return func | ||
|
||
def __call__(self, *args, **kwargs): | ||
return self._execute_call_once(self.func, *args, **kwargs) | ||
|
@@ -82,22 +112,35 @@ def __call__(self, *args, **kwargs): | |
class once_per_class(_OnceBase): # pylint: disable=invalid-name | ||
"""A version of once for class methods which runs once across all instances.""" | ||
|
||
def _inspect_function(self, func): | ||
is_classmethod: bool | ||
is_staticmethod: bool | ||
|
||
def _inspect_function(self, func: collections.abc.Callable): | ||
if not _is_method(func): | ||
raise SyntaxError( | ||
"Attempting to use @once.once_per_class method-only decorator " | ||
"instead of @once.once" | ||
) | ||
if isinstance(func, classmethod): | ||
self.is_classmethod = True | ||
self.is_staticmethod = False | ||
return func.__func__ | ||
if isinstance(func, staticmethod): | ||
self.is_classmethod = False | ||
self.is_staticmethod = True | ||
return func.__func__ | ||
self.is_classmethod = False | ||
self.is_staticmethod = False | ||
return func | ||
|
||
# This is needed for a decorator on a class method to return a | ||
# bound version of the function to the object or class. | ||
def __get__(self, obj, cls): | ||
if isinstance(self.func, classmethod): | ||
func = functools.partial(self.func.__func__, cls) | ||
if self.is_classmethod: | ||
func = functools.partial(self.func, cls) | ||
return functools.partial(self._execute_call_once, func) | ||
if isinstance(self.func, staticmethod): | ||
# The additional __func__ is required for python <= 3.9 | ||
return functools.partial(self._execute_call_once, self.func.__func__) | ||
if self.is_staticmethod: | ||
return functools.partial(self._execute_call_once, self.func) | ||
return functools.partial(self._execute_call_once, self.func, obj) | ||
|
||
|
||
|
@@ -119,6 +162,7 @@ def _inspect_function(self, func: collections.abc.Callable): | |
"Attempting to use @once.once_per_instance method-only decorator " | ||
"instead of @once.once" | ||
) | ||
return func | ||
|
||
# This is needed for a decorator on a class method to return a | ||
# bound version of the function to the object. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe a comment about why you chose iscoroutinefunction vs something like isawaitable would be good