diff --git a/metaflow/plugins/aws/batch/batch_decorator.py b/metaflow/plugins/aws/batch/batch_decorator.py index 5c61d3f8f03..84f5cea9526 100644 --- a/metaflow/plugins/aws/batch/batch_decorator.py +++ b/metaflow/plugins/aws/batch/batch_decorator.py @@ -299,6 +299,8 @@ def task_pre_step( self._save_logs_sidecar.start() # Start spot termination monitor sidecar. + # TODO: A nicer way to pass the main process id to a Sidecar, in order to allow sidecars to send signals back to the main process. + os.environ["MF_MAIN_PID"] = str(os.getpid()) current._update_env( {"spot_termination_notice": "/tmp/spot_termination_notice"} ) diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index 069c63ef211..f67afd578d8 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -559,6 +559,8 @@ def task_pre_step( self._save_logs_sidecar.start() # Start spot termination monitor sidecar. + # TODO: A nicer way to pass the main process id to a Sidecar, in order to allow sidecars to send signals back to the main process. + os.environ["MF_MAIN_PID"] = str(os.getpid()) current._update_env( {"spot_termination_notice": "/tmp/spot_termination_notice"} ) diff --git a/metaflow/plugins/kubernetes/spot_monitor_sidecar.py b/metaflow/plugins/kubernetes/spot_monitor_sidecar.py index 59f821f885e..115b815a6e4 100644 --- a/metaflow/plugins/kubernetes/spot_monitor_sidecar.py +++ b/metaflow/plugins/kubernetes/spot_monitor_sidecar.py @@ -21,6 +21,9 @@ def __init__(self): self._token = None self._token_expiry = 0 + # Due to nesting, os.getppid is not reliable for fetching the main task pid + self.main_pid = int(os.getenv("MF_MAIN_PID", os.getppid())) + if self._is_aws_spot_instance(): self._process = Process(target=self._monitor_loop) self._process.start() @@ -71,7 +74,7 @@ def _monitor_loop(self): if response.status_code == 200: termination_time = response.text self._emit_termination_metadata(termination_time) - os.kill(os.getppid(), signal.SIGTERM) + os.kill(self.main_pid, signal.SIGUSR1) break except (requests.exceptions.RequestException, requests.exceptions.Timeout): pass