Use threading.Barrier in tests.
This lets us ensure that parallel test calls to the functions under test actually execute in parallel.
aebrahim committed Oct 14, 2023
1 parent 889d5db commit 893935c
Showing 1 changed file with 39 additions and 27 deletions.
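
In outline, the change replaces ad-hoc pause flags with threading.Barrier: every worker blocks at the barrier before calling the function under test, so no call can finish before all of them have started. A minimal, self-contained sketch of that pattern (the worker function and count below are illustrative, not taken from the diff):

import concurrent.futures
import threading

N_WORKERS = 4  # illustrative count; the test module uses its own _N_WORKERS constant
barrier = threading.Barrier(N_WORKERS)

def worker(i):
    # Block until all N_WORKERS threads have reached this point, so the
    # calls below genuinely overlap instead of running one after another.
    barrier.wait()
    return i * 2

with concurrent.futures.ThreadPoolExecutor(max_workers=N_WORKERS) as executor:
    results = list(executor.map(worker, range(N_WORKERS)))

assert results == [0, 2, 4, 6]  # executor.map preserves input order
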
66 changes: 39 additions & 27 deletions once_test.py
@@ -6,6 +6,7 @@
 import gc
 import inspect
 import sys
+import threading
 import time
 import unittest
 import weakref
@@ -50,6 +51,31 @@ def get_incremented(self) -> int:
         return self.value
 
 
+def execute_with_barrier(*args, n_workers=None):
+    """Decorator to ensure function calls do not begin until at least n_workers have started.
+
+    This ensures that our parallel tests actually test concurrency. Without this, it is possible
+    that function calls execute as they are being scheduled, and do not truly execute in parallel.
+    """
+    # Trick to make the decorator accept an argument. The first call only gets the n_workers
+    # parameter, and then returns a new function with it set, which then accepts the function to wrap.
+    if n_workers is None:
+        raise ValueError("n_workers not set")
+    if len(args) == 0:
+        return functools.partial(execute_with_barrier, n_workers=n_workers)
+    if len(args) > 1:
+        raise ValueError("Up to one argument expected.")
+    func = args[0]
+    barrier = threading.Barrier(n_workers)
+
+    def wrapped(*args, **kwargs):
+        barrier.wait()
+        return func(*args, **kwargs)
+
+    functools.update_wrapper(wrapped, func)
+    return wrapped
+
+
 def generate_once_counter_fn():
     """Generates a once.once decorated function which counts its calls."""
 
@@ -323,8 +349,8 @@ def sample_failing_fn():
 
     def test_iterator_parallel_execution(self):
         counter = Counter()
-        counter.paused = True
 
+        @execute_with_barrier(n_workers=_N_WORKERS)
         @once.once
         def yielding_iterator():
             nonlocal counter
@@ -363,12 +389,11 @@ def yielding_iterator():
 
     def test_threaded_single_function(self):
         counting_fn, counter = generate_once_counter_fn()
-        counter.paused = True
+        barrier_counting_fn = execute_with_barrier(counting_fn, n_workers=_N_WORKERS)
         with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
-            results_generator = executor.map(counting_fn, range(_N_WORKERS * 2))
-            counter.paused = False  # starter pistol, the race is off!
+            results_generator = executor.map(barrier_counting_fn, range(_N_WORKERS))
             results = list(results_generator)
-        self.assertEqual(len(results), _N_WORKERS * 2)
+        self.assertEqual(len(results), _N_WORKERS)
         for r in results:
             self.assertEqual(r, 1)
         self.assertEqual(counter.value, 1)
@@ -379,17 +404,14 @@ def test_threaded_multiple_functions(self):
 
         for _ in range(4):
             cfn, counter = generate_once_counter_fn()
-            counter.paused = True
             counters.append(counter)
-            fns.append(cfn)
+            fns.append(execute_with_barrier(cfn, n_workers=_N_WORKERS))
 
         promises = []
         with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
             for cfn in fns:
                 for _ in range(_N_WORKERS):
                     promises.append(executor.submit(cfn))
-            for counter in counters:
-                counter.paused = False
             del cfn
         fns.clear()
         for promise in promises:
@@ -623,51 +645,42 @@ def receiving_iterator():
         self.assertEqual(list(receiving_iterator()), [0, 1, 2, 5])
 
     def test_receiving_iterator_parallel_execution(self):
-        # Pause so we actually are able to test parallel execution, by ensuring that each exec
-        # does not complete before the next one is scheduled.
-        paused = True
-
         @once.once
         def receiving_iterator():
-            nonlocal paused
             next = yield 0
             while next is not None:
-                while paused:
-                    pass
                 next = yield next
 
+        barrier = threading.Barrier(_N_WORKERS)
+
         def call_iterator(_):
             gen = receiving_iterator()
             result = []
+            barrier.wait()
             result.append(gen.send(None))
             for i in range(1, _N_WORKERS * 4):
                 result.append(gen.send(i))
             return result
 
         with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
-            results = executor.map(call_iterator, range(_N_WORKERS * 2))
-            paused = False  # starter pistol, the race is off!
+            results = executor.map(call_iterator, range(_N_WORKERS))
         for result in results:
             self.assertEqual(result, list(range(_N_WORKERS * 4)))
 
     def test_receiving_iterator_parallel_execution_halting(self):
-        # Pause so we actually are able to test parallel execution, by ensuring that each exec
-        # does not complete before the next one is scheduled.
-        paused = True
-
         @once.once
         def receiving_iterator():
-            nonlocal paused
             next = yield 0
             while next is not None:
-                while paused:
-                    pass
                 next = yield next
 
+        barrier = threading.Barrier(_N_WORKERS)
+
         def call_iterator(n):
             """Call the iterator but end early"""
             gen = receiving_iterator()
             result = []
+            barrier.wait()
             result.append(gen.send(None))
             for i in range(1, n):
                 result.append(gen.send(i))
@@ -676,8 +689,7 @@ def call_iterator(n):
         # Unlike the previous test, each execution should yield lists of different lengths.
         # This ensures that the iterator does not hang, even if not exhausted
         with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
-            results = executor.map(call_iterator, range(1, _N_WORKERS * 2))
-            paused = False  # starter pistol, the race is off!
+            results = executor.map(call_iterator, range(1, _N_WORKERS + 1))
         for i, result in enumerate(results):
             self.assertEqual(result, list(range(i + 1)))
 
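A note on the new execute_with_barrier helper: the diff uses it both as a decorator factory (n_workers bound up front) and as a direct wrapper around an existing callable. Also, a wave of fewer than n_workers pending calls would wait at the barrier indefinitely, which is presumably why several call counts drop from _N_WORKERS * 2 to _N_WORKERS. A small illustrative sketch of the two invocation forms, not part of the commit; it assumes the module's own names (execute_with_barrier, once, _N_WORKERS) are in scope, and the wrapped functions are hypothetical:

# Form 1: decorator factory, as used on yielding_iterator above.
@execute_with_barrier(n_workers=_N_WORKERS)
@once.once
def some_yielding_fn():  # hypothetical stand-in
    yield 1

# Form 2: direct call on an existing function, as with counting_fn above.
wrapped = execute_with_barrier(some_existing_fn, n_workers=_N_WORKERS)  # some_existing_fn is hypothetical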
