diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5a55c56..2be49b0 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,12 +21,7 @@ jobs: matrix: os: [ubuntu-latest] - python: ["3.8", "3.9", "3.10", "3.11"] - include: - - os: macos-latest - python: "3.8" - - os: macos-latest - python: "3.11" + python: ["3.11"] steps: - uses: actions/checkout@v3 diff --git a/src/isolate/connections/grpc/_base.py b/src/isolate/connections/grpc/_base.py index a4d3bfc..5ebcd80 100644 --- a/src/isolate/connections/grpc/_base.py +++ b/src/isolate/connections/grpc/_base.py @@ -1,4 +1,6 @@ +import os import socket +import subprocess from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path @@ -23,6 +25,11 @@ class AgentError(Exception): """An internal problem caused by (most probably) the agent.""" +PROCESS_SHUTDOWN_TIMEOUT_SECONDS = float( + os.getenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "60") +) + + @dataclass class GRPCExecutionBase(EnvironmentConnection): """A customizable gRPC-based execution backend.""" @@ -128,9 +135,18 @@ def find_free_port() -> Tuple[str, int]: with self.start_process(address) as process: yield address, grpc.local_channel_credentials() finally: - if process is not None: - # TODO: should we check the status code here? + self.terminate_process(process) + + def terminate_process(self, process: Union[None, subprocess.Popen]) -> None: + if process is not None: + try: + print("Terminating the agent process...") process.terminate() + process.wait(timeout=PROCESS_SHUTDOWN_TIMEOUT_SECONDS) + print("Agent process shutdown gracefully") + except Exception as exc: + print(f"Failed to shutdown the agent process gracefully: {exc}") + process.kill() def get_python_cmd( self, diff --git a/src/isolate/connections/grpc/agent.py b/src/isolate/connections/grpc/agent.py index 08bca36..56e648c 100644 --- a/src/isolate/connections/grpc/agent.py +++ b/src/isolate/connections/grpc/agent.py @@ -10,7 +10,9 @@ from __future__ import annotations +import functools import os +import signal import sys import traceback from argparse import ArgumentParser @@ -23,7 +25,9 @@ ) import grpc -from grpc import ServicerContext, StatusCode +from grpc import StatusCode, local_server_credentials + +from isolate.connections.grpc.definitions import PartialRunResult try: from isolate import __version__ as agent_version @@ -48,12 +52,13 @@ def __init__(self, log_fd: int | None = None): self._run_cache: dict[str, Any] = {} self._log = sys.stdout if log_fd is None else os.fdopen(log_fd, "w") + self._current_callable: Any = None def Run( self, request: definitions.FunctionCall, - context: ServicerContext, - ) -> Iterator[definitions.PartialRunResult]: + context: grpc.ServicerContext, + ) -> Iterator[PartialRunResult]: self.log(f"A connection has been established: {context.peer()}!") server_version = os.getenv("ISOLATE_SERVER_VERSION") or "unknown" self.log(f"Isolate info: server {server_version}, agent {agent_version}") @@ -87,7 +92,8 @@ def Run( ) raise AbortException("The setup function has thrown an error.") except AbortException as exc: - return self.abort_with_msg(context, exc.message) + self.abort_with_msg(context, exc.message) + return else: assert not was_it_raised self._run_cache[cache_key] = result @@ -107,7 +113,8 @@ def Run( stringized_tb, ) except AbortException as exc: - return self.abort_with_msg(context, exc.message) + self.abort_with_msg(context, exc.message) + return def execute_function( self, @@ -143,12 +150,23 @@ def execute_function( was_it_raised = False stringized_tb = None try: + self._current_callable = function + shutdown_registered = hasattr(function, "__shutdown__") + is_partial = isinstance(function, functools.partial) + func_type = type(function) + self.log( + f"Shutdown registered: {shutdown_registered}, " + f"fname: {getattr(function, '__name__', repr(function))}," + f"is_partial: {is_partial}, type: {func_type}" + ) result = function(*extra_args) except BaseException as exc: result = exc was_it_raised = True num_frames = len(traceback.extract_stack()[:-5]) stringized_tb = "".join(traceback.format_exc(limit=-num_frames)) + finally: + self._current_callable = None self.log(f"Completed the execution of the {function_kind} function.") return result, was_it_raised, stringized_tb @@ -195,7 +213,7 @@ def log(self, message: str) -> None: def abort_with_msg( self, - context: ServicerContext, + context: grpc.ServicerContext, message: str, *, code: StatusCode = StatusCode.INVALID_ARGUMENT, @@ -204,28 +222,65 @@ def abort_with_msg( context.set_details(message) return None - -def create_server(address: str) -> grpc.Server: + def handle_shutdown(self) -> None: + if self._current_callable is None: + print("No current callable, skipping shutdown.") + return + + # Check for teardown on the callable itself or on the wrapped function + # (in case it's a functools.partial) + shutdown_callable = None + + if hasattr(self._current_callable, "__shutdown__"): + shutdown_callable = self._current_callable.__shutdown__ + elif isinstance(self._current_callable, functools.partial) and hasattr( + self._current_callable.func, "__shutdown__" + ): + shutdown_callable = self._current_callable.func.__shutdown__ + + if shutdown_callable is not None and callable(shutdown_callable): + self.log("Calling shutdown callback.") + try: + shutdown_callable() + except Exception as exc: + self.log(f"Error during shutdown: {exc}") + self.log(traceback.format_exc()) + else: + self.log("No shutdown callback found, skipping.") + + +def create_server(address: str) -> tuple[grpc.Server, futures.ThreadPoolExecutor]: """Create a new (temporary) gRPC server listening on the given - address.""" + address. Returns the server and its executor.""" + executor = futures.ThreadPoolExecutor(max_workers=1) server = grpc.server( - futures.ThreadPoolExecutor(max_workers=1), + executor, maximum_concurrent_rpcs=1, options=get_default_options(), ) # Local server credentials allow us to ensure that the # connection is established by a local process. - server_credentials = grpc.local_server_credentials() + server_credentials = local_server_credentials() server.add_secure_port(address, server_credentials) - return server + return server, executor def run_agent(address: str, log_fd: int | None = None) -> int: """Run the agent servicer on the given address.""" - server = create_server(address) + server, executor = create_server(address) servicer = AgentServicer(log_fd=log_fd) + # Set up SIGTERM handler + def sigterm_handler(signum, frame): + print("Received SIGTERM, shutting down the agent...") + servicer.handle_shutdown() + print("Shutdown complete, stopping the agent server.") + server.stop(grace=0.1) + executor.shutdown(wait=False, cancel_futures=True) + + signal.signal(signal.SIGTERM, sigterm_handler) + # This function just calls some methods on the server # and register a generic handler for the bridge. It does # not have any global side effects. @@ -242,7 +297,9 @@ def main() -> int: parser.add_argument("--log-fd", type=int) options = parser.parse_args() - return run_agent(options.address, log_fd=options.log_fd) + ret_code = run_agent(options.address, log_fd=options.log_fd) + print("Agent process exiting.") + sys.exit(ret_code) if __name__ == "__main__": diff --git a/src/isolate/server/server.py b/src/isolate/server/server.py index 12a9931..a48f509 100644 --- a/src/isolate/server/server.py +++ b/src/isolate/server/server.py @@ -2,6 +2,7 @@ import functools import os +import signal import threading import time import traceback @@ -178,11 +179,18 @@ class RunTask: def cancel(self): while True: - self.future.cancel() + # Cancelling a running future is not possible, and it sometimes blocks, + # which means we never terminate the agent. So check if it's not running + if self.future and not self.future.running(): + self.future.cancel() + if self.agent: + print("Terminating the agent...") self.agent.terminate() + try: - self.future.exception(timeout=0.1) + if self.future: + self.future.exception(timeout=0.1) return except futures.TimeoutError: pass @@ -197,6 +205,7 @@ class IsolateServicer(definitions.IsolateServicer): bridge_manager: BridgeManager default_settings: IsolateSettings = field(default_factory=IsolateSettings) background_tasks: dict[str, RunTask] = field(default_factory=dict) + _shutting_down: bool = field(default=False) _thread_pool: futures.ThreadPoolExecutor = field( default_factory=lambda: futures.ThreadPoolExecutor(max_workers=MAX_THREADS) @@ -386,11 +395,12 @@ def Run( self.background_tasks["RUN"] = task yield from self._run_task(task) except GRPCException as exc: - return self.abort_with_msg( + self.abort_with_msg( exc.message, context, code=exc.code, ) + return finally: self.background_tasks.pop("RUN", None) @@ -420,6 +430,17 @@ def Cancel( return definitions.CancelResponse() + def shutdown(self) -> None: + if self._shutting_down: + print("Shutdown already in progress...") + return + + self._shutting_down = True + task_count = len(self.background_tasks) + print(f"Shutting down, canceling {task_count} tasks...") + self.cancel_tasks() + print("All tasks canceled.") + def watch_queue_until_completed( self, queue: Queue, is_completed: Callable[[], bool] ) -> Iterator[definitions.PartialRunResult]: @@ -584,8 +605,10 @@ def _wrapper(request: Any, context: grpc.ServicerContext) -> Any: def termination() -> None: if is_run: print("Stopping server since run is finished") + self.servicer.shutdown() # Stop the server after the Run task is finished self.server.stop(grace=0.1) + print("Server stopped") elif is_submit: # Wait until the task_id is assigned @@ -610,7 +633,9 @@ def _stop(*args): # Small sleep to make sure the cancellation is processed time.sleep(0.1) print("Stopping server since the task is finished") + self.servicer.shutdown() self.server.stop(grace=0.1) + print("Server stopped") # Add a callback which will stop the server # after the task is finished @@ -671,11 +696,21 @@ def main(argv: list[str] | None = None) -> None: definitions.register_isolate(servicer, server) health.register_health(HealthServicer(), server) - server.add_insecure_port("[::]:50001") - print("Started listening at localhost:50001") + def handle_termination(*args): + print("Termination signal received, shutting down...") + servicer.shutdown() + time.sleep(10) + server.stop(grace=0.1) + + signal.signal(signal.SIGINT, handle_termination) + signal.signal(signal.SIGTERM, handle_termination) + + server.add_insecure_port(f"[::]:{options.port}") + print(f"Started listening at {options.host}:{options.port}") server.start() server.wait_for_termination() + print("Server shut down") if __name__ == "__main__": diff --git a/tests/test_shutdown.py b/tests/test_shutdown.py new file mode 100644 index 0000000..37a122d --- /dev/null +++ b/tests/test_shutdown.py @@ -0,0 +1,176 @@ +"""End-to-end tests for graceful shutdown behavior of IsolateServicer.""" + +import functools +import os +import signal +import subprocess +import sys +import threading +import time +from unittest.mock import Mock + +import grpc +import pytest +from isolate.server.definitions.server_pb2 import BoundFunction, EnvironmentDefinition +from isolate.server.definitions.server_pb2_grpc import IsolateStub +from isolate.server.interface import to_serialized_object +from isolate.server.server import BridgeManager, IsolateServicer, RunnerAgent, RunTask + + +def create_run_request(func, *args, **kwargs): + """Convert a Python function into a BoundFunction request for stub.Run().""" + bound_function = functools.partial(func, *args, **kwargs) + serialized_function = to_serialized_object(bound_function, method="cloudpickle") + + env_def = EnvironmentDefinition() + env_def.kind = "local" + + request = BoundFunction() + request.function.CopyFrom(serialized_function) + request.environments.append(env_def) + request.stream_logs = True + + return request + + +@pytest.fixture +def servicer(): + """Create a real IsolateServicer instance for testing.""" + with BridgeManager() as bridge_manager: + servicer = IsolateServicer(bridge_manager) + yield servicer + + +@pytest.fixture +def isolate_server_subprocess(monkeypatch): + """Set up a gRPC server with the IsolateServicer for testing.""" + # Find a free port + import socket + + monkeypatch.setenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "2") + + with socket.socket() as s: + s.bind(("", 0)) + port = s.getsockname()[1] + + process = subprocess.Popen( + [ + sys.executable, + "-m", + "isolate.server.server", + "--single-use", + "--port", + str(port), + ] + ) + + time.sleep(5) # Wait for server to start + yield process, port + + # Cleanup + if process.poll() is None: + process.terminate() + process.wait(timeout=10) + + +def consume_responses(responses): + def _consume(): + try: + for response in responses: + pass + except grpc.RpcError: + # Expected when connection is closed + pass + + response_thread = threading.Thread(target=_consume, daemon=True) + response_thread.start() + + +def test_shutdown_with_terminate(servicer): + task = RunTask(request=Mock(), future=Mock()) + servicer.background_tasks["TEST_BLOCKING"] = task + task.agent = RunnerAgent(Mock(), Mock(), Mock(), Mock()) + task.agent.terminate = Mock(wraps=task.agent.terminate) + servicer.shutdown() + task.agent.terminate.assert_called_once() # agent should be terminated + + +def test_exit_on_client_close(isolate_server_subprocess): + """Connect with grpc client, run a task and then close the client.""" + process, port = isolate_server_subprocess + channel = grpc.insecure_channel(f"localhost:{port}") + stub = IsolateStub(channel) + + def fn(): + import time + + time.sleep(30) # Simulate long-running task + + responses = stub.Run(create_run_request(fn)) + consume_responses(responses) + + # Give task time to start + time.sleep(2) + + # there is a running grpc client connected to an isolate servicer which is + # emitting responses from an agent running a infinite loop + assert process.poll() is None, "Server should be running while client is connected" + + # Close the channel to simulate client disconnect + channel.close() + + # Give time for the channel close to propagate and trigger termination + time.sleep(1.0) + + try: + # Wait for server process to exit + process.wait(timeout=5) + except subprocess.TimeoutExpired: + raise AssertionError("Server did not shut down after client disconnect") + + assert ( + process.poll() is not None + ), "Server should have shut down after client disconnect" + + +def test_running_function_receives_sigterm(isolate_server_subprocess, tmp_path): + """Test that the user provided code receives the SIGTERM""" + process, port = isolate_server_subprocess + channel = grpc.insecure_channel(f"localhost:{port}") + stub = IsolateStub(channel) + + # Send SIGTERM to the current process + assert process.poll() is None, "Server should be running initially" + + sigterm_file_path = tmp_path.joinpath("sigterm_test") + + def teardown(path): + import pathlib + + pathlib.Path(path).touch() + + def func_with_teardown(): + import time + + time.sleep(30) # Simulate long-running task + + func_with_teardown.__shutdown__ = functools.partial( + teardown, str(sigterm_file_path) + ) + + assert not sigterm_file_path.exists() + + responses = stub.Run(create_run_request(func_with_teardown)) + consume_responses(responses) + time.sleep(2) # Give task time to start + + os.kill(process.pid, signal.SIGTERM) + process.wait(timeout=5) + assert process.poll() is not None, "Server should have shut down after SIGTERM" + assert ( + sigterm_file_path.exists() + ), "Function should have received SIGTERM and created the file" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])