Skip to content

Commit

Permalink
Automatically get group id by pid, and change some default value
Browse files Browse the repository at this point in the history
When using killpg, we need to call with group id instead of pid

1. Get group id by pid for `killpg`
2. Modify `kill` and `terminate` 's default group to `False` for better
   compatibility.

Co-authored-by: Peter Rowlands (변기호) <[email protected]>
  • Loading branch information
karajan1001 and pmrowla committed Dec 9, 2022
1 parent b5cc7d4 commit 80e11d3
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 15 deletions.
18 changes: 12 additions & 6 deletions src/dvc_task/proc/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def run_signature(
immutable=immutable,
)

def send_signal(self, name: str, sig: int, group: bool = True):
def send_signal(self, name: str, sig: int, group: bool = False):
"""Send `signal` to the specified named process."""
try:
process_info = self[name]
Expand All @@ -135,9 +135,10 @@ def handle_closed_process():
if process_info.returncode is None:
try:
if sys.platform != "win32" and group:
os.killpg( # pylint: disable=no-member
process_info.pid, sig
pgid = os.getpgid( # pylint: disable=no-member
process_info.pid
)
os.killpg(pgid, sig) # pylint: disable=no-member
else:
os.kill(process_info.pid, sig)
except ProcessLookupError:
Expand All @@ -154,13 +155,18 @@ def handle_closed_process():

def interrupt(self, name: str, group: bool = True):
"""Send interrupt signal to specified named process"""
self.send_signal(name, signal.SIGINT, group)
if sys.platform == "win32":
self.send_signal(
name, signal.CTRL_C_EVENT, group # pylint: disable=no-member
)
else:
self.send_signal(name, signal.SIGINT, group)

def terminate(self, name: str, group: bool = True):
def terminate(self, name: str, group: bool = False):
"""Terminate the specified named process."""
self.send_signal(name, signal.SIGTERM, group)

def kill(self, name: str, group: bool = True):
def kill(self, name: str, group: bool = False):
"""Kill the specified named process."""
if sys.platform == "win32":
self.send_signal(name, signal.SIGTERM, group)
Expand Down
33 changes: 24 additions & 9 deletions tests/proc/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,25 @@ def test_send_signal(
mock_kill.assert_called_once_with(PID_RUNNING, signal.SIGTERM)

if sys.platform != "win32":
gid = 100
mocker.patch("os.getpgid", return_value=gid)
mock_killpg = mocker.patch("os.killpg")
process_manager.send_signal(running_process, signal.SIGINT, True)
mock_killpg.assert_called_once_with(PID_RUNNING, signal.SIGINT)
mock_killpg.assert_called_once_with(gid, signal.SIGINT)
else:
mock_kill.reset_mock()
process_manager.send_signal(
running_process,
signal.CTRL_C_EVENT, # pylint: disable=no-member
True,
)
mock_kill.assert_called_once_with(
PID_RUNNING, signal.CTRL_C_EVENT # pylint: disable=no-member
)

mock_kill.reset_mock()
with pytest.raises(ProcessLookupError):
process_manager.send_signal(finished_process, signal.SIGTERM)
process_manager.send_signal(finished_process, signal.SIGTERM, False)
mock_kill.assert_not_called()

if sys.platform == "win32":
Expand All @@ -59,39 +71,42 @@ def side_effect(*args):

mocker.patch("os.kill", side_effect=side_effect)
with pytest.raises(ProcessLookupError):
process_manager.send_signal(running_process, signal.SIGTERM)
process_manager.send_signal(running_process, signal.SIGTERM, False)
assert process_manager[running_process].returncode == -1

with pytest.raises(ProcessLookupError):
process_manager.send_signal("nonexists", signal.SIGTERM)
process_manager.send_signal("nonexists", signal.SIGTERM, False)


if sys.platform == "win32":
SIGKILL = signal.SIGTERM
SIGINT = signal.CTRL_C_EVENT # pylint: disable=no-member
else:
SIGKILL = signal.SIGKILL # pylint: disable=no-member
SIGINT = signal.SIGINT


@pytest.mark.parametrize(
"method, sig",
"method, sig, group",
[
("kill", SIGKILL),
("terminate", signal.SIGTERM),
("interrupt", signal.SIGINT),
("kill", SIGKILL, False),
("terminate", signal.SIGTERM, False),
("interrupt", SIGINT, True),
],
)
def test_kill_commands(
mocker: MockerFixture,
process_manager: ProcessManager,
method: str,
sig: signal.Signals,
group: bool,
):
"""Test shortcut for different signals."""
name = "process"
mock_kill = mocker.patch.object(process_manager, "send_signal")
func = getattr(process_manager, method)
func(name)
mock_kill.assert_called_once_with(name, sig, True)
mock_kill.assert_called_once_with(name, sig, group)


def test_remove(
Expand Down

0 comments on commit 80e11d3

Please sign in to comment.