diff --git a/python/src/utils.rs b/python/src/utils.rs index 5ec2fe0a65..b063b64d08 100644 --- a/python/src/utils.rs +++ b/python/src/utils.rs @@ -10,6 +10,23 @@ use tokio::runtime::Runtime; #[inline] pub fn rt() -> &'static Runtime { static TOKIO_RT: OnceLock = OnceLock::new(); + static PID: OnceLock = OnceLock::new(); + match PID.get() { + Some(pid) if pid == &std::process::id() => {} // Reuse the static runtime. + Some(pid) => { + panic!( + "Forked process detected - current PID is {} but the tokio runtime was created by {}. The tokio \ + runtime does not support forked processes https://github.com/tokio-rs/tokio/issues/4301. If you are \ + seeing this message while using Python multithreading make sure to use the `spawn` or `forkserver` \ + mode.", + pid, std::process::id() + ); + } + None => { + PID.set(std::process::id()) + .expect("Failed to record PID for tokio runtime."); + } + } TOKIO_RT.get_or_init(|| Runtime::new().expect("Failed to create a tokio runtime.")) } diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index 8d03ff0863..cc36fc0274 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -3,7 +3,7 @@ from pathlib import Path from random import random from threading import Barrier, Thread -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Type from unittest.mock import Mock from deltalake._util import encode_partition_value @@ -18,6 +18,9 @@ else: _has_pandas = True +import multiprocessing +from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor + import pyarrow as pa import pyarrow.dataset as ds import pytest @@ -57,6 +60,52 @@ def test_read_simple_table_to_dict(): assert dt.to_pyarrow_dataset().to_table().to_pydict() == {"id": [5, 7, 9]} +class _SerializableException(BaseException): + pass + + +def _recursively_read_simple_table(executor_class: Type[Executor], depth): + try: + test_read_simple_table_to_dict() + except BaseException as e: # Ideally this would catch `pyo3_runtime.PanicException` but its seems that is not possible. + # Re-raise as something that can be serialized and therefore sent back to parent processes. + raise _SerializableException(f"Seraializatble exception: {e}") from e + + if depth == 0: + return + # We use concurrent.futures.Executors instead of `threading.Thread` or `multiprocessing.Process` to that errors + # are re-rasied in the parent process/thread when we call `future.result()`. + with executor_class(max_workers=1) as executor: + future = executor.submit( + _recursively_read_simple_table, executor_class, depth - 1 + ) + future.result() + + +@pytest.mark.parametrize( + "executor_class,multiprocessing_start_method,expect_panic", + [ + (ThreadPoolExecutor, None, False), + (ProcessPoolExecutor, "forkserver", False), + (ProcessPoolExecutor, "spawn", False), + (ProcessPoolExecutor, "fork", True), + ], +) +def test_read_simple_in_threads_and_processes( + executor_class, multiprocessing_start_method, expect_panic +): + if multiprocessing_start_method is not None: + multiprocessing.set_start_method(multiprocessing_start_method, force=True) + if expect_panic: + with pytest.raises( + _SerializableException, + match="The tokio runtime does not support forked processes", + ): + _recursively_read_simple_table(executor_class=executor_class, depth=5) + else: + _recursively_read_simple_table(executor_class=executor_class, depth=5) + + def test_read_simple_table_by_version_to_dict(): table_path = "../crates/test/tests/data/delta-0.2.0" dt = DeltaTable(table_path, version=2)