Skip to content

Commit

Permalink
mypy: add strict typing to prefork
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Nov 27, 2024
1 parent ec572ad commit 29a5017
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 48 deletions.
155 changes: 112 additions & 43 deletions pytools/prefork.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,65 @@
initialization that can do the forking for the fork-challenged
parent process.
Since none of this is MPI-specific, it got parked in pytools.
Since none of this is MPI-specific, it got parked in :mod:`pytools`.
.. autoexception:: ExecError
:show-inheritance:
.. autoclass:: Forker
.. autoclass:: DirectForker
.. autoclass:: IndirectForker
.. autofunction:: enable_prefork
.. autofunction:: call
.. autofunction:: call_async
.. autofunction:: call_capture_output
.. autofunction:: wait
.. autofunction:: waitall
"""

import socket
from abc import ABC, abstractmethod
from collections.abc import Sequence
from subprocess import Popen
from typing import Any


class ExecError(OSError):
pass


class DirectForker:
def __init__(self):
self.apids = {}
self.count = 0
class Forker(ABC):
@abstractmethod
def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
pass

@abstractmethod
def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
pass

@abstractmethod
def call_capture_output(self,
cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]:
pass

@staticmethod
def call(cmdline, cwd=None):
@abstractmethod
def wait(self, aid: int) -> int:
pass

@abstractmethod
def waitall(self) -> dict[int, int]:
pass


class DirectForker(Forker):
def __init__(self) -> None:
self.apids: dict[int, Popen[bytes]] = {}
self.count: int = 0

def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
from subprocess import call as spcall

try:
Expand All @@ -26,9 +70,7 @@ def call(cmdline, cwd=None):
raise ExecError(
"error invoking '{}': {}".format(" ".join(cmdline), e)) from e

def call_async(self, cmdline, cwd=None):
from subprocess import Popen

def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
try:
self.count += 1

Expand All @@ -40,8 +82,10 @@ def call_async(self, cmdline, cwd=None):
raise ExecError(
"error invoking '{}': {}".format(" ".join(cmdline), e)) from e

@staticmethod
def call_capture_output(cmdline, cwd=None, error_on_nonzero=True):
def call_capture_output(self,
cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]:
from subprocess import PIPE, Popen

try:
Expand All @@ -60,22 +104,22 @@ def call_capture_output(cmdline, cwd=None, error_on_nonzero=True):
raise ExecError(
"error invoking '{}': {}".format(" ".join(cmdline), e)) from e

def wait(self, aid):
def wait(self, aid: int) -> int:
proc = self.apids.pop(aid)
retc = proc.wait()

return retc

def waitall(self):
def waitall(self) -> dict[int, int]:
rets = {}

for aid in list(self.apids):
for aid in self.apids:
rets[aid] = self.wait(aid)

return rets


def _send_packet(sock, data):
def _send_packet(sock: socket.socket, data: object) -> None:
from pickle import dumps
from struct import pack

Expand All @@ -85,7 +129,9 @@ def _send_packet(sock, data):
sock.sendall(packet)


def _recv_packet(sock, who="Process", partner="other end"):
def _recv_packet(sock: socket.socket,
who: str = "Process",
partner: str = "other end") -> tuple[object, ...]:
from struct import calcsize, unpack
size_bytes_size = calcsize("I")
size_bytes = sock.recv(size_bytes_size)
Expand All @@ -100,10 +146,14 @@ def _recv_packet(sock, who="Process", partner="other end"):
packet += sock.recv(size)

from pickle import loads
return loads(packet)

result = loads(packet)
assert isinstance(result, tuple)

def _fork_server(sock):
return result


def _fork_server(sock: socket.socket) -> None:
# Ignore keyboard interrupts, we'll get notified by the parent.
import signal
signal.signal(signal.SIGINT, signal.SIG_IGN)
Expand All @@ -124,13 +174,14 @@ def _fork_server(sock):
func_name, args, kwargs = _recv_packet(
sock, who="Prefork server", partner="parent"
)
assert isinstance(func_name, str)

if func_name == "quit":
df.waitall()
_send_packet(sock, ("ok", None))
break
try:
result = funcs[func_name](*args, **kwargs)
result = funcs[func_name](*args, **kwargs) # type: ignore[operator]
# FIXME: Is catching all exceptions the right course of action?
except Exception as e: # pylint:disable=broad-except
_send_packet(sock, ("exception", e))
Expand All @@ -143,60 +194,76 @@ def _fork_server(sock):
os._exit(0)


class IndirectForker:
def __init__(self, server_pid, sock):
class IndirectForker(Forker):
def __init__(self, server_pid: int, sock: socket.socket) -> None:
self.server_pid = server_pid
self.socket = sock

import atexit
atexit.register(self._quit)

def _remote_invoke(self, name, *args, **kwargs):
def _remote_invoke(self, name: str, *args: Any, **kwargs: Any) -> object:
_send_packet(self.socket, (name, args, kwargs))
status, result = _recv_packet(
self.socket, who="Prefork client", partner="prefork server"
)

if status == "exception":
assert isinstance(result, Exception)
raise result

assert status == "ok"
return result

def _quit(self):
def _quit(self) -> None:
self._remote_invoke("quit")

from os import waitpid
waitpid(self.server_pid, 0)

def call(self, cmdline, cwd=None):
return self._remote_invoke("call", cmdline, cwd)
def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
result = self._remote_invoke("call", cmdline, cwd)

assert isinstance(result, int)
return result

def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
result = self._remote_invoke("call_async", cmdline, cwd)

def call_async(self, cmdline, cwd=None):
return self._remote_invoke("call_async", cmdline, cwd)
assert isinstance(result, int)
return result

def call_capture_output(self, cmdline, cwd=None, error_on_nonzero=True):
return self._remote_invoke("call_capture_output", cmdline, cwd,
def call_capture_output(self,
cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True,
) -> tuple[int, bytes, bytes]:
return self._remote_invoke("call_capture_output", cmdline, cwd, # type: ignore[return-value]
error_on_nonzero)

def wait(self, aid):
return self._remote_invoke("wait", aid)
def wait(self, aid: int) -> int:
result = self._remote_invoke("wait", aid)

assert isinstance(result, int)
return result

def waitall(self):
return self._remote_invoke("waitall")
def waitall(self) -> dict[int, int]:
result = self._remote_invoke("waitall")

assert isinstance(result, dict)
return result


forker = DirectForker()
forker: Forker = DirectForker()


def enable_prefork():
def enable_prefork() -> None:
global forker

if isinstance(forker, IndirectForker):
return

from socket import socketpair
s_parent, s_child = socketpair()
s_parent, s_child = socket.socketpair()

from os import fork
fork_res = fork()
Expand All @@ -211,21 +278,23 @@ def enable_prefork():
forker = IndirectForker(fork_res, s_parent)


def call(cmdline, cwd=None):
def call(cmdline: Sequence[str], cwd: str | None = None) -> int:
return forker.call(cmdline, cwd)


def call_async(cmdline, cwd=None):
def call_async(cmdline: Sequence[str], cwd: str | None = None) -> int:
return forker.call_async(cmdline, cwd)


def call_capture_output(cmdline, cwd=None, error_on_nonzero=True):
def call_capture_output(cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]:
return forker.call_capture_output(cmdline, cwd, error_on_nonzero)


def wait(aid):
def wait(aid: int) -> int:
return forker.wait(aid)


def waitall():
def waitall() -> dict[int, int]:
return forker.waitall()
11 changes: 6 additions & 5 deletions run-mypy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

set -ex

mypy pytools
python -m mypy pytools

mypy --strict --follow-imports=silent \
pytools/tag.py \
pytools/graph.py \
python -m mypy --strict --follow-imports=silent \
pytools/datatable.py \
pytools/persistent_dict.py
pytools/graph.py \
pytools/persistent_dict.py \
pytools/prefork.py \
pytools/tag.py \

0 comments on commit 29a5017

Please sign in to comment.