diff --git a/src/isolate/server/server.py b/src/isolate/server/server.py
index a90ec58..67fb2a4 100644
--- a/src/isolate/server/server.py
+++ b/src/isolate/server/server.py
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
+import functools
 import os
 import threading
 import time
 import traceback
 import uuid
+from argparse import ArgumentParser
 from collections import defaultdict
 from concurrent import futures
 from concurrent.futures import ThreadPoolExecutor
@@ -16,6 +18,7 @@
 
 import grpc
 from grpc import ServicerContext, StatusCode
+from grpc.experimental import wrap_server_method_handler
 
 from isolate.backends import (
     EnvironmentCreationError,
@@ -465,13 +468,148 @@ def _add_log_to_queue(self, log: Log) -> None:
         self.messages.put_nowait(grpc_result)
 
 
-def main() -> None:
+@dataclass
+class ServerBoundInterceptor(grpc.ServerInterceptor):
+    """Base interceptor that can be bound to a server and a servicer.
+
+    The interceptor list is passed to ``grpc.server()`` before the server
+    object exists, so the server and servicer are attached afterwards via
+    ``register_server`` / ``register_servicer``. Each can be bound only once.
+    """
+
+    _server: grpc.Server | None = None
+    _servicer: IsolateServicer | None = None
+
+    def register_server(self, server: grpc.Server) -> None:
+        """Bind *server*; raise ``RuntimeError`` if one is already bound."""
+        if self._server is not None:
+            raise RuntimeError("A server is already bound to this interceptor.")
+
+        self._server = server
+
+    @property
+    def server(self) -> grpc.Server:
+        """Return the bound server; raise ``RuntimeError`` if none is bound."""
+        if self._server is None:
+            raise RuntimeError("No server was bound to this interceptor.")
+
+        return self._server
+
+    def register_servicer(self, servicer: IsolateServicer) -> None:
+        """Bind *servicer*; raise ``RuntimeError`` if one is already bound."""
+        if self._servicer is not None:
+            raise RuntimeError("A servicer is already bound to this interceptor.")
+
+        self._servicer = servicer
+
+    @property
+    def servicer(self) -> IsolateServicer:
+        """Return the bound servicer; raise ``RuntimeError`` if none is bound."""
+        if self._servicer is None:
+            raise RuntimeError("No servicer was bound to this interceptor.")
+
+        return self._servicer
+
+
+@dataclass
+class SingleTaskInterceptor(ServerBoundInterceptor):
+    """Sets server to terminate after the first Submit/Run task.
+
+    The first ``/Isolate/Run`` or ``/Isolate/Submit`` call is served and
+    schedules the server to stop when it is over; later Run/Submit calls are
+    rejected with ``UNAVAILABLE``. Every other method passes through as-is.
+    """
+
+    _done: bool = False
+
+    def intercept_service(self, continuation, handler_call_details):
+        """Return the handler that should serve the incoming call."""
+        handler = continuation(handler_call_details)
+
+        is_submit = handler_call_details.method == "/Isolate/Submit"
+        is_run = handler_call_details.method == "/Isolate/Run"
+        is_new_task = is_submit or is_run
+
+        if not is_new_task:
+            # Let other requests like List/Cancel/etc pass through
+            return handler
+
+        if self._done:
+            # An exception raised from here reaches the client as UNKNOWN, so
+            # reject from inside the handler, where the status can be set.
+            def reject(method_impl):
+                @functools.wraps(method_impl)
+                def _reject(request, context):
+                    context.abort(
+                        grpc.StatusCode.UNAVAILABLE,
+                        "Server has already served one Run/Submit task.",
+                    )
+
+                return _reject
+
+            return wrap_server_method_handler(reject, handler)
+
+        self._done = True
+
+        def wrapper(method_impl):
+            @functools.wraps(method_impl)
+            def _wrapper(request, context):
+                def _stop():
+                    if is_submit:
+                        # Wait for the task to finish
+                        while self.servicer.background_tasks:
+                            time.sleep(0.1)
+                    self.server.stop(grace=0.1)
+
+                # Registered to run once this RPC terminates.
+                context.add_callback(_stop)
+                return method_impl(request, context)
+
+            return _wrapper
+
+        return wrap_server_method_handler(wrapper, handler)
+
+
+def main(argv: list[str] | None = None) -> None:
+    parser = ArgumentParser()
+    parser.add_argument("--host", default="0.0.0.0")
+    parser.add_argument("--port", type=int, default=50001)
+    parser.add_argument(
+        "--single-use",
+        action="store_true",
+        help="Terminate the server after the first Run or Submit task is completed.",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=MAX_THREADS,
+        help="Number of worker threads to use for the gRPC server.",
+    )
+
+    options = parser.parse_args(argv)
+    if options.num_workers is None:
+        options.num_workers = 1 if options.single_use else os.cpu_count()
+
+    interceptors: list[ServerBoundInterceptor] = []
+    if options.single_use:
+        interceptors.append(SingleTaskInterceptor())
+
     server = grpc.server(
-        futures.ThreadPoolExecutor(max_workers=MAX_THREADS),
+        futures.ThreadPoolExecutor(max_workers=options.num_workers),
         options=get_default_options(),
+        interceptors=interceptors,
     )
+
+    for interceptor in interceptors:
+        interceptor.register_server(server)
+
     with BridgeManager() as bridge_manager:
-        definitions.register_isolate(IsolateServicer(bridge_manager), server)
+        servicer = IsolateServicer(bridge_manager)
+
+        for interceptor in interceptors:
+            interceptor.register_servicer(servicer)
+
+        definitions.register_isolate(servicer, server)
         health.register_health(HealthServicer(), server)
 
         server.add_insecure_port("[::]:50001")