Skip to content
Draft

WIP #178

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@ jobs:

matrix:
os: [ubuntu-latest]
python: ["3.8", "3.9", "3.10", "3.11"]
include:
- os: macos-latest
python: "3.8"
- os: macos-latest
python: "3.11"
python: ["3.11"]
steps:
- uses: actions/checkout@v3

Expand Down
20 changes: 18 additions & 2 deletions src/isolate/connections/grpc/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import socket
import subprocess
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
Expand All @@ -23,6 +25,11 @@ class AgentError(Exception):
"""An internal problem caused by (most probably) the agent."""


# Seconds to wait for the agent process to exit after it has been asked to
# terminate. Read once at import time from ISOLATE_SHUTDOWN_GRACE_PERIOD,
# falling back to 60 when the variable is not set.
PROCESS_SHUTDOWN_TIMEOUT_SECONDS = float(
    os.environ.get("ISOLATE_SHUTDOWN_GRACE_PERIOD", "60")
)


@dataclass
class GRPCExecutionBase(EnvironmentConnection):
"""A customizable gRPC-based execution backend."""
Expand Down Expand Up @@ -128,9 +135,18 @@ def find_free_port() -> Tuple[str, int]:
with self.start_process(address) as process:
yield address, grpc.local_channel_credentials()
finally:
if process is not None:
# TODO: should we check the status code here?
self.terminate_process(process)

def terminate_process(self, process: Union[None, subprocess.Popen]) -> None:
    """Stop the agent process, escalating to a hard kill if needed.

    A terminate request is sent first and the process is given up to
    ``PROCESS_SHUTDOWN_TIMEOUT_SECONDS`` to exit. If anything goes wrong
    during that attempt (including the wait timing out), the process is
    killed and then waited on so its exit status is collected.

    Args:
        process: The agent process to stop. ``None`` is accepted and
            results in a no-op.
    """
    if process is None:
        return

    try:
        print("Terminating the agent process...")
        process.terminate()
        process.wait(timeout=PROCESS_SHUTDOWN_TIMEOUT_SECONDS)
        print("Agent process shutdown gracefully")
    except Exception as exc:
        # Deliberately broad: this is best-effort cleanup, so any failure
        # of the graceful path falls back to a hard kill.
        print(f"Failed to shutdown the agent process gracefully: {exc}")
        process.kill()
        # Reap the killed child. Without this wait its exit status is never
        # collected and the dead process lingers until the parent exits.
        process.wait()

def get_python_cmd(
self,
Expand Down
85 changes: 71 additions & 14 deletions src/isolate/connections/grpc/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

from __future__ import annotations

import functools
import os
import signal
import sys
import traceback
from argparse import ArgumentParser
Expand All @@ -23,7 +25,9 @@
)

import grpc
from grpc import ServicerContext, StatusCode
from grpc import StatusCode, local_server_credentials

from isolate.connections.grpc.definitions import PartialRunResult

try:
from isolate import __version__ as agent_version
Expand All @@ -48,12 +52,13 @@ def __init__(self, log_fd: int | None = None):

self._run_cache: dict[str, Any] = {}
self._log = sys.stdout if log_fd is None else os.fdopen(log_fd, "w")
self._current_callable: Any = None

def Run(
self,
request: definitions.FunctionCall,
context: ServicerContext,
) -> Iterator[definitions.PartialRunResult]:
context: grpc.ServicerContext,
) -> Iterator[PartialRunResult]:
self.log(f"A connection has been established: {context.peer()}!")
server_version = os.getenv("ISOLATE_SERVER_VERSION") or "unknown"
self.log(f"Isolate info: server {server_version}, agent {agent_version}")
Expand Down Expand Up @@ -87,7 +92,8 @@ def Run(
)
raise AbortException("The setup function has thrown an error.")
except AbortException as exc:
return self.abort_with_msg(context, exc.message)
self.abort_with_msg(context, exc.message)
return
else:
assert not was_it_raised
self._run_cache[cache_key] = result
Expand All @@ -107,7 +113,8 @@ def Run(
stringized_tb,
)
except AbortException as exc:
return self.abort_with_msg(context, exc.message)
self.abort_with_msg(context, exc.message)
return

def execute_function(
self,
Expand Down Expand Up @@ -143,12 +150,23 @@ def execute_function(
was_it_raised = False
stringized_tb = None
try:
self._current_callable = function
shutdown_registered = hasattr(function, "__shutdown__")
is_partial = isinstance(function, functools.partial)
func_type = type(function)
self.log(
f"Shutdown registered: {shutdown_registered}, "
f"fname: {getattr(function, '__name__', repr(function))},"
f"is_partial: {is_partial}, type: {func_type}"
)
result = function(*extra_args)
except BaseException as exc:
result = exc
was_it_raised = True
num_frames = len(traceback.extract_stack()[:-5])
stringized_tb = "".join(traceback.format_exc(limit=-num_frames))
finally:
self._current_callable = None

self.log(f"Completed the execution of the {function_kind} function.")
return result, was_it_raised, stringized_tb
Expand Down Expand Up @@ -195,7 +213,7 @@ def log(self, message: str) -> None:

def abort_with_msg(
self,
context: ServicerContext,
context: grpc.ServicerContext,
message: str,
*,
code: StatusCode = StatusCode.INVALID_ARGUMENT,
Expand All @@ -204,28 +222,65 @@ def abort_with_msg(
context.set_details(message)
return None


def create_server(address: str) -> grpc.Server:
def handle_shutdown(self) -> None:
    """Run the ``__shutdown__`` hook of the function currently executing.

    The hook is looked up on the current callable itself and, when that
    callable is a ``functools.partial``, on the function it wraps. Errors
    raised by the hook are logged together with their traceback and are
    not propagated. Nothing is called when no function is executing or
    when no callable hook is found.
    """
    # Take a single snapshot: execute_function() resets this attribute to
    # None when the call finishes, so reading it repeatedly could observe
    # the callable on one read and None on the next.
    current = self._current_callable
    if current is None:
        print("No current callable, skipping shutdown.")
        return

    # Check for teardown on the callable itself or on the wrapped function
    # (in case it's a functools.partial)
    shutdown_callable = None

    if hasattr(current, "__shutdown__"):
        shutdown_callable = current.__shutdown__
    elif isinstance(current, functools.partial) and hasattr(
        current.func, "__shutdown__"
    ):
        shutdown_callable = current.func.__shutdown__

    if shutdown_callable is None or not callable(shutdown_callable):
        self.log("No shutdown callback found, skipping.")
        return

    self.log("Calling shutdown callback.")
    try:
        shutdown_callable()
    except Exception as exc:
        # The hook is user code; a failure here must not break shutdown.
        self.log(f"Error during shutdown: {exc}")
        self.log(traceback.format_exc())


def create_server(address: str) -> tuple[grpc.Server, futures.ThreadPoolExecutor]:
"""Create a new (temporary) gRPC server listening on the given
address."""
address. Returns the server and its executor."""
executor = futures.ThreadPoolExecutor(max_workers=1)
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=1),
executor,
maximum_concurrent_rpcs=1,
options=get_default_options(),
)

# Local server credentials allow us to ensure that the
# connection is established by a local process.
server_credentials = grpc.local_server_credentials()
server_credentials = local_server_credentials()
server.add_secure_port(address, server_credentials)
return server
return server, executor


def run_agent(address: str, log_fd: int | None = None) -> int:
"""Run the agent servicer on the given address."""
server = create_server(address)
server, executor = create_server(address)
servicer = AgentServicer(log_fd=log_fd)

# Set up SIGTERM handler
def sigterm_handler(signum, frame):
print("Received SIGTERM, shutting down the agent...")
servicer.handle_shutdown()
print("Shutdown complete, stopping the agent server.")
server.stop(grace=0.1)
executor.shutdown(wait=False, cancel_futures=True)

signal.signal(signal.SIGTERM, sigterm_handler)

# This function just calls some methods on the server
# and registers a generic handler for the bridge. It does
# not have any global side effects.
Expand All @@ -242,7 +297,9 @@ def main() -> int:
parser.add_argument("--log-fd", type=int)

options = parser.parse_args()
return run_agent(options.address, log_fd=options.log_fd)
ret_code = run_agent(options.address, log_fd=options.log_fd)
print("Agent process exiting.")
sys.exit(ret_code)


if __name__ == "__main__":
Expand Down
45 changes: 40 additions & 5 deletions src/isolate/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import functools
import os
import signal
import threading
import time
import traceback
Expand Down Expand Up @@ -178,11 +179,18 @@ class RunTask:

def cancel(self):
while True:
self.future.cancel()
# A running future cannot be cancelled, and the attempt sometimes blocks,
# which would keep us from ever terminating the agent. So only cancel it when it is not running.
if self.future and not self.future.running():
self.future.cancel()

if self.agent:
print("Terminating the agent...")
self.agent.terminate()

try:
self.future.exception(timeout=0.1)
if self.future:
self.future.exception(timeout=0.1)
return
except futures.TimeoutError:
pass
Expand All @@ -197,6 +205,7 @@ class IsolateServicer(definitions.IsolateServicer):
bridge_manager: BridgeManager
default_settings: IsolateSettings = field(default_factory=IsolateSettings)
background_tasks: dict[str, RunTask] = field(default_factory=dict)
_shutting_down: bool = field(default=False)

_thread_pool: futures.ThreadPoolExecutor = field(
default_factory=lambda: futures.ThreadPoolExecutor(max_workers=MAX_THREADS)
Expand Down Expand Up @@ -386,11 +395,12 @@ def Run(
self.background_tasks["RUN"] = task
yield from self._run_task(task)
except GRPCException as exc:
return self.abort_with_msg(
self.abort_with_msg(
exc.message,
context,
code=exc.code,
)
return
finally:
self.background_tasks.pop("RUN", None)

Expand Down Expand Up @@ -420,6 +430,17 @@ def Cancel(

return definitions.CancelResponse()

def shutdown(self) -> None:
    """Cancel every background task, doing the work at most once.

    The first call sets ``_shutting_down``, reports how many entries are
    in ``background_tasks`` and calls ``cancel_tasks()``. Any later call
    only prints a notice and returns.
    """
    if not self._shutting_down:
        # Set the flag before cancelling so a nested call takes the
        # "already in progress" branch below.
        self._shutting_down = True
        print(f"Shutting down, canceling {len(self.background_tasks)} tasks...")
        self.cancel_tasks()
        print("All tasks canceled.")
    else:
        print("Shutdown already in progress...")

def watch_queue_until_completed(
self, queue: Queue, is_completed: Callable[[], bool]
) -> Iterator[definitions.PartialRunResult]:
Expand Down Expand Up @@ -584,8 +605,10 @@ def _wrapper(request: Any, context: grpc.ServicerContext) -> Any:
def termination() -> None:
if is_run:
print("Stopping server since run is finished")
self.servicer.shutdown()
# Stop the server after the Run task is finished
self.server.stop(grace=0.1)
print("Server stopped")

elif is_submit:
# Wait until the task_id is assigned
Expand All @@ -610,7 +633,9 @@ def _stop(*args):
# Small sleep to make sure the cancellation is processed
time.sleep(0.1)
print("Stopping server since the task is finished")
self.servicer.shutdown()
self.server.stop(grace=0.1)
print("Server stopped")

# Add a callback which will stop the server
# after the task is finished
Expand Down Expand Up @@ -671,11 +696,21 @@ def main(argv: list[str] | None = None) -> None:
definitions.register_isolate(servicer, server)
health.register_health(HealthServicer(), server)

server.add_insecure_port("[::]:50001")
print("Started listening at localhost:50001")
def handle_termination(*args):
print("Termination signal received, shutting down...")
servicer.shutdown()
time.sleep(10)
server.stop(grace=0.1)

signal.signal(signal.SIGINT, handle_termination)
signal.signal(signal.SIGTERM, handle_termination)

server.add_insecure_port(f"[::]:{options.port}")
print(f"Started listening at {options.host}:{options.port}")

server.start()
server.wait_for_termination()
print("Server shut down")


if __name__ == "__main__":
Expand Down
Loading
Loading