diff --git a/pytools/prefork.py b/pytools/prefork.py index 6be5cf9b..3ca04c8f 100644 --- a/pytools/prefork.py +++ b/pytools/prefork.py @@ -3,21 +3,65 @@ initialization that can do the forking for the fork-challenged parent process. -Since none of this is MPI-specific, it got parked in pytools. +Since none of this is MPI-specific, it got parked in :mod:`pytools`. + +.. autoexception:: ExecError + :show-inheritance: + +.. autoclass:: Forker +.. autoclass:: DirectForker +.. autoclass:: IndirectForker + +.. autofunction:: enable_prefork +.. autofunction:: call +.. autofunction:: call_async +.. autofunction:: call_capture_output +.. autofunction:: wait +.. autofunction:: waitall """ +import socket +from abc import ABC, abstractmethod +from collections.abc import Sequence +from subprocess import Popen +from typing import Any + class ExecError(OSError): pass -class DirectForker: - def __init__(self): - self.apids = {} - self.count = 0 +class Forker(ABC): + @abstractmethod + def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int: + pass + + @abstractmethod + def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int: + pass + + @abstractmethod + def call_capture_output(self, + cmdline: Sequence[str], + cwd: str | None = None, + error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]: + pass - @staticmethod - def call(cmdline, cwd=None): + @abstractmethod + def wait(self, aid: int) -> int: + pass + + @abstractmethod + def waitall(self) -> dict[int, int]: + pass + + +class DirectForker(Forker): + def __init__(self) -> None: + self.apids: dict[int, Popen[bytes]] = {} + self.count: int = 0 + + def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int: from subprocess import call as spcall try: @@ -26,9 +70,7 @@ def call(cmdline, cwd=None): raise ExecError( "error invoking '{}': {}".format(" ".join(cmdline), e)) from e - def call_async(self, cmdline, cwd=None): - from subprocess import Popen - + def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int: try: self.count += 1 @@ -40,8 +82,10 @@ def call_async(self, cmdline, cwd=None): raise ExecError( "error invoking '{}': {}".format(" ".join(cmdline), e)) from e - @staticmethod - def call_capture_output(cmdline, cwd=None, error_on_nonzero=True): + def call_capture_output(self, + cmdline: Sequence[str], + cwd: str | None = None, + error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]: from subprocess import PIPE, Popen try: @@ -60,22 +104,22 @@ def call_capture_output(cmdline, cwd=None, error_on_nonzero=True): raise ExecError( "error invoking '{}': {}".format(" ".join(cmdline), e)) from e - def wait(self, aid): + def wait(self, aid: int) -> int: proc = self.apids.pop(aid) retc = proc.wait() return retc - def waitall(self): + def waitall(self) -> dict[int, int]: rets = {} - for aid in list(self.apids): + for aid in self.apids: rets[aid] = self.wait(aid) return rets -def _send_packet(sock, data): +def _send_packet(sock: socket.socket, data: object) -> None: from pickle import dumps from struct import pack @@ -85,7 +129,9 @@ def _send_packet(sock, data): sock.sendall(packet) -def _recv_packet(sock, who="Process", partner="other end"): +def _recv_packet(sock: socket.socket, + who: str = "Process", + partner: str = "other end") -> tuple[object, ...]: from struct import calcsize, unpack size_bytes_size = calcsize("I") size_bytes = sock.recv(size_bytes_size) @@ -100,10 +146,14 @@ def _recv_packet(sock, who="Process", partner="other end"): packet += sock.recv(size) from pickle import loads - return loads(packet) + result = loads(packet) + assert isinstance(result, tuple) -def _fork_server(sock): + return result + + +def _fork_server(sock: socket.socket) -> None: # Ignore keyboard interrupts, we'll get notified by the parent. import signal signal.signal(signal.SIGINT, signal.SIG_IGN) @@ -124,13 +174,14 @@ def _fork_server(sock): func_name, args, kwargs = _recv_packet( sock, who="Prefork server", partner="parent" ) + assert isinstance(func_name, str) if func_name == "quit": df.waitall() _send_packet(sock, ("ok", None)) break try: - result = funcs[func_name](*args, **kwargs) + result = funcs[func_name](*args, **kwargs) # type: ignore[operator] # FIXME: Is catching all exceptions the right course of action? except Exception as e: # pylint:disable=broad-except _send_packet(sock, ("exception", e)) @@ -143,60 +194,76 @@ def _fork_server(sock): os._exit(0) -class IndirectForker: - def __init__(self, server_pid, sock): +class IndirectForker(Forker): + def __init__(self, server_pid: int, sock: socket.socket) -> None: self.server_pid = server_pid self.socket = sock import atexit atexit.register(self._quit) - def _remote_invoke(self, name, *args, **kwargs): + def _remote_invoke(self, name: str, *args: Any, **kwargs: Any) -> object: _send_packet(self.socket, (name, args, kwargs)) status, result = _recv_packet( self.socket, who="Prefork client", partner="prefork server" ) if status == "exception": + assert isinstance(result, Exception) raise result assert status == "ok" return result - def _quit(self): + def _quit(self) -> None: self._remote_invoke("quit") from os import waitpid waitpid(self.server_pid, 0) - def call(self, cmdline, cwd=None): - return self._remote_invoke("call", cmdline, cwd) + def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int: + result = self._remote_invoke("call", cmdline, cwd) + + assert isinstance(result, int) + return result + + def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int: + result = self._remote_invoke("call_async", cmdline, cwd) - def call_async(self, cmdline, cwd=None): - return self._remote_invoke("call_async", cmdline, cwd) + assert isinstance(result, int) + return result - def call_capture_output(self, cmdline, cwd=None, error_on_nonzero=True): - return self._remote_invoke("call_capture_output", cmdline, cwd, + def call_capture_output(self, + cmdline: Sequence[str], + cwd: str | None = None, + error_on_nonzero: bool = True, + ) -> tuple[int, bytes, bytes]: + return self._remote_invoke("call_capture_output", cmdline, cwd, # type: ignore[return-value] error_on_nonzero) - def wait(self, aid): - return self._remote_invoke("wait", aid) + def wait(self, aid: int) -> int: + result = self._remote_invoke("wait", aid) + + assert isinstance(result, int) + return result - def waitall(self): - return self._remote_invoke("waitall") + def waitall(self) -> dict[int, int]: + result = self._remote_invoke("waitall") + + assert isinstance(result, dict) + return result -forker = DirectForker() +forker: Forker = DirectForker() -def enable_prefork(): +def enable_prefork() -> None: global forker if isinstance(forker, IndirectForker): return - from socket import socketpair - s_parent, s_child = socketpair() + s_parent, s_child = socket.socketpair() from os import fork fork_res = fork() @@ -211,21 +278,23 @@ def enable_prefork(): forker = IndirectForker(fork_res, s_parent) -def call(cmdline, cwd=None): +def call(cmdline: Sequence[str], cwd: str | None = None) -> int: return forker.call(cmdline, cwd) -def call_async(cmdline, cwd=None): +def call_async(cmdline: Sequence[str], cwd: str | None = None) -> int: return forker.call_async(cmdline, cwd) -def call_capture_output(cmdline, cwd=None, error_on_nonzero=True): +def call_capture_output(cmdline: Sequence[str], + cwd: str | None = None, + error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]: return forker.call_capture_output(cmdline, cwd, error_on_nonzero) -def wait(aid): +def wait(aid: int) -> int: return forker.wait(aid) -def waitall(): +def waitall() -> dict[int, int]: return forker.waitall() diff --git a/run-mypy.sh b/run-mypy.sh index 73955e7b..ec4ff3cb 100755 --- a/run-mypy.sh +++ b/run-mypy.sh @@ -2,10 +2,11 @@ set -ex -mypy pytools +python -m mypy pytools -mypy --strict --follow-imports=silent \ - pytools/tag.py \ - pytools/graph.py \ +python -m mypy --strict --follow-imports=silent \ pytools/datatable.py \ - pytools/persistent_dict.py + pytools/graph.py \ + pytools/persistent_dict.py \ + pytools/prefork.py \ + pytools/tag.py \