Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mypy: add strict typing to prefork #270

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 112 additions & 43 deletions pytools/prefork.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,65 @@
initialization that can do the forking for the fork-challenged
parent process.

Since none of this is MPI-specific, it got parked in pytools.
Since none of this is MPI-specific, it got parked in :mod:`pytools`.

.. autoexception:: ExecError
:show-inheritance:

.. autoclass:: Forker
.. autoclass:: DirectForker
.. autoclass:: IndirectForker

.. autofunction:: enable_prefork
.. autofunction:: call
.. autofunction:: call_async
.. autofunction:: call_capture_output
.. autofunction:: wait
.. autofunction:: waitall
"""

import socket
from abc import ABC, abstractmethod
from collections.abc import Sequence
from subprocess import Popen
from typing import Any


class ExecError(OSError):
pass


class DirectForker:
def __init__(self):
self.apids = {}
self.count = 0
class Forker(ABC):
@abstractmethod
def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
pass

@abstractmethod
def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
pass

@abstractmethod
def call_capture_output(self,
cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]:
pass

@staticmethod
def call(cmdline, cwd=None):
@abstractmethod
def wait(self, aid: int) -> int:
pass

@abstractmethod
def waitall(self) -> dict[int, int]:
pass


class DirectForker(Forker):
def __init__(self) -> None:
self.apids: dict[int, Popen[bytes]] = {}
self.count: int = 0

def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
from subprocess import call as spcall

try:
Expand All @@ -26,9 +70,7 @@ def call(cmdline, cwd=None):
raise ExecError(
"error invoking '{}': {}".format(" ".join(cmdline), e)) from e

def call_async(self, cmdline, cwd=None):
from subprocess import Popen

def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
try:
self.count += 1

Expand All @@ -40,8 +82,10 @@ def call_async(self, cmdline, cwd=None):
raise ExecError(
"error invoking '{}': {}".format(" ".join(cmdline), e)) from e

@staticmethod
def call_capture_output(cmdline, cwd=None, error_on_nonzero=True):
def call_capture_output(self,
cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]:
from subprocess import PIPE, Popen

try:
Expand All @@ -60,22 +104,22 @@ def call_capture_output(cmdline, cwd=None, error_on_nonzero=True):
raise ExecError(
"error invoking '{}': {}".format(" ".join(cmdline), e)) from e

def wait(self, aid):
def wait(self, aid: int) -> int:
proc = self.apids.pop(aid)
retc = proc.wait()

return retc

def waitall(self):
def waitall(self) -> dict[int, int]:
rets = {}

for aid in list(self.apids):
for aid in self.apids:
rets[aid] = self.wait(aid)

return rets


def _send_packet(sock, data):
def _send_packet(sock: socket.socket, data: object) -> None:
from pickle import dumps
from struct import pack

Expand All @@ -85,7 +129,9 @@ def _send_packet(sock, data):
sock.sendall(packet)


def _recv_packet(sock, who="Process", partner="other end"):
def _recv_packet(sock: socket.socket,
who: str = "Process",
partner: str = "other end") -> tuple[object, ...]:
from struct import calcsize, unpack
size_bytes_size = calcsize("I")
size_bytes = sock.recv(size_bytes_size)
Expand All @@ -100,10 +146,14 @@ def _recv_packet(sock, who="Process", partner="other end"):
packet += sock.recv(size)

from pickle import loads
return loads(packet)

result = loads(packet)
assert isinstance(result, tuple)

def _fork_server(sock):
return result


def _fork_server(sock: socket.socket) -> None:
# Ignore keyboard interrupts, we'll get notified by the parent.
import signal
signal.signal(signal.SIGINT, signal.SIG_IGN)
Expand All @@ -124,13 +174,14 @@ def _fork_server(sock):
func_name, args, kwargs = _recv_packet(
sock, who="Prefork server", partner="parent"
)
assert isinstance(func_name, str)

if func_name == "quit":
df.waitall()
_send_packet(sock, ("ok", None))
break
try:
result = funcs[func_name](*args, **kwargs)
result = funcs[func_name](*args, **kwargs) # type: ignore[operator]
# FIXME: Is catching all exceptions the right course of action?
except Exception as e: # pylint:disable=broad-except
_send_packet(sock, ("exception", e))
Expand All @@ -143,60 +194,76 @@ def _fork_server(sock):
os._exit(0)


class IndirectForker:
def __init__(self, server_pid, sock):
class IndirectForker(Forker):
def __init__(self, server_pid: int, sock: socket.socket) -> None:
self.server_pid = server_pid
self.socket = sock

import atexit
atexit.register(self._quit)

def _remote_invoke(self, name, *args, **kwargs):
def _remote_invoke(self, name: str, *args: Any, **kwargs: Any) -> object:
_send_packet(self.socket, (name, args, kwargs))
status, result = _recv_packet(
self.socket, who="Prefork client", partner="prefork server"
)

if status == "exception":
assert isinstance(result, Exception)
raise result

assert status == "ok"
return result

def _quit(self):
def _quit(self) -> None:
self._remote_invoke("quit")

from os import waitpid
waitpid(self.server_pid, 0)

def call(self, cmdline, cwd=None):
return self._remote_invoke("call", cmdline, cwd)
def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
result = self._remote_invoke("call", cmdline, cwd)

assert isinstance(result, int)
return result

def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
result = self._remote_invoke("call_async", cmdline, cwd)

def call_async(self, cmdline, cwd=None):
return self._remote_invoke("call_async", cmdline, cwd)
assert isinstance(result, int)
return result

def call_capture_output(self, cmdline, cwd=None, error_on_nonzero=True):
return self._remote_invoke("call_capture_output", cmdline, cwd,
def call_capture_output(self,
cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True,
) -> tuple[int, bytes, bytes]:
return self._remote_invoke("call_capture_output", cmdline, cwd, # type: ignore[return-value]
error_on_nonzero)

def wait(self, aid):
return self._remote_invoke("wait", aid)
def wait(self, aid: int) -> int:
result = self._remote_invoke("wait", aid)

assert isinstance(result, int)
return result

def waitall(self):
return self._remote_invoke("waitall")
def waitall(self) -> dict[int, int]:
result = self._remote_invoke("waitall")

assert isinstance(result, dict)
return result


forker = DirectForker()
forker: Forker = DirectForker()


def enable_prefork():
def enable_prefork() -> None:
global forker

if isinstance(forker, IndirectForker):
return

from socket import socketpair
s_parent, s_child = socketpair()
s_parent, s_child = socket.socketpair()

from os import fork
fork_res = fork()
Expand All @@ -211,21 +278,23 @@ def enable_prefork():
forker = IndirectForker(fork_res, s_parent)


def call(cmdline, cwd=None):
def call(cmdline: Sequence[str], cwd: str | None = None) -> int:
return forker.call(cmdline, cwd)


def call_async(cmdline, cwd=None):
def call_async(cmdline: Sequence[str], cwd: str | None = None) -> int:
return forker.call_async(cmdline, cwd)


def call_capture_output(cmdline, cwd=None, error_on_nonzero=True):
def call_capture_output(cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]:
return forker.call_capture_output(cmdline, cwd, error_on_nonzero)


def wait(aid):
def wait(aid: int) -> int:
return forker.wait(aid)


def waitall():
def waitall() -> dict[int, int]:
return forker.waitall()
11 changes: 6 additions & 5 deletions run-mypy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

set -ex

mypy pytools
python -m mypy pytools

mypy --strict --follow-imports=silent \
pytools/tag.py \
pytools/graph.py \
python -m mypy --strict --follow-imports=silent \
pytools/datatable.py \
pytools/persistent_dict.py
pytools/graph.py \
pytools/persistent_dict.py \
pytools/prefork.py \
pytools/tag.py \
Loading