Skip to content

Commit

Permalink
Make per_thread an option for once.
Browse files Browse the repository at this point in the history
More elegant than separate functions, especially as we consider other
behavior modification options in the future.
  • Loading branch information
aebrahim committed Oct 18, 2023
1 parent ade1135 commit 059c90c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 40 deletions.
64 changes: 28 additions & 36 deletions once/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def _wrap(
) -> collections.abc.Callable:
"""Generate a wrapped function appropriate to the function type.
The once_factory lets us reuse logic for both once and once_per_thread.
For once, the factory always returns the same _OnceBase object, but for
once_per_thread, it would return a unique one for each thread.
The once_factory lets us reuse logic for both per-thread and singleton.
For a singleton, the factory always returns the same _OnceBase object, but
for per thread, it would return a unique one for each thread.
"""
# Theoretically, we could compute fn_type now. However, this code may be executed at runtime
# OR at definition time (due to once_per_instance), and we want to only be doing reflection at
Expand Down Expand Up @@ -195,7 +195,7 @@ def _once_factory(is_async: bool, per_thread: bool) -> _ONCE_FACTORY_TYPE:
return lambda: singleton_once


def once(func: collections.abc.Callable):
def once(*args, per_thread=False) -> collections.abc.Callable:
"""Decorator to ensure a function is only called once.
The restriction of only one call also holds across threads. However, this
Expand All @@ -217,20 +217,22 @@ def once(func: collections.abc.Callable):
module and class level functions (i.e. non-closures), this means the return
value will never be deleted.
"""
if len(args) == 1:
func: collections.abc.Callable = args[0]
elif len(args) > 1:
raise ValueError("Up to 1 argument expected.")
else:
# This trick lets this function be a decorator directly, or be called
# to create a decorator.
# Both @once and @once() will function correctly.
return functools.partial(once, per_thread=per_thread)
if _is_method(func):
raise SyntaxError(
"Attempting to use @once.once decorator on method "
"instead of @once.once_per_class or @once.once_per_instance"
)
fn_type = _wrapped_function_type(func)
once_factory = _once_factory(is_async=fn_type in _ASYNC_FN_TYPES, per_thread=False)
return _wrap(func, once_factory, fn_type)


def once_per_thread(func: collections.abc.Callable):
"""A version of once which executes only once per thread."""
fn_type = _wrapped_function_type(func)
once_factory = _once_factory(is_async=fn_type in _ASYNC_FN_TYPES, per_thread=True)
once_factory = _once_factory(is_async=fn_type in _ASYNC_FN_TYPES, per_thread=per_thread)
return _wrap(func, once_factory, fn_type)


Expand All @@ -240,11 +242,15 @@ class once_per_class: # pylint: disable=invalid-name
is_classmethod: bool
is_staticmethod: bool

def __init__(self, func: collections.abc.Callable) -> None:
@classmethod
def with_options(cls, per_thread: bool = False):
return lambda func: cls(func, per_thread=per_thread)

def __init__(self, func: collections.abc.Callable, per_thread: bool = False) -> None:
self.func = self._inspect_function(func)
self.fn_type = _wrapped_function_type(self.func)
self.once_factory = _once_factory(
is_async=self.fn_type in _ASYNC_FN_TYPES, per_thread=False
is_async=self.fn_type in _ASYNC_FN_TYPES, per_thread=per_thread
)

def _inspect_function(self, func: collections.abc.Callable):
Expand Down Expand Up @@ -277,29 +283,29 @@ def __get__(self, obj, cls) -> collections.abc.Callable:
return _wrap(func, self.once_factory, self.fn_type)


class once_per_class_per_thread(once_per_class): # pylint: disable=invalid-name
def __init__(self, func: collections.abc.Callable) -> None:
super().__init__(func)
self.once_factory = _once_factory(self.fn_type in _ASYNC_FN_TYPES, per_thread=True)
class once_per_instance: # pylint: disable=invalid-name
"""A version of once for class methods which runs once per instance."""

@classmethod
def with_options(cls, per_thread: bool = False):
return lambda func: cls(func, per_thread=per_thread)

class _OncePerInstanceBase(abc.ABC):
def __init__(self, func: collections.abc.Callable) -> None:
def __init__(self, func: collections.abc.Callable, per_thread: bool = False) -> None:
self.func = self._inspect_function(func)
self.fn_type = _wrapped_function_type(self.func)
self.is_async_fn = self.fn_type in _ASYNC_FN_TYPES
self.callables_lock = threading.Lock()
self.callables: weakref.WeakKeyDictionary[
typing.Any, collections.abc.Callable
] = weakref.WeakKeyDictionary()
self.per_thread = per_thread

@abc.abstractmethod
def once_factory(self) -> _ONCE_FACTORY_TYPE:
"""Generate a new once factory.
A once factory factory if you will.
"""
raise NotImplementedError()
return _once_factory(self.is_async_fn, per_thread=self.per_thread)

def _inspect_function(self, func: collections.abc.Callable):
if isinstance(func, (classmethod, staticmethod)):
Expand All @@ -321,17 +327,3 @@ def __get__(self, obj, cls) -> collections.abc.Callable:
callable = _wrap(bound_func, self.once_factory(), self.fn_type)
self.callables[obj] = callable
return callable


class once_per_instance(_OncePerInstanceBase): # pylint: disable=invalid-name
"""A version of once for class methods which runs once per instance."""

def once_factory(self):
return _once_factory(self.is_async_fn, per_thread=False)


class once_per_instance_per_thread(_OncePerInstanceBase): # pylint: disable=invalid-name
"""A version of once for class methods which runs once per instance per thread."""

def once_factory(self):
return _once_factory(self.is_async_fn, per_thread=True)
8 changes: 4 additions & 4 deletions once_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def test_once_per_thread(self):
counter = Counter()

@execute_with_barrier(n_workers=_N_WORKERS)
@once.once_per_thread
@once.once(per_thread=True)
def counting_fn(*args) -> int:
"""Returns the call count, which should always be 1."""
nonlocal counter
Expand Down Expand Up @@ -555,7 +555,7 @@ def execute(_):

def test_once_per_class_per_thread(self):
class _CallOnceClass(Counter):
@once.once_per_class_per_thread
@once.once_per_class.with_options(per_thread=True)
def once_fn(self):
return self.get_incremented()

Expand Down Expand Up @@ -681,7 +681,7 @@ def execute(i):

def test_once_per_instance_per_thread(self):
class _CallOnceClass(Counter):
@once.once_per_instance_per_thread
@once.once_per_instance.with_options(per_thread=True)
def once_fn(self):
return self.get_incremented()

Expand Down Expand Up @@ -817,7 +817,7 @@ async def counting_fn2():
async def test_once_per_thread(self):
counter = Counter()

@once.once_per_thread
@once.once(per_thread=True)
async def counting_fn(*args) -> int:
"""Returns the call count, which should always be 1."""
nonlocal counter
Expand Down

0 comments on commit 059c90c

Please sign in to comment.