Skip to content

Commit 01ab15c

Browse files
committed
Test for SIGTERM
feature: Test for SIGTERM
1 parent 5d0ebcb commit 01ab15c

File tree

4 files changed

+60
-4
lines changed

4 files changed

+60
-4
lines changed

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ GitHub provides additional document on [forking a repository](https://help.githu
3737
### Running the unit tests
3838

3939
1. Install tox using `pip install tox`
40-
1. Install coverage using `pip install .[test]`
40+
1. Install coverage using `pip install ".[test]"`
4141
1. cd into the sagemaker-training-toolkit folder: `cd sagemaker-training-toolkit`
4242
1. Run the following tox command and verify that all code checks and unit tests pass: `tox test/unit`
4343

src/sagemaker_training/process.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ async def run_async(cmd, processes_per_host, env, cwd, stderr, **kwargs):
137137
cmd, env=env, cwd=cwd, stdout=PIPE, stderr=stderr, **kwargs
138138
)
139139

140-
with capture_signal(signal.SIGTERM, lambda signalnum, frame: proc.send_signal(signalnum)):
140+
with capture_signal(signal.SIGTERM, lambda signalnum, *_: proc.send_signal(signalnum)):
141141
output = await asyncio.gather(
142142
watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host)
143143
)
@@ -218,7 +218,7 @@ def check_error(cmd, error_class, processes_per_host, cwd=None, capture_error=Tr
218218
process = subprocess.Popen(
219219
cmd, env=os.environ, cwd=cwd or environment.code_dir, stderr=stderr, **kwargs
220220
)
221-
with capture_signal(signal.SIGTERM, lambda signalnum, frame: process.send_signal(signalnum)):
221+
with capture_signal(signal.SIGTERM, lambda signalnum, *_: process.send_signal(signalnum)):
222222
return_code = process.wait()
223223
if return_code:
224224
extra_info = None

test/unit/_test_process_helper.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""
2+
Helper script for testing signal handling
3+
4+
- If it receives SIGTERM, immediately exit "21"
5+
- If it doesn't receive a signal, sleep for 3 seconds then exit "-1"
6+
"""
7+
8+
import signal
9+
import time
10+
11+
12+
def signal_handler(signalnum, *_):
13+
assert signalnum == signal.SIGTERM
14+
exit(21)
15+
16+
17+
def main():
18+
signal.signal(signal.SIGTERM, signal_handler)
19+
time.sleep(3)
20+
exit(-1)
21+
22+
23+
if __name__ == "__main__":
24+
main()

test/unit/test_process.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
import asyncio
1616
import os
1717
import sys
18-
18+
import time
19+
import multiprocessing
1920
from mock import ANY, MagicMock, patch
2021
import pytest
2122

2223
from sagemaker_training import environment, errors, process
24+
from sagemaker_training.process import capture_signal
25+
import signal
2326

2427

2528
class AsyncMock(MagicMock):
@@ -175,3 +178,32 @@ def test_run_python(log, async_shell, async_gather, entry_point_type_script, eve
175178
stdout=asyncio.subprocess.PIPE,
176179
)
177180
log.assert_called_with(cmd, {})
181+
182+
183+
def _sleep_subprocess(capture_error):
184+
with pytest.raises(errors.ExecuteUserScriptError) as error:
185+
process.check_error(
186+
[
187+
sys.executable,
188+
os.path.abspath(os.path.join(__file__, "../_test_process_helper.py"))
189+
],
190+
errors.ExecuteUserScriptError,
191+
1,
192+
capture_error=capture_error
193+
)
194+
assert int(error.value.return_code) == 21
195+
exit(42)
196+
197+
198+
@pytest.mark.skipif(sys.version_info != (3, 7), reason="requires python3.7")
199+
@pytest.mark.parametrize("capture_error", [True, False])
200+
def test_check_error_signal(capture_error):
201+
proc = multiprocessing.Process(
202+
target=_sleep_subprocess,
203+
args=(capture_error,)
204+
)
205+
proc.start()
206+
time.sleep(1)
207+
os.kill(proc.pid, signal.SIGTERM)
208+
proc.join(1)
209+
assert int(proc.exitcode) == 42

0 commit comments

Comments
 (0)