diff --git a/once_test.py b/once_test.py index 74e096f..b620a44 100644 --- a/once_test.py +++ b/once_test.py @@ -1,15 +1,18 @@ """Unit tests for once decorators.""" # pylint: disable=missing-function-docstring import asyncio +import collections.abc import concurrent.futures import functools import gc import inspect +import itertools import math import sys import threading import time import unittest +import uuid import weakref import once @@ -31,6 +34,74 @@ async def anext(iter, default=StopAsyncIteration): _N_WORKERS = 16 +class WrappedException: + def __init__(self, exception): + self.exception = exception + + +def parallel_map( + test: unittest.TestCase, + func: collections.abc.Callable, + # would be collections.abc.Iterable[tuple] | None on py >= 3.10 + call_args=None, + n_threads: int = _N_WORKERS, + timeout: float = 10.0, +) -> list: + """Run a function multiple times in parallel. + + We ensure that N parallel tasks are all launched at the "same time", which + means all have parallel threads which are released to the GIL to execute at + the same time. + Why? + We can't rely on the thread pool excector to always spin up the full list of _N_WORKERS. + In pypy, we have observed that even with blocked tasks, the same thread executes multiple + function calls. This lets us handle the scheduling in a predictable way for testing. + """ + if call_args is None: + call_args = (tuple() for _ in range(n_threads)) + + batches = [[] for i in range(n_threads)] # type: list[list[tuple[int, tuple]]] + for i, call_args in enumerate(call_args): + if not isinstance(call_args, tuple): + raise TypeError("call arguments must be a tuple") + batches[i % n_threads].append((i, call_args)) + n_calls = i + 1 + unset = object() + results_lock = threading.Lock() + results = [unset for _ in range(n_calls)] + + # This barrier is used to ensure that all calls release together, after this function has + # completed its setup of creating them. + start_barrier = threading.Barrier(min(len(batches), n_calls)) + + def wrapped_fn(batch): + start_barrier.wait() + for index, args in batch: + try: + result = func(*args) + except Exception as e: + result = WrappedException(e) + with results_lock: + results[index] = result + + # We manually set thread names for easier debugging. + invocation_id = str(uuid.uuid4()) + threads = [ + threading.Thread(target=wrapped_fn, args=[batch], name=f"{test.id()}-{i}-{invocation_id}") + for i, batch in enumerate(batches) + ] + for t in threads: + t.start() + for t in threads: + t.join(timeout=timeout) + for i, result in enumerate(results): + if result is unset: + test.fail(f"Call {i} did not complete succesfully") + elif isinstance(result, WrappedException): + raise result.exception + return results + + class Counter: """Holding object for a counter. @@ -388,8 +459,11 @@ def yielding_iterator(): for _ in range(3): yield counter.get_incremented() - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = executor.map(lambda _: list(yielding_iterator()), range(_N_WORKERS * 2)) + results = parallel_map( + self, + lambda: list(yielding_iterator()), + (tuple() for _ in range(_N_WORKERS * 2)), + ) for result in results: self.assertEqual(result, [1, 2, 3]) @@ -422,9 +496,7 @@ def yielding_iterator(): def test_threaded_single_function(self): counting_fn, counter = generate_once_counter_fn() barrier_counting_fn = execute_with_barrier(counting_fn, n_workers=_N_WORKERS) - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results_generator = executor.map(barrier_counting_fn, range(_N_WORKERS)) - results = list(results_generator) + results = parallel_map(self, barrier_counting_fn) self.assertEqual(len(results), _N_WORKERS) for r in results: self.assertEqual(r, 1) @@ -433,7 +505,6 @@ def test_threaded_single_function(self): def test_once_per_thread(self): counter = Counter() - @execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race @once.once(per_thread=True) @execute_with_barrier(n_workers=_N_WORKERS) def counting_fn(*args) -> int: @@ -442,8 +513,7 @@ def counting_fn(*args) -> int: del args return counter.get_incremented() - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = list(executor.map(counting_fn, range(_N_WORKERS * 4))) + results = parallel_map(self, counting_fn, (tuple() for _ in range(_N_WORKERS * 4))) self.assertEqual(min(results), 1) self.assertEqual(max(results), _N_WORKERS) @@ -565,12 +635,10 @@ def once_fn(self): once_obj = _CallOnceClass() - @execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race - def execute(_): + def execute(): return once_obj.once_fn() - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = list(executor.map(execute, range(_N_WORKERS * 4))) + results = parallel_map(self, execute, (tuple() for _ in range(_N_WORKERS * 4))) self.assertEqual(min(results), 1) self.assertEqual(max(results), 1) @@ -583,12 +651,10 @@ def once_fn(self): once_obj = _CallOnceClass() - @execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race - def execute(_): + def execute(): return once_obj.once_fn() - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = list(executor.map(execute, range(_N_WORKERS * 4))) + results = parallel_map(self, execute, (tuple() for _ in range(_N_WORKERS * 4))) self.assertEqual(min(results), 1) self.assertEqual(max(results), _N_WORKERS) @@ -697,8 +763,7 @@ def once_fn(self): def execute(i): return once_objs[i % 4].once_fn() - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = list(executor.map(execute, range(_N_WORKERS * 4))) + results = parallel_map(self, execute, ((i,) for i in range(_N_WORKERS * 4))) self.assertEqual(min(results), 1) self.assertEqual(max(results), 1) @@ -711,12 +776,10 @@ def once_fn(self): once_objs = [_CallOnceClass(), _CallOnceClass(), _CallOnceClass(), _CallOnceClass()] - @execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race def execute(i): return once_objs[i % 4].once_fn() - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = list(executor.map(execute, range(_N_WORKERS))) + results = parallel_map(self, execute, ((i,) for i in range(_N_WORKERS))) self.assertEqual(min(results), 1) self.assertEqual(max(results), math.ceil(_N_WORKERS / 4)) @@ -776,7 +839,7 @@ def receiving_iterator(): barrier = threading.Barrier(_N_WORKERS) - def call_iterator(_): + def call_iterator(): gen = receiving_iterator() result = [] barrier.wait() @@ -785,8 +848,7 @@ def call_iterator(_): result.append(gen.send(i)) return result - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = executor.map(call_iterator, range(_N_WORKERS)) + results = parallel_map(self, call_iterator) for result in results: self.assertEqual(result, list(range(_N_WORKERS * 4))) @@ -811,8 +873,7 @@ def call_iterator(n): # Unlike the previous test, each execution should yield lists of different lengths. # This ensures that the iterator does not hang, even if not exhausted - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = executor.map(call_iterator, range(1, _N_WORKERS + 1)) + results = parallel_map(self, call_iterator, ((i,) for i in range(1, _N_WORKERS + 1))) for i, result in enumerate(results): self.assertEqual(result, list(range(i + 1))) @@ -850,16 +911,24 @@ async def counting_fn(*args) -> int: del args return counter.get_incremented() - @execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race + results_lock = asyncio.Lock() + results = [] + + @execute_with_barrier(n_workers=_N_WORKERS, is_async=True) # increases chance of a race + async def counting_fn_multiple_caller(*args): + """Calls counting_fn() multiple times ensuring identical result.""" + result = await counting_fn() + for i in range(5): + self.assertEqual(await counting_fn(), result) + async with results_lock: + results.append(result) + return result + def execute(*args): - coro = counting_fn(*args) + coro = counting_fn_multiple_caller(*args) return asyncio.run(coro) - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = list(executor.map(execute, range(_N_WORKERS))) - self.assertEqual(sorted(results), list(range(1, _N_WORKERS + 1))) - results = list(executor.map(execute, range(_N_WORKERS))) - self.assertEqual(sorted(results), list(range(1, _N_WORKERS + 1))) + parallel_map(self, execute) async def test_failing_function(self): counter = Counter()