diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..9d9eb46 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,31 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/python +{ + "name": "Python 3", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "image": "mcr.microsoft.com/devcontainers/python:1-3.11-bullseye", + "customizations": { + "vscode": { + "extensions": [ + "ms-python.python", + "github.vscode-github-actions" + ] + } + } + + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + + // Use 'postCreateCommand' to run commands after the container is created. + // "postCreateCommand": "pip3 install --user -r requirements.txt", + + // Configure tool-specific properties. + // "customizations": {}, + + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. + // "remoteUser": "root" +} + diff --git a/.github/workflows/annotate_pr.yml b/.github/workflows/annotate_pr.yml new file mode 100644 index 0000000..c5d9554 --- /dev/null +++ b/.github/workflows/annotate_pr.yml @@ -0,0 +1,30 @@ +name: Annotate Pull Request + +# In a separate workflow because of +# https://securitylab.github.com/research/github-actions-preventing-pwn-requests/ +on: + workflow_run: + workflows: + - Python test + types: + - completed + +jobs: + python_test_reporter: + runs-on: ubuntu-latest + permissions: + checks: write + if: github.event.workflow_run.name == 'Python test' + strategy: + fail-fast: false + matrix: + os: [ubuntu-20.04, windows-2019, macos-11] + python-version: ["3.10", "3.11", "3.12-dev"] + steps: + - name: Test Report + uses: dorny/test-reporter@v1 + with: + name: test-py${{ matrix.python-version}}-on-${{ matrix.os }} + reporter: java-junit + artifact: test-py${{ matrix.python-version}}-on-${{ matrix.os }} + path: test_py${{ matrix.python-version }}_on_${{ matrix.os }}.xml diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..f319956 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,55 @@ +name: Release Wheels + +on: + push: + release: + types: + - published + +jobs: + build_wheel: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: 3.11 + cache: pip + - run: pip wheel . + - uses: actions/upload-artifact@v3 + with: + name: wheel + path: ./*.whl + + upload_wheel_test: + needs: [build_wheel] + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/once-py + permissions: + id-token: write + steps: + - uses: actions/download-artifact@v3 + with: + name: wheel + path: dist/ + - uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: https://test.pypi.org/legacy/ + + upload_wheel: + needs: [build_wheel] + runs-on: ubuntu-latest + if: github.event_name == 'release' && github.event.action == 'published' + environment: + name: pypi + url: https://pypi.org/p/once-py + permissions: + id-token: write + steps: + - uses: actions/download-artifact@v3 + with: + name: wheel + path: dist/ + - uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..cf7e0a1 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,46 @@ +name: Python test + +on: [push, pull_request] + +jobs: + test: + name: test-py${{ matrix.python-version}}-on-${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-20.04, windows-2019, macos-11] + python-version: ["3.10", "3.11", "3.12-dev", "pypy3.10"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: pip + - run: pip install pytest + - run: pytest . --junitxml=junit/test_py${{ matrix.python-version }}_on_${{ matrix.os }}.xml + - name: Upload pytest test results + uses: actions/upload-artifact@v3 + if: success() || failure() + with: + name: test-py${{ matrix.python-version}}-on-${{ matrix.os }} + path: junit/test_py${{ matrix.python-version }}_on_${{ matrix.os }}.xml + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: psf/black@stable + + mypy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: 3.11 + cache: pip + - run: | + pip install mypy + mypy . diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f5899d3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +build/ +dist/ +*.egg-info +__pycache__ +.DS_Store +*.whl +_version.py +.pytest_cache/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..3c541dc --- /dev/null +++ b/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2023 Delfina Care Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..4ca7c78 --- /dev/null +++ b/README.md @@ -0,0 +1,38 @@ +# Once + +This library provides functionality to ensure a function is called exactly +once in Python, heavily inspired by `std::call_once`. + +During initialization, we often want to ensure code is run **exactly once**. +But thinking about all the different ways this constraint can be violated can +be time-consuming and complex. We don't want to have to reason about what other +callers are doing and from which thread. + +Introducing a simple solution - the `once.once` decorator! Simply decorate a +function with this decorator, and this library will handle all the edge cases +to ensure it is called exactly once! The first call will invoke the function, +and all subsequent calls will return the same result. Enough talking, let's +cut to an example: + +```python +import once + +@once.once +def my_expensive_object(): + load_expensive_resource() + load_more_expensive_resources() + return ObjectSingletonUsingLotsOfMemory() + +def caller_one(): + my_expensive_object().use_it() + +def caller_two_from_a_separate_thread(): + my_expensive_object().use_it() + +def optional_init_function_to_prewarm(): + my_expensive_object() + +``` + +This module is extremely simple, with no external dependencies, and heavily +tested for races. diff --git a/once.py b/once.py new file mode 100644 index 0000000..9ed01c2 --- /dev/null +++ b/once.py @@ -0,0 +1,157 @@ +"""Utility for initialization ensuring functions are called only once.""" +import abc +import inspect +import functools +import threading +import weakref + + +def _new_lock() -> threading.Lock: + return threading.Lock() + + +def _is_method(func): + """Determine if a function is a method on a class.""" + if isinstance(func, (classmethod, staticmethod)): + return True + sig = inspect.signature(func) + return "self" in sig.parameters + + +class _OnceBase(abc.ABC): + """Abstract Base Class for once function decorators.""" + + def __init__(self, func): + self._inspect_function(func) + functools.update_wrapper(self, func) + self.lock = _new_lock() + self.called = False + self.return_value = None + self.func = func + + @abc.abstractmethod + def _inspect_function(self, func): + """Inspect the passed-in function to ensure it can be wrapped. + + This function should raise a SyntaxError if the passed-in function is + not suitable.""" + + def _execute_call_once(self, func, *args, **kwargs): + if self.called: + return self.return_value + with self.lock: + if self.called: + return self.return_value + self.return_value = func(*args, **kwargs) + self.called = True + return self.return_value + + +class once(_OnceBase): # pylint: disable=invalid-name + """Decorator to ensure a function is only called once. + + The restriction of only one call also holds across threads. However, this + restriction does not apply to unsuccessful function calls. If the function + raises an exception, the next call will invoke a new call to the function. + If the function is called with multiple arguments, it will still only be + called only once. + + This decorator will fail for methods defined on a class. Use + once_per_class or once_per_instance for methods on a class instead. + + Please note that because the value returned by the decorated function is + stored to return for subsequent calls, it will not be eligible for garbage + collection until after the decorated function itself has been deleted. For + module and class level functions (i.e. non-closures), this means the return + value will never be deleted. + """ + + def _inspect_function(self, func): + if _is_method(func): + raise SyntaxError( + "Attempting to use @once.once decorator on method " + "instead of @once.once_per_class or @once.once_per_instance" + ) + + def __call__(self, *args, **kwargs): + return self._execute_call_once(self.func, *args, **kwargs) + + +class once_per_class(_OnceBase): # pylint: disable=invalid-name + """A version of once for class methods which runs once across all instances.""" + + def _inspect_function(self, func): + if not _is_method(func): + raise SyntaxError( + "Attempting to use @once.once_per_class method-only decorator " + "instead of @once.once" + ) + + # This is needed for a decorator on a class method to return a + # bound version of the function to the object or class. + def __get__(self, obj, cls): + if isinstance(self.func, classmethod): + func = functools.partial(self.func.__func__, cls) + return functools.partial(self._execute_call_once, func) + if isinstance(self.func, staticmethod): + return functools.partial(self._execute_call_once, self.func) + return functools.partial(self._execute_call_once, self.func, obj) + + +class once_per_instance(_OnceBase): # pylint: disable=invalid-name + """A version of once for class methods which runs once per instance.""" + + def __init__(self, func): + super().__init__(func) + self.return_value = weakref.WeakKeyDictionary() + self.inflight_lock = {} + + def _inspect_function(self, func): + if isinstance(func, (classmethod, staticmethod)): + raise SyntaxError("Must use @once.once_per_class on classmethod and staticmethod") + if not _is_method(func): + raise SyntaxError( + "Attempting to use @once.once_per_instance method-only decorator " + "instead of @once.once" + ) + + # This is needed for a decorator on a class method to return a + # bound version of the function to the object. + def __get__(self, obj, cls): + del cls + return functools.partial(self._execute_call_once_per_instance, obj) + + def _execute_call_once_per_instance(self, obj, *args, **kwargs): + # We only append to the call history, and do not overwrite or remove keys. + # Therefore, we can check the call history without a lock for an early + # exit. + # Another concern might be the weakref dictionary for return_value + # getting garbage collected without a lock. However, because + # user_function references whichever key it matches, it cannot be + # garbage collected during this call. + if obj in self.return_value: + return self.return_value[obj] + with self.lock: + if obj in self.return_value: + return self.return_value[obj] + if obj in self.inflight_lock: + inflight_lock = self.inflight_lock[obj] + else: + inflight_lock = _new_lock() + self.inflight_lock[obj] = inflight_lock + # Now we have a per-object lock. This means that we will not block + # other instances. In addition to better performance, this reduces the + # potential for deadlocks. + with inflight_lock: + if obj in self.return_value: + return self.return_value[obj] + result = self.func(obj, *args, **kwargs) + self.return_value[obj] = result + # At this point, any new call will find a cache hit before + # even grabbing a lock. It is now safe to clean up the inflight + # lock entry from the dictionary, as all subsequent will not need + # it. Any other previously called inflight requests already have + # their reference to the lock object, and do not need it present + # in this dict either. + self.inflight_lock.pop(obj) + return result diff --git a/once_test.py b/once_test.py new file mode 100644 index 0000000..55783a3 --- /dev/null +++ b/once_test.py @@ -0,0 +1,342 @@ +"""Unit tests for once decorators.""" +# pylint: disable=missing-function-docstring +import concurrent.futures +import inspect +import time +import unittest +from unittest import mock +import threading +import weakref +import gc + +import once + + +class Counter: + """Holding object for a counter. + + If we return an integer directly, it will simply return a copy and + will not update as the number of calls increases. + """ + + def __init__(self) -> None: + self.value = 0 + + def get_incremented(self) -> int: + self.value += 1 + return self.value + + +def generate_once_counter_fn(): + """Generates a once.once decorated function which counts its calls.""" + + counter = Counter() + + @once.once + def counting_fn(*args) -> int: + """Returns the call count, which should always be 1.""" + nonlocal counter + del args + return counter.get_incremented() + + return counting_fn, counter + + +class TestOnce(unittest.TestCase): + """Unit tests for once decorators.""" + + def test_counter_works(self): + """Ensure the counter text fixture works.""" + counter = Counter() + self.assertEqual(counter.value, 0) + self.assertEqual(counter.get_incremented(), 1) + self.assertEqual(counter.value, 1) + self.assertEqual(counter.get_incremented(), 2) + self.assertEqual(counter.value, 2) + + def test_different_args_same_result(self): + counting_fn, counter = generate_once_counter_fn() + self.assertEqual(counting_fn(1), 1) + self.assertEqual(counter.value, 1) + # Should return the same result as the first call. + self.assertEqual(counting_fn(2), 1) + self.assertEqual(counter.value, 1) + + def test_threaded_single_function(self): + counting_fn, counter = generate_once_counter_fn() + with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: + results = list(executor.map(counting_fn, range(32))) + self.assertEqual(len(results), 32) + for r in results: + self.assertEqual(r, 1) + self.assertEqual(counter.value, 1) + + def test_threaded_multiple_functions(self): + counters = [] + fns = [] + + for _ in range(4): + cfn, counter = generate_once_counter_fn() + counters.append(counter) + fns.append(cfn) + + promises = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: + for cfn in fns: + for _ in range(16): + promises.append(executor.submit(cfn)) + del cfn + fns.clear() + for promise in promises: + self.assertEqual(promise.result(), 1) + for counter in counters: + self.assertEqual(counter.value, 1) + + def test_different_fn_do_not_deadlock(self): + """Ensure different functions use different locks to avoid deadlock.""" + + fn2_called = False + + # If fn1 is called first, these functions will deadlock unless they can + # run in parallel. + @once.once + def fn1(): + nonlocal fn2_called + start = time.time() + while not fn2_called: + if time.time() - start > 5: + self.fail("Function fn1 deadlocked for 5 seconds.") + time.sleep(0.01) + + @once.once + def fn2(): + nonlocal fn2_called + fn2_called = True + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + fn1_promise = executor.submit(fn1) + executor.submit(fn2) + fn1_promise.result() + + def test_closure_gc(self): + """Tests that closure function is not cached indefinitely""" + + class EphemeralObject: + """Object which should get GC'ed""" + + def create_closure(): + ephemeral = EphemeralObject() + ephemeral_ref = weakref.ref(ephemeral) + + @once.once + def closure(): + return ephemeral + + return closure, ephemeral_ref + + closure, ephemeral_ref = create_closure() + + # Cannot yet be garbage collected because kept alive in the closure. + self.assertIsNotNone(ephemeral_ref()) + self.assertIsInstance(closure(), EphemeralObject) + self.assertIsNotNone(ephemeral_ref()) + self.assertIsInstance(closure(), EphemeralObject) + del closure + # Can now be garbage collected. + # In CPython this call technically should not be needed, because + # garbage collection should have happened automatically. However, that + # is an implementation detail which does not hold on all platforms, + # such as for example pypy. Therefore, we manually trigger a garbage + # collection cycle. + gc.collect() + self.assertIsNone(ephemeral_ref()) + + @mock.patch.object(once, "_new_lock") + def test_lock_bypass(self, lock_mocker) -> None: + """Test both with and without lock bypass cache lookup.""" + + # We mock the lock to return our specific lock, so we can specifically + # test behavior with it held. + lock = threading.Lock() + lock_mocker.return_value = lock + + counter = Counter() + + @once.once + def sample_fn() -> int: + nonlocal counter + return counter.get_incremented() + + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: + with lock: + potential_first_call_promises = [executor.submit(sample_fn) for i in range(32)] + # Give the promises enough time to finish, if they were not blocked. + # The test will still pass without this, but we wouldn't be + # testing anything. + time.sleep(0.01) + # At this point, all of the promises will be waiting for the lock, + # and none of them will have completed. + for promise in potential_first_call_promises: + self.assertFalse(promise.done()) + # Now that we have released the lock, all of these should complete. + for promise in potential_first_call_promises: + self.assertEqual(promise.result(), 1) + self.assertEqual(counter.value, 1) + # Now that we know the function has already been called, we should + # be able to get a result without waiting for a lock. + with lock: + bypass_lock_promises = [executor.submit(sample_fn) for i in range(32)] + for promise in bypass_lock_promises: + self.assertEqual(promise.result(), 1) + self.assertEqual(counter.value, 1) + + def test_function_signature_preserved(self): + @once.once + def type_annotated_fn(arg: float) -> int: + """Very descriptive docstring.""" + del arg + return 1 + + sig = inspect.signature(type_annotated_fn) + self.assertIs(sig.parameters["arg"].annotation, float) + self.assertIs(sig.return_annotation, int) + self.assertEqual(type_annotated_fn.__doc__, "Very descriptive docstring.") + + def test_once_per_class(self): + class _CallOnceClass(Counter): + @once.once_per_class + def once_fn(self): + return self.get_incremented() + + a = _CallOnceClass() # pylint: disable=invalid-name + b = _CallOnceClass() # pylint: disable=invalid-name + + self.assertEqual(a.once_fn(), 1) + self.assertEqual(a.once_fn(), 1) + self.assertEqual(b.once_fn(), 1) + self.assertEqual(b.once_fn(), 1) + + def test_once_not_allowed_on_method(self): + with self.assertRaises(SyntaxError): + + class _InvalidClass: # pylint: disable=unused-variable + @once.once + def once_method(self): + pass + + def test_once_per_instance_not_allowed_on_function(self): + with self.assertRaises(SyntaxError): + + @once.once_per_instance + def once_fn(): + pass + + def test_once_per_class_not_allowed_on_classmethod(self): + with self.assertRaises(SyntaxError): + + class _InvalidClass: # pylint: disable=unused-variable + @once.once_per_instance + @classmethod + def once_method(cls): + pass + + def test_once_per_class_not_allowed_on_staticmethod(self): + with self.assertRaises(SyntaxError): + + class _InvalidClass: # pylint: disable=unused-variable + @once.once_per_instance + @staticmethod + def once_method(): + pass + + def test_once_per_instance(self): + class _CallOnceClass: + def __init__(self, value: str, test: unittest.TestCase): + self._value = value + self.called = False + self.test = test + + @once.once_per_instance + def value(self): # pylint: disable=inconsistent-return-statements + if not self.called: + self.called = True + return self._value + if self.called: + self.test.fail(f"Method on {self.value} called a second time.") + + a = _CallOnceClass("a", self) # pylint: disable=invalid-name + b = _CallOnceClass("b", self) # pylint: disable=invalid-name + + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: + a_jobs = [executor.submit(a.value) for _ in range(16)] + b_jobs = [executor.submit(b.value) for _ in range(16)] + for a_job in a_jobs: + self.assertEqual(a_job.result(), "a") + for b_job in b_jobs: + self.assertEqual(b_job.result(), "b") + + self.assertEqual(a.value(), "a") + self.assertEqual(a.value(), "a") + self.assertEqual(b.value(), "b") + self.assertEqual(b.value(), "b") + + def test_once_per_instance_do_not_block_each_other(self): + class _BlockableClass: + def __init__(self, test: unittest.TestCase): + self.lock = threading.Lock() + self.test = test + self.started = False + self.counter = Counter() + + @once.once_per_instance + def run(self) -> int: + self.started = True + with self.lock: + pass + return self.counter.get_incremented() + + a = _BlockableClass(self) + b = _BlockableClass(self) + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + with a.lock: + a_job = executor.submit(a.run) + while not a.started: + pass + # At this point, the A job has started. However, it cannot + # complete while we hold its lock. Despite this, we want to ensure + # that B can still run. + b_job = executor.submit(b.run) + # The b_job will deadlock and this will fail if different + # object executions block each other. + self.assertEqual(b_job.result(timeout=5), 1) + self.assertEqual(a_job.result(timeout=5), 1) + + def test_once_per_class_classmethod(self): + counter = Counter() + + class _CallOnceClass: + @once.once_per_class + @classmethod + def value(cls): + nonlocal counter + return counter.get_incremented() + + self.assertEqual(_CallOnceClass.value(), 1) + self.assertEqual(_CallOnceClass.value(), 1) + + def test_once_per_class_staticmethod(self): + counter = Counter() + + class _CallOnceClass: + @once.once_per_class + @staticmethod + def value(): + nonlocal counter + return counter.get_incremented() + + self.assertEqual(_CallOnceClass.value(), 1) + self.assertEqual(_CallOnceClass.value(), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..1495215 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[tool.hatch.version] +source = "vcs" + +[tool.hatch.version.raw-options] +local_scheme = "no-local-version" + +[tool.hatch.metadata.hooks.vcs.urls] +source_archive = "https://github.com/DelfinaCare/once/archive/{commit_hash}.zip" + +[tool.black] +line-length = 100 + +[project] +name = "once-py" +dynamic = ["version"] +authors = [ + { name="Ali.Ebrahim", email="ali@delfina.com" }, +] +description = "Utility for initialization ensuring functions are called only once" +readme = "README.md" +license = "MIT" +requires-python = ">=3.10" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] + +[project.urls] +"Homepage" = "https://github.com/DelfinaCare/once" +"Bug Tracker" = "https://github.com/DelfinaCare/once/issues"