diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index 10f0ee54228b..984504b412de 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -414,3 +414,31 @@ def __getattr__(name: str) -> Any: msg = f"module {__name__!r} has no attribute {name!r}" raise AttributeError(msg) + + +# fork() breaks Polars thread pool, so warn users who might be doing this. +def __install_postfork_hook() -> None: + message = """\ +Using fork() can cause Polars to deadlock in the child process. +In addition, using fork() with Python in general is a recipe for mysterious +deadlocks and crashes. + +The most likely reason you are seeing this error is because you are using the +multiprocessing module on Linux, which uses fork() by default. This will be +fixed in Python 3.14. Until then, you want to use the "spawn" context instead. + +See https://docs.pola.rs/user-guide/misc/multiprocessing/ for details. +""" + + def before_hook() -> None: + import warnings + + warnings.warn(message, RuntimeWarning, stacklevel=2) + + import os + + if hasattr(os, "register_at_fork"): + os.register_at_fork(before=before_hook) + + +__install_postfork_hook() diff --git a/py-polars/tests/unit/test_polars_import.py b/py-polars/tests/unit/test_polars_import.py index fa1779de3478..2686c094999b 100644 --- a/py-polars/tests/unit/test_polars_import.py +++ b/py-polars/tests/unit/test_polars_import.py @@ -1,6 +1,8 @@ from __future__ import annotations import compileall +import multiprocessing +import os import subprocess import sys from pathlib import Path @@ -97,3 +99,32 @@ def test_polars_import() -> None: import_time_ms = polars_import_time // 1_000 msg = f"Possible import speed regression; took {import_time_ms}ms\n{df_import}" raise AssertionError(msg) + + +def run_in_child() -> int: + return 123 + + +@pytest.mark.skipif(not hasattr(os, "fork"), reason="Requires fork()") +def test_fork_safety(recwarn: pytest.WarningsRecorder) -> None: + def get_num_fork_warnings() -> int: + fork_warnings = 0 + for warning in recwarn: + if issubclass(warning.category, RuntimeWarning) and str( + warning.message + ).startswith("Using fork() can cause Polars"): + fork_warnings += 1 + return fork_warnings + + assert get_num_fork_warnings() == 0 + + # Using forkserver and spawn context should not do any of our warning: + for context in ["spawn", "forkserver"]: + with multiprocessing.get_context(context).Pool(1) as pool: + assert pool.apply(run_in_child) == 123 + assert get_num_fork_warnings() == 0 + + # Using fork()-based multiprocessing should raise a warning: + with multiprocessing.get_context("fork").Pool(1) as pool: + assert pool.apply(run_in_child) == 123 + assert get_num_fork_warnings() == 1