Skip to content

Commit

Permalink
Run tests with -O and -OO. (#16)
Browse files Browse the repository at this point in the history
* Run tests with -O and -OO.

* Parallel threads schedule tasks predictably.

* PR comments.
  • Loading branch information
aebrahim authored Nov 8, 2023
1 parent 4e98518 commit 846adcc
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 65 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ jobs:
cache: pip
- run: pip install pytest
- run: pytest . --junitxml=junit/test_py${{ matrix.python-version }}_on_${{ matrix.os }}.xml
- run: python -O once_test.py
- run: python -OO once_test.py
- name: Upload pytest test results
uses: actions/upload-artifact@v3
if: success() || failure()
Expand Down
207 changes: 142 additions & 65 deletions once_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Unit tests for once decorators."""
# pylint: disable=missing-function-docstring
import asyncio
import collections.abc
import concurrent.futures
import functools
import gc
Expand All @@ -9,6 +10,7 @@
import sys
import threading
import unittest
import uuid
import weakref

import once
Expand All @@ -30,6 +32,74 @@ async def anext(iter, default=StopAsyncIteration):
_N_WORKERS = 32


class WrappedException:
def __init__(self, exception):
self.exception = exception


def parallel_map(
test: unittest.TestCase,
func: collections.abc.Callable,
# would be collections.abc.Iterable[tuple] | None on py >= 3.10
call_args=None,
n_threads: int = _N_WORKERS,
timeout: float = 10.0,
) -> list:
"""Run a function multiple times in parallel.
We ensure that N parallel tasks are all launched at the "same time", which
means all have parallel threads which are released to the GIL to execute at
the same time.
Why?
We can't rely on the thread pool excector to always spin up the full list of _N_WORKERS.
In pypy, we have observed that even with blocked tasks, the same thread executes multiple
function calls. This lets us handle the scheduling in a predictable way for testing.
"""
if call_args is None:
call_args = (tuple() for _ in range(n_threads))

batches = [[] for i in range(n_threads)] # type: list[list[tuple[int, tuple]]]
for i, call_args in enumerate(call_args):
if not isinstance(call_args, tuple):
raise TypeError("call arguments must be a tuple")
batches[i % n_threads].append((i, call_args))
n_calls = i + 1 # len(call_args), but it is now an exhuasted iterator.
unset = object()
results_lock = threading.Lock()
results = [unset for _ in range(n_calls)]

# This barrier is used to ensure that all calls release together, after this function has
# completed its setup of creating them.
start_barrier = threading.Barrier(min(n_threads, n_calls))

def wrapped_fn(batch):
start_barrier.wait()
for index, args in batch:
try:
result = func(*args)
except Exception as e:
result = WrappedException(e)
with results_lock:
results[index] = result

# We manually set thread names for easier debugging.
invocation_id = str(uuid.uuid4())
threads = [
threading.Thread(target=wrapped_fn, args=[batch], name=f"{test.id()}-{i}-{invocation_id}")
for i, batch in enumerate(batches)
]
for t in threads:
t.start()
for t in threads:
t.join(timeout=timeout)
for i, result in enumerate(results):
if result is unset:
test.fail(f"Call {i} did not complete succesfully")
elif isinstance(result, WrappedException):
raise result.exception
return results


class Counter:
"""Holding object for a counter.
Expand Down Expand Up @@ -429,16 +499,17 @@ def sample_failing_fn():
def test_iterator_parallel_execution(self):
counter = Counter()

# Must be called over an integer multiple of _N_WORKERS
@execute_with_barrier(n_workers=_N_WORKERS)
@once.once
def yielding_iterator():
nonlocal counter
for _ in range(3):
yield counter.get_incremented()

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = executor.map(lambda _: list(yielding_iterator()), range(_N_WORKERS * 2))
results = parallel_map(
self,
lambda: list(yielding_iterator()),
(tuple() for _ in range(_N_WORKERS * 2)),
)
for result in results:
self.assertEqual(result, [1, 2, 3])

Expand Down Expand Up @@ -470,10 +541,7 @@ def yielding_iterator():

def test_threaded_single_function(self):
counting_fn, counter = generate_once_counter_fn()
barrier_counting_fn = execute_with_barrier(counting_fn, n_workers=_N_WORKERS)
with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results_generator = executor.map(barrier_counting_fn, range(_N_WORKERS))
results = list(results_generator)
results = parallel_map(self, counting_fn)
self.assertEqual(len(results), _N_WORKERS)
for r in results:
self.assertEqual(r, 1)
Expand All @@ -482,7 +550,6 @@ def test_threaded_single_function(self):
def test_once_per_thread(self):
counter = Counter()

@execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race
@once.once(per_thread=True)
@execute_with_barrier(n_workers=_N_WORKERS)
def counting_fn(*args) -> int:
Expand All @@ -491,8 +558,7 @@ def counting_fn(*args) -> int:
del args
return counter.get_incremented()

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = list(executor.map(counting_fn, range(_N_WORKERS * 4)))
results = parallel_map(self, counting_fn, (tuple() for _ in range(_N_WORKERS * 4)))
self.assertEqual(min(results), 1)
self.assertEqual(max(results), _N_WORKERS)

Expand All @@ -503,17 +569,13 @@ def test_threaded_multiple_functions(self):
for _ in range(4):
cfn, counter = generate_once_counter_fn()
counters.append(counter)
fns.append(execute_with_barrier(cfn, n_workers=_N_WORKERS))

promises = []
with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
for cfn in fns:
for _ in range(_N_WORKERS):
promises.append(executor.submit(cfn))
del cfn
fns.clear()
for promise in promises:
self.assertEqual(promise.result(), 1)
fns.append(cfn)

def call_all_functions(i):
for j in range(i, i + 4):
self.assertEqual(fns[j % 4](), 1)

parallel_map(self, call_all_functions, ((i,) for i in range(_N_WORKERS)))
for counter in counters:
self.assertEqual(counter.value, 1)

Expand Down Expand Up @@ -575,16 +637,22 @@ def closure():
self.assertIsNone(ephemeral_ref())

def test_function_signature_preserved(self):
@once.once
def type_annotated_fn(arg: float) -> int:
"""Very descriptive docstring."""
del arg
return 1

sig = inspect.signature(type_annotated_fn)
self.assertIs(sig.parameters["arg"].annotation, float)
self.assertIs(sig.return_annotation, int)
self.assertEqual(type_annotated_fn.__doc__, "Very descriptive docstring.")
decorated_function = once.once(type_annotated_fn)
original_sig = inspect.signature(type_annotated_fn)
decorated_sig = inspect.signature(decorated_function)
self.assertIs(original_sig.parameters["arg"].annotation, float)
self.assertIs(decorated_sig.parameters["arg"].annotation, float)
self.assertIs(original_sig.return_annotation, int)
self.assertIs(decorated_sig.return_annotation, int)
self.assertEqual(inspect.getdoc(type_annotated_fn), inspect.getdoc(decorated_function))
if sys.flags.optimize >= 2:
self.skipTest("docstrings get stripped with -OO")
self.assertEqual(inspect.getdoc(type_annotated_fn), "Very descriptive docstring.")

def test_once_per_class(self):
class _CallOnceClass(Counter):
Expand All @@ -608,12 +676,10 @@ def once_fn(self):

once_obj = _CallOnceClass()

@execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race
def execute(_):
def execute():
return once_obj.once_fn()

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = list(executor.map(execute, range(_N_WORKERS * 4)))
results = parallel_map(self, execute, (tuple() for _ in range(_N_WORKERS * 4)))
self.assertEqual(min(results), 1)
self.assertEqual(max(results), 1)

Expand All @@ -626,12 +692,10 @@ def once_fn(self):

once_obj = _CallOnceClass()

@execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race
def execute(_):
def execute():
return once_obj.once_fn()

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = list(executor.map(execute, range(_N_WORKERS * 4)))
results = parallel_map(self, execute, (tuple() for _ in range(_N_WORKERS * 4)))
self.assertEqual(min(results), 1)
self.assertEqual(max(results), _N_WORKERS)

Expand Down Expand Up @@ -686,18 +750,30 @@ def value(self): # pylint: disable=inconsistent-return-statements
a = _CallOnceClass("a", self) # pylint: disable=invalid-name
b = _CallOnceClass("b", self) # pylint: disable=invalid-name

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
a_jobs = [executor.submit(a.value) for _ in range(_N_WORKERS // 2)]
b_jobs = [executor.submit(b.value) for _ in range(_N_WORKERS // 2)]
for a_job in a_jobs:
self.assertEqual(a_job.result(), "a")
for b_job in b_jobs:
self.assertEqual(b_job.result(), "b")

self.assertEqual(a.value(), "a")
self.assertEqual(a.value(), "a")
self.assertEqual(b.value(), "b")
self.assertEqual(b.value(), "b")
def call_and_check_both(i: int):
# Run in different order based on the call
if i % 4 == 0:
self.assertEqual(a.value(), "a")
self.assertEqual(a.value(), "a")
self.assertEqual(b.value(), "b")
self.assertEqual(b.value(), "b")
elif i % 4 == 1:
self.assertEqual(a.value(), "a")
self.assertEqual(b.value(), "b")
self.assertEqual(a.value(), "a")
self.assertEqual(b.value(), "b")
elif i % 4 == 2:
self.assertEqual(b.value(), "b")
self.assertEqual(a.value(), "a")
self.assertEqual(b.value(), "b")
self.assertEqual(a.value(), "a")
else:
self.assertEqual(b.value(), "b")
self.assertEqual(b.value(), "b")
self.assertEqual(a.value(), "a")
self.assertEqual(a.value(), "a")

parallel_map(self, call_and_check_both, ((i,) for i in range(_N_WORKERS)))

def test_once_per_instance_do_not_block_each_other(self):
class _BlockableClass:
Expand Down Expand Up @@ -736,12 +812,10 @@ def once_fn(self):

once_objs = [_CallOnceClass(), _CallOnceClass(), _CallOnceClass(), _CallOnceClass()]

@execute_with_barrier(n_workers=_N_WORKERS)
def execute(i):
return once_objs[i % 4].once_fn()

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = list(executor.map(execute, range(_N_WORKERS * 4)))
results = parallel_map(self, execute, ((i,) for i in range(_N_WORKERS * 4)))
self.assertEqual(min(results), 1)
self.assertEqual(max(results), 1)

Expand All @@ -754,12 +828,10 @@ def once_fn(self):

once_objs = [_CallOnceClass(), _CallOnceClass(), _CallOnceClass(), _CallOnceClass()]

@execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race
def execute(i):
return once_objs[i % 4].once_fn()

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = list(executor.map(execute, range(_N_WORKERS)))
results = parallel_map(self, execute, ((i,) for i in range(_N_WORKERS)))
self.assertEqual(min(results), 1)
self.assertEqual(max(results), math.ceil(_N_WORKERS / 4))

Expand Down Expand Up @@ -819,7 +891,7 @@ def receiving_iterator():

barrier = threading.Barrier(_N_WORKERS)

def call_iterator(_):
def call_iterator():
gen = receiving_iterator()
result = []
barrier.wait()
Expand All @@ -828,8 +900,7 @@ def call_iterator(_):
result.append(gen.send(i))
return result

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = executor.map(call_iterator, range(_N_WORKERS))
results = parallel_map(self, call_iterator)
for result in results:
self.assertEqual(result, list(range(_N_WORKERS * 4)))

Expand All @@ -854,8 +925,7 @@ def call_iterator(n):

# Unlike the previous test, each execution should yield lists of different lengths.
# This ensures that the iterator does not hang, even if not exhausted
with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = executor.map(call_iterator, range(1, _N_WORKERS + 1))
results = parallel_map(self, call_iterator, ((i,) for i in range(1, _N_WORKERS + 1)))
for i, result in enumerate(results):
self.assertEqual(result, list(range(i + 1)))

Expand Down Expand Up @@ -893,16 +963,23 @@ async def counting_fn(*args) -> int:
del args
return counter.get_incremented()

@execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race
results_lock = asyncio.Lock()
results = []

async def counting_fn_multiple_caller(*args):
"""Calls counting_fn() multiple times ensuring identical result."""
result = await counting_fn()
for i in range(5):
self.assertEqual(await counting_fn(), result)
async with results_lock:
results.append(result)
return result

def execute(*args):
coro = counting_fn(*args)
coro = counting_fn_multiple_caller(*args)
return asyncio.run(coro)

with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor:
results = list(executor.map(execute, range(_N_WORKERS)))
self.assertEqual(sorted(results), list(range(1, _N_WORKERS + 1)))
results = list(executor.map(execute, range(_N_WORKERS)))
self.assertEqual(sorted(results), list(range(1, _N_WORKERS + 1)))
parallel_map(self, execute)

async def test_failing_function(self):
counter = Counter()
Expand Down

0 comments on commit 846adcc

Please sign in to comment.