import logging
import sys
import threading
import time

try:
    import gevent
    from gevent import Greenlet

    HAS_GEVENT = True
except ImportError:
    # gevent is optional: without it the flusher falls back to a plain thread.
    HAS_GEVENT = False

logger = logging.getLogger(__name__)


class GeventCompatibleStdoutFlusher:
    """Flush stdout/stderr once no activity has been reported for a while.

    A background worker periodically checks how long ago
    ``set_last_activity_time()`` was called and flushes ``sys.stdout`` and
    ``sys.stderr`` when that is at least ``max_time_until_flushing`` seconds
    ago. The worker is a gevent greenlet when gevent can be imported, and a
    ``threading.Thread`` otherwise.
    """

    def __init__(self, max_time_until_flushing=1.0):
        """
        Args:
            max_time_until_flushing: Idle time in seconds after which the
                streams are flushed; also used as the polling interval.
        """
        self._max_time_until_flushing = max_time_until_flushing
        self._last_predict_time = None
        self._flusher_greenlet = None
        self._flusher_thread = None
        self._stop_event = None
        self._running = False

    def start(self):
        """Start the background flusher; a no-op when it is already running.

        Returns:
            bool: Whether the flusher is alive after the call.
        """
        if self._running:
            return self.is_alive()

        self._running = True

        if HAS_GEVENT:
            self._flusher_greenlet = gevent.spawn(self._flush_greenlet_method)
        else:
            self._stop_event = threading.Event()
            # Daemon thread: a flusher that was never stopped must not keep
            # the interpreter from exiting.
            self._flusher_thread = threading.Thread(
                target=self._flush_thread_method, daemon=True
            )
            self._flusher_thread.start()
        return self.is_alive()

    def is_alive(self):
        """Check if the stdout flusher is alive."""
        # Compare against None explicitly: the truth value of a greenlet
        # reflects its run state, not whether one was created.
        if self._flusher_greenlet is not None:
            return not self._flusher_greenlet.dead
        if self._flusher_thread is not None:
            return self._flusher_thread.is_alive()
        return False

    def stop(self):
        """Stop the flusher in a synchronous fashion; safe to call repeatedly."""
        if not self._running:
            return

        self._running = False
        if self._flusher_greenlet is not None:
            self._flusher_greenlet.kill()
            self._flusher_greenlet = None
        elif self._flusher_thread is not None and self._stop_event is not None:
            self._stop_event.set()
            self._flusher_thread.join(timeout=2.0)  # Timeout to prevent hanging
            self._flusher_thread = None
            self._stop_event = None
        logger.debug("Stdout flusher stopped.")

    def set_last_activity_time(self):
        """Set the last activity time that will be used as the reference for time comparison."""
        self._last_predict_time = self._current_time()

    @staticmethod
    def _current_time():
        """Return the current wall-clock time in seconds."""
        return time.time()

    def _flush_greenlet_method(self):
        """Greenlet body: check for a pending flush once per interval."""
        try:
            while self._running:
                self._process_stdout_flushing()
                # Sleep for the whole interval; sleep(0) would only yield and
                # turn this loop into a busy spin.
                gevent.sleep(self._max_time_until_flushing)
        except gevent.GreenletExit:
            pass  # Normal termination via kill()

    def _flush_thread_method(self):
        """Thread body: check for a pending flush once per interval until stopped."""
        # Event.wait() returns True as soon as stop() sets the event, which
        # ends the loop without waiting out the rest of the interval.
        while self._running and not self._stop_event.wait(self._max_time_until_flushing):
            self._process_stdout_flushing()

    def _process_stdout_flushing(self):
        """Flush stdout and stderr when the idle period has elapsed."""
        if self._is_predict_time_set_and_max_waiting_time_expired():
            sys.stdout.flush()
            sys.stderr.flush()

    def _is_predict_time_set_and_max_waiting_time_expired(self):
        """Return True when an activity time is recorded and is old enough."""
        if self._last_predict_time is None:
            return False

        return (self._current_time() - self._last_predict_time) >= self._max_time_until_flushing
def get_server_type(self):
    """Return the HTTP server backend selected by ``DRUM_SERVER_TYPE``.

    Returns:
        str: ``"flask"`` or ``"gunicorn"``. ``"flask"`` is returned when the
        runtime parameter is absent or holds an unsupported value, so the
        result is always one of the two supported backends.
    """
    default_type = "flask"
    if not RuntimeParameters.has("DRUM_SERVER_TYPE"):
        return default_type

    requested = str(RuntimeParameters.get("DRUM_SERVER_TYPE")).lower()
    if requested in {"flask", "gunicorn"}:
        return requested

    # Do not leak an arbitrary string to the caller; fall back explicitly.
    logger.warning(
        "Unsupported DRUM_SERVER_TYPE %r; falling back to %r.", requested, default_type
    )
    return default_type
class GunicornApp(BaseApplication):
    """Gunicorn application that serves an already-built WSGI app in-process.

    Args:
        app: The WSGI application to serve.
        host: Interface to bind to.
        port: Port to bind to.
        params: Predictor parameters; ``max_workers`` / ``processes`` supply
            the worker count when no explicit override is configured.
        gunicorn_config: Validated gunicorn overrides (worker class,
            timeouts, limits, log level, worker count).
        termination_hook: Optional zero-argument callable run when a worker
            exits.
    """

    # Upper bound, in seconds, on how long a worker waits for its hook.
    _HOOK_TIMEOUT_SECONDS = 20

    def __init__(self, app, host, port, params, gunicorn_config, termination_hook):
        # State is assigned before super().__init__() because gunicorn's
        # BaseApplication loads the configuration from its constructor.
        self.application = app
        self.host = host
        self.port = port
        self.params = params
        self.gunicorn_config = gunicorn_config
        self.termination_hook = termination_hook
        super().__init__()

    def load_config(self):
        """Push bind address, worker settings and hooks into the gunicorn config."""
        overrides = self.gunicorn_config
        # Fall back to a single worker so gunicorn never receives None.
        workers = (
            overrides.get("workers")
            or self.params.get("max_workers")
            or self.params.get("processes")
            or 1
        )

        settings = {
            "bind": f"{self.host}:{self.port}",
            "workers": workers,
            "reuse_port": True,
            "preload_app": True,
            "worker_class": overrides.get("worker_class", "sync"),
            "backlog": overrides.get("backlog", 2048),
            "timeout": overrides.get("timeout", 120),
            "graceful_timeout": overrides.get("graceful_timeout", 60),
            "keepalive": overrides.get("keepalive", 5),
            "max_requests": overrides.get("max_requests", 1000),
            "max_requests_jitter": overrides.get("max_requests_jitter", 500),
            "loglevel": overrides.get("loglevel", "info"),
        }
        # worker_connections is set only when explicitly configured.
        if overrides.get("worker_connections"):
            settings["worker_connections"] = overrides.get("worker_connections")
        if self.termination_hook:
            settings["worker_exit"] = self._worker_exit_hook

        for key, value in settings.items():
            self.cfg.set(key, value)

    def load(self):
        """Return the WSGI application gunicorn should serve."""
        return self.application

    def _worker_exit_hook(self, server, worker):
        """Run the termination hook when a worker exits, waiting a bounded time."""
        pid = worker.pid
        server.log.info(f"[HOOK] Worker PID {pid} exiting — running termination hook.")

        def run_hook():
            try:
                self.termination_hook()
                server.log.info(f"[HOOK] Worker PID {pid} termination logic completed.")
            except Exception as e:
                # Hook boundary: report the failure instead of raising.
                server.log.error(f"[HOOK ERROR] Worker PID {pid}: {e}")

        # Daemon thread so a hung hook cannot outlive the join timeout and
        # keep the worker process from exiting.
        thread = threading.Thread(target=run_hook, daemon=True)
        thread.start()
        thread.join(timeout=self._HOOK_TIMEOUT_SECONDS)
        server.log.info(f"[HOOK] Worker PID {pid} cleanup done or timed out.")