Skip to content

Commit

Permalink
Black formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
mattalbr committed Feb 12, 2024
1 parent a9af9e6 commit 083d7e4
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 22 deletions.
68 changes: 51 additions & 17 deletions once/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utility for initialization ensuring functions are called only once."""

import abc
import asyncio
import collections.abc
Expand Down Expand Up @@ -126,7 +127,11 @@ async def wrapped(*args, _once_force_rerun: bool = False, **kwargs) -> typing.An
async with once_base.async_lock:
if not once_base.called:
once_base.return_value = _iterator_wrappers.AsyncGeneratorWrapper(
retry_exceptions, func, *args, allow_force_rerun=once_base.allow_force_rerun, **kwargs
retry_exceptions,
func,
*args,
allow_force_rerun=once_base.allow_force_rerun,
**kwargs,
)
once_base.called = True
return_value = once_base.return_value
Expand Down Expand Up @@ -181,7 +186,11 @@ def wrapped(*args, _once_force_rerun: bool = False, **kwargs) -> typing.Any:
with once_base.lock:
if not once_base.called:
once_base.return_value = _iterator_wrappers.GeneratorWrapper(
retry_exceptions, func, *args, allow_force_rerun=once_base.allow_force_rerun, **kwargs
retry_exceptions,
func,
*args,
allow_force_rerun=once_base.allow_force_rerun,
**kwargs,
)
once_base.called = True
iterator = once_base.return_value
Expand All @@ -195,6 +204,7 @@ def wrapped(*args, _once_force_rerun: bool = False, **kwargs) -> typing.Any:
if once_base.allow_force_rerun:
wrapped.force_rerun = functools.partial(wrapped, _once_force_rerun=True)
else:

def force_rerun(*args, **kwargs):
# This force_rerun won't necessarily have the right type, but it doesn't
# need to since it's an error case anyway. Just here for a more helpful
Expand All @@ -203,9 +213,9 @@ def force_rerun(*args, **kwargs):
f"force_rerun() is not allowed to be called on onced function {func}.\n"
"Did you mean to add `allow_force_rerun=True` to your once.once() annotation?"
)
wrapped.force_rerun = force_rerun

wrapped.force_rerun = force_rerun

functools.update_wrapper(wrapped, func)
return wrapped

Expand All @@ -229,7 +239,9 @@ def _get_once_per_thread():
return _get_once_per_thread


def once(*args, per_thread=False, retry_exceptions=False, allow_force_rerun=False) -> collections.abc.Callable:
def once(
*args, per_thread=False, retry_exceptions=False, allow_force_rerun=False
) -> collections.abc.Callable:
"""Decorator to ensure a function is only called once.
The restriction of only one call also holds across threads. However, this
Expand Down Expand Up @@ -273,15 +285,22 @@ def once(*args, per_thread=False, retry_exceptions=False, allow_force_rerun=Fals
# to create a decorator.
# Both @once and @once() will function correctly.
return functools.partial(
once, per_thread=per_thread, retry_exceptions=retry_exceptions, allow_force_rerun=allow_force_rerun
once,
per_thread=per_thread,
retry_exceptions=retry_exceptions,
allow_force_rerun=allow_force_rerun,
)
if _is_method(func):
raise SyntaxError(
"Attempting to use @once.once decorator on method "
"instead of @once.once_per_class or @once.once_per_instance"
)
fn_type = _wrapped_function_type(func)
once_factory = _once_factory(is_async=fn_type in _ASYNC_FN_TYPES, per_thread=per_thread, allow_force_rerun=allow_force_rerun)
once_factory = _once_factory(
is_async=fn_type in _ASYNC_FN_TYPES,
per_thread=per_thread,
allow_force_rerun=allow_force_rerun,
)
return _wrap(func, once_factory, fn_type, retry_exceptions)


Expand All @@ -292,8 +311,15 @@ class once_per_class: # pylint: disable=invalid-name
is_staticmethod: bool

@classmethod
def with_options(cls, per_thread: bool = False, retry_exceptions=False, allow_force_rerun=False):
return lambda func: cls(func, per_thread=per_thread, retry_exceptions=retry_exceptions, allow_force_rerun=allow_force_rerun)
def with_options(
cls, per_thread: bool = False, retry_exceptions=False, allow_force_rerun=False
):
return lambda func: cls(
func,
per_thread=per_thread,
retry_exceptions=retry_exceptions,
allow_force_rerun=allow_force_rerun,
)

def __init__(
self,
Expand All @@ -305,7 +331,9 @@ def __init__(
self.func = self._inspect_function(func)
self.fn_type = _wrapped_function_type(self.func)
self.once_factory = _once_factory(
is_async=self.fn_type in _ASYNC_FN_TYPES, per_thread=per_thread, allow_force_rerun=allow_force_rerun
is_async=self.fn_type in _ASYNC_FN_TYPES,
per_thread=per_thread,
allow_force_rerun=allow_force_rerun,
)
self.retry_exceptions = retry_exceptions

Expand Down Expand Up @@ -343,8 +371,12 @@ class once_per_instance: # pylint: disable=invalid-name
"""A version of once for class methods which runs once per instance."""

@classmethod
def with_options(cls, per_thread: bool = False, retry_exceptions=False, allow_force_rerun=False):
return lambda func: cls(func, per_thread=per_thread, retry_exceptions=retry_exceptions, allow_force_rerun=False)
def with_options(
cls, per_thread: bool = False, retry_exceptions=False, allow_force_rerun=False
):
return lambda func: cls(
func, per_thread=per_thread, retry_exceptions=retry_exceptions, allow_force_rerun=False
)

def __init__(
self,
Expand All @@ -357,9 +389,9 @@ def __init__(
self.fn_type = _wrapped_function_type(self.func)
self.is_async_fn = self.fn_type in _ASYNC_FN_TYPES
self.callables_lock = threading.Lock()
self.callables: weakref.WeakKeyDictionary[
typing.Any, collections.abc.Callable
] = weakref.WeakKeyDictionary()
self.callables: weakref.WeakKeyDictionary[typing.Any, collections.abc.Callable] = (
weakref.WeakKeyDictionary()
)
self.per_thread = per_thread
self.retry_exceptions = retry_exceptions
self.allow_force_rerun = allow_force_rerun
Expand All @@ -369,7 +401,9 @@ def once_factory(self) -> _ONCE_FACTORY_TYPE:
A once factory factory if you will.
"""
return _once_factory(self.is_async_fn, per_thread=self.per_thread, allow_force_rerun=self.allow_force_rerun)
return _once_factory(
self.is_async_fn, per_thread=self.per_thread, allow_force_rerun=self.allow_force_rerun
)

def _inspect_function(self, func: collections.abc.Callable):
if isinstance(func, (classmethod, staticmethod)):
Expand Down
7 changes: 6 additions & 1 deletion once/_iterator_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ class _GeneratorWrapperBase(abc.ABC):
"""

def __init__(
self, reset_on_exception: bool, func: collections.abc.Callable, allow_force_rerun: bool = False, *args, **kwargs
self,
reset_on_exception: bool,
func: collections.abc.Callable,
allow_force_rerun: bool = False,
*args,
**kwargs,
) -> None:
self.callable: collections.abc.Callable | None = functools.partial(func, *args, **kwargs)
self.generator = self.callable()
Expand Down
12 changes: 8 additions & 4 deletions once_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Unit tests for once decorators."""

# pylint: disable=missing-function-docstring
import asyncio
import collections.abc
Expand Down Expand Up @@ -384,9 +385,10 @@ def test_partial(self):
func = once.once(functools.partial(lambda _: counter.get_incremented(), None))
self.assertEqual(func(), 1)
self.assertEqual(func(), 1)

def test_force_rerun(self):
counter = Counter()

@once.once(allow_force_rerun=True)
def counting_fn():
return counter.get_incremented()
Expand All @@ -395,7 +397,7 @@ def counting_fn():
self.assertEqual(counting_fn.force_rerun(), 2)
self.assertEqual(counting_fn(), 2)
self.assertEqual(counting_fn.force_rerun(), 3)

def test_force_rerun_not_allowed(self):
counting_fn, counter = generate_once_counter_fn()
self.assertEqual(counting_fn(None), 1)
Expand Down Expand Up @@ -996,6 +998,7 @@ async def counting_fn2():

async def test_force_rerun(self):
counter = Counter()

@once.once(allow_force_rerun=True)
async def counting_fn():
return counter.get_incremented()
Expand All @@ -1006,6 +1009,7 @@ async def counting_fn():

async def test_force_rerun_not_allowed(self):
counter = Counter()

@once.once
async def counting_fn():
return counter.get_incremented()
Expand Down Expand Up @@ -1117,7 +1121,7 @@ async def async_yielding_iterator():

self.assertEqual([i async for i in async_yielding_iterator()], [1, 2, 3])
self.assertEqual([i async for i in async_yielding_iterator()], [1, 2, 3])

async def test_iterator_force_rerun(self):
counter = Counter()

Expand All @@ -1129,7 +1133,7 @@ async def async_yielding_iterator():
self.assertEqual([i async for i in async_yielding_iterator()], [1, 2, 3])
self.assertEqual([i async for i in async_yielding_iterator.force_rerun()], [4, 5, 6])
self.assertEqual([i async for i in async_yielding_iterator()], [4, 5, 6])

async def test_iterator_force_rerun_not_allowed(self):
counter = Counter()

Expand Down

0 comments on commit 083d7e4

Please sign in to comment.