From 99ebc98afd4d055e73f37e4a02953132a9db26e0 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 12 Aug 2024 17:37:03 +0100 Subject: [PATCH 1/3] Add explicit error --- python/src/utils.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/python/src/utils.rs b/python/src/utils.rs index 5ec2fe0a65..3b0a0b0041 100644 --- a/python/src/utils.rs +++ b/python/src/utils.rs @@ -10,6 +10,22 @@ use tokio::runtime::Runtime; #[inline] pub fn rt() -> &'static Runtime { static TOKIO_RT: OnceLock = OnceLock::new(); + static PID: OnceLock = OnceLock::new(); + match PID.get() { + Some(pid) if pid == &std::process::id() => {} // Reuse the static runtime. + Some(pid) => { + panic!( + "Forked process detected - current PID is {} but the tokio runtime was by {}. The tokio runtime + does not support forked processes https://github.com/tokio-rs/tokio/issues/4301. If you are seeing this + message while using Python multithreading make sure to use the `spawn` or `forkserver` mode.", + pid, std::process::id() + ); + } + None => { + PID.set(std::process::id()) + .expect("Failed to record PID for tokio runtime."); + } + } TOKIO_RT.get_or_init(|| Runtime::new().expect("Failed to create a tokio runtime.")) } From 8f3f96a632f355ac7acc203c7258d2b8e98b4f43 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 12 Aug 2024 18:26:44 +0100 Subject: [PATCH 2/3] Start a test --- python/tests/test_table_read.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index 8d03ff0863..6ab5030c8a 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -1,3 +1,4 @@ +from itertools import product import os from datetime import date, datetime, timezone from pathlib import Path @@ -24,6 +25,8 @@ from pyarrow.dataset import ParquetReadOptions from pyarrow.fs import LocalFileSystem, SubTreeFileSystem +import multiprocessing +import threading from deltalake import DeltaTable @@ 
-56,6 +59,23 @@ def test_read_simple_table_to_dict(): dt = DeltaTable(table_path) assert dt.to_pyarrow_dataset().to_table().to_pydict() == {"id": [5, 7, 9]} +def recursively_read_simple_table(thread_or_process_class, depth): + print(thread_or_process_class, depth) + test_read_simple_table_to_dict() + if depth == 0: + return + + process_or_thread = thread_or_process_class(target=recursively_read_simple_table, args=(thread_or_process_class, depth - 1)) + process_or_thread.start() + process_or_thread.join() + + +@pytest.mark.parametrize("thread_or_process_class, multiprocessing_start_method", [(threading.Thread, None), (multiprocessing.Process, "forkserver"), (multiprocessing.Process, "spawn"), (multiprocessing.Process, "fork")]) +def test_read_simple_in_threads_and_processes(thread_or_process_class, multiprocessing_start_method): + if multiprocessing_start_method is not None: + multiprocessing.set_start_method(multiprocessing_start_method, force=True) + recursively_read_simple_table(thread_or_process_class=thread_or_process_class, depth=10) + def test_read_simple_table_by_version_to_dict(): table_path = "../crates/test/tests/data/delta-0.2.0" From 7f79c4351c912858e8671b7c67abbe85395b9e2a Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 12 Aug 2024 20:32:35 +0100 Subject: [PATCH 3/3] Tidy --- python/src/utils.rs | 7 ++-- python/tests/test_table_read.py | 57 +++++++++++++++++++++++++-------- 2 files changed, 47 insertions(+), 17 deletions(-) diff --git a/python/src/utils.rs b/python/src/utils.rs index 3b0a0b0041..b063b64d08 100644 --- a/python/src/utils.rs +++ b/python/src/utils.rs @@ -15,9 +15,10 @@ pub fn rt() -> &'static Runtime { Some(pid) if pid == &std::process::id() => {} // Reuse the static runtime. Some(pid) => { panic!( - "Forked process detected - current PID is {} but the tokio runtime was by {}. The tokio runtime - does not support forked processes https://github.com/tokio-rs/tokio/issues/4301. 
If you are seeing this - message while using Python multithreading make sure to use the `spawn` or `forkserver` mode.", + "Forked process detected - current PID is {1} but the tokio runtime was created by {0}. The tokio \ + runtime does not support forked processes https://github.com/tokio-rs/tokio/issues/4301. If you are \ + seeing this message while using Python multiprocessing make sure to use the `spawn` or `forkserver` \ + mode.", pid, std::process::id() ); } diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index 6ab5030c8a..cc36fc0274 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -1,10 +1,9 @@ -from itertools import product import os from datetime import date, datetime, timezone from pathlib import Path from random import random from threading import Barrier, Thread -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Type from unittest.mock import Mock from deltalake._util import encode_partition_value @@ -19,14 +18,15 @@ else: _has_pandas = True +import multiprocessing +from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor + import pyarrow as pa import pyarrow.dataset as ds import pytest from pyarrow.dataset import ParquetReadOptions from pyarrow.fs import LocalFileSystem, SubTreeFileSystem -import multiprocessing -import threading from deltalake import DeltaTable @@ -59,22 +59,51 @@ def test_read_simple_table_to_dict(): dt = DeltaTable(table_path) assert dt.to_pyarrow_dataset().to_table().to_pydict() == {"id": [5, 7, 9]} -def recursively_read_simple_table(thread_or_process_class, depth): - print(thread_or_process_class, depth) - test_read_simple_table_to_dict() + +class _SerializableException(BaseException): + pass + + +def _recursively_read_simple_table(executor_class: Type[Executor], depth): + try: + test_read_simple_table_to_dict() + except BaseException as e: # Ideally this would catch `pyo3_runtime.PanicException` but it seems that is 
not possible. + # Re-raise as something that can be serialized and therefore sent back to parent processes. + raise _SerializableException(f"Serializable exception: {e}") from e + if depth == 0: return - - process_or_thread = thread_or_process_class(target=recursively_read_simple_table, args=(thread_or_process_class, depth - 1)) - process_or_thread.start() - process_or_thread.join() + # We use concurrent.futures.Executors instead of `threading.Thread` or `multiprocessing.Process` so that errors + # are re-raised in the parent process/thread when we call `future.result()`. + with executor_class(max_workers=1) as executor: + future = executor.submit( + _recursively_read_simple_table, executor_class, depth - 1 + ) + future.result() -@pytest.mark.parametrize("thread_or_process_class, multiprocessing_start_method", [(threading.Thread, None), (multiprocessing.Process, "forkserver"), (multiprocessing.Process, "spawn"), (multiprocessing.Process, "fork")]) -def test_read_simple_in_threads_and_processes(thread_or_process_class, multiprocessing_start_method): +@pytest.mark.parametrize( + "executor_class,multiprocessing_start_method,expect_panic", + [ + (ThreadPoolExecutor, None, False), + (ProcessPoolExecutor, "forkserver", False), + (ProcessPoolExecutor, "spawn", False), + (ProcessPoolExecutor, "fork", True), + ], +) +def test_read_simple_in_threads_and_processes( + executor_class, multiprocessing_start_method, expect_panic +): if multiprocessing_start_method is not None: multiprocessing.set_start_method(multiprocessing_start_method, force=True) - recursively_read_simple_table(thread_or_process_class=thread_or_process_class, depth=10) + if expect_panic: + with pytest.raises( + _SerializableException, + match="The tokio runtime does not support forked processes", + ): + _recursively_read_simple_table(executor_class=executor_class, depth=5) + else: + _recursively_read_simple_table(executor_class=executor_class, depth=5) def test_read_simple_table_by_version_to_dict():