From 6c55dc0889820f5dad3308a4024a55775a103f09 Mon Sep 17 00:00:00 2001 From: Ali Ebrahim Date: Tue, 24 Oct 2023 18:40:51 +0000 Subject: [PATCH] Improve async once_per_thread tests. --- once_test.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/once_test.py b/once_test.py index a328a77..1fa83d5 100644 --- a/once_test.py +++ b/once_test.py @@ -52,7 +52,7 @@ def get_incremented(self) -> int: return self.value -def execute_with_barrier(*args, n_workers=None): +def execute_with_barrier(*args, n_workers=None, is_async=False): """Decorator to ensure function calls do not begin until at least n_workers have started. This ensures that our parallel tests actually test concurrency. Without this, it is possible @@ -71,15 +71,23 @@ def execute_with_barrier(*args, n_workers=None): if n_workers is None: raise ValueError("n_workers not set") if len(args) == 0: - return functools.partial(execute_with_barrier, n_workers=n_workers) + return functools.partial(execute_with_barrier, n_workers=n_workers, is_async=is_async) if len(args) > 1: raise ValueError("Up to one argument expected.") func = args[0] barrier = threading.Barrier(n_workers) - def wrapped(*args, **kwargs): - barrier.wait() - return func(*args, **kwargs) + if is_async: + + async def wrapped(*args, **kwargs): + barrier.wait() # yes I know + return await func(*args, **kwargs) + + else: + + def wrapped(*args, **kwargs): + barrier.wait() + return func(*args, **kwargs) functools.update_wrapper(wrapped, func) return wrapped @@ -828,24 +836,24 @@ async def test_once_per_thread(self): counter = Counter() @once.once(per_thread=True) + @execute_with_barrier(n_workers=_N_WORKERS, is_async=True) async def counting_fn(*args) -> int: """Returns the call count, which should always be 1.""" nonlocal counter + del args return counter.get_incremented() @execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race - def execute(): - coro = counting_fn() + def execute(*args): + coro = counting_fn(*args) return asyncio.run(coro) - threads = [threading.Thread(target=execute) for i in range(_N_WORKERS)] - for t in threads: - t.start() - for t in threads: - t.join() - self.assertEqual(await counting_fn(), _N_WORKERS + 1) - self.assertEqual(await counting_fn(), _N_WORKERS + 1) + with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: + results = list(executor.map(execute, range(_N_WORKERS))) + self.assertEqual(sorted(results), list(range(1, _N_WORKERS + 1))) + results = list(executor.map(execute, range(_N_WORKERS))) + self.assertEqual(sorted(results), list(range(1, _N_WORKERS + 1))) async def test_failing_function(self): counter = Counter()