Skip to content

Commit

Permalink
Tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom-Newton authored and ion-elgreco committed Aug 12, 2024
1 parent 28982f8 commit 9bdb8ac
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 17 deletions.
7 changes: 4 additions & 3 deletions python/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ pub fn rt() -> &'static Runtime {
Some(pid) if pid == &std::process::id() => {} // Reuse the static runtime.
Some(pid) => {
panic!(
"Forked process detected - current PID is {} but the tokio runtime was by {}. The tokio runtime
does not support forked processes https://github.com/tokio-rs/tokio/issues/4301. If you are seeing this
message while using Python multithreading make sure to use the `spawn` or `forkserver` mode.",
"Forked process detected - current PID is {} but the tokio runtime was created by {}. The tokio \
runtime does not support forked processes https://github.com/tokio-rs/tokio/issues/4301. If you are \
seeing this message while using Python multithreading make sure to use the `spawn` or `forkserver` \
mode.",
pid, std::process::id()
);
}
Expand Down
57 changes: 43 additions & 14 deletions python/tests/test_table_read.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from itertools import product
import os
from datetime import date, datetime, timezone
from pathlib import Path
from random import random
from threading import Barrier, Thread
from typing import Any, List, Tuple
from typing import Any, List, Tuple, Type
from unittest.mock import Mock

from deltalake._util import encode_partition_value
Expand All @@ -19,14 +18,15 @@
else:
_has_pandas = True

import multiprocessing
from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor

import pyarrow as pa
import pyarrow.dataset as ds
import pytest
from pyarrow.dataset import ParquetReadOptions
from pyarrow.fs import LocalFileSystem, SubTreeFileSystem

import multiprocessing
import threading
from deltalake import DeltaTable


Expand Down Expand Up @@ -59,22 +59,51 @@ def test_read_simple_table_to_dict():
dt = DeltaTable(table_path)
assert dt.to_pyarrow_dataset().to_table().to_pydict() == {"id": [5, 7, 9]}

def recursively_read_simple_table(thread_or_process_class, depth):
print(thread_or_process_class, depth)
test_read_simple_table_to_dict()

class _SerializableException(BaseException):
pass


def _recursively_read_simple_table(executor_class: Type[Executor], depth):
try:
test_read_simple_table_to_dict()
except BaseException as e: # Ideally this would catch `pyo3_runtime.PanicException` but its seems that is not possible.
# Re-raise as something that can be serialized and therefore sent back to parent processes.
raise _SerializableException(f"Seraializatble exception: {e}") from e

if depth == 0:
return

process_or_thread = thread_or_process_class(target=recursively_read_simple_table, args=(thread_or_process_class, depth - 1))
process_or_thread.start()
process_or_thread.join()
# We use concurrent.futures.Executors instead of `threading.Thread` or `multiprocessing.Process` to that errors
# are re-rasied in the parent process/thread when we call `future.result()`.
with executor_class(max_workers=1) as executor:
future = executor.submit(
_recursively_read_simple_table, executor_class, depth - 1
)
future.result()


@pytest.mark.parametrize("thread_or_process_class, multiprocessing_start_method", [(threading.Thread, None), (multiprocessing.Process, "forkserver"), (multiprocessing.Process, "spawn"), (multiprocessing.Process, "fork")])
def test_read_simple_in_threads_and_processes(thread_or_process_class, multiprocessing_start_method):
@pytest.mark.parametrize(
"executor_class,multiprocessing_start_method,expect_panic",
[
(ThreadPoolExecutor, None, False),
(ProcessPoolExecutor, "forkserver", False),
(ProcessPoolExecutor, "spawn", False),
(ProcessPoolExecutor, "fork", True),
],
)
def test_read_simple_in_threads_and_processes(
executor_class, multiprocessing_start_method, expect_panic
):
if multiprocessing_start_method is not None:
multiprocessing.set_start_method(multiprocessing_start_method, force=True)
recursively_read_simple_table(thread_or_process_class=thread_or_process_class, depth=10)
if expect_panic:
with pytest.raises(
_SerializableException,
match="The tokio runtime does not support forked processes",
):
_recursively_read_simple_table(executor_class=executor_class, depth=5)
else:
_recursively_read_simple_table(executor_class=executor_class, depth=5)


def test_read_simple_table_by_version_to_dict():
Expand Down

0 comments on commit 9bdb8ac

Please sign in to comment.