From 2e542a5e6fc49048aa2303fd1a93beb2840f73a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 15 Aug 2025 12:06:15 +0000 Subject: [PATCH 01/11] Initial ThreadPoolExecutor implementation and refactoring --- src/qasync/__init__.py | 145 +++++++++++++++++++++++++++++++++++------ 1 file changed, 126 insertions(+), 19 deletions(-) diff --git a/src/qasync/__init__.py b/src/qasync/__init__.py index d7702c9..43e3c95 100644 --- a/src/qasync/__init__.py +++ b/src/qasync/__init__.py @@ -8,8 +8,9 @@ BSD License """ -__all__ = ["QEventLoop", "QThreadExecutor", "asyncSlot", "asyncClose", "asyncWrap"] +__all__ = ["QEventLoop", "QThreadExecutor", "QThreadPoolExecutor", "asyncSlot", "asyncClose", "asyncWrap"] +from ast import Not import asyncio import contextlib import functools @@ -20,8 +21,9 @@ import os import sys import time -from concurrent.futures import Future +from concurrent.futures import CancelledError, Future, TimeoutError from queue import Queue +from weakref import WeakSet logger = logging.getLogger(__name__) @@ -162,8 +164,66 @@ def wait(self): super().wait() +class QThreadExecutorBase: + def __init__(self): + self._been_shutdown = False + self.futures = WeakSet() + + def submit(self, callback, *args, **kwargs): + raise NotImplementedError() + + def map(self, func, *iterables, timeout=None, chunksize=1): + """Map the function to the iterables in a blocking way.""" + # iterables are consumed immediately + start = time.monotonic() + if chunksize <= 1: + futures = list(map(lambda *args: self.submit(func, *args), *iterables)) + try: + for future in futures: + if timeout is not None: + yield future.result(timeout=time.monotonic()-start) + else: + yield future.result() + except TimeoutError: + map(lambda f: f.cancel(), futures) + raise + else: + calls = list(map(lambda *args: args, *iterables)) + chunks = (calls[i:i + chunksize] for i in range(0, len(calls), chunksize)) + def helper(chunk): + """Helper to execute a chunk of calls""" + return [func(*args) for args in chunk] + + # submit all the chunks + chunkfutures = [self.submit(helper, chunk) for chunk in chunks] + + # await all the chunk futures + try: + for future in chunkfutures: + if timeout is not None: + results = future.result(timeout=time.monotonic()-start) + for result in results: + yield result + except TimeoutError: + map(lambda f: f.cancel(), chunkfutures) + raise + + def shutdown(self, wait=True, *, cancel_futures=False): + if self._been_shutdown: + raise RuntimeError(f"{self.__class__.__name__} has been shutdown") + self._been_shutdown = True + + def __enter__(self, *args): + if self._been_shutdown: + raise RuntimeError(f"{self.__class__.__name__} has been shutdown") + return self + + def __exit__(self, *args): + self.shutdown() + + @with_logger -class QThreadExecutor: +class QThreadExecutor(QThreadExecutorBase): """ ThreadExecutor that produces QThreads. @@ -191,13 +251,12 @@ def __init__(self, max_workers=10, stack_size=None): self.__workers = [ _QThreadWorker(self.__queue, i + 1, stack_size) for i in range(max_workers) ] - self.__been_shutdown = False for w in self.__workers: w.start() def submit(self, callback, *args, **kwargs): - if self.__been_shutdown: + if self._been_shutdown: raise RuntimeError("QThreadExecutor has been shutdown") future = Future() @@ -208,32 +267,80 @@ def submit(self, callback, *args, **kwargs): kwargs, ) self.__queue.put((future, callback, args, kwargs)) + self.futures.add(future) return future - def map(self, func, *iterables, timeout=None): - raise NotImplementedError("use as_completed on the event loop") - - def shutdown(self, wait=True): - if self.__been_shutdown: - raise RuntimeError("QThreadExecutor has been shutdown") - - self.__been_shutdown = True + def shutdown(self, wait=True, *, cancel_futures=False): + super().shutdown(wait=wait, cancel_futures=cancel_futures) self._logger.debug("Shutting down") for i in range(len(self.__workers)): # Signal workers to stop self.__queue.put(None) + if cancel_futures: + for future in self.futures: + future.cancel() if wait: for w in self.__workers: w.wait() - def __enter__(self, *args): - if self.__been_shutdown: - raise RuntimeError("QThreadExecutor has been shutdown") - return self - def __exit__(self, *args): - self.shutdown() +class _QThreadPoolExecutorRunnable(QtCore.QRunnable): + def __init__(self, callback, *args, **kwargs): + super().__init__() + self._callback = callback + self._args = args + self._kwargs = kwargs + self.future = Future() + + def run(self): + if self.future.set_running_or_notify_cancel(): + try: + result = self._callback(*self._args, **self._kwargs) + self.future.set_result(result) + except Exception as e: + self.future.set_exception(e) + + +@with_logger +class QThreadPoolExecutor(QThreadExecutorBase): + """ + ThreadPoolExecutor uses a QThreadPool as the underlying implementation. + + Same API as `concurrent.futures.Executor` + + >>> from qasync import QThreadPoolExecutor + >>> with QThreadPoolExecutor() as executor: + ... f = executor.submit(lambda x: 2 + x, 2) + ... r = f.result() + ... assert r == 4 + """ + + def __init__(self, pool=None): + super().__init__() + self.pool = pool or QtCore.QThreadPool.globalInstance() + + def submit(self, callback, *args, **kwargs): + if self._been_shutdown: + raise RuntimeError(f"{self.__class__.__name__} has been shutdown") + + runnable = _QThreadPoolExecutorRunnable(callback, *args, **kwargs) + self.pool.start(runnable) + self.futures.add(runnable.future) + return runnable.future + + def shutdown(self, wait=True, *, cancel_futures=False): + super().shutdown(wait=wait, cancel_futures=cancel_futures) + self._logger.debug("Shutting down") + if cancel_futures: + for future in self.futures: + future.cancel() + if wait: + for w in list(self.futures): + try: + w.wait() + except CancelledError: + pass def _format_handle(handle: asyncio.Handle): From de001bcd3e8c6fc85cf2253c21f74cfad7aace38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 15 Aug 2025 12:15:14 +0000 Subject: [PATCH 02/11] simplify map. chunksize can be ignored and cancellation is up to the caller. --- src/qasync/__init__.py | 37 ++++++------------------------------- 1 file changed, 6 insertions(+), 31 deletions(-) diff --git a/src/qasync/__init__.py b/src/qasync/__init__.py index 43e3c95..432efe4 100644 --- a/src/qasync/__init__.py +++ b/src/qasync/__init__.py @@ -176,37 +176,12 @@ def map(self, func, *iterables, timeout=None, chunksize=1): """Map the function to the iterables in a blocking way.""" # iterables are consumed immediately start = time.monotonic() - if chunksize <= 1: - futures = list(map(lambda *args: self.submit(func, *args), *iterables)) - try: - for future in futures: - if timeout is not None: - yield future.result(timeout=time.monotonic()-start) - else: - yield future.result() - except TimeoutError: - map(lambda f: f.cancel(), futures) - raise - else: - calls = list(map(lambda *args: args, *iterables)) - chunks = (calls[i:i + chunksize] for i in range(0, len(calls), chunksize)) - def helper(chunk): - """Helper to execute a chunk of calls""" - return [func(*args) for args in chunk] - - # submit all the chunks - chunkfutures = [self.submit(helper, chunk) for chunk in chunks] - - # await all the chunk futures - try: - for future in chunkfutures: - if timeout is not None: - results = future.result(timeout=time.monotonic()-start) - for result in results: - yield result - except TimeoutError: - map(lambda f: f.cancel(), chunkfutures) - raise + futures = list(map(lambda *args: self.submit(func, *args), *iterables)) + for future in futures: + if timeout is not None: + yield future.result(timeout=time.monotonic()-start) + else: + yield future.result() def shutdown(self, wait=True, *, cancel_futures=False): if self._been_shutdown: From 03a205fb0c75d75733144cf7a318dfe7185e841d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 15 Aug 2025 12:16:02 +0000 Subject: [PATCH 03/11] format --- src/qasync/__init__.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/qasync/__init__.py b/src/qasync/__init__.py index 432efe4..2437240 100644 --- a/src/qasync/__init__.py +++ b/src/qasync/__init__.py @@ -8,7 +8,14 @@ BSD License """ -__all__ = ["QEventLoop", "QThreadExecutor", "QThreadPoolExecutor", "asyncSlot", "asyncClose", "asyncWrap"] +__all__ = [ + "QEventLoop", + "QThreadExecutor", + "QThreadPoolExecutor", + "asyncSlot", + "asyncClose", + "asyncWrap", +] from ast import Not import asyncio @@ -179,7 +186,7 @@ def map(self, func, *iterables, timeout=None, chunksize=1): futures = list(map(lambda *args: self.submit(func, *args), *iterables)) for future in futures: if timeout is not None: - yield future.result(timeout=time.monotonic()-start) + yield future.result(timeout=time.monotonic() - start) else: yield future.result() From 6aff28765d7765d6403603b61e0c11fa8f125faf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 15 Aug 2025 12:45:05 +0000 Subject: [PATCH 04/11] fix shutdown of threadpoolexecutor --- src/qasync/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/qasync/__init__.py b/src/qasync/__init__.py index 2437240..6e3f665 100644 --- a/src/qasync/__init__.py +++ b/src/qasync/__init__.py @@ -320,8 +320,8 @@ def shutdown(self, wait=True, *, cancel_futures=False): if wait: for w in list(self.futures): try: - w.wait() - except CancelledError: + w.result() + except Exception: pass From 8e80c0e8053406194d5f57d441f6ee06e4553b97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 15 Aug 2025 13:25:34 +0000 Subject: [PATCH 05/11] add tests --- tests/test_qthreadexec.py | 71 ++++++++++++++++++++++++++++++++++----- 1 file changed, 63 insertions(+), 8 deletions(-) diff --git a/tests/test_qthreadexec.py b/tests/test_qthreadexec.py index 67c1833..3a9005f 100644 --- a/tests/test_qthreadexec.py +++ b/tests/test_qthreadexec.py @@ -2,6 +2,7 @@ # © 2014 Mark Harviston # © 2014 Arve Knudsen # BSD License +import time import logging import threading import weakref @@ -21,7 +22,11 @@ def disable_executor_logging(): To avoid issues with tests targeting stale references, we disable logging for QThreadExecutor and _QThreadWorker classes. """ - for cls in (qasync.QThreadExecutor, qasync._QThreadWorker): + for cls in ( + qasync.QThreadExecutor, + qasync._QThreadWorker, + qasync.QThreadPoolExecutor, + ): logger_name = cls.__qualname__ if cls.__module__ is not None: logger_name = f"{cls.__module__}.{logger_name}" @@ -30,16 +35,32 @@ def disable_executor_logging(): logger.propagate = False -@pytest.fixture +@pytest.fixture(params=[qasync.QThreadExecutor, qasync.QThreadPoolExecutor]) def executor(request): - exe = qasync.QThreadExecutor(5) - request.addfinalizer(exe.shutdown) + exe = get_executor(request) + request.addfinalizer(lambda: safe_shutdown(exe)) return exe -@pytest.fixture -def shutdown_executor(): - exe = qasync.QThreadExecutor(5) +def get_executor(request): + if request.param is qasync.QThreadPoolExecutor: + pool = qasync.QtCore.QThreadPool() + pool.setMaxThreadCount(5) + return request.param(pool) + else: + return request.param(5) + + +def safe_shutdown(executor): + try: + executor.shutdown() + except Exception: + pass + + +@pytest.fixture(params=[qasync.QThreadExecutor, qasync.QThreadPoolExecutor]) +def shutdown_executor(request): + exe = get_executor(request) exe.shutdown() return exe @@ -55,7 +76,7 @@ def test_ctx_after_shutdown(shutdown_executor): pass -def test_submit_after_shutdown(shutdown_executor): +def _test_submit_after_shutdown(shutdown_executor): with pytest.raises(RuntimeError): shutdown_executor.submit(None) @@ -64,6 +85,7 @@ def test_stack_recursion_limit(executor): # Test that worker threads have sufficient stack size for the default # sys.getrecursionlimit. If not this should fail with SIGSEGV or SIGBUS # (or event SIGILL?) + def rec(a, *args, **kwargs): rec(a, *args, **kwargs) @@ -104,3 +126,36 @@ def test_no_stale_reference_as_result(executor, disable_executor_logging): assert collected is True, ( "Stale reference to executor result not collected within timeout." ) + + +def test_map(executor): + results = list(executor.map(lambda x: x + 1, range(10))) + assert results == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + +@pytest.mark.parametrize("cancel", [True, False]) +def test_map_timeout(executor, cancel): + results = [] + + def func(x): + nonlocal results + time.sleep(0.05) + results.append(x) + return x + + start = time.monotonic() + with pytest.raises(TimeoutError): + list(executor.map(func, range(10), timeout=0.01)) + duration = time.monotonic() - start + assert duration < 0.02 + + executor.shutdown(wait=True, cancel_futures=cancel) + if not cancel: + # they were not cancelled + assert set(results) == {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + else: + # only about half of the tasks should have completed + # because the max number of workers is 5 and the rest of + # the tasks were not started at the time of the cancel. + assert results + assert set(results) != {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} From 32a4f0528f8b67f66e939ea3ce690516ed6c5229 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 15 Aug 2025 13:51:40 +0000 Subject: [PATCH 06/11] Add a few more tests --- tests/test_qthreadexec.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/test_qthreadexec.py b/tests/test_qthreadexec.py index 3a9005f..11601b0 100644 --- a/tests/test_qthreadexec.py +++ b/tests/test_qthreadexec.py @@ -56,6 +56,9 @@ def safe_shutdown(executor): executor.shutdown() except Exception: pass + if isinstance(executor, qasync.QThreadPoolExecutor): + # empty the underlying QThreadPool object + executor.pool.waitForDone() @pytest.fixture(params=[qasync.QThreadExecutor, qasync.QThreadPoolExecutor]) @@ -129,12 +132,14 @@ def test_no_stale_reference_as_result(executor, disable_executor_logging): def test_map(executor): + """Basic test of executor map functionality""" results = list(executor.map(lambda x: x + 1, range(10))) assert results == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] @pytest.mark.parametrize("cancel", [True, False]) def test_map_timeout(executor, cancel): + """Test that map with timeout raises TimeoutError and cancels futures""" results = [] def func(x): @@ -157,5 +162,21 @@ def func(x): # only about half of the tasks should have completed # because the max number of workers is 5 and the rest of # the tasks were not started at the time of the cancel. - assert results assert set(results) != {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + + +def test_context(executor): + """Test that the context manager will shutdown executor""" + with executor: + f = executor.submit(lambda: 42) + assert f.result() == 42 + + with pytest.raises(RuntimeError): + executor.submit(lambda: 42) + + +def test_default_pool_executor(): + """Test that using the global instance of QThreadPool works""" + with qasync.QThreadPoolExecutor() as executor: + f = executor.submit(lambda: 42) + assert f.result() == 42 From e65622ca08a33caa79a510932f72a6892d54346e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 15 Aug 2025 14:20:01 +0000 Subject: [PATCH 07/11] Use correct TimeoutError alias from the concurrent.futures module --- src/qasync/__init__.py | 3 +-- tests/test_qthreadexec.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/qasync/__init__.py b/src/qasync/__init__.py index 6e3f665..13a7863 100644 --- a/src/qasync/__init__.py +++ b/src/qasync/__init__.py @@ -17,7 +17,6 @@ "asyncWrap", ] -from ast import Not import asyncio import contextlib import functools @@ -28,7 +27,7 @@ import os import sys import time -from concurrent.futures import CancelledError, Future, TimeoutError +from concurrent.futures import Future from queue import Queue from weakref import WeakSet diff --git a/tests/test_qthreadexec.py b/tests/test_qthreadexec.py index 11601b0..f3d7bba 100644 --- a/tests/test_qthreadexec.py +++ b/tests/test_qthreadexec.py @@ -2,10 +2,11 @@ # © 2014 Mark Harviston # © 2014 Arve Knudsen # BSD License -import time import logging import threading +import time import weakref +from concurrent.futures import TimeoutError import pytest From af83d9239dbf4439d0ee897feb57f4ac02ce0204 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 15 Aug 2025 14:26:30 +0000 Subject: [PATCH 08/11] make time test less flaky --- tests/test_qthreadexec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_qthreadexec.py b/tests/test_qthreadexec.py index f3d7bba..fe98f73 100644 --- a/tests/test_qthreadexec.py +++ b/tests/test_qthreadexec.py @@ -153,7 +153,7 @@ def func(x): with pytest.raises(TimeoutError): list(executor.map(func, range(10), timeout=0.01)) duration = time.monotonic() - start - assert duration < 0.02 + assert duration < 0.05 executor.shutdown(wait=True, cancel_futures=cancel) if not cancel: From 274f835311cb78ab4691ca08d21f3082894cf23e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 15 Aug 2025 14:35:10 +0000 Subject: [PATCH 09/11] Set the QThreadPool stacksize to match python expectations. --- src/qasync/__init__.py | 21 ++++++++++++++------- tests/test_qthreadexec.py | 3 +++ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/qasync/__init__.py b/src/qasync/__init__.py index 13a7863..616c1b3 100644 --- a/src/qasync/__init__.py +++ b/src/qasync/__init__.py @@ -202,6 +202,19 @@ def __enter__(self, *args): def __exit__(self, *args): self.shutdown() + @staticmethod + def compute_stack_size(): + # Match cpython/Python/thread_pthread.h + if sys.platform.startswith("darwin"): + stack_size = 16 * 2**20 + elif sys.platform.startswith("freebsd"): + stack_size = 4 * 2**20 + elif sys.platform.startswith("aix"): + stack_size = 2 * 2**20 + else: + stack_size = None + return stack_size + @with_logger class QThreadExecutor(QThreadExecutorBase): @@ -222,13 +235,7 @@ def __init__(self, max_workers=10, stack_size=None): self.__max_workers = max_workers self.__queue = Queue() if stack_size is None: - # Match cpython/Python/thread_pthread.h - if sys.platform.startswith("darwin"): - stack_size = 16 * 2**20 - elif sys.platform.startswith("freebsd"): - stack_size = 4 * 2**20 - elif sys.platform.startswith("aix"): - stack_size = 2 * 2**20 + stack_size = self.compute_stack_size() self.__workers = [ _QThreadWorker(self.__queue, i + 1, stack_size) for i in range(max_workers) ] diff --git a/tests/test_qthreadexec.py b/tests/test_qthreadexec.py index fe98f73..8506dc9 100644 --- a/tests/test_qthreadexec.py +++ b/tests/test_qthreadexec.py @@ -46,6 +46,9 @@ def executor(request): def get_executor(request): if request.param is qasync.QThreadPoolExecutor: pool = qasync.QtCore.QThreadPool() + stack_size = qasync.QThreadExecutorBase.compute_stack_size() + if stack_size is not None: + pool.setStackSize(stack_size) pool.setMaxThreadCount(5) return request.param(pool) else: From 2af0894b6beed4e3fb4d12b5dd635d9c7fa3bc22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 15 Aug 2025 14:42:22 +0000 Subject: [PATCH 10/11] shut down the system global pool after test --- tests/test_qthreadexec.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_qthreadexec.py b/tests/test_qthreadexec.py index 8506dc9..ce07c52 100644 --- a/tests/test_qthreadexec.py +++ b/tests/test_qthreadexec.py @@ -184,3 +184,4 @@ def test_default_pool_executor(): with qasync.QThreadPoolExecutor() as executor: f = executor.submit(lambda: 42) assert f.result() == 42 + executor.pool.waitForDone() From 645c3f7e8587c82ec3c117dfdba73ec5c0ada291 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 17 Aug 2025 12:40:27 +0000 Subject: [PATCH 11/11] Use the robust map() implementation from standard library. --- src/qasync/__init__.py | 42 +++++++++++++++++++----- tests/test_qthreadexec.py | 67 +++++++++++++++++++++++++++++++++------ 2 files changed, 92 insertions(+), 17 deletions(-) diff --git a/src/qasync/__init__.py b/src/qasync/__init__.py index 616c1b3..e5f568d 100644 --- a/src/qasync/__init__.py +++ b/src/qasync/__init__.py @@ -170,6 +170,17 @@ def wait(self): super().wait() +def _result_or_cancel(fut, timeout=None): + try: + try: + return fut.result(timeout) + finally: + fut.cancel() + finally: + # Break a reference cycle with the exception in self._exception + del fut + + class QThreadExecutorBase: def __init__(self): self._been_shutdown = False @@ -180,14 +191,29 @@ def submit(self, callback, *args, **kwargs): def map(self, func, *iterables, timeout=None, chunksize=1): """Map the function to the iterables in a blocking way.""" - # iterables are consumed immediately - start = time.monotonic() - futures = list(map(lambda *args: self.submit(func, *args), *iterables)) - for future in futures: - if timeout is not None: - yield future.result(timeout=time.monotonic() - start) - else: - yield future.result() + # based on standard python implementation for BaseExecutor.map + end_time = time.monotonic() + timeout if timeout is not None else None + futures = [self.submit(func, *args) for args in zip(*iterables)] + + # the generator must be an inner function so that map() and the submit + # occurs immediately. + def generator(): + # reverse and pop to not keep future references around + # (for reference cycles in exceptions) + try: + futures.reverse() + while futures: + if end_time is not None: + yield _result_or_cancel( + futures.pop(), timeout=end_time - time.monotonic() + ) + else: + yield _result_or_cancel(futures.pop()) + finally: + for future in futures: + future.cancel() + + return generator() def shutdown(self, wait=True, *, cancel_futures=False): if self._been_shutdown: diff --git a/tests/test_qthreadexec.py b/tests/test_qthreadexec.py index ce07c52..7363dfa 100644 --- a/tests/test_qthreadexec.py +++ b/tests/test_qthreadexec.py @@ -7,6 +7,7 @@ import time import weakref from concurrent.futures import TimeoutError +from itertools import islice import pytest @@ -140,9 +141,11 @@ def test_map(executor): results = list(executor.map(lambda x: x + 1, range(10))) assert results == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + results = list(executor.map(lambda x, y: x + y, range(10), range(9))) + assert results == [0, 2, 4, 6, 8, 10, 12, 14, 16] -@pytest.mark.parametrize("cancel", [True, False]) -def test_map_timeout(executor, cancel): + +def test_map_timeout(executor): """Test that map with timeout raises TimeoutError and cancels futures""" results = [] @@ -158,15 +161,61 @@ def func(x): duration = time.monotonic() - start assert duration < 0.05 + executor.shutdown(wait=True) + # only about half of the tasks should have completed + # because the max number of workers is 5 and the rest of + # the tasks were not started at the time of the cancel. + assert set(results) != {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + + +def test_map_error(executor): + """Test that map with an exception will raise, and remaining tasks are cancelled""" + results = [] + + def func(x): + nonlocal results + time.sleep(0.05) + if len(results) == 5: + raise ValueError("Test error") + results.append(x) + return x + + with pytest.raises(ValueError): + list(executor.map(func, range(15))) + + executor.shutdown(wait=True, cancel_futures=False) + assert len(results) <= 10, "Final 5 at least should have been cancelled" + + +@pytest.mark.parametrize("cancel", [True, False]) +def test_map_shutdown(executor, cancel): + results = [] + + def func(x): + nonlocal results + time.sleep(0.05) + results.append(x) + return x + + # Get the first few results. + # Keep the iterator alive so that it isn't closed when its reference is dropped. + m = executor.map(func, range(15)) + values = list(islice(m, 5)) + assert values == [0, 1, 2, 3, 4] + executor.shutdown(wait=True, cancel_futures=cancel) - if not cancel: - # they were not cancelled - assert set(results) == {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + if cancel: + assert len(results) < 15, "Some tasks should have been cancelled" else: - # only about half of the tasks should have completed - # because the max number of workers is 5 and the rest of - # the tasks were not started at the time of the cancel. - assert set(results) != {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + assert len(results) == 15, "All tasks should have been completed" + + +def test_map_start(executor): + """Test that map starts tasks immediately, before iterating""" + e = threading.Event() + m = executor.map(lambda x: (e.set(), x), range(1)) + e.wait(timeout=0.1) + assert list(m) == [(None, 0)] def test_context(executor):