Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions custom_model_runner/datarobot_drum/drum/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@
This is proprietary source code of DataRobot, Inc. and its affiliates.
Released under the terms of DataRobot Tool and Utility Agreement.
"""
from ..runtime_parameters import RuntimeParameters

# Apply gevent monkey patching when DRUM is configured to serve through gunicorn
# with the gevent worker class. Both runtime parameters have to be present, and the
# comparison of their values is case-insensitive.
if (
    RuntimeParameters.has("DRUM_SERVER_TYPE")
    and RuntimeParameters.has("DRUM_GUNICORN_WORKER_CLASS")
    and str(RuntimeParameters.get("DRUM_SERVER_TYPE")).lower() == "gunicorn"
    and str(RuntimeParameters.get("DRUM_GUNICORN_WORKER_CLASS")).lower() == "gevent"
):
    try:
        from gevent import monkey

        monkey.patch_all()
    except ImportError:
        # gevent cannot be imported: carry on without patching.
        pass

from datarobot_drum.drum.lazy_loading.lazy_loading_handler import LazyLoadingHandler

Expand Down Expand Up @@ -39,6 +55,7 @@
# Run regression user model in fit mode.
drum fit --code-dir <custom code dir> --input <input.csv> --output <output_dir> --target-type regression --target <target feature> --verbose
"""

import os
import signal
import sys
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import logging
import sys
import time
import threading

# gevent is optional. HAS_GEVENT records whether it could be imported and selects
# between the greenlet-based and the thread-based flusher implementation below.
try:
    import gevent
    from gevent import Greenlet

    HAS_GEVENT = True
except ImportError:
    HAS_GEVENT = False

logger = logging.getLogger(__name__)

class GeventCompatibleStdoutFlusher:
    """Flush stdout and stderr once no activity has been reported for a while.

    The periodic check runs in a gevent greenlet when the module-level ``HAS_GEVENT``
    flag is true and in a regular thread otherwise.
    """

    def __init__(self, max_time_until_flushing=1.0):
        """Create a stopped flusher.

        ``max_time_until_flushing`` is the number of seconds without activity after
        which the streams are flushed; it is also the polling interval of the
        background worker.
        """
        self._max_time_until_flushing = max_time_until_flushing
        self._last_predict_time = None
        self._flusher_greenlet = None
        self._flusher_thread = None
        self._stop_event = None
        self._running = False

    def start(self):
        """Start the stdout flusher; does nothing when it is already running."""
        if self._running:
            return

        self._running = True

        if HAS_GEVENT:
            self._flusher_greenlet = gevent.spawn(self._flush_greenlet_method)
        else:
            self._stop_event = threading.Event()
            self._flusher_thread = threading.Thread(target=self._flush_thread_method)
            # Creating the thread object is not enough, it has to be started too.
            self._flusher_thread.start()

    def is_alive(self):
        """Check if the stdout flusher is alive."""
        # Only one of the two workers is ever created, so checking which attribute is
        # populated is sufficient to pick the right liveness test.
        if self._flusher_greenlet is not None:
            return not self._flusher_greenlet.dead
        if self._flusher_thread is not None:
            return self._flusher_thread.is_alive()
        return False

    def stop(self):
        """Stop the flusher in a synchronous fashion."""
        if not self._running:
            return

        self._running = False
        if self._flusher_greenlet is not None:
            # Compared against None on purpose: a greenlet that already finished must
            # still be cleared here.
            self._flusher_greenlet.kill()
            self._flusher_greenlet = None
        elif self._flusher_thread is not None and self._stop_event is not None:
            # Wake the worker up immediately instead of waiting for the next poll.
            self._stop_event.set()
            self._flusher_thread.join(timeout=2.0)  # Timeout to prevent hanging
            self._flusher_thread = None
            self._stop_event = None
        logger.debug("Stdout flusher stopped.")

    def set_last_activity_time(self):
        """Set the last activity time that will be used as the reference for time comparison."""
        self._last_predict_time = self._current_time()

    @staticmethod
    def _current_time():
        """Return the current wall-clock time in seconds."""
        return time.time()

    def _flush_greenlet_method(self):
        """Gevent greenlet method for stdout flushing"""
        try:
            while self._running:
                self._process_stdout_flushing()
                # Sleep for the whole polling interval. sleep(0) merely yields to other
                # greenlets and would turn this loop into a busy spin.
                gevent.sleep(self._max_time_until_flushing)
        except gevent.GreenletExit:
            pass  # Normal termination

    def _flush_thread_method(self):
        """Threading method for stdout flushing"""
        # Event.wait() returns True once stop() has set the event, which ends the loop.
        while self._running and not self._stop_event.wait(self._max_time_until_flushing):
            self._process_stdout_flushing()

    def _process_stdout_flushing(self):
        """Flush stdout and stderr when the inactivity period has elapsed."""
        if self._is_predict_time_set_and_max_waiting_time_expired():
            sys.stdout.flush()
            sys.stderr.flush()

    def _is_predict_time_set_and_max_waiting_time_expired(self):
        """Return True when activity was recorded and the waiting time has passed since."""
        if self._last_predict_time is None:
            return False

        return (self._current_time() - self._last_predict_time) >= self._max_time_until_flushing
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
import logging
import os
import sys
import threading
import time
from pathlib import Path
from threading import Thread
import subprocess
import signal

import psutil
import requests
from flask import Response, jsonify, request
from werkzeug.exceptions import HTTPException
Expand Down Expand Up @@ -312,17 +314,79 @@ def handle_exception(e):

app = get_flask_app(model_api)
self.load_flask_extensions(app)
self._run_flask_app(app)
self._run_flask_app(app, self._terminate)

if self._stats_collector:
self._stats_collector.print_reports()

return []

def _run_flask_app(self, app):
def get_gunicorn_config(self):
    """Build the gunicorn settings dict from the ``DRUM_GUNICORN_*`` runtime parameters.

    A parameter contributes a setting only when it is present and its value passes
    validation: string parameters must match one of the allowed names
    (case-insensitively, the value is stored as given) and integer parameters must
    lie within an inclusive range. Values failing validation are left out. A value
    that ``int()`` cannot convert raises the conversion error.
    """
    # (runtime parameter, gunicorn setting, converter, allowed names or (low, high))
    specs = (
        ("DRUM_GUNICORN_WORKER_CLASS", "worker_class", str, {"sync", "gevent"}),
        ("DRUM_GUNICORN_WORKER_CONNECTIONS", "worker_connections", int, (1, 10000)),
        ("DRUM_GUNICORN_BACKLOG", "backlog", int, (1, 2048)),
        ("DRUM_GUNICORN_TIMEOUT", "timeout", int, (1, 3600)),
        ("DRUM_GUNICORN_GRACEFUL_TIMEOUT", "graceful_timeout", int, (1, 3600)),
        ("DRUM_GUNICORN_KEEP_ALIVE", "keepalive", int, (1, 3600)),
        ("DRUM_GUNICORN_MAX_REQUESTS", "max_requests", int, (1, 10000)),
        ("DRUM_GUNICORN_MAX_REQUESTS_JITTER", "max_requests_jitter", int, (1, 10000)),
        (
            "DRUM_GUNICORN_LOG_LEVEL",
            "loglevel",
            str,
            {"debug", "info", "warning", "error", "critical"},
        ),
        ("DRUM_GUNICORN_WORKERS", "workers", int, (1, 199)),
    )

    settings = {}
    for parameter, setting, convert, constraint in specs:
        if not RuntimeParameters.has(parameter):
            continue

        value = convert(RuntimeParameters.get(parameter))
        if convert is str:
            accepted = value.lower() in constraint
        else:
            lowest, highest = constraint
            accepted = lowest <= value <= highest

        if accepted:
            settings[setting] = value

    return settings

def get_server_type(self):
    """Return the server type requested through ``DRUM_SERVER_TYPE``.

    Falls back to ``"flask"`` when the runtime parameter is absent. ``"flask"`` and
    ``"gunicorn"`` are recognized case-insensitively and returned in lower case;
    any other value is returned unchanged.
    """
    if not RuntimeParameters.has("DRUM_SERVER_TYPE"):
        return "flask"

    requested = str(RuntimeParameters.get("DRUM_SERVER_TYPE"))
    normalized = requested.lower()
    return normalized if normalized in ("flask", "gunicorn") else requested

def _run_flask_app(self, app, termination_hook):
host = self._params.get("host", None)
port = self._params.get("port", None)

server_type = self.get_server_type()
processes = 1
if self._params.get("processes"):
processes = self._params.get("processes")
Expand All @@ -340,20 +404,142 @@ def _run_flask_app(self, app):
)
self._server_watchdog.start()

# Configure the server with timeout settings
app.run(
host=host,
port=port,
threaded=False,
processes=processes,
**(
{"request_handler": TimeoutWSGIRequestHandler}
if RuntimeParameters.has("DRUM_CLIENT_REQUEST_TIMEOUT")
else {}
),
)
if server_type == "gunicorn":
logger.info("Starting gunicorn server")
try:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably this can be removed, as we require gunicorn.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the import can even be moved to the beginning of the file.

from gunicorn.app.base import BaseApplication
except ImportError:
BaseApplication = None
raise DrumCommonException("gunicorn is not installed. Please install gunicorn.")

class GunicornApp(BaseApplication):
    """Gunicorn application that serves the given Flask app.

    Settings come from the validated ``gunicorn_config`` mapping and fall back to the
    defaults hard-coded in ``load_config``. When a termination hook is supplied it is
    run from gunicorn's ``worker_exit`` callback.
    """

    def __init__(self, app, host, port, params, gunicorn_config, termination_hook):
        # Everything load_config() reads is assigned before calling the base
        # initializer (presumably it triggers load_config() -- confirm against the
        # gunicorn BaseApplication documentation).
        self.application = app
        self.host = host
        self.port = port
        self.params = params
        self.gunicorn_config = gunicorn_config
        self.termination_hook = termination_hook
        super().__init__()

    def load_config(self):
        """Populate ``self.cfg`` from the DRUM parameters and the gunicorn overrides."""
        self.cfg.set("bind", f"{self.host}:{self.port}")

        # An explicit gunicorn override wins over the DRUM worker/process parameters.
        workers = (
            self.gunicorn_config.get("workers")
            or self.params.get("max_workers")
            or self.params.get("processes")
        )
        if workers:
            # When nothing is configured the setting is left untouched, so gunicorn's
            # own default applies instead of passing None to cfg.set().
            self.cfg.set("workers", workers)
        self.cfg.set("reuse_port", True)
        self.cfg.set("preload_app", True)

        self.cfg.set("worker_class", self.gunicorn_config.get("worker_class", "sync"))
        self.cfg.set("backlog", self.gunicorn_config.get("backlog", 2048))
        self.cfg.set("timeout", self.gunicorn_config.get("timeout", 120))
        self.cfg.set("graceful_timeout", self.gunicorn_config.get("graceful_timeout", 60))
        self.cfg.set("keepalive", self.gunicorn_config.get("keepalive", 5))
        self.cfg.set("max_requests", self.gunicorn_config.get("max_requests", 1000))
        self.cfg.set(
            "max_requests_jitter",
            self.gunicorn_config.get("max_requests_jitter", 500),
        )

        if self.gunicorn_config.get("worker_connections"):
            self.cfg.set("worker_connections", self.gunicorn_config.get("worker_connections"))
        self.cfg.set("loglevel", self.gunicorn_config.get("loglevel", "info"))

        # Access and error log destinations are not configured here.
        if self.termination_hook:
            self.cfg.set("worker_exit", self._worker_exit_hook)

    def load(self):
        """Return the WSGI application gunicorn should serve."""
        return self.application

    def _worker_exit_hook(self, server, worker):
        """Run the termination hook when a worker exits, waiting at most 20 seconds."""
        pid = worker.pid
        server.log.info(f"[HOOK] Worker PID {pid} exiting — running termination hook.")

        def run_hook():
            try:
                self.termination_hook()
                server.log.info(f"[HOOK] Worker PID {pid} termination logic completed.")
            except Exception as e:
                # Boundary of the hook: a failing cleanup must not break worker exit.
                server.log.error(f"[HOOK ERROR] Worker PID {pid}: {e}")

        # The hook runs in a helper thread so that a hanging cleanup cannot block the
        # worker forever. It is a daemon thread because interpreter shutdown waits for
        # non-daemon threads, which would make the join timeout below ineffective.
        thread = threading.Thread(target=run_hook, daemon=True)
        thread.start()
        thread.join(timeout=20)
        server.log.info(f"[HOOK] Worker PID {pid} cleanup done or timed out.")

gunicorn_config = self.get_gunicorn_config()
GunicornApp(app, host, port, self._params, gunicorn_config, termination_hook).run()
else:
# Configure the server with timeout settings
app.run(
host=host,
port=port,
threaded=False,
processes=processes,
**(
{"request_handler": TimeoutWSGIRequestHandler}
if RuntimeParameters.has("DRUM_CLIENT_REQUEST_TIMEOUT")
else {}
),
)
except OSError as e:
raise DrumCommonException("{}: host: {}; port: {}".format(e, host, port))
raise DrumCommonException(f"{e}: host: {host}; port: {port}")

def _kill_all_processes(self):
"""
Expand Down
2 changes: 2 additions & 0 deletions custom_model_runner/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ argcomplete
trafaret>=2.0.0
docker>=4.2.2
flask
gevent
gunicorn
jinja2>=3.0.0
memory_profiler<1.0.0
numpy
Expand Down