Skip to content

Commit

Permalink
Parallel threads schedule tasks predictably.
Browse files Browse the repository at this point in the history
  • Loading branch information
aebrahim committed Oct 30, 2023
1 parent 7aa50b1 commit c2d85f9
Showing 1 changed file with 102 additions and 33 deletions.
135 changes: 102 additions & 33 deletions once_test.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
"""Unit tests for once decorators."""
# pylint: disable=missing-function-docstring
import asyncio
import collections.abc
import concurrent.futures
import functools
import gc
import inspect
import itertools
import math
import sys
import threading
import time
import unittest
import uuid
import weakref

import once
Expand All @@ -31,6 +34,74 @@ async def anext(iter, default=StopAsyncIteration):
_N_WORKERS = 16  # default number of worker threads used by the threaded tests


class WrappedException:
    """Carrier for an exception raised by a call running on a worker thread.

    ``parallel_map`` stores one of these in a result slot so that a call which
    *raised* can be told apart from a call which merely *returned* an
    exception object.
    """

    def __init__(self, exception):
        self.exception = exception

    def __repr__(self):
        return f"{type(self).__name__}({self.exception!r})"


def parallel_map(
    test: unittest.TestCase,
    func: collections.abc.Callable,
    # would be collections.abc.Iterable[tuple] | None on py >= 3.10
    call_args=None,
    n_threads: int = _N_WORKERS,
    timeout: float = 10.0,
) -> list:
    """Run a function multiple times in parallel.

    We ensure that the parallel tasks are all launched at the "same time":
    every worker thread blocks on a shared barrier and is only released to
    the GIL once all of them have been started.

    Why?
    We can't rely on the thread pool executor to always spin up the full list of _N_WORKERS.
    In pypy, we have observed that even with blocked tasks, the same thread executes multiple
    function calls. This lets us handle the scheduling in a predictable way for testing.

    Args:
        test: The running test case. Its ``id()`` is used to name the worker
            threads, and it is failed if a call does not finish in time.
        func: The callable to invoke.
        call_args: Iterable of argument tuples, one per call. Call ``i`` runs
            on worker ``i % n_threads``. Defaults to ``n_threads`` calls with
            no arguments.
        n_threads: Maximum number of worker threads to start.
        timeout: Overall number of seconds to wait for every call to finish.

    Returns:
        The return values of the calls, in the same order as ``call_args``.
        An empty ``call_args`` yields an empty list.

    Raises:
        TypeError: If an element of ``call_args`` is not a tuple.
        ValueError: If ``n_threads`` is less than 1.
        Exception: An exception raised by ``func`` is re-raised in the calling
            thread (the first one found when scanning results in call order).
    """
    if n_threads < 1:
        raise ValueError("n_threads must be at least 1")
    if call_args is None:
        call_args = (tuple() for _ in range(n_threads))

    # Distribute the calls round-robin, remembering each call's index so the
    # results can be returned in the caller's order.
    batches = [[] for _ in range(n_threads)]  # type: list[list[tuple[int, tuple]]]
    n_calls = 0
    for index, args in enumerate(call_args):
        if not isinstance(args, tuple):
            raise TypeError("call arguments must be a tuple")
        batches[index % n_threads].append((index, args))
        n_calls = index + 1
    if n_calls == 0:
        return []

    # Only start threads which have work to do. A thread with an empty batch
    # would still wait on the barrier, and because the barrier is sized to the
    # number of busy workers, leftover idle threads could block on it forever.
    batches = [batch for batch in batches if batch]

    unset = object()
    results_lock = threading.Lock()
    results = [unset for _ in range(n_calls)]

    # This barrier is used to ensure that all calls release together, after this function has
    # completed its setup of creating them.
    start_barrier = threading.Barrier(len(batches))

    def wrapped_fn(batch):
        start_barrier.wait()
        for index, args in batch:
            try:
                result = func(*args)
            except Exception as e:
                # Captured here and re-raised in the calling thread below.
                result = WrappedException(e)
            with results_lock:
                results[index] = result

    # We manually set thread names for easier debugging.
    invocation_id = str(uuid.uuid4())
    threads = [
        threading.Thread(
            target=wrapped_fn,
            args=[batch],
            name=f"{test.id()}-{i}-{invocation_id}",
            # A call that never returns must not keep the interpreter alive
            # after the test has already been reported as failed.
            daemon=True,
        )
        for i, batch in enumerate(batches)
    ]
    for t in threads:
        t.start()

    # One shared deadline: `timeout` bounds the total wait instead of being
    # applied to each thread in turn (which could add up to n_threads * timeout).
    deadline = time.monotonic() + timeout
    for t in threads:
        t.join(timeout=max(0.0, deadline - time.monotonic()))

    # Snapshot under the lock in case a timed-out worker is still writing.
    with results_lock:
        final_results = list(results)
    for i, result in enumerate(final_results):
        if result is unset:
            test.fail(f"Call {i} did not complete succesfully")
        elif isinstance(result, WrappedException):
            raise result.exception
    return final_results


class Counter:
"""Holding object for a counter.
Expand Down Expand Up @@ -388,8 +459,11 @@ def yielding_iterator():
for _ in range(3):
yield counter.get_incremented()

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = executor.map(lambda _: list(yielding_iterator()), range(_N_WORKERS * 2))
results = parallel_map(
self,
lambda: list(yielding_iterator()),
(tuple() for _ in range(_N_WORKERS * 2)),
)
for result in results:
self.assertEqual(result, [1, 2, 3])

Expand Down Expand Up @@ -422,9 +496,7 @@ def yielding_iterator():
def test_threaded_single_function(self):
counting_fn, counter = generate_once_counter_fn()
barrier_counting_fn = execute_with_barrier(counting_fn, n_workers=_N_WORKERS)
with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results_generator = executor.map(barrier_counting_fn, range(_N_WORKERS))
results = list(results_generator)
results = parallel_map(self, barrier_counting_fn)
self.assertEqual(len(results), _N_WORKERS)
for r in results:
self.assertEqual(r, 1)
Expand All @@ -433,7 +505,6 @@ def test_threaded_single_function(self):
def test_once_per_thread(self):
counter = Counter()

@execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race
@once.once(per_thread=True)
@execute_with_barrier(n_workers=_N_WORKERS)
def counting_fn(*args) -> int:
Expand All @@ -442,8 +513,7 @@ def counting_fn(*args) -> int:
del args
return counter.get_incremented()

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = list(executor.map(counting_fn, range(_N_WORKERS * 4)))
results = parallel_map(self, counting_fn, (tuple() for _ in range(_N_WORKERS * 4)))
self.assertEqual(min(results), 1)
self.assertEqual(max(results), _N_WORKERS)

Expand Down Expand Up @@ -565,12 +635,10 @@ def once_fn(self):

once_obj = _CallOnceClass()

@execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race
def execute(_):
def execute():
return once_obj.once_fn()

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = list(executor.map(execute, range(_N_WORKERS * 4)))
results = parallel_map(self, execute, (tuple() for _ in range(_N_WORKERS * 4)))
self.assertEqual(min(results), 1)
self.assertEqual(max(results), 1)

Expand All @@ -583,12 +651,10 @@ def once_fn(self):

once_obj = _CallOnceClass()

@execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race
def execute(_):
def execute():
return once_obj.once_fn()

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = list(executor.map(execute, range(_N_WORKERS * 4)))
results = parallel_map(self, execute, (tuple() for _ in range(_N_WORKERS * 4)))
self.assertEqual(min(results), 1)
self.assertEqual(max(results), _N_WORKERS)

Expand Down Expand Up @@ -697,8 +763,7 @@ def once_fn(self):
def execute(i):
return once_objs[i % 4].once_fn()

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = list(executor.map(execute, range(_N_WORKERS * 4)))
results = parallel_map(self, execute, ((i,) for i in range(_N_WORKERS * 4)))
self.assertEqual(min(results), 1)
self.assertEqual(max(results), 1)

Expand All @@ -711,12 +776,10 @@ def once_fn(self):

once_objs = [_CallOnceClass(), _CallOnceClass(), _CallOnceClass(), _CallOnceClass()]

@execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race
def execute(i):
return once_objs[i % 4].once_fn()

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = list(executor.map(execute, range(_N_WORKERS)))
results = parallel_map(self, execute, ((i,) for i in range(_N_WORKERS)))
self.assertEqual(min(results), 1)
self.assertEqual(max(results), math.ceil(_N_WORKERS / 4))

Expand Down Expand Up @@ -776,7 +839,7 @@ def receiving_iterator():

barrier = threading.Barrier(_N_WORKERS)

def call_iterator(_):
def call_iterator():
gen = receiving_iterator()
result = []
barrier.wait()
Expand All @@ -785,8 +848,7 @@ def call_iterator(_):
result.append(gen.send(i))
return result

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = executor.map(call_iterator, range(_N_WORKERS))
results = parallel_map(self, call_iterator)
for result in results:
self.assertEqual(result, list(range(_N_WORKERS * 4)))

Expand All @@ -811,8 +873,7 @@ def call_iterator(n):

# Unlike the previous test, each execution should yield lists of different lengths.
# This ensures that the iterator does not hang, even if not exhausted
with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = executor.map(call_iterator, range(1, _N_WORKERS + 1))
results = parallel_map(self, call_iterator, ((i,) for i in range(1, _N_WORKERS + 1)))
for i, result in enumerate(results):
self.assertEqual(result, list(range(i + 1)))

Expand Down Expand Up @@ -850,16 +911,24 @@ async def counting_fn(*args) -> int:
del args
return counter.get_incremented()

@execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race
results_lock = asyncio.Lock()
results = []

@execute_with_barrier(n_workers=_N_WORKERS, is_async=True) # increases chance of a race
async def counting_fn_multiple_caller(*args):
"""Calls counting_fn() multiple times ensuring identical result."""
result = await counting_fn()
for i in range(5):
self.assertEqual(await counting_fn(), result)
async with results_lock:
results.append(result)
return result

def execute(*args):
coro = counting_fn(*args)
coro = counting_fn_multiple_caller(*args)
return asyncio.run(coro)

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = list(executor.map(execute, range(_N_WORKERS)))
self.assertEqual(sorted(results), list(range(1, _N_WORKERS + 1)))
results = list(executor.map(execute, range(_N_WORKERS)))
self.assertEqual(sorted(results), list(range(1, _N_WORKERS + 1)))
parallel_map(self, execute)

async def test_failing_function(self):
counter = Counter()
Expand Down

0 comments on commit c2d85f9

Please sign in to comment.