Skip to content

Commit

Permalink
mypy: add strict typing to prefork
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Nov 27, 2024
1 parent 4685e66 commit 17ecc73
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 45 deletions.
147 changes: 105 additions & 42 deletions pytools/prefork.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,65 @@
initialization that can do the forking for the fork-challenged
parent process.
Since none of this is MPI-specific, it got parked in pytools.
Since none of this is MPI-specific, it got parked in :mod:`pytools`.
.. autoexception:: ExecError
:show-inheritance:
.. autoclass:: Forker
.. autoclass:: DirectForker
.. autoclass:: IndirectForker
.. autofunction:: enable_prefork
.. autofunction:: call
.. autofunction:: call_async
.. autofunction:: call_capture_output
.. autofunction:: wait
.. autofunction:: waitall
"""

import socket
from abc import ABC, abstractmethod
from collections.abc import Sequence
from subprocess import Popen
from typing import Any


class ExecError(OSError):
pass


class DirectForker:
def __init__(self):
self.apids = {}
self.count = 0
class Forker(ABC):
@abstractmethod
def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
pass

@abstractmethod
def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
pass

@abstractmethod
def call_capture_output(self,
cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]:
pass

@abstractmethod
def wait(self, aid: int) -> int:
pass

@abstractmethod
def waitall(self) -> dict[int, int]:
pass


class DirectForker(Forker):
def __init__(self) -> None:
self.apids: dict[int, Popen[bytes]] = {}
self.count: int = 0

@staticmethod
def call(cmdline, cwd=None):
def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
from subprocess import call as spcall

try:
Expand All @@ -26,9 +70,7 @@ def call(cmdline, cwd=None):
raise ExecError(
"error invoking '{}': {}".format(" ".join(cmdline), e)) from e

def call_async(self, cmdline, cwd=None):
from subprocess import Popen

def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
try:
self.count += 1

Expand All @@ -40,8 +82,10 @@ def call_async(self, cmdline, cwd=None):
raise ExecError(
"error invoking '{}': {}".format(" ".join(cmdline), e)) from e

@staticmethod
def call_capture_output(cmdline, cwd=None, error_on_nonzero=True):
def call_capture_output(self,
cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]:
from subprocess import PIPE, Popen

try:
Expand All @@ -60,22 +104,22 @@ def call_capture_output(cmdline, cwd=None, error_on_nonzero=True):
raise ExecError(
"error invoking '{}': {}".format(" ".join(cmdline), e)) from e

def wait(self, aid):
def wait(self, aid: int) -> int:
proc = self.apids.pop(aid)
retc = proc.wait()

return retc

def waitall(self):
def waitall(self) -> dict[int, int]:
rets = {}

for aid in list(self.apids):
for aid in self.apids:
rets[aid] = self.wait(aid)

return rets


def _send_packet(sock, data):
def _send_packet(sock: socket.socket, data: Any) -> None:
from pickle import dumps
from struct import pack

Expand All @@ -85,7 +129,9 @@ def _send_packet(sock, data):
sock.sendall(packet)


def _recv_packet(sock, who="Process", partner="other end"):
def _recv_packet(sock: socket.socket,
who: str = "Process",
partner: str = "other end") -> Any:
from struct import calcsize, unpack
size_bytes_size = calcsize("I")
size_bytes = sock.recv(size_bytes_size)
Expand All @@ -103,7 +149,7 @@ def _recv_packet(sock, who="Process", partner="other end"):
return loads(packet)


def _fork_server(sock):
def _fork_server(sock: socket.socket) -> None:
# Ignore keyboard interrupts, we'll get notified by the parent.
import signal
signal.signal(signal.SIGINT, signal.SIG_IGN)
Expand Down Expand Up @@ -131,7 +177,7 @@ def _fork_server(sock):
break
else:
try:
result = funcs[func_name](*args, **kwargs)
result = funcs[func_name](*args, **kwargs) # type: ignore[operator]
# FIXME: Is catching all exceptions the right course of action?
except Exception as e: # pylint:disable=broad-except
_send_packet(sock, ("exception", e))
Expand All @@ -144,15 +190,15 @@ def _fork_server(sock):
os._exit(0)


class IndirectForker:
def __init__(self, server_pid, sock):
class IndirectForker(Forker):
def __init__(self, server_pid: int, sock: socket.socket) -> None:
self.server_pid = server_pid
self.socket = sock

import atexit
atexit.register(self._quit)

def _remote_invoke(self, name, *args, **kwargs):
def _remote_invoke(self, name: str, *args: Any, **kwargs: Any) -> Any:
_send_packet(self.socket, (name, args, kwargs))
status, result = _recv_packet(
self.socket, who="Prefork client", partner="prefork server"
Expand All @@ -164,40 +210,55 @@ def _remote_invoke(self, name, *args, **kwargs):
assert status == "ok"
return result

def _quit(self):
def _quit(self) -> None:
self._remote_invoke("quit")

from os import waitpid
waitpid(self.server_pid, 0)

def call(self, cmdline, cwd=None):
return self._remote_invoke("call", cmdline, cwd)
def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
result = self._remote_invoke("call", cmdline, cwd)

assert isinstance(result, int)
return result

def call_async(self, cmdline, cwd=None):
return self._remote_invoke("call_async", cmdline, cwd)
def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
result = self._remote_invoke("call_async", cmdline, cwd)

assert isinstance(result, int)
return result

def call_capture_output(self, cmdline, cwd=None, error_on_nonzero=True):
return self._remote_invoke("call_capture_output", cmdline, cwd,
def call_capture_output(self,
cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True,
) -> tuple[int, bytes, bytes]:
return self._remote_invoke("call_capture_output", cmdline, cwd, # type: ignore[no-any-return,unused-ignore]
error_on_nonzero)

def wait(self, aid):
return self._remote_invoke("wait", aid)
def wait(self, aid: int) -> int:
result = self._remote_invoke("wait", aid)

def waitall(self):
return self._remote_invoke("waitall")
assert isinstance(result, int)
return result

def waitall(self) -> dict[int, int]:
result = self._remote_invoke("waitall")

assert isinstance(result, dict)
return result


forker = DirectForker()
forker: Forker = DirectForker()


def enable_prefork():
def enable_prefork() -> None:
global forker

if isinstance(forker, IndirectForker):
return

from socket import socketpair
s_parent, s_child = socketpair()
s_parent, s_child = socket.socketpair()

from os import fork
fork_res = fork()
Expand All @@ -212,21 +273,23 @@ def enable_prefork():
forker = IndirectForker(fork_res, s_parent)


def call(cmdline, cwd=None):
def call(cmdline: Sequence[str], cwd: str | None = None) -> int:
return forker.call(cmdline, cwd)


def call_async(cmdline, cwd=None):
def call_async(cmdline: Sequence[str], cwd: str | None = None) -> int:
return forker.call_async(cmdline, cwd)


def call_capture_output(cmdline, cwd=None, error_on_nonzero=True):
def call_capture_output(cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]:
return forker.call_capture_output(cmdline, cwd, error_on_nonzero)


def wait(aid):
def wait(aid: int) -> int:
return forker.wait(aid)


def waitall():
def waitall() -> dict[int, int]:
return forker.waitall()
7 changes: 4 additions & 3 deletions run-mypy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ set -ex
mypy --show-error-codes pytools

mypy --strict --follow-imports=silent \
pytools/tag.py \
pytools/graph.py \
pytools/datatable.py \
pytools/persistent_dict.py
pytools/graph.py \
pytools/persistent_dict.py \
pytools/prefork.py \
pytools/tag.py \

0 comments on commit 17ecc73

Please sign in to comment.