Skip to content

Commit

Permalink
Improve async once_per_thread tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
aebrahim committed Oct 24, 2023
1 parent 3fe3c78 commit 6c55dc0
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions once_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_incremented(self) -> int:
return self.value


def execute_with_barrier(*args, n_workers=None):
def execute_with_barrier(*args, n_workers=None, is_async=False):
"""Decorator to ensure function calls do not begin until at least n_workers have started.
This ensures that our parallel tests actually test concurrency. Without this, it is possible
Expand All @@ -71,15 +71,23 @@ def execute_with_barrier(*args, n_workers=None):
if n_workers is None:
raise ValueError("n_workers not set")
if len(args) == 0:
return functools.partial(execute_with_barrier, n_workers=n_workers)
return functools.partial(execute_with_barrier, n_workers=n_workers, is_async=is_async)
if len(args) > 1:
raise ValueError("Up to one argument expected.")
func = args[0]
barrier = threading.Barrier(n_workers)

def wrapped(*args, **kwargs):
barrier.wait()
return func(*args, **kwargs)
if is_async:

async def wrapped(*args, **kwargs):
barrier.wait() # yes I know
return await func(*args, **kwargs)

else:

def wrapped(*args, **kwargs):
barrier.wait()
return func(*args, **kwargs)

functools.update_wrapper(wrapped, func)
return wrapped
Expand Down Expand Up @@ -828,24 +836,24 @@ async def test_once_per_thread(self):
counter = Counter()

@once.once(per_thread=True)
@execute_with_barrier(n_workers=_N_WORKERS, is_async=True)
async def counting_fn(*args) -> int:
"""Returns the call count, which should always be 1."""
nonlocal counter

del args
return counter.get_incremented()

@execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race
def execute():
coro = counting_fn()
def execute(*args):
coro = counting_fn(*args)
return asyncio.run(coro)

threads = [threading.Thread(target=execute) for i in range(_N_WORKERS)]
for t in threads:
t.start()
for t in threads:
t.join()
self.assertEqual(await counting_fn(), _N_WORKERS + 1)
self.assertEqual(await counting_fn(), _N_WORKERS + 1)
with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = list(executor.map(execute, range(_N_WORKERS)))
self.assertEqual(sorted(results), list(range(1, _N_WORKERS + 1)))
results = list(executor.map(execute, range(_N_WORKERS)))
self.assertEqual(sorted(results), list(range(1, _N_WORKERS + 1)))

async def test_failing_function(self):
counter = Counter()
Expand Down

0 comments on commit 6c55dc0

Please sign in to comment.