Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: fail fast on forked process #2765

Merged
merged 3 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions python/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@ use tokio::runtime::Runtime;
#[inline]
pub fn rt() -> &'static Runtime {
static TOKIO_RT: OnceLock<Runtime> = OnceLock::new();
static PID: OnceLock<u32> = OnceLock::new();
match PID.get() {
Some(pid) if pid == &std::process::id() => {} // Reuse the static runtime.
Some(pid) => {
panic!(
"Forked process detected - current PID is {} but the tokio runtime was created by {}. The tokio \
runtime does not support forked processes https://github.com/tokio-rs/tokio/issues/4301. If you are \
seeing this message while using Python multithreading make sure to use the `spawn` or `forkserver` \
mode.",
pid, std::process::id()
);
}
None => {
PID.set(std::process::id())
.expect("Failed to record PID for tokio runtime.");
}
}
TOKIO_RT.get_or_init(|| Runtime::new().expect("Failed to create a tokio runtime."))
}

Expand Down
51 changes: 50 additions & 1 deletion python/tests/test_table_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from random import random
from threading import Barrier, Thread
from typing import Any, List, Tuple
from typing import Any, List, Tuple, Type
from unittest.mock import Mock

from deltalake._util import encode_partition_value
Expand All @@ -18,6 +18,9 @@
else:
_has_pandas = True

import multiprocessing
from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor

import pyarrow as pa
import pyarrow.dataset as ds
import pytest
Expand Down Expand Up @@ -57,6 +60,52 @@ def test_read_simple_table_to_dict():
assert dt.to_pyarrow_dataset().to_table().to_pydict() == {"id": [5, 7, 9]}


class _SerializableException(BaseException):
pass


def _recursively_read_simple_table(executor_class: Type[Executor], depth):
try:
test_read_simple_table_to_dict()
except BaseException as e: # Ideally this would catch `pyo3_runtime.PanicException` but its seems that is not possible.
# Re-raise as something that can be serialized and therefore sent back to parent processes.
raise _SerializableException(f"Seraializatble exception: {e}") from e

if depth == 0:
return
# We use concurrent.futures.Executors instead of `threading.Thread` or `multiprocessing.Process` to that errors
# are re-rasied in the parent process/thread when we call `future.result()`.
with executor_class(max_workers=1) as executor:
future = executor.submit(
_recursively_read_simple_table, executor_class, depth - 1
)
future.result()


@pytest.mark.parametrize(
"executor_class,multiprocessing_start_method,expect_panic",
[
(ThreadPoolExecutor, None, False),
(ProcessPoolExecutor, "forkserver", False),
(ProcessPoolExecutor, "spawn", False),
(ProcessPoolExecutor, "fork", True),
],
)
def test_read_simple_in_threads_and_processes(
executor_class, multiprocessing_start_method, expect_panic
):
if multiprocessing_start_method is not None:
multiprocessing.set_start_method(multiprocessing_start_method, force=True)
if expect_panic:
with pytest.raises(
_SerializableException,
match="The tokio runtime does not support forked processes",
):
_recursively_read_simple_table(executor_class=executor_class, depth=5)
else:
_recursively_read_simple_table(executor_class=executor_class, depth=5)


def test_read_simple_table_by_version_to_dict():
table_path = "../crates/test/tests/data/delta-0.2.0"
dt = DeltaTable(table_path, version=2)
Expand Down
Loading